Sid

Reputation: 13

Apply UDFs to pyspark dataframe based on row value

I have a pyspark dataframe with the following schema

+-----------+---------+----------+-----------+
|     userID|grouping1| grouping2|   features|
+-----------+---------+----------+-----------+
|12462563356|        1|        A | [5.0,43.0]|
|12462563701|        2|        A |  [1.0,8.0]|
|12462563701|        1|        B | [2.0,12.0]|
|12462564356|        1|        C |  [1.0,1.0]|
|12462565487|        3|        C |  [2.0,3.0]|
|12462565698|        2|        D |  [1.0,1.0]|
|12462565698|        1|        E |  [1.0,1.0]|
|12462566081|        2|        C |  [1.0,2.0]|
|12462566081|        1|        D | [1.0,15.0]|
|12462566225|        2|        E |  [1.0,1.0]|
|12462566225|        1|        A | [9.0,85.0]|
|12462566526|        2|        C |  [1.0,1.0]|
|12462566526|        1|        D | [3.0,79.0]|
|12462567006|        2|        D |[11.0,15.0]|
|12462567006|        1|        B |[10.0,15.0]|
|12462567006|        3|        A |[10.0,15.0]|
|12462586595|        2|        B | [2.0,42.0]|
|12462586595|        3|        D | [2.0,16.0]|
|12462589343|        3|        E |  [1.0,1.0]|
+-----------+---------+----------+-----------+

For the values A, B, C and D in grouping2 I need to apply UDF_A, UDF_B, UDF_C and UDF_D respectively. Is there a way I can write something along the lines of

dataset = dataset.withColumn('outputColName', selectUDF(**params))

where selectUDF is defined as

def selectUDF(**params):
    if row['grouping2'] == 'A':
        return UDF_A(**params)
    elif row['grouping2'] == 'B':
        return UDF_B(**params)
    elif row['grouping2'] == 'C':
        return UDF_C(**params)
    elif row['grouping2'] == 'D':
        return UDF_D(**params)

I'm using the following toy code to illustrate what I'm trying to do and to check this:

>>> df = sc.parallelize([[1,2,3], [2,3,4]]).toDF(("a", "b", "c"))
>>> df.show()
+---+---+---+
|  a|  b|  c|
+---+---+---+
|  1|  2|  3|
|  2|  3|  4|
+---+---+---+
>>> def udf1(col):
...     return col*col
...
>>> def udf2(col):
...     return col*col*col
...
>>> def select_udf(col1, col2):
...     if col1 == 2:
...         return udf1(col2)
...     elif col1 == 3:
...         return udf2(col2)
...     else:
...         return 0
...
>>> from pyspark.sql.functions import col
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType
>>> select_udf = udf(select_udf, IntegerType())
>>> udf1 = udf(udf1, IntegerType())
>>> udf2 = udf(udf2, IntegerType())
>>> df.withColumn("outCol", select_udf(col("b"), col("c"))).show()
[Stage 9:============================================>              (3 + 1) / 4]

This seems to be stuck at this stage forever. Can anyone suggest what might be wrong here?

Upvotes: 1

Views: 1452

Answers (1)

blackbishop

Reputation: 32700

You don't need a selectUDF; simply use a when expression to apply the desired UDF depending on the value of the grouping2 column:

from pyspark.sql.functions import col, when

df = df.withColumn(
    "outCol",
    when(col("grouping2") == "A", UDF_A(*params))
        .when(col("grouping2") == "B", UDF_B(*params))
        .when(col("grouping2") == "C", UDF_C(*params))
        .when(col("grouping2") == "D", UDF_D(*params))
)

Upvotes: 1
