Reputation: 13
I have a PySpark DataFrame with the following data:
+-----------+---------+----------+-----------+
|     userID|grouping1| grouping2|   features|
+-----------+---------+----------+-----------+
|12462563356|        1|         A| [5.0,43.0]|
|12462563701|        2|         A|  [1.0,8.0]|
|12462563701|        1|         B| [2.0,12.0]|
|12462564356|        1|         C|  [1.0,1.0]|
|12462565487|        3|         C|  [2.0,3.0]|
|12462565698|        2|         D|  [1.0,1.0]|
|12462565698|        1|         E|  [1.0,1.0]|
|12462566081|        2|         C|  [1.0,2.0]|
|12462566081|        1|         D| [1.0,15.0]|
|12462566225|        2|         E|  [1.0,1.0]|
|12462566225|        1|         A| [9.0,85.0]|
|12462566526|        2|         C|  [1.0,1.0]|
|12462566526|        1|         D| [3.0,79.0]|
|12462567006|        2|         D|[11.0,15.0]|
|12462567006|        1|         B|[10.0,15.0]|
|12462567006|        3|         A|[10.0,15.0]|
|12462586595|        2|         B| [2.0,42.0]|
|12462586595|        3|         D| [2.0,16.0]|
|12462589343|        3|         E|  [1.0,1.0]|
+-----------+---------+----------+-----------+
For values in grouping2 A, B, C and D I need to apply UDF_A, UDF_B, UDF_C and UDF_D respectively. Is there a way I can write something along the lines of
dataset = dataset.withColumn('outputColName', selectUDF(**params))
where selectUDF is defined as
def selectUDF(**params):
    if row[grouping2] == A:
        return UDF_A(**params)
    elif row[grouping2] == B:
        return UDF_B(**params)
    elif row[grouping2] == C:
        return UDF_C(**params)
    elif row[grouping2] == D:
        return UDF_D(**params)
I'm using the following toy example to illustrate what I'm trying to do:
>>> df = sc.parallelize([[1,2,3], [2,3,4]]).toDF(("a", "b", "c"))
>>> df.show()
+---+---+---+
|  a|  b|  c|
+---+---+---+
|  1|  2|  3|
|  2|  3|  4|
+---+---+---+
>>> def udf1(col):
...     return col1*col1
...
>>> def udf2(col):
...     return col2*col2*col2
...
>>> def select_udf(col1, col2):
...     if col1 == 2:
...         return udf1(col2)
...     elif col1 == 3:
...         return udf2(col2)
...     else:
...         return 0
...
>>> from pyspark.sql.functions import col
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType
>>> select_udf = udf(select_udf, IntegerType())
>>> udf1 = udf(udf1, IntegerType())
>>> udf2 = udf(udf2, IntegerType())
>>> df.withColumn("outCol", select_udf(col("b"), col("c"))).show()
[Stage 9:============================================> (3 + 1) / 4]
This seems to be stuck at this stage forever. Can anyone suggest what might be wrong here?
Upvotes: 1
Views: 1452
Reputation: 32700
You don't need a selectUDF; simply use a when expression to apply the desired UDF depending on the value of the grouping2 column:
from pyspark.sql.functions import col, when

# Each branch applies one UDF; rows matching no branch (e.g. grouping2 == "E")
# get null because there is no .otherwise(...)
df = df.withColumn(
    "outCol",
    when(col("grouping2") == "A", UDF_A(*params))
    .when(col("grouping2") == "B", UDF_B(*params))
    .when(col("grouping2") == "C", UDF_C(*params))
    .when(col("grouping2") == "D", UDF_D(*params))
)
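As for the toy code in the question, two things look likely to break it. udf1 and udf2 refer to col1 and col2, which are not defined inside them (the parameter is named col). Also, after udf1 = udf(udf1, IntegerType()) the name udf1 points at a Spark UDF object, and as far as I know a UDF object cannot be called on plain Python values from inside another UDF. If you do want a single dispatching UDF, a minimal sketch is to keep the helpers as plain Python functions and wrap only the dispatcher. The names square, cube, DISPATCH, select_fn and add_out_col are made up for illustration, and the lazy import is only there so the pure-Python part can be tried without Spark:

```python
# Plain Python helpers: they run inside the Spark worker, so they are
# deliberately NOT wrapped with udf() themselves.
def square(x):
    return x * x


def cube(x):
    return x * x * x


# Hypothetical dispatch table mirroring the toy example: 2 -> square, 3 -> cube.
DISPATCH = {2: square, 3: cube}


def select_fn(key, value):
    """Pick a helper by key and apply it to value; unknown keys give 0."""
    fn = DISPATCH.get(key)
    return fn(value) if fn is not None else 0


def add_out_col(df):
    """Wrap select_fn once as a UDF and apply it to columns b and c."""
    # Imported lazily (an assumption of this sketch) so select_fn can be
    # tested without a Spark installation.
    from pyspark.sql.functions import col, udf
    from pyspark.sql.types import IntegerType

    select_udf = udf(select_fn, IntegerType())
    return df.withColumn("outCol", select_udf(col("b"), col("c")))
```

With the toy df from the question, add_out_col(df).show() would be the call to try. The when version above is still preferable when each branch is already a separate UDF, since it keeps the branching in Spark's own expression language.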
Upvotes: 1