Reputation: 329
I need to write some custom code using multiple columns within a group of my data.
My custom code sets a flag if a value is over a threshold, but suppresses the flag if it is within a certain time of a previous flag.
Here is some sample code:
df = spark.createDataFrame(
    [
        ("a", 1, 0),
        ("a", 2, 1),
        ("a", 3, 1),
        ("a", 4, 1),
        ("a", 5, 1),
        ("a", 6, 0),
        ("a", 7, 1),
        ("a", 8, 1),
        ("b", 1, 0),
        ("b", 2, 1)
    ],
    ["group_col", "order_col", "flag_col"]
)
df.show()
+---------+---------+--------+
|group_col|order_col|flag_col|
+---------+---------+--------+
| a| 1| 0|
| a| 2| 1|
| a| 3| 1|
| a| 4| 1|
| a| 5| 1|
| a| 6| 0|
| a| 7| 1|
| a| 8| 1|
| b| 1| 0|
| b| 2| 1|
+---------+---------+--------+
from pyspark.sql.functions import udf, col, asc
from pyspark.sql.types import DoubleType
from pyspark.sql.window import Window
def _suppress(dates=None, alert_flags=None, window=2):
    # Copy the input flags; flags within `window` of a kept alert are zeroed out.
    sup_alert_flag = list(alert_flags)
    last_alert_date = None
    for i, alert_flag in enumerate(alert_flags):
        current_date = dates[i]
        if alert_flag == 1:
            if not last_alert_date:
                sup_alert_flag[i] = 1
                last_alert_date = current_date
            elif (current_date - last_alert_date) > window:
                sup_alert_flag[i] = 1
                last_alert_date = current_date
            else:
                sup_alert_flag[i] = 0
        else:
            sup_alert_flag[i] = 0
    return sup_alert_flag
suppress_udf = udf(_suppress, DoubleType())
df_out = df.withColumn("supressed_flag_col", suppress_udf(dates=col("order_col"), alert_flags=col("flag_col"), window=4).Window.partitionBy(col("group_col")).orderBy(asc("order_col")))
df_out.show()
The above fails, but my expected output is the following:
+---------+---------+--------+------------------+
|group_col|order_col|flag_col|supressed_flag_col|
+---------+---------+--------+------------------+
| a| 1| 0| 0|
| a| 2| 1| 1|
| a| 3| 1| 0|
| a| 4| 1| 0|
| a| 5| 1| 0|
| a| 6| 0| 0|
| a| 7| 1| 1|
| a| 8| 1| 0|
| b| 1| 0| 0|
| b| 2| 1| 1|
+---------+---------+--------+------------------+
Upvotes: 1
Views: 1042
Reputation: 341
Editing answer after more thought.
The general problem seems to be that the result of the current row depends upon the result of the previous row. In effect, there is a recurrence relation. I haven't found a good way to implement a recursive UDF in Spark. There are several challenges, arising from the distributed nature of the data in Spark, that make this difficult to achieve, at least as far as I can see. The following solution should work, but it may not scale to large data sets.
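To make the recurrence concrete, here is the per-group rule in plain Python (an illustrative sketch mirroring the _suppress loop from the question; the function and variable names are mine): an alert is kept only if no kept alert occurred within the last window time units. The Spark implementation of the same rule follows.
def keep_alerts(dates, flags, window=4):
    # Illustrative helper only, not part of the Spark solution below.
    kept = []
    last_kept_date = None
    for d, f in zip(dates, flags):
        # Keep the alert only if no previously kept alert falls within `window`.
        keep = f == 1 and (last_kept_date is None or d - last_kept_date > window)
        if keep:
            last_kept_date = d
        kept.append(1 if keep else 0)
    return kept
# keep_alerts([1,2,3,4,5,6,7,8], [0,1,1,1,1,0,1,1]) -> [0,1,0,0,0,0,1,0], matching group "a" above.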
from pyspark.sql import Row
import pyspark.sql.functions as F
import pyspark.sql.types as T
suppress_flag_row = Row("order_col", "flag_col", "res_flag")
def suppress_flag(date_alert_flags, window_size):
    # Sort the collected structs by order_col so the recurrence is applied in time order.
    sorted_alerts = sorted(date_alert_flags, key=lambda x: x["order_col"])
    res_flags = []
    last_alert_date = None
    for row in sorted_alerts:
        current_date = row["order_col"]
        aflag = row["flag_col"]
        # Keep the alert only if there is no previously kept alert within window_size.
        if aflag == 1 and (not last_alert_date or (current_date - last_alert_date) > window_size):
            res = suppress_flag_row(current_date, aflag, True)
            last_alert_date = current_date
        else:
            res = suppress_flag_row(current_date, aflag, False)
        res_flags.append(res)
    return res_flags
in_fields = [T.StructField("order_col", T.IntegerType(), nullable=True)]
in_fields.append(T.StructField("flag_col", T.IntegerType(), nullable=True))
# Copy the list so appending the result field does not also modify in_fields.
out_fields = list(in_fields)
out_fields.append(T.StructField("res_flag", T.BooleanType(), nullable=True))
out_schema = T.StructType(out_fields)
suppress_udf = F.udf(suppress_flag, T.ArrayType(out_schema) )
window_size = 4
tmp = df.groupBy("group_col").agg( F.collect_list( F.struct( F.col("order_col"), F.col("flag_col") ) ).alias("date_alert_flags"))
tmp2 = tmp.select(F.col("group_col"), suppress_udf(F.col("date_alert_flags"), F.lit(window_size)).alias("suppress_res"))
expand_fields = [F.col("group_col")] + [F.col("res_expand")[f.name].alias(f.name) for f in out_fields]
final_df = tmp2.select(F.col("group_col"), F.explode(F.col("suppress_res")).alias("res_expand")).select( expand_fields )
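For reference, the boolean res_flag can be cast to an integer and the result displayed (a small usage sketch; column names as defined above):
final_df = final_df.withColumn("supressed_flag_col", F.col("res_flag").cast(T.IntegerType()))
final_df.orderBy("group_col", "order_col").show()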
Upvotes: 1
Reputation: 5880
I think you don't need a custom function for this. You can use the rowsBetween option along with a window to get the range of the 5 preceding rows. Please check and let me know if I missed something.
>>> from pyspark.sql import functions as F
>>> from pyspark.sql import Window
>>> w = Window.partitionBy('group_col').orderBy('order_col').rowsBetween(-5,-1)
>>> df = df.withColumn('supr_flag_col',F.when(F.sum('flag_col').over(w) == 0,1).otherwise(0))
>>> df.orderBy('group_col','order_col').show()
+---------+---------+--------+-------------+
|group_col|order_col|flag_col|supr_flag_col|
+---------+---------+--------+-------------+
| a| 1| 0| 0|
| a| 2| 1| 1|
| a| 3| 1| 0|
| b| 1| 0| 0|
| b| 2| 1| 1|
+---------+---------+--------+-------------+
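If the question's window of 4 is meant as the 4 preceding rows rather than 5, the frame can be narrowed accordingly (a sketch of the same approach, not verified against the full expected output):
>>> w = Window.partitionBy('group_col').orderBy('order_col').rowsBetween(-4,-1)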
Upvotes: 0