Scripting PySpark DataFrames

Developing Spark applications means dealing with Spark DataFrames. These objects are in-memory data structures accessible via various APIs, but they live only within a running Spark session. Sometimes we need to use them outside that runtime environment. Scripting a dataframe, that is, generating the collection of Python commands that fully reproduces the object, is a possible and occasionally preferred solution.

The use cases

The most common example is debugging a production issue. Think of a data processing framework built around Spark. It runs in a production environment, and you spot that one of the workflows fails. Further checks show that there is a flaw in the processing logic. It is handy to identify the few rows on which the processing crashes, create a dataframe from them, and then transfer it to your development environment.

The transported dataframe might then be used for the following:

  1. Debugging the issue in the development/feature environment
  2. Creating a unit test to prevent this kind of error
  3. Creating default sample data when a new development/feature environment must be deployed

Why not just save the dataframe as parquet?

Some will ask: “Why bother with scripting if the dataframe can be dumped as a parquet file?”

The quick answer is that parquet is a binary format; it is very efficient but hard to view and edit using text tools.

The Python script of the dataframe contains both the data and the schema and can be edited in place when needed, which brings flexibility. The dataframe script can eventually be turned into a parameterized Pytest fixture to enable multiple testing scenarios without crafting a set of binary files for each one, as in the sketch below.
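For illustration, here is a minimal sketch of such a fixture. The spark fixture, the fixture name customer_df, and the trimmed-down data and schema are assumptions made for this example, not something the helper generates:

import datetime
from decimal import Decimal

import pytest
from pyspark.sql import Row
from pyspark.sql.types import StructType

@pytest.fixture(params=[Decimal("1005.50"), None])
def customer_df(request, spark):
    # A hand-trimmed version of the generated __data/__schema;
    # the Balance value varies per test scenario via request.param
    __data = [
        Row(
            ID=1,
            Name="John Doe",
            BirthDate=datetime.date(1980, 1, 1),
            Balance=request.param,
        )
    ]
    __schema = StructType.fromJson(
        {
            "type": "struct",
            "fields": [
                {"name": "ID", "type": "integer", "nullable": False, "metadata": {}},
                {"name": "Name", "type": "string", "nullable": False, "metadata": {}},
                {"name": "BirthDate", "type": "date", "nullable": True, "metadata": {}},
                {"name": "Balance", "type": "decimal(6,2)", "nullable": True, "metadata": {}},
            ],
        }
    )
    return spark.createDataFrame(__data, __schema)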

Another aspect is Git. Git handles plain text/code better than binary content. A data team might use small parquet/delta files as the input for unit tests, but when their data content must be updated, the peer reviewer cannot see what has changed.

And, as an analogy from the database world: there are scenarios when a binary backup is the most robust and preferred solution, but sometimes a script with a table definition and a few insert statements is better.

Scripting a dataframe in practice

Before we proceed with an example, let’s define a sample dataframe with a complex enough schema:

df = spark.sql(
    """ 
    SELECT 
        1 AS ID, 
        "John Doe" AS Name, 
        STRUCT(
            "line1" AS Line1, 
            "line2" AS Line2
        ) AS Address, 
        TO_DATE("1980-01-01") AS BirthDate,
        ARRAY(
            STRUCT(
                012345678 AS AccountNumber, 
                "NL" AS CountryCode,
                1005.50 AS Balance
            ),
            STRUCT(
                012345678 AS AccountNumber, 
                "UK" AS CountryCode,
                NULL AS Balance
            )
        ) AS BankAccounts
    """
)
df.printSchema()


"""
root
 |-- ID: integer (nullable = false)
 |-- Name: string (nullable = false)
 |-- Address: struct (nullable = false)
 |    |-- Line1: string (nullable = false)
 |    |-- Line2: string (nullable = false)
 |-- BirthDate: date (nullable = true)
 |-- BankAccounts: array (nullable = false)
 |    |-- element: struct (containsNull = false)
 |    |    |-- AccountNumber: integer (nullable = false)
 |    |    |-- CountryCode: string (nullable = false)
 |    |    |-- Balance: decimal(6,2) (nullable = true)
"""

The dataframe can be scripted in one line using a custom helper function script_dataframe(), which returns the generated code as a string:

dataframe_script = script_dataframe(df)
print(dataframe_script)

The content of dataframe_script is a Python snippet small enough to be copy-pasted, edited if needed (to remove PII data, for instance), and executed in any other environment. It contains __data, a collection of rows, and __schema, which holds the exact schema of the scripted dataframe:

from pyspark.sql import Row
import datetime
from decimal import Decimal
from pyspark.sql.types import *

# Scripted data and schema:
__data = [
    Row(
        ID=1,
        Name="John Doe",
        Address=Row(Line1="line1", Line2="line2"),
        BirthDate=datetime.date(1980, 1, 1),
        BankAccounts=[
            Row(AccountNumber=12345678, CountryCode="NL", Balance=Decimal("1005.50")),
            Row(AccountNumber=12345678, CountryCode="UK", Balance=None),
        ],
    )
]
__schema = StructType.fromJson(
    {
        "type": "struct",
        "fields": [
            {"name": "ID", "type": "integer", "nullable": False, "metadata": {}},
            {"name": "Name", "type": "string", "nullable": False, "metadata": {}},
            {
                "name": "Address",
                "type": {
                    "type": "struct",
                    "fields": [
                        {"name": "Line1", "type": "string", "nullable": False, "metadata": {}},
                        {"name": "Line2", "type": "string", "nullable": False, "metadata": {}},
                    ],
                },
                "nullable": False,
                "metadata": {},
            },
            {"name": "BirthDate", "type": "date", "nullable": True, "metadata": {}},
            {
                "name": "BankAccounts",
                "type": {
                    "type": "array",
                    "elementType": {
                        "type": "struct",
                        "fields": [
                            {
                                "name": "AccountNumber",
                                "type": "integer",
                                "nullable": False,
                                "metadata": {},
                            },
                            {
                                "name": "CountryCode",
                                "type": "string",
                                "nullable": False,
                                "metadata": {},
                            },
                            {
                                "name": "Balance",
                                "type": "decimal(6,2)",
                                "nullable": True,
                                "metadata": {},
                            },
                        ],
                    },
                    "containsNull": False,
                },
                "nullable": False,
                "metadata": {},
            },
        ],
    }
)

outcome_dataframe = spark.createDataFrame(__data, __schema)

Let’s verify that the script works correctly by running it and exploring the resulting dataframe’s content:

outcome_dataframe.show(truncate=False)

"""
+---+--------+--------------+----------+-----------------------------------------------+
|ID |Name    |Address       |BirthDate |BankAccounts                                   |
+---+--------+--------------+----------+-----------------------------------------------+
|1  |John Doe|{line1, line2}|1980-01-01|[{12345678, NL, 1005.50}, {12345678, UK, null}]|
+---+--------+--------------+----------+-----------------------------------------------+
"""

outcome_dataframe.printSchema()

"""
root
 |-- ID: integer (nullable = false)
 |-- Name: string (nullable = false)
 |-- Address: struct (nullable = false)
 |    |-- Line1: string (nullable = false)
 |    |-- Line2: string (nullable = false)
 |-- BirthDate: date (nullable = true)
 |-- BankAccounts: array (nullable = false)
 |    |-- element: struct (containsNull = false)
 |    |    |-- AccountNumber: integer (nullable = false)
 |    |    |-- CountryCode: string (nullable = false)
 |    |    |-- Balance: decimal(6,2) (nullable = true)
"""

Limitations

The logic of script_dataframe() internally relies on DataFrame.collect(), which means it gathers all scripted rows on the Spark driver. Therefore, the function is intended to dump relatively small dataframes. By default, it limits the number of rows to 20, but that limit can be increased via the extra parameter limit_rows.
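For example, a larger (but still driver-safe) dataframe can be scripted by raising the limit; the value 100 below is just an illustrative choice:

dataframe_script = script_dataframe(df, limit_rows=100)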

Source code

The following code was tested and runs in PySpark and Databricks:

from pprint import pformat
from pyspark.sql import DataFrame

def script_dataframe(
    input_dataframe: DataFrame, limit_rows: int = 20
) -> str:
    """Generate a script to recreate the dataframe
    The script includes the schema and the data

    Args:
        input_dataframe (DataFrame): Input spark dataframe
        limit_rows (int, optional): Prevents too large a dataframe from being \
            processed. Defaults to 20.

    Raises:
        ValueError: when the dataframe is too large (by default > 20 rows)

    Returns:
        The script to recreate the dataframe

    Examples:
        >>> script = script_dataframe(input_dataframe=df)
        >>> print(script)
    """

    # Guard against collecting a large dataframe onto the driver
    if input_dataframe.count() > limit_rows:
        raise ValueError(
            "This method is limited to script up "
            f"to {limit_rows} row(s) per call"
        )

    # pformat renders the collected Rows as readable, valid Python literals
    __data = pformat(input_dataframe.collect())

    # Capture the schema as its JSON (dict) representation
    __schema = input_dataframe.schema.jsonValue()

    __script_lines = [
        "from pyspark.sql import Row",
        "import datetime",
        "from decimal import Decimal",
        "from pyspark.sql.types import *",
        "",
        "# Scripted data and schema:",
        f"__data = {__data}",
        f"__schema = StructType.fromJson({__schema})",
        "",
        "outcome_dataframe = spark.createDataFrame(__data, __schema)",
    ]

    __final_script = "\n".join(__script_lines)

    return __final_script
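As a possible follow-up (the file path here is only an example), the generated script can be written straight to a file so it can be committed next to the tests or kept as default sample data:

script = script_dataframe(df)
with open("tests/fixtures/sample_customers.py", "w", encoding="utf-8") as f:
    f.write(script)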

Final words

While dataframe scripting is not (yet) part of PySpark’s native functionality, it is a relatively simple operation, and it can be done via the custom helper function script_dataframe() shared above. It helps me craft unit tests faster and make processing frameworks more reliable. I hope it will also help you in your PySpark journey.