Henry8

Reputation: 110

When testing code with a pyspark DataFrame: how to mock the chained .repartition() method?

I have code that uses the pyspark library, and I want to test it with pytest.

However, I want to mock the .repartition() method on DataFrames when running the tests.

  1. Suppose the code I want to test is a pyspark chained function like the one below:
import pyspark.sql

def transform(df: pyspark.sql.DataFrame):
    return (
        df
        .repartition("id")   # repartition by the "id" column
        .groupby("id")
        .sum("quantity")
    )
  2. Currently my testing function looks like this:
import pytest

@pytest.mark.parametrize("df, expected_df", [(..., ...)])  # my input args
def test_transform(df, expected_df):
    df_output = transform(df)
    assert df_output == expected_df
  3. Now, how can I mock the .repartition() method for my test? Something like this pseudo-code (currently not working):
import pytest
from unittest import mock

@pytest.mark.parametrize("df, expected_df", [(..., ...)])  # my input args
@mock.patch("pyspark.sql.DataFrame.repartition")
def test_transform(df, expected_df):
    df_output = transform(df)
    assert df_output == expected_df

Upvotes: 2

Views: 8681

Answers (1)

Anton

Reputation: 970

Chain the return values of the mocked calls as shown below (see here for a similar question). Note that @mock.patch injects the patched mock as the test function's first argument, which is received as df here:

from unittest import mock
from unittest.mock import Mock

from mymodule import transform  # the function under test

@mock.patch("pyspark.sql.DataFrame")
def test_transform(df: Mock):
    expected_df = "expected value"
    # Stub the whole chain: repartition -> groupby -> sum
    df.repartition.return_value.groupby.return_value.sum.return_value = expected_df
    df_output = transform(df)
    assert df_output == expected_df
    # Verify each link in the chain was called with the expected arguments
    df.repartition.assert_called_with("id")
    df.repartition().groupby.assert_called_with("id")
    df.repartition().groupby().sum.assert_called_with("quantity")
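
If you would rather keep your parametrized real DataFrames and only neutralize the shuffle, a minimal sketch of that variant is below. It assumes you patch pyspark.sql.DataFrame.repartition with a no-op replacement via the new= argument (so mock.patch injects no extra mock argument into the test), and it compares rows with .collect(), since == on two DataFrames does not compare their contents. The test name is illustrative.

import pytest
from unittest import mock

# Sketch: replace .repartition() with a no-op that returns the DataFrame
# unchanged, so .groupby() and .sum() still execute for real.
# Passing new=... means mock.patch does not inject a mock argument.
@pytest.mark.parametrize("df, expected_df", [(..., ...)])  # real DataFrames
@mock.patch("pyspark.sql.DataFrame.repartition", new=lambda self, *cols: self)
def test_transform_with_real_frames(df, expected_df):
    df_output = transform(df)
    # DataFrame == DataFrame does not compare contents; compare collected rows,
    # sorted because groupby output order is not guaranteed
    assert sorted(df_output.collect()) == sorted(expected_df.collect())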

Upvotes: 4
