Implement data quality checks in your pySpark data pipelines

Data quality is a critical part of any production data pipeline. To provide accurate SLA metrics and to ensure that the data is correct, you need a way to validate the data and report the resulting metrics for further analysis. In this post, we will look at how to build data quality checks into your pySpark data pipelines.

Exploring Delta Live Tables

Delta Live Tables is a new feature in Databricks that allows users to build reliable data pipelines with built-in data quality metrics and monitoring. It is an abstraction on top of Delta Lake that lets users query data through streaming live tables, which are updated in real time as the underlying data changes. What caught my eye were the data quality capabilities that users can specify at the dataset level. Using python decorators we can declare expectations: @expect, @expect_or_drop, and @expect_or_fail take an expectation name and a constraint, while @expect_all, @expect_all_or_drop, and @expect_all_or_fail accept a python dictionary where the key is the expectation name and the value is the expectation constraint. Example:

@dlt.expect("valid_timestamp", "timestamp > '2012-01-01'")
@dlt.expect_or_drop("valid_current_page", "current_page_id IS NOT NULL AND current_page_title IS NOT NULL")
@dlt.expect_or_fail("valid_count", "count > 0")
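
The *_all variants take a python dictionary instead, so several constraints can be declared in a single decorator. A minimal sketch in the same spirit as the example above (the constraint names are illustrative):

@dlt.expect_all({
    "valid_timestamp": "timestamp > '2012-01-01'",
    "valid_count": "count > 0"
})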

Metrics for clean and failed records are collected automatically and stored in the Delta Live Tables metadata, so users can set up alerts and monitor the data quality of their pipelines.

Delta Live Tables, however, still has quite a few limitations and is not yet ready for production use. Some of them are:

1. The data quality checks are only available for streaming live tables, not for batch tables. We can still create streaming tables from batch tables, but if the version of your data changes, the pipeline will fail.
2. Lack of testing capabilities. There is no way to test the data quality checks in a local environment, because the dlt package is only available in the Databricks runtime.
3. Lack of documentation. The documentation is very limited and it is not clear how to use the data quality checks. Currently only the python and SQL APIs are supported.
4. Setting up a DLT job doesn't support all the parameters that are available in a regular Databricks job.

Building your own data quality checks as python decorators

To overcome the limitations of Delta Live Tables, we can build our own data quality checks as python decorators. The idea is to create a decorator that accepts a python list of constraints, each defined on a specific column. We will collect all the necessary metrics and store them as part of the Delta Lake metadata.

We will start by building two simple conditions for our data: uniqueness and filtering based on a condition.

from abc import ABC, abstractmethod

class ColumnCondition(ABC):
    @abstractmethod
    def get_cols(self):
        pass


class UniqueCondition(ColumnCondition):

    def __init__(self, col):
        self.col = col

    def get_cols(self):
        return self.col


class FilterCondition(ColumnCondition):

    def __init__(self, left_col, right_col):
        self.left_col = left_col
        self.right_col = right_col

    def get_cols(self):
        return self.left_col, self.right_col


def is_not_null(col):
    return FilterCondition(col + " is not null", col + " is null")

def is_unique(col):
    return UniqueCondition(col)
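
As a quick illustration of what these helpers produce (the return values are shown as comments):

not_null = is_not_null("row")
not_null.get_cols()   # ("row is not null", "row is null")

unique = is_unique("row")
unique.get_cols()     # "row"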

The main idea is to pass these helper functions, applied to a certain column, as decorator arguments; each returns a condition object. We can then match on the condition type and apply the corresponding check. We start by creating a simple python decorator using functools.wraps:

    def expect_or_drop(self, conditions: List[FilterCondition]):
        def decorator(function):
            @wraps(function)
            def wrapper(*args, **kwargs):
                retval = function(*args, **kwargs)
                # apply conditions
                return retval
            return wrapper
        return decorator

We will create an Expectations class that will contain all the data quality checks. The rsd argument represents the maximum relative standard deviation allowed for the approx_count_distinct function (see the pySpark documentation for details).

class Expectations:

    def __init__(self, spark: SparkSession, rsd=0.05):
        self.spark = spark
        self.schema = StructType([StructField("condition", StringType(), True),
                                  StructField("dropped_records", IntegerType(), True),
                                  StructField("clean_records", IntegerType(), True)])
        emptyRDD = spark.sparkContext.emptyRDD()
        self.metrics = spark.createDataFrame(emptyRDD, schema=self.schema)
        self.rsd = rsd
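
To make the role of rsd concrete: approx_count_distinct returns an estimate rather than an exact count, and rsd bounds the relative standard deviation of that estimate, so a smaller value is more precise but more expensive. A quick sketch, assuming an active SparkSession named spark:

import pyspark.sql.functions as F

df = spark.createDataFrame([("a",), ("a",), ("b",)], ["col"])
# exact distinct count is 2; the estimate may deviate by up to the configured rsd (5% here)
df.select(F.approx_count_distinct("col", rsd=0.05)).show()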

The metrics dataframe will contain the metrics for each data quality check. We can proceed to create our filtering and uniqueness checks:

    def apply_condition(self, dataframe, condition):
        if isinstance(condition, FilterCondition):
            return self.filter_condition(dataframe, condition.get_cols())
        elif isinstance(condition, UniqueCondition):
            return self.is_unique_extend(dataframe, condition.get_cols())
        return dataframe

    def filter_condition(self, dataframe: DataFrame, left_right) -> DataFrame:
        # left keeps the clean records, right matches the records to drop
        left, right = left_right
        total_records = dataframe.count()
        dropped_records = dataframe.filter(right).count()
        df = self.spark.createDataFrame([(left, dropped_records, (total_records - dropped_records))], schema=self.schema)
        self.metrics = self.metrics.unionAll(df)
        return dataframe.filter(left)

    def is_unique_extend(self, dataframe: DataFrame, col) -> DataFrame:
        # the metrics use an approximate distinct count; the deduplication itself is exact
        total_records = dataframe.select(F.col(col)).count()
        distinct_records = dataframe.select(F.approx_count_distinct(col, self.rsd)).collect()[0][0]
        dropped_records = total_records - distinct_records
        df = self.spark.createDataFrame([(col + " is unique", dropped_records, distinct_records)], schema=self.schema)
        self.metrics = self.metrics.unionAll(df)
        return dataframe.dropDuplicates([col])

To apply the conditions, we run the apply_condition function over every condition in the list. We do that with the functools reduce function, used as a foldLeft:

foldl = lambda func, acc, xs: reduce(func, xs, acc)

@wraps(function)
def wrapper(*args, **kwargs):
    retval = function(*args, **kwargs)
    return foldl(self.apply_condition, retval, conditions)
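
For readers less familiar with reduce, foldl threads an accumulator through the list from left to right; in our case the accumulator is the dataframe and every condition returns a new, filtered one. The same mechanics with plain numbers:

from functools import reduce

foldl = lambda func, acc, xs: reduce(func, xs, acc)
foldl(lambda acc, x: acc + x, 0, [1, 2, 3])  # ((0 + 1) + 2) + 3 = 6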

Wrapping it all together:

from abc import ABC, abstractmethod
from functools import reduce, wraps
from typing import List

import pyspark.sql.functions as F
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

foldl = lambda func, acc, xs: reduce(func, xs, acc)


class ColumnCondition(ABC):
    @abstractmethod
    def get_cols(self):
        pass


class UniqueCondition(ColumnCondition):

    def __init__(self, col):
        self.col = col

    def get_cols(self):
        return self.col


class FilterCondition(ColumnCondition):

    def __init__(self, left_col, right_col):
        self.left_col = left_col
        self.right_col = right_col

    def get_cols(self):
        return self.left_col, self.right_col


def is_not_null(col):
    return FilterCondition(col + " is not null", col + " is null")


def is_unique(col):
    return UniqueCondition(col)


class Expectations:

    def __init__(self, spark: SparkSession, rsd=0.05):
        self.spark = spark
        self.schema = StructType([StructField("condition", StringType(), True),
                                  StructField("dropped_records", IntegerType(), True),
                                  StructField("clean_records", IntegerType(), True)])
        emptyRDD = spark.sparkContext.emptyRDD()
        self.metrics = spark.createDataFrame(emptyRDD, schema=self.schema)
        self.rsd = rsd

    def expect_or_drop(self, conditions: List[FilterCondition]):
        def decorator(function):
            @wraps(function)
            def wrapper(*args, **kwargs):
                retval = function(*args, **kwargs)
                return foldl(self.apply_condition, retval, conditions)
            return wrapper
        return decorator

    def apply_condition(self, dataframe, condition):
        if isinstance(condition, FilterCondition):
            return self.filter_condition(dataframe, condition.get_cols())
        elif isinstance(condition, UniqueCondition):
            return self.is_unique_extend(dataframe, condition.get_cols())
        return dataframe

    def filter_condition(self, dataframe: DataFrame, left_right) -> DataFrame:
        left, right = left_right
        total_records = dataframe.count()
        dropped_records = dataframe.filter(right).count()
        df = self.spark.createDataFrame([(left, dropped_records, (total_records - dropped_records))], schema=self.schema)
        self.metrics = self.metrics.unionAll(df)
        return dataframe.filter(left)

    def is_unique_extend(self, dataframe: DataFrame, col) -> DataFrame:
        total_records = dataframe.select(F.col(col)).count()
        distinct_records = dataframe.select(F.approx_count_distinct(col, self.rsd)).collect()[0][0]
        dropped_records = total_records - distinct_records
        df = self.spark.createDataFrame([(col + " is unique", dropped_records, distinct_records)], schema=self.schema)
        self.metrics = self.metrics.unionAll(df)
        return dataframe.dropDuplicates([col])

We can proceed with writing a simple test that checks the results and the collected metrics, using a simple input dataframe:

def test_multiple_conditions(spark):
    df_1 = spark.createDataFrame([("row1", 1), ("row1", 2), (None, 3)], ["row", "row_number"])
    expectation = Expectations(spark)

    @expectation.expect_or_drop([is_not_null("row"), is_unique("row")])
    def read_dataframe(df):
        return df

    result = read_dataframe(df_1)
    print(result.collect())
    print(expectation.metrics.collect())
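
The test assumes a pytest fixture named spark, which is not shown above; a minimal conftest.py sketch could look like this:

# conftest.py
import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark():
    # local SparkSession shared across the test session
    session = SparkSession.builder.master("local[1]").appName("expectations-tests").getOrCreate()
    yield session
    session.stop()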

Running the test, the console will print:

[Row(row='row1', row_number=1)]
[Row(condition='row is not null', dropped_records=1, clean_records=2), Row(condition='row is unique',
dropped_records=1, clean_records=1)]
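
Since the goal stated earlier was to store these metrics as part of the Delta Lake metadata, one simple option is to append the metrics dataframe to a dedicated Delta table after the checks have run; a minimal sketch with a hypothetical table name:

# assumes Delta Lake is available in your environment; the table name is illustrative
expectation.metrics.write.format("delta").mode("append").saveAsTable("data_quality_metrics")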

We can also extend our Expectations class with some plotting functions (this requires importing matplotlib.pyplot as plt).

    def plot_pie_with_total(self, figsize=(10, 10)):
        labels = ["clean_records", "dropped_records"]
        df = self.metrics.toPandas()
        size = len(df.index)
        if size == 1:
            fig, axs = plt.subplots(1)
            axs.pie(df.iloc[0][labels], labels=labels, autopct='%1.1f%%')
            axs.set_title(df.iloc[0]["condition"])
        else:
            fig, axs = plt.subplots(1, size, figsize=figsize)
            for i in range(size):
                axs[i].pie(df.iloc[i][labels], labels=labels, autopct='%1.1f%%')
                axs[i].set_title(df.iloc[i]["condition"])
        plt.show()

That will generate a pie chart with the metrics for every condition.

Happy coding and stay safe!