Reputation: 110
I have code using the pyspark library and I want to test it with pytest. However, I want to mock the .repartition() method on DataFrames when running tests.
def transform(df: pyspark.sql.DataFrame):
    return (
        df
        .repartition("id")
        .groupby("id")
        .sum("quantity")
    )
@pytest.mark.parametrize("df, expected_df", [(..., ...)])  # my input args
def test_transform(df, expected_df):
    df_output = transform(df)
    assert df_output == expected_df
How can I mock the .repartition() method for my test? Something like this pseudo-code (currently not working):

from unittest import mock
@pytest.mark.parametrize("df, expected_df", [(..., ...)])  # my input args
@mock.patch("pyspark.sql.DataFrame.repartition")
def test_transform(df, expected_df):
    df_output = transform(df)
    assert df_output == expected_df
Upvotes: 2
Views: 8681
Reputation: 970
Chain the calls on the mock as shown below:
from unittest import mock
from unittest.mock import Mock

@mock.patch("pyspark.sql.DataFrame")
def test_transform(df: Mock):
    expected_df = "expected value"
    # Set up the chained calls so the final .sum() returns the sentinel value
    df.repartition.return_value.groupby.return_value.sum.return_value = expected_df

    df_output = transform(df)

    assert df_output == expected_df
    # Verify each step of the chain was called with the expected arguments
    df.repartition.assert_called_with("id")
    df.repartition().groupby.assert_called_with("id")
    df.repartition().groupby().sum.assert_called_with("quantity")
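If you would rather keep your original parametrized tests running against real DataFrames and only neutralize the shuffle, another option is to replace .repartition with a pass-through using pytest's monkeypatch fixture. This is a minimal sketch, not the only way to do it; the spark fixture, the sample data, and the import of transform are assumptions you would adapt to your project:

import pyspark.sql
import pytest

# from mymodule import transform  # assumed import of the function under test


@pytest.fixture(scope="session")
def spark():
    # Assumed local session fixture; adjust builder options to your environment.
    return pyspark.sql.SparkSession.builder.master("local[1]").getOrCreate()


def test_transform_with_real_dataframes(spark, monkeypatch):
    # Make .repartition() a no-op that returns the DataFrame unchanged,
    # so transform() still runs the real groupby/sum logic without a shuffle.
    monkeypatch.setattr(
        pyspark.sql.DataFrame,
        "repartition",
        lambda self, *args, **kwargs: self,
    )

    df = spark.createDataFrame([("a", 1), ("a", 2), ("b", 3)], ["id", "quantity"])
    expected = {("a", 3), ("b", 6)}

    # Compare collected rows as a set, since DataFrame objects don't
    # compare equal with == and row order is not guaranteed.
    result = {tuple(row) for row in transform(df).collect()}
    assert result == expected

Unlike the full Mock approach, this keeps the assertion on actual data, at the cost of needing a local SparkSession in the test run.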
Upvotes: 4