prk

Reputation: 329

How can I create a PySpark UDF using multiple columns?

I need to write some custom code using multiple columns within a group of my data.

My custom logic sets a flag when a value is over a threshold, but suppresses the flag if it falls within a certain time window of a previous flag.

Here is some sample code:

df = spark.createDataFrame(
    [
        ("a", 1, 0),
        ("a", 2, 1),
        ("a", 3, 1),
        ("a", 4, 1),
        ("a", 5, 1),
        ("a", 6, 0),
        ("a", 7, 1),
        ("a", 8, 1),
        ("b", 1, 0),
        ("b", 2, 1)
    ],
    ["group_col","order_col", "flag_col"]
)
df.show()
+---------+---------+--------+
|group_col|order_col|flag_col|
+---------+---------+--------+
|        a|        1|       0|
|        a|        2|       1|
|        a|        3|       1|
|        a|        4|       1|
|        a|        5|       1|
|        a|        6|       0|
|        a|        7|       1|
|        a|        8|       1|
|        b|        1|       0|
|        b|        2|       1|
+---------+---------+--------+

from pyspark.sql.functions import udf, col, asc
from pyspark.sql.types import DoubleType
from pyspark.sql.window import Window

def _suppress(dates=None, alert_flags=None, window=2):
    sup_alert_flag = list(alert_flags)
    last_alert_date = None
    for i, alert_flag in enumerate(alert_flags):
        current_date = dates[i]
        if alert_flag == 1:
            if not last_alert_date:
                sup_alert_flag[i] = 1
                last_alert_date = current_date
            elif (current_date - last_alert_date) > window:
                sup_alert_flag[i] = 1
                last_alert_date = current_date
            else:
                sup_alert_flag[i] = 0
        else:
            sup_alert_flag[i] = 0
    return sup_alert_flag

suppress_udf = udf(_suppress, DoubleType())

df_out = df.withColumn("supressed_flag_col", suppress_udf(dates=col("order_col"), alert_flags=col("flag_col"), window=4).Window.partitionBy(col("group_col")).orderBy(asc("order_col")))
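# NOTE: the call above fails. A Column has no .Window attribute (window specs
# are attached with .over() on window/aggregate functions, not on a UDF result),
# a plain Python value like window=4 would need to be wrapped in lit(), and a
# row-wise UDF only sees the current row's values, so it cannot look back at
# earlier rows in the group.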

df_out.show()

The above fails, but my expected output is the following:

+---------+---------+--------+------------------+
|group_col|order_col|flag_col|supressed_flag_col|
+---------+---------+--------+------------------+
|        a|        1|       0|                 0|
|        a|        2|       1|                 1|
|        a|        3|       1|                 0|
|        a|        4|       1|                 0|
|        a|        5|       1|                 0|
|        a|        6|       0|                 0|
|        a|        7|       1|                 1|
|        a|        8|       1|                 0|
|        b|        1|       0|                 0|
|        b|        2|       1|                 1|
+---------+---------+--------+------------------+

Upvotes: 1

Views: 1042

Answers (2)

putnampp

Reputation: 341

Editing answer after more thought.

The general problem seems to be that the result for the current row depends upon the result of the previous row. In effect, there is a recurrence relation. I haven't found a good way to implement a recursive UDF in Spark; there are several challenges resulting from the distributed nature of the data in Spark that would make this difficult to achieve, at least in my mind. The following solution should work, but may not scale for large data sets.

from pyspark.sql import Row
import pyspark.sql.functions as F
import pyspark.sql.types as T

suppress_flag_row = Row("order_col", "flag_col", "res_flag")

def suppress_flag( date_alert_flags, window_size ):

    sorted_alerts = sorted( date_alert_flags, key=lambda x: x["order_col"])

    res_flags = []
    last_alert_date = None
    for row in sorted_alerts:
        current_date = row["order_col"]
        aflag = row["flag_col"]
        if aflag == 1 and (not last_alert_date or (current_date - last_alert_date) > window_size):
            res = suppress_flag_row(current_date, aflag, True)
            last_alert_date = current_date
        else:
            res = suppress_flag_row(current_date, aflag, False)

        res_flags.append(res)
    return res_flags

in_fields = [T.StructField("order_col", T.IntegerType(), nullable=True )]
in_fields.append( T.StructField("flag_col", T.IntegerType(), nullable=True) )

out_fields = list(in_fields)
out_fields.append(T.StructField("res_flag", T.BooleanType(), nullable=True) )
out_schema = T.StructType(out_fields)
suppress_udf = F.udf(suppress_flag, T.ArrayType(out_schema) )

window_size = 4
tmp = df.groupBy("group_col").agg( F.collect_list( F.struct( F.col("order_col"), F.col("flag_col") ) ).alias("date_alert_flags"))
tmp2 = tmp.select(F.col("group_col"), suppress_udf(F.col("date_alert_flags"), F.lit(window_size)).alias("suppress_res"))

expand_fields = [F.col("group_col")] + [F.col("res_expand")[f.name].alias(f.name) for f in out_fields]
final_df = tmp2.select(F.col("group_col"), F.explode(F.col("suppress_res")).alias("res_expand")).select( expand_fields )
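
As a quick sanity check (a sketch, assuming the df and the names above are in scope), casting the boolean res_flag to an integer should line up with the 0/1 supressed_flag_col in the question's expected output:

# Cast the boolean result to 0/1 and display in the original row order.
final_df.withColumn("supressed_flag_col", F.col("res_flag").cast(T.IntegerType())) \
    .orderBy("group_col", "order_col") \
    .select("group_col", "order_col", "flag_col", "supressed_flag_col") \
    .show()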

Upvotes: 1

Suresh
Suresh

Reputation: 5880

I think you don't need a custom function for this. You can use the rowsBetween option along with a window to get the 5-row range. Please check and let me know if I missed something.

>>> from pyspark.sql import functions as F
>>> from pyspark.sql import Window

>>> w = Window.partitionBy('group_col').orderBy('order_col').rowsBetween(-5,-1)
>>> df = df.withColumn('supr_flag_col',F.when(F.sum('flag_col').over(w) == 0,1).otherwise(0))
>>> df.orderBy('group_col','order_col').show()
+---------+---------+--------+-------------+
|group_col|order_col|flag_col|supr_flag_col|
+---------+---------+--------+-------------+
|        a|        1|       0|            0|
|        a|        2|       1|            1|
|        a|        3|       1|            0|
|        b|        1|       0|            0|
|        b|        2|       1|            1|
+---------+---------+--------+-------------+
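
Note that rowsBetween(-5, -1) looks back a fixed number of rows, so it assumes order_col values are consecutive within each group; if there can be gaps, a value-based frame such as rangeBetween on the order column may be a closer fit to the "within a certain time" requirement.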

Upvotes: 0
