Ron

Reputation: 207

Pass a dictionary to pyspark udf

I am new to PySpark, and I am trying to use a UDF to map some string names. I need to map some data values to new names, so I was going to send the column value from sparkdf and a dictionary of mapped fields to a UDF, instead of writing a ton of .when() calls after .withColumn().

Passing just two strings to the UDF works, but passing the dictionary doesn't.

from pyspark.sql.functions import lit, udf
from pyspark.sql.types import StringType

def stringToStr_function(checkCol, dict1):
    # look up checkCol in the mapping dictionary (dict1.items() on Python 3)
    for key, value in dict1.iteritems():
        if checkCol is not None and checkCol == key:
            return value

stringToStr_udf = udf(stringToStr_function, StringType())

df = sparkdf.toDF().withColumn(
    "new_col",
     stringToStr_udf(
        lit("REQUEST"),
        {"REQUEST": "Requested", "CONFIRM": "Confirmed", "CANCEL": "Cancelled"}
     )
)

But I'm getting this error saying that method col() does not exist. Any ideas?

File "<stdin>", line 2, in <module>
  File "/usr/lib/spark/python/pyspark/sql/functions.py", line 1957, in wrapper
    return udf_obj(*args)
  File "/usr/lib/spark/python/pyspark/sql/functions.py", line 1918, in __call__
    return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
  File "/usr/lib/spark/python/pyspark/sql/column.py", line 60, in _to_seq
    cols = [converter(c) for c in cols]
  File "/usr/lib/spark/python/pyspark/sql/column.py", line 48, in _to_java_column
    jcol = _create_column_from_name(col)
  File "/usr/lib/spark/python/pyspark/sql/column.py", line 41, in _create_column_from_name
    return sc._jvm.functions.col(name)
  File "/usr/lib/spark/python/lib/py4j-0.10.4-src.zip/py4j/java_gateway.py", line 1133, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "/usr/lib/spark/python/pyspark/sql/utils.py", line 63, in deco
    return f(*a, **kw)
  File "/usr/lib/spark/python/lib/py4j-0.10.4-src.zip/py4j/protocol.py", line 323, in get_return_value
    format(target_id, ".", name, value))
Py4JError: An error occurred while calling z:org.apache.spark.sql.functions.col. Trace:

py4j.Py4JException: Method col([class java.util.HashMap]) does not exist
        at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
        at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:339)
        at py4j.Gateway.invoke(Gateway.java:274)
        at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
        at py4j.commands.CallCommand.execute(CallCommand.java:79)
        at py4j.GatewayConnection.run(GatewayConnection.java:214)
        at java.lang.Thread.run(Thread.java:748)

Thanks for any help. I am using AWS Glue with Python 2.x, and I am testing in a notebook.

Upvotes: 4

Views: 12790

Answers (2)

Powers

Reputation: 19328

Here's how to solve this with a UDF and a broadcasted dictionary. pault's solution is clever, but it seems to rely on the dictionary being auto-broadcast because it's small; I don't think it would work for a dictionary bigger than the auto-broadcast limit. Explicitly broadcasting is the safest way to write PySpark code, in my opinion. The UDF approach can also be better when the logic can't be expressed with native Spark functions.

Suppose you have the following DataFrame.

+-------+
| status|
+-------+
|REQUEST|
|CONFIRM|
+-------+

Here's the code to apply the mapping dictionary to the column.

from pyspark.sql import functions as F

def stringToStr(dict1_broadcasted):
    # look up each value in the broadcasted dictionary; returns None when there is no match
    def f(x):
        return dict1_broadcasted.value.get(x)
    return F.udf(f)

df = spark.createDataFrame([["REQUEST",], ["CONFIRM",]]).toDF("status")
b = spark.sparkContext.broadcast({"REQUEST": "Requested", "CONFIRM": "Confirmed", "CANCEL": "Cancelled"})
df.withColumn(
    "new_col",
    stringToStr(b)(F.col("status"))
).show()
+-------+---------+
| status|  new_col|
+-------+---------+
|REQUEST|Requested|
|CONFIRM|Confirmed|
+-------+---------+

See this post for more details about all the errors you might encounter when broadcasting dictionaries in PySpark. It's tricky to get right, but it's a powerful technique to have in your toolkit.

Upvotes: 1

pault

Reputation: 43534

As shown in the linked duplicate:

The cleanest solution is to pass additional arguments using closure
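For reference, here's a minimal sketch of that closure pattern (the helper name make_mapping_udf and the column name "status" are just assumptions for illustration):

from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType

def make_mapping_udf(dict1):
    # dict1 is captured in the closure, so the UDF itself only receives Column arguments
    return udf(lambda x: dict1.get(x), StringType())

mapping = {"REQUEST": "Requested", "CONFIRM": "Confirmed", "CANCEL": "Cancelled"}
df = sparkdf.toDF().withColumn("new_col", make_mapping_udf(mapping)(col("status")))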

However, you don't need a UDF for this particular problem. (See Spark functions vs UDF performance?)

You can use pyspark.sql.functions.when to implement IF-THEN-ELSE logic:

from pyspark.sql.functions import coalesce, lit, when

def stringToStr_function(checkCol, dict1):
    # build one when() per mapping; coalesce() keeps the first non-null result
    # (use dict1.items() on Python 3)
    return coalesce(
        *[when(checkCol == key, lit(value)) for key, value in dict1.iteritems()]
    )

df = sparkdf.withColumn(
    "new_col",
    stringToStr_function(
        checkCol=lit("REQUEST"),
        dict1={"REQUEST": "Requested", "CONFIRM": "Confirmed", "CANCEL": "Cancelled"}
    )
)

We iterate through the items in the dictionary and use when to return the value if the value in checkCol matches the key. Wrap that in a call to pyspark.sql.functions.coalesce() which will return the first non-null value.
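With checkCol = lit("REQUEST") and the dictionary above, the function builds an expression equivalent to the following (the order of the when() clauses follows the dictionary's iteration order, which doesn't matter here because the keys are distinct):

from pyspark.sql.functions import coalesce, lit, when

checkCol = lit("REQUEST")
new_col = coalesce(
    when(checkCol == "REQUEST", lit("Requested")),
    when(checkCol == "CONFIRM", lit("Confirmed")),
    when(checkCol == "CANCEL", lit("Cancelled")),
)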

Upvotes: 3
