How to unit test PySpark

Recently, I came across an interesting problem: how to speed up the feedback loop while maintaining a PySpark DAG. Of course, I could just run the Spark job and look at the data, but that is not practical.

The DAG needed a few hours to finish. Because of that, I could make and verify only two code changes a day. That is ridiculous.

Making DAGs testable

Fortunately, it is trivial to turn Spark DAGs into testable code. Let’s imagine that I want to test the following code:

some_input_data = spark.table('source.some_table') \
    .where(concat('year', 'month', 'day').between(START_DATE, END_DATE)) \
    .where(col('some_column') != 'value') \
    .where(col('another_column') == 123) \
    .where(col('something_that_cant_be_null').isNotNull()) \
    .select('column_1', 'column_2', 'column_3', 'column_4', 'column_5')

some_filtered_data = some_input_data.select('column_1', 'column_5').distinct() \
    .groupBy('column_1').count().where(col('count') == 1)

some_filtered_data.repartition(100) \
    .write.saveAsTable('destination.table_name', path='s3://bucket/table_name', mode='overwrite')

Now, imagine that the some_filtered_data is a result of some long and complicated computation, not just some grouping and filtering like in the example.

First, I copy all of the code into a new function:

def do_the_calculations():
    some_input_data = spark.table('source.some_table') \
        .where(concat('year', 'month', 'day').between(START_DATE, END_DATE)) \
        .where(col('some_column') != 'value') \
        .where(col('another_column') == 123) \
        .where(col('something_that_cant_be_null').isNotNull()) \
        .select('column_1', 'column_2', 'column_3', 'column_4', 'column_5')

    some_filtered_data = some_input_data.select('column_1', 'column_5').distinct() \
        .groupBy('column_1').count().where(col('count') == 1)

    some_filtered_data.repartition(100) \
        .write.saveAsTable('destination.table_name', path='s3://bucket/table_name', mode='overwrite')

I remove the parts that handle side effects (reading from a database and writing data to the output location). I end up with a code snippet that does not work:

def do_the_calculations():
    some_filtered_data = some_input_data.select('column_1', 'column_5').distinct() \
        .groupBy('column_1').count().where(col('count') == 1)

After that, I add the missing variables as function parameters so that I can pass the input to the function.

def do_the_calculations(some_input_data):
    some_filtered_data = some_input_data.select('column_1', 'column_5').distinct() \
        .groupBy('column_1').count().where(col('count') == 1)

The DAG needs to produce output somehow. Therefore, I return the filtered data from the function.

def do_the_calculations(some_input_data):
    some_filtered_data = some_input_data.select('column_1', 'column_5').distinct() \
        .groupBy('column_1').count().where(col('count') == 1)

    return some_filtered_data

Note that I have just made a pure function. It is trivial to test such subroutines because their output depends only on the input parameters.

At this point, I replace the processing part of my DAG with the function call.

some_input_data = spark.table('source.some_table') \
    .where(concat('year', 'month', 'day').between(START_DATE, END_DATE)) \
    .where(col('some_column') != 'value') \
    .where(col('another_column') == 123) \
    .where(col('something_that_cant_be_null').isNotNull()) \
    .select('column_1', 'column_2', 'column_3', 'column_4', 'column_5')

some_filtered_data = do_the_calculations(some_input_data)

some_filtered_data.repartition(100) \
    .write.saveAsTable('destination.table_name', path='s3://bucket/table_name', mode='overwrite')

Writing tests

Perfect. Now, I am going to use the Python unittest module to write tests for that pure function.

First, I have to create a class that extends unittest.TestCase and initialize it correctly. Note that the constructor arguments must be passed on to the superclass. In the constructor, I also start a local Spark session and define the input schema.

import unittest
from pyspark.sql import Row, SparkSession
from pyspark.sql.types import *
from module_under_test import do_the_calculations

class SomeUnitTest(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.spark = SparkSession.builder \
            .master('local[2]') \
            .appName('some-test-app') \
            .enableHiveSupport() \
            .getOrCreate()

        self.input_schema = StructType([
            StructField('column_1', IntegerType()),
            StructField('column_2', IntegerType()),
            StructField('column_3', IntegerType()),
            StructField('column_4', IntegerType()),
            StructField('column_5', IntegerType())
        ])

Second, I write the first test function inside the test class. The function name must begin with the “test” prefix; otherwise, unittest will not run it!

    def test_here_put_the_test_description(self):
        test_input = self.spark.createDataFrame([
            [1, 2, 3, 4, 5],
            [6, 7, 8, 9, 10]
        ], schema=self.input_schema)

        expected_result = [
            Row(column_1=1, count=1),
            Row(column_1=6, count=1)
        ]

        test_result = do_the_calculations(test_input)
        actual_result = test_result.collect()

        self.assertEqual(expected_result, actual_result)

In the first part of the test function, I prepare a Spark DataFrame with my test input data and a collection that contains the expected output. After that, I call the function under test and compare the actual output with the expected values.
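
The first test covers only rows that survive the grouping. I could also add a second test that checks the core rule implemented in do_the_calculations: a column_1 value that maps to more than one distinct column_5 value must be dropped. A sketch of such a test (the test name and the input values are mine):

    def test_drops_column_1_values_with_multiple_distinct_column_5(self):
        # Two rows share column_1 = 1 but differ in column_5, so the count
        # for column_1 = 1 is 2 and that group is filtered out.
        test_input = self.spark.createDataFrame([
            [1, 2, 3, 4, 5],
            [1, 2, 3, 4, 7],
            [6, 7, 8, 9, 10]
        ], schema=self.input_schema)

        expected_result = [Row(column_1=6, count=1)]

        actual_result = do_the_calculations(test_input).collect()

        self.assertEqual(expected_result, actual_result)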

Problems and limitations

It is not a perfect way to test Spark code. It certainly shortens the feedback loop, but every run still needs to start and initialize a local Spark context, so it takes a minute or two to run such a test.

On the one hand, it is painfully slow if you compare a unit test running for a minute or two with tests of any Scala, Java, or Python service that don’t run Spark. On the other hand, if it takes four hours to run that Spark DAG on production data, even a test running for a minute is a 240x improvement.
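
The startup cost cannot be removed, but it can be paid in one well-defined place. A sketch of the same test class with the Spark session created in setUpClass instead of the constructor (the schema definition and the test functions stay unchanged):

import unittest
from pyspark.sql import SparkSession

class SomeUnitTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Start the local Spark session once for the whole test class
        # instead of once per test instance.
        cls.spark = SparkSession.builder \
            .master('local[2]') \
            .appName('some-test-app') \
            .getOrCreate()

    @classmethod
    def tearDownClass(cls):
        # Stop the session so the local JVM does not keep running after the tests.
        cls.spark.stop()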

I recommend sorting the output data before comparing it with the expected output because Spark does not guarantee the order of the returned rows.
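
An assertion that relies on row order may fail randomly, especially after a groupBy. Inside the test function, the comparison could look like this sketch, which sorts both sides on the grouping column:

# Sort both collections on the grouping column before comparing,
# because Spark does not guarantee the order of the returned rows.
actual_result = sorted(test_result.collect(), key=lambda row: row.column_1)
expected_result = sorted(expected_result, key=lambda row: row.column_1)

self.assertEqual(expected_result, actual_result)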

If I have only a few rows in the expected collection, I prefer to convert them all into dictionaries and compare the rows one by one because that gives me more user-friendly error messages in case of a test failure.
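
The Row class has an asDict() method, so the conversion is straightforward. A sketch of such a comparison, assuming both collections are already sorted:

# Compare plain dictionaries row by row; a failing assertion then prints
# the column names next to the mismatched values.
self.assertEqual(len(expected_result), len(actual_result))

for expected, actual in zip(expected_result, actual_result):
    self.assertEqual(expected.asDict(), actual.asDict())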
