DMe

Reputation: 7876

Unable to pass class object to PySpark UDF

I am trying to pass a custom Python class object to a UDF in PySpark. I do not want a new instance of the object created for every row that it processes since it needs to make an expensive API call to get a secret key. My thinking is to first make the API call when instantiating the object, and then pass that object to tasks. Ideally all executors make use of the same object / a copy of it.

I also make use of an external library whose object is not serializable. It's less of a concern if this has to be instantiated multiple times.

The class looks like this:

class MyClass(object):
    def __init__(self, arg1):
        self.secret = some_api_call(arg1)
        self.third_party_obj = None
    
    def set_3rd_party_obj(self):
        self.third_party_obj = third_party_lib(self.secret)

    def do_thing(self, val):
        return self.third_party_obj(val)

In PySpark, this is what I am attempting:

my_obj = MyClass("arg1")
my_udf = udf(lambda a, b: b.do_thing(a))

df = spark.read.parquet(inputUri)
df = df.withColumn("col1", my_udf(col("col2"), lit(my_obj)))

However, I get AttributeError: 'MyClass' object has no attribute '_get_object_id'. If I try to broadcast my_obj, I get AttributeError: 'Broadcast' object has no attribute '_get_object_id' (trace below).

What does work is making the API call for the secret outside, then instantiating a new object inside the UDF and passing the secret in (with the class modified so that set_3rd_party_obj is effectively called in __init__). However, I want to keep the secret abstracted away inside this class. I split set_3rd_party_obj out of __init__ in the hope that the UDF could check whether the third-party object has already been initialized before initializing it again, to avoid repeated work. At this stage I haven't even got that far, since just passing an object with a couple of standard-typed attributes throws an error. A sketch of that workaround follows.
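For reference, this is roughly what that working-but-unwanted version looks like (a rough sketch only; some_api_call and third_party_lib are the same placeholders as above, and the modified class is named MyClassWithSecret here just to keep it separate from the original):

from pyspark.sql.functions import udf, col, lit

# modified class: the secret is passed in directly and the third-party
# object is built in __init__ (this is what exposes the secret)
class MyClassWithSecret(object):
    def __init__(self, secret):
        self.secret = secret
        self.third_party_obj = third_party_lib(secret)

    def do_thing(self, val):
        return self.third_party_obj(val)

secret = some_api_call("arg1")   # API call made once, on the driver

# a new object (and third-party instance) is created for every row
my_udf = udf(lambda val, secret: MyClassWithSecret(secret).do_thing(val))

df = df.withColumn("col1", my_udf(col("col2"), lit(secret)))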

I'd be grateful for any pointers you could give either around how to pass the object to the UDF successfully or if there's a better way to accomplish this.

Stack trace:

    my_udf(col("col2"), lit(my_obj))
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/functions.py", line 44, in _
  File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1248, in __call__
  File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1218, in _build_args
  File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1218, in <listcomp>
  File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 298, in get_command_part
AttributeError: 'Broadcast' object has no attribute '_get_object_id'

Upvotes: 0

Views: 2303

Answers (2)

DMe

Reputation: 7876

Ended up resolving this by first creating a broadcast variable and then using its value inside the UDF:

my_obj = MyClass("arg1")
my_obj_broadcast = spark.sparkContext.broadcast(my_obj)
my_udf = udf(lambda a: my_obj_broadcast.value.do_thing(a))

df = spark.read.parquet(inputUri)
df = df.withColumn("col1", my_udf(col("col2")))
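For completeness: because third_party_obj is left as None until set_3rd_party_obj is called, the broadcast copy stays picklable, and the non-serializable third-party object can then be created lazily on each executor the first time do_thing runs. A minimal sketch of that check inside the class (this is the check described in the question, not something broadcast itself requires):

# inside MyClass, replacing the original do_thing
def do_thing(self, val):
    # lazily build the non-serializable third-party object on first use,
    # once per Python worker, rather than shipping it from the driver
    if self.third_party_obj is None:
        self.set_3rd_party_obj()
    return self.third_party_obj(val)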

Upvotes: 0

Vaebhav

Reputation: 5062

You can utilise functools.partial for this. The idea is to bind the object you intend to use as a fixed argument of the function with partial, and then wrap that partial in the udf API.

I have demonstrated a simple example using your structure. However, if you are interested in storing the object returned from each do_thing call, you can look into StructField to create the necessary datatype for each function call (see the sketch at the end of this answer).

Data Preparation

import pandas as pd
from io import StringIO

df = pd.read_csv(StringIO("""
a|b|c
1|3|p
2|4|q
3|4|r
4|7|s
"""), delimiter='|')

# sql is the SparkSession
sparkDF = sql.createDataFrame(df)

sparkDF.show()

+---+---+---+
|  a|  b|  c|
+---+---+---+
|  1|  3|  p|
|  2|  4|  q|
|  3|  4|  r|
|  4|  7|  s|
+---+---+---+

Class Obj Definition

Note - I have modified the function definitions to demonstrate a working example.

class MyClass(object):
    def __init__(self, arg1):
        self.secret = arg1 # some_api_call(arg1)
        self.third_party_obj = None
    
    def set_3rd_party_obj(self):
        self.third_party_obj = third_party_lib(self.secret)

    def third_party_obj_func(self,inp):
        
        if inp == 2:
            self.third_party_obj = inp
            return "Success"
        else:
            return "Fail"
    
    def do_thing(self, val):
        
        status = self.third_party_obj_func(val)
        
        return status
    
from functools import partial

from pyspark.sql import functions as F
from pyspark.sql.types import StringType

my_obj = MyClass("arg1")

def do_thing_global(a, obj=None):
    return obj.do_thing(a)

my_udf = F.udf(partial(do_thing_global, obj=my_obj), StringType())

Partial UDF

sparkDF = sparkDF.withColumn("col1", my_udf(F.col("a")))

sparkDF.show()

+---+---+---+-------+
|  a|  b|  c|   col1|
+---+---+---+-------+
|  1|  3|  p|   Fail|
|  2|  4|  q|Success|
|  3|  4|  r|   Fail|
|  4|  7|  s|   Fail|
+---+---+---+-------+

As the question involves custom function calls, the above approach outlines the steps you can follow to achieve your use case.
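If do_thing needs to return more than a single string, as mentioned above, a minimal sketch of a struct-returning UDF building on the same example could look like this (the schema and the status/value field names are only placeholders):

from functools import partial

from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

# hypothetical return schema - adapt the fields to whatever do_thing produces
result_schema = StructType([
    StructField("status", StringType()),
    StructField("value", IntegerType())
])

def do_thing_struct(a, obj=None):
    status = obj.do_thing(a)
    # return a tuple matching the schema above
    return (status, a)

my_struct_udf = F.udf(partial(do_thing_struct, obj=my_obj), result_schema)

sparkDF = sparkDF.withColumn("col2", my_struct_udf(F.col("a")))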

Upvotes: 1
