Reputation: 7876
I am trying to pass a custom Python class object to a UDF in PySpark. I do not want a new instance of the object created for every row that it processes since it needs to make an expensive API call to get a secret key. My thinking is to first make the API call when instantiating the object, and then pass that object to tasks. Ideally all executors make use of the same object / a copy of it.
I also make use of an external library whose object is not serializable. It's less of a concern if this has to be instantiated multiple times.
The class looks like this:
class MyClass(object):
    def __init__(self, arg1):
        self.secret = some_api_call(arg1)
        self.third_party_obj = None

    def set_3rd_party_obj(self):
        self.third_party_obj = third_party_lib(self.secret)

    def do_thing(self, val):
        return self.third_party_obj(val)
In PySpark, this is what I am attempting:
my_obj = MyClass("arg1")
my_udf = udf(lambda a, b: b.do_thing(a))
df = spark.read.parquet(inputUri)
df = df.withColumn("col1", my_udf(col("col2"), lit(my_obj)))
However, I get AttributeError: 'MyClass' object has no attribute '_get_object_id'. If I try to broadcast my_obj, I get AttributeError: 'Broadcast' object has no attribute '_get_object_id' (trace below).
What does work is if I make the call for the secret outside, and then instantiate a new object in the UDF and pass that in (modifying the class so that set_3rd_party_obj is called in the init). However, I want to keep the secret abstracted away in this class. I split set_3rd_party_obj out (not called in init) in the hope that I could check in the UDF whether it has already been initialized before initializing it again, to avoid repeated work. At this stage I haven't even got that far, since just passing an object with a couple of standard typed variables is throwing an error.
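For context, the lazy-initialization check I had in mind inside do_thing would look roughly like this (a sketch only; some_api_call and third_party_lib are the same placeholders as above):

class MyClass(object):
    def __init__(self, arg1):
        self.secret = some_api_call(arg1)  # expensive call, made once on the driver
        self.third_party_obj = None

    def set_3rd_party_obj(self):
        self.third_party_obj = third_party_lib(self.secret)

    def do_thing(self, val):
        # Only build the non-serializable third-party object if this copy
        # of the class hasn't done so yet.
        if self.third_party_obj is None:
            self.set_3rd_party_obj()
        return self.third_party_obj(val)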
I'd be grateful for any pointers you could give either around how to pass the object to the UDF successfully or if there's a better way to accomplish this.
Stack trace:
my_udf(col("col2"), lit(my_obj))
File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/functions.py", line 44, in _
File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1248, in __call__
File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1218, in _build_args
File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1218, in <listcomp>
File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 298, in get_command_part
AttributeError: 'Broadcast' object has no attribute '_get_object_id'
Upvotes: 0
Views: 2303
Reputation: 7876
Ended up resolving this by first creating a broadcast variable and then using it inside the UDF (via its closure, rather than passing the object in as a column):
my_obj = MyClass("arg1")
my_obj_broadcast = spark.sparkContext.broadcast(my_obj)

# The UDF reads the broadcast value through its closure, so the object
# no longer needs to be passed in as a column.
my_udf = udf(lambda a: my_obj_broadcast.value.do_thing(a))

df = spark.read.parquet(inputUri)
df = df.withColumn("col1", my_udf(col("col2")))
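A variant of the same approach with a named function and an explicit return type (StringType here is an assumption; use whatever do_thing actually returns):

from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType

def do_thing_on_worker(val):
    # .value deserialises the broadcast object on each worker the first time
    # it is accessed and reuses it for subsequent rows in that process.
    return my_obj_broadcast.value.do_thing(val)

my_udf = udf(do_thing_on_worker, StringType())
df = df.withColumn("col1", my_udf(col("col2")))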
Upvotes: 0
Reputation: 5062
You can utilise partial for this. The idea is to wrap the function call in partial with the argument you intend to use, and then wrap that in the udf API.
I have demonstrated a simple example using your structure; however, if you are interested in storing the object returned from each do_thing call, you can look into StructField to create the necessary datatype for each function call (a sketch is included at the end of this answer).
import pandas as pd
from io import StringIO
from functools import partial
from pyspark.sql import functions as F
from pyspark.sql.types import StringType

df = pd.read_csv(StringIO("""
a|b|c
1|3|p
2|4|q
3|4|r
4|7|s
"""), delimiter='|')

sparkDF = sql.createDataFrame(df)  # `sql` is the active SparkSession
sparkDF.show()
sparkDF.show()
+---+---+---+
| a| b| c|
+---+---+---+
| 1| 3| p|
| 2| 4| q|
| 3| 4| r|
| 4| 7| s|
+---+---+---+
Note - I have modified the function definitions to demonstrate a working example:
class MyClass(object):
    def __init__(self, arg1):
        self.secret = arg1  # some_api_call(arg1)
        self.third_party_obj = None

    def set_3rd_party_obj(self):
        self.third_party_obj = third_party_lib(self.secret)

    def third_party_obj_func(self, inp):
        if inp == 2:
            self.third_party_obj = inp
            return "Success"
        else:
            return "Fail"

    def do_thing(self, val):
        status = self.third_party_obj_func(val)
        return status
my_obj = MyClass("arg1")

def do_thing_global(a, obj=None):
    return obj.do_thing(a)

# partial binds my_obj to the obj argument; the bound object is pickled
# and shipped to the workers along with the UDF.
my_udf = F.udf(partial(do_thing_global, obj=my_obj), StringType())
sparkDF = sparkDF.withColumn("col1", my_udf(F.col("a")))
sparkDF.show()
+---+---+---+-------+
| a| b| c| col1|
+---+---+---+-------+
| 1| 3| p| Fail|
| 2| 4| q|Success|
| 3| 4| r| Fail|
| 4| 7| s| Fail|
+---+---+---+-------+
As the question involves custom function calls, the above approach outlines the steps you can follow to achieve your use case.
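As mentioned above, if you would rather have do_thing return a structured result instead of a plain string, a struct-typed UDF is one option. A rough sketch, reusing the objects defined earlier (the field names and types here are illustrative, not taken from your code):

from pyspark.sql.types import StructType, StructField, StringType, IntegerType

# Illustrative schema; adapt the fields to whatever do_thing actually returns.
result_schema = StructType([
    StructField("status", StringType(), True),
    StructField("input_val", IntegerType(), True),
])

def do_thing_struct(a, obj=None):
    status = obj.do_thing(a)
    return (status, a)  # tuple fields map positionally onto the schema

struct_udf = F.udf(partial(do_thing_struct, obj=my_obj), result_schema)
sparkDF = sparkDF.withColumn("res", struct_udf(F.col("a")))

The resulting res column can then be expanded into separate columns with sparkDF.select("res.*") if needed.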
Upvotes: 1