Reputation: 207
I am new to PySpark, and I am trying to use a udf to map some string names. I have to map some data values to new names, so I was going to send the column value from sparkdf and a dictionary of mapped fields to a udf, instead of writing a ton of .when()'s after .withColumn().
I tried passing just two strings to the udf, and that works, but passing the dictionary doesn't.
def stringToStr_function(checkCol, dict1):
    for key, value in dict1.iteritems():
        if checkCol != None and checkCol == key:
            return value

stringToStr_udf = udf(stringToStr_function, StringType())
df = sparkdf.toDF().withColumn(
    "new_col",
    stringToStr_udf(
        lit("REQUEST"),
        {"REQUEST": "Requested", "CONFIRM": "Confirmed", "CANCEL": "Cancelled"}
    )
)
But I'm getting this error saying method col() does not exist. Any ideas?
File "<stdin>", line 2, in <module>
File "/usr/lib/spark/python/pyspark/sql/functions.py", line 1957, in wrapper
return udf_obj(*args)
File "/usr/lib/spark/python/pyspark/sql/functions.py", line 1918, in __call__
return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
File "/usr/lib/spark/python/pyspark/sql/column.py", line 60, in _to_seq
cols = [converter(c) for c in cols]
File "/usr/lib/spark/python/pyspark/sql/column.py", line 48, in _to_java_column
jcol = _create_column_from_name(col)
File "/usr/lib/spark/python/pyspark/sql/column.py", line 41, in _create_column_from_name
return sc._jvm.functions.col(name)
File "/usr/lib/spark/python/lib/py4j-0.10.4-src.zip/py4j/java_gateway.py", line 1133, in __call__
answer, self.gateway_client, self.target_id, self.name)
File "/usr/lib/spark/python/pyspark/sql/utils.py", line 63, in deco
return f(*a, **kw)
File "/usr/lib/spark/python/lib/py4j-0.10.4-src.zip/py4j/protocol.py", line 323, in get_return_value
format(target_id, ".", name, value))
Py4JError: An error occurred while calling z:org.apache.spark.sql.functions.col. Trace:
py4j.Py4JException: Method col([class java.util.HashMap]) does not exist
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:339)
at py4j.Gateway.invoke(Gateway.java:274)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:214)
at java.lang.Thread.run(Thread.java:748)
Thanks for any help. I am using AWS Glue with Python 2.x, and I am testing in a notebook.
Upvotes: 4
Views: 12790
Reputation: 19328
Here's how to solve this with a UDF and a broadcasted dictionary. pault's solution is clever, but it seems to rely on the dictionary being auto-broadcast because it's small; I don't think it works for a dictionary that's bigger than the auto-broadcast limit. Explicitly broadcasting is the safest way to write PySpark code, in my opinion. The UDF approach can also be better when the logic cannot be expressed with the native Spark functions.
Suppose you have the following DataFrame.
+-------+
| status|
+-------+
|REQUEST|
|CONFIRM|
+-------+
Here's the code to apply the mapping dictionary to the column.
import pyspark.sql.functions as F

def stringToStr(dict1_broadcasted):
    # Look up each value in the broadcasted dictionary; missing keys
    # return None, which becomes null in the DataFrame. F.udf without
    # an explicit return type defaults to StringType.
    def f(x):
        return dict1_broadcasted.value.get(x)
    return F.udf(f)
df = spark.createDataFrame([["REQUEST",], ["CONFIRM",]]).toDF("status")
b = spark.sparkContext.broadcast({"REQUEST": "Requested", "CONFIRM": "Confirmed", "CANCEL": "Cancelled"})
df.withColumn(
    "new_col",
    stringToStr(b)(F.col("status"))
).show()
+-------+---------+
| status| new_col|
+-------+---------+
|REQUEST|Requested|
|CONFIRM|Confirmed|
+-------+---------+
See this post for more details about the errors you might encounter when broadcasting dictionaries in PySpark. It's tricky to get right, but it's a powerful technique to have in your toolkit.
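If your job keeps running after the mapping step, it can also be worth releasing the broadcast once you're done with it. A minimal sketch, reusing the b variable from above:

# Remove the cached copies from the executors; Spark re-broadcasts
# the value automatically if it is used again later.
b.unpersist()

# Or tear it down permanently (b cannot be used at all after this).
b.destroy()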
Upvotes: 1
Reputation: 43534
As shown in the linked duplicate:

The cleanest solution is to pass additional arguments using closure

However, you don't need a udf for this particular problem. (See Spark functions vs UDF performance?) You can use pyspark.sql.functions.when to implement IF-THEN-ELSE logic:
from pyspark.sql.functions import coalesce, lit, when

def stringToStr_function(checkCol, dict1):
    # checkCol is already a Column, so compare it directly rather than
    # wrapping it in col(). Each when() evaluates to null unless its
    # key matches, and coalesce() picks the first non-null result.
    return coalesce(
        *[when(checkCol == key, lit(value)) for key, value in dict1.iteritems()]
    )

df = sparkdf.withColumn(
    "new_col",
    stringToStr_function(
        checkCol=lit("REQUEST"),
        dict1={"REQUEST": "Requested", "CONFIRM": "Confirmed", "CANCEL": "Cancelled"}
    )
)
We iterate through the items in the dictionary and use when to return the value if the value in checkCol matches the key. Wrap that in a call to pyspark.sql.functions.coalesce(), which will return the first non-null value.
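Since sparkdf comes from the question, here's a minimal self-contained sketch of the same technique; the toy DataFrame, the status column name, and the lit("Other") fallback are assumptions for illustration:

from pyspark.sql.functions import coalesce, lit, when

mapping = {"REQUEST": "Requested", "CONFIRM": "Confirmed", "CANCEL": "Cancelled"}
df = spark.createDataFrame([("REQUEST",), ("CONFIRM",), ("UNKNOWN",)], ["status"])

# .items() works on both Python 2 and 3. Appending lit("Other") to the
# coalesce() arguments gives unmatched rows a default instead of null.
df.withColumn(
    "new_col",
    coalesce(
        *([when(df["status"] == k, lit(v)) for k, v in mapping.items()] + [lit("Other")])
    )
).show()

Here the UNKNOWN row comes out as Other rather than null, which is often what you want for a mapping like this.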
Upvotes: 3