Reputation: 1079
Some of these transform functions respect the mock and others don't, and I don't know why.
Here's a file called etl_job.py. It contains various transform functions that use rdd.map to add a column to a DataFrame using the get_random_bool function imported from dependencies.utils.
# etl_job.py
from dependencies.utils import get_random_bool

def transform_data1(df):
    return df.rdd.map(lambda row: row + (get_random_bool(),)).toDF()

def transform_data2(df):
    g = lambda row: row + (get_random_bool(),)
    return df.rdd.map(lambda row: g(row)).toDF()

def transform_data3(df):
    g = lambda row: row + (get_random_bool(),)
    h = lambda row: g(row)
    return df.rdd.map(lambda row: h(row)).toDF()

def transform_data4(df):
    return df.rdd.map(lambda row: f(row)).toDF()

def transform_data5(df):
    g = lambda row: f(row)
    return df.rdd.map(lambda row: g(row)).toDF()

def f(row):
    return row + (get_random_bool(),)
This test file tries to patch the get_random_bool function as imported by jobs/etl_job.py.
from pyspark.sql import SparkSession
from unittest.mock import patch
from jobs.etl_job import transform_data1, transform_data2, transform_data3, transform_data4, transform_data5

# Create SparkSession
spark = SparkSession.builder \
    .master('local[1]') \
    .appName('Time Tests') \
    .getOrCreate()
spark.sparkContext.setLogLevel("WARN")

df = spark.createDataFrame([["1"], ["2"]])
print('original dataframe')
df.show()

with patch('jobs.etl_job.get_random_bool') as f:
    f.return_value = 'notabool'

    df_t = transform_data1(df)
    print('with transform_data1')
    df_t.show()

    df_t = transform_data2(df)
    print('with transform_data2')
    df_t.show()

    df_t = transform_data3(df)
    print('with transform_data3')
    df_t.show()

    df_t = transform_data4(df)
    print('with transform_data4')
    df_t.show()

    df_t = transform_data5(df)
    print('with transform_data5')
    df_t.show()
Here's the output.
original dataframe
+---+
| _1|
+---+
| 1|
| 2|
+---+
with transform_data1
+---+--------+
| _1| _2|
+---+--------+
| 1|notabool|
| 2|notabool|
+---+--------+
with transform_data2
+---+--------+
| _1| _2|
+---+--------+
| 1|notabool|
| 2|notabool|
+---+--------+
with transform_data3
+---+--------+
| _1| _2|
+---+--------+
| 1|notabool|
| 2|notabool|
+---+--------+
with transform_data4
+---+----+
| _1| _2|
+---+----+
| 1|true|
| 2|true|
+---+----+
with transform_data5
+---+-----+
| _1| _2|
+---+-----+
| 1|false|
| 2|false|
+---+-----+
The first three transforms work fine, adding the mocked column where every value is notabool, but the fourth and fifth don't: they add a column of real booleans instead of the mocked value. If the mocked function is called directly inside the transform function, or by an inner function defined in the transform function, the mock is respected; but if the transform function calls a module-level function outside itself, which in turn calls the mocked function, the real, unmocked function is used.
Can anyone explain this behavior?
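For comparison, outside of Spark the patch behaves the way I'd expect for all of these call shapes. Here is a minimal stdlib-only sketch of the in-process behavior (it uses a synthetic stand-in module, hypothetically named fake_etl, in place of jobs.etl_job, since that package isn't importable here):

```python
from unittest.mock import patch
import sys
import types

# Build a stand-in module that mirrors the shape of jobs.etl_job:
# a module-level f() that calls get_random_bool() at call time.
mod = types.ModuleType("fake_etl")
exec(
    "def get_random_bool():\n"
    "    return True\n"
    "def f():\n"
    "    return get_random_bool()\n",
    mod.__dict__,
)
sys.modules["fake_etl"] = mod

# patch() replaces the module attribute, so any code in that module
# that looks the name up at call time, in the same process, sees the mock.
with patch("fake_etl.get_random_bool", return_value="notabool"):
    print(mod.f())  # prints: notabool

print(mod.f())  # outside the patch, the real function runs; prints: True
```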
Upvotes: 2
Views: 421
Reputation: 1145
I know this is an old question, but this answer may help others who run into the same issue.
I've spent a lot of time puzzled by this one. It turns out that mocking doesn't play well with Spark's parallelized execution. Remember that the lambda inside the map() call runs on an executor, not on the driver, i.e. in a different process. So the mocking has to happen within the context of the executor worker, not the driver.
The only way I got this to work properly is to wrap the lambda in a named, module-level function that can itself be mocked; then, in the test, mock that function with one that applies the patch and delegates to the original function, so that the patching happens in the context of the worker.
# etl_job.py
from dependencies.utils import get_random_bool

def transform_row(row):
    # will be called within the context of a worker
    return f(row)

def transform_data4(df):
    return df.rdd.map(transform_row).toDF()

def f(row):
    return row + (get_random_bool(),)
Test file:
from pyspark.sql import SparkSession
from unittest.mock import patch
from jobs.etl_job import ..., transform_data4, transform_row

...

def mock_transform_row(row):
    # patch in the context of the worker
    with patch('jobs.etl_job.get_random_bool') as f:
        f.return_value = 'notabool'
        return transform_row(row)  # delegate to the original function

with patch('jobs.etl_job.transform_row', side_effect=mock_transform_row):
    ...
    df_t = transform_data4(df)
    print('with transform_data4')
    df_t.show()
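The same wrap-and-delegate pattern can be exercised without Spark. A stdlib-only sketch, again with a synthetic stand-in module (hypothetically named fake_job) in place of jobs.etl_job; note that a reference to the original transform_row must be captured before the outer patch is applied, otherwise the delegate would call the mock and recurse:

```python
from unittest.mock import patch
import sys
import types

# Stand-in for jobs.etl_job: get_random_bool, f, and the wrapper.
mod = types.ModuleType("fake_job")
exec(
    "def get_random_bool():\n"
    "    return True\n"
    "def f(row):\n"
    "    return row + (get_random_bool(),)\n"
    "def transform_row(row):\n"
    "    return f(row)\n",
    mod.__dict__,
)
sys.modules["fake_job"] = mod

# Capture the original before patching, like the test file's import does.
orig_transform_row = mod.transform_row

def mock_transform_row(row):
    # apply the inner patch at call time, then delegate to the original
    with patch("fake_job.get_random_bool", return_value="notabool"):
        return orig_transform_row(row)

with patch("fake_job.transform_row", side_effect=mock_transform_row):
    print(mod.transform_row(("1",)))  # prints: ('1', 'notabool')

print(mod.transform_row(("1",)))  # patches reverted; prints: ('1', True)
```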
Upvotes: 0