Georg Heiler

Reputation: 17676

Spark: aggregate the set of events per key, including their change timestamps

For a DataFrame such as:

+----+--------+-------------------+----+
|user|      dt|         time_value|item|
+----+--------+-------------------+----+
| id1|20200101|2020-01-01 00:00:00|   A|
| id1|20200101|2020-01-01 10:00:00|   B|
| id1|20200101|2020-01-01 09:00:00|   A|
| id1|20200101|2020-01-01 11:00:00|   B|
+----+--------+-------------------+----+

I want to capture all the unique items, i.e. collect_set, but retain each item's own time_value.

import org.apache.spark.sql.functions.{col, collect_set, unix_timestamp}
import org.apache.spark.sql.types.TimestampType
import spark.implicits._ // needed for toDF; spark-shell imports it automatically

val timeFormat = "yyyy-MM-dd HH:mm"
val dx = Seq(
    ("id1", "20200101", "2020-01-01 00:00", "A"),
    ("id1", "20200101", "2020-01-01 10:00", "B"),
    ("id1", "20200101", "2020-01-01 09:00", "A"), // zero-padded so it matches the HH pattern
    ("id1", "20200101", "2020-01-01 11:00", "B")
  ).toDF("user", "dt", "time_value", "item")
  // parse the string to epoch seconds, then cast back to a timestamp
  .withColumn("time_value", unix_timestamp(col("time_value"), timeFormat).cast(TimestampType))
dx.show

A plain collect_set aggregation loses the timestamps:

dx.groupBy("user", "dt").agg(collect_set("item")).show
+----+--------+-----------------+
|user|      dt|collect_set(item)|
+----+--------+-----------------+
| id1|20200101|           [B, A]|
+----+--------+-----------------+

It does not retain the time_value at which the signal switched from A to B. How can I keep the time_value information for each item in the set?

Would it be possible to use collect_set within a window function to achieve the desired result? Currently, I can only think of:

  1. use a window function to determine pairs of events
  2. filter to change events
  3. aggregate

which needs to shuffle multiple times. Alternatively, a UDF over sort_array(collect_list(struct(time_value, item))) would be possible, but that also seems rather clumsy.
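For comparison, the body of such a UDF only has to walk one key's time-sorted (time_value, item) pairs and keep the points where the item changes. Below is a minimal plain-Scala sketch (no Spark involved), assuming "change" means the first event of each run of identical items, and modelling time_value as zero-padded "yyyy-MM-dd HH:mm" strings, which sort chronologically. The names ChangePoints and firstOfEachRun are hypothetical, and registering the function with Spark's udf is left out.

```scala
// Hypothetical plain-Scala model of the UDF body; not a Spark API.
// Input: one key's (time_value, item) pairs, already sorted by time,
// as sort_array(collect_list(struct(time_value, item))) would deliver them.
object ChangePoints {
  // Keep the first event of every run of identical items,
  // i.e. the moments at which the signal switched to a new item.
  def firstOfEachRun(sorted: Seq[(String, String)]): Seq[(String, String)] =
    sorted.headOption.toList ++
      sorted.zip(sorted.drop(1)).collect {
        case ((_, prevItem), cur @ (_, item)) if item != prevItem => cur
      }

  def main(args: Array[String]): Unit = {
    val events = Seq(
      ("2020-01-01 10:00", "B"),
      ("2020-01-01 00:00", "A"),
      ("2020-01-01 11:00", "B"),
      ("2020-01-01 09:00", "A")
    )
    // sorting by the first tuple field mirrors sort_array on the structs
    println(firstOfEachRun(events.sortBy(_._1)))
  }
}
```

For the sample data this keeps the 00:00 A and 10:00 B events; the sort is cheap compared to the shuffle, but the per-key list must fit in memory.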

Is there a better way?

Upvotes: 0

Views: 80

Answers (1)

Raphael Roth

Reputation: 27373

I would indeed use window functions to isolate the change points; I don't think there is an alternative:

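import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{coalesce, lag, lead, lit}
// assumes spark.implicits._ is in scope for the $"col" syntax

val win = Window.partitionBy($"user", $"dt").orderBy($"time_value")

dx
  .orderBy($"time_value")
  // true on the first row of a new item (its item differs from the previous row's)
  .withColumn("item_change_post", coalesce(lag($"item", 1).over(win) =!= $"item", lit(false)))
  // true on the last row before a change (the next row is flagged item_change_post)
  .withColumn("item_change_pre", lead($"item_change_post", 1).over(win))
  // keep only the rows on either side of a change
  .where($"item_change_pre" or $"item_change_post")
  .show()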

+----+--------+-------------------+----+----------------+---------------+
|user|      dt|         time_value|item|item_change_post|item_change_pre|
+----+--------+-------------------+----+----------------+---------------+
| id1|20200101|2020-01-01 09:00:00|   A|           false|           true|
| id1|20200101|2020-01-01 10:00:00|   B|            true|          false|
+----+--------+-------------------+----+----------------+---------------+
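To see why only the 09:00 and 10:00 rows survive, here is a plain-Scala model of the same lag/lead logic for a single (user, dt) partition that is already sorted by time_value. It is an illustrative sketch, not Spark code: WindowFlags, Flagged, flag and keep are hypothetical names, and Option[Boolean] stands in for the nullable item_change_pre column, because lead() returns NULL on the last row and NULL OR false evaluates to NULL, which where() filters out.

```scala
// Hypothetical plain-Scala model of the window logic for ONE (user, dt) partition.
object WindowFlags {
  // changePre = None models the SQL NULL that lead() yields on the last row
  final case class Flagged(time: String, item: String, changePost: Boolean, changePre: Option[Boolean])

  def flag(sorted: Seq[(String, String)]): Seq[Flagged] = {
    val items = sorted.map(_._2).toIndexedSeq
    // item_change_post = coalesce(lag(item) =!= item, false): no lag on row 0, so false
    val post = items.indices.map(i => i > 0 && items(i - 1) != items(i))
    // item_change_pre = lead(item_change_post): NULL (None) on the last row
    sorted.zipWithIndex.map { case ((time, item), i) =>
      val pre = if (i + 1 < items.length) Some(post(i + 1)) else None
      Flagged(time, item, post(i), pre)
    }
  }

  // where(pre or post) with SQL three-valued logic: NULL OR false is NULL, so dropped
  def keep(r: Flagged): Boolean = r.changePost || r.changePre.contains(true)

  def main(args: Array[String]): Unit = {
    val sorted = Seq(
      ("2020-01-01 00:00", "A"),
      ("2020-01-01 09:00", "A"),
      ("2020-01-01 10:00", "B"),
      ("2020-01-01 11:00", "B")
    )
    flag(sorted).filter(keep).foreach(println)
  }
}
```

Running main prints the 09:00 A row (post false, pre true) and the 10:00 B row (post true, pre false), the same flags as in the table above.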

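Then aggregate with something like groupBy($"user",$"dt").agg(sort_array(collect_list(struct($"time_value",$"item")))); the sort_array keeps the structs in time order, since collect_list alone does not guarantee an order after a shuffle.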
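The grouping step can also be modelled in plain Scala. This sketch assumes the rows kept by the change-point filter as input and sorts each key's (time_value, item) pairs explicitly, which is what wrapping collect_list in sort_array does in Spark (collect_list by itself does not guarantee an order after a shuffle). GroupChanges, Event and collect are hypothetical names.

```scala
// Hypothetical plain-Scala model of the final groupBy(user, dt) aggregation.
object GroupChanges {
  // (user, dt, time_value, item) rows that survived the change-point filter
  type Event = (String, String, String, String)

  // One time-ordered list of (time_value, item) per (user, dt) key
  def collect(rows: Seq[Event]): Map[(String, String), Seq[(String, String)]] =
    rows
      .groupBy { case (user, dt, _, _) => (user, dt) }
      .map { case (key, group) =>
        // sort by (time_value, item), as sort_array does on the collected structs
        key -> group.map { case (_, _, time, item) => (time, item) }.sorted
      }

  def main(args: Array[String]): Unit = {
    val rows: Seq[Event] = Seq(
      ("id1", "20200101", "2020-01-01 10:00", "B"),
      ("id1", "20200101", "2020-01-01 09:00", "A")
    )
    println(collect(rows))
  }
}
```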

I don't think that multiple shuffles occur, because you always partition/group by the same keys.

You can try to make it more efficient by first aggregating your initial DataFrame to the min/max time_value of each item, then doing the same as above.
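A plain-Scala sketch of that pre-aggregation for one (user, dt) key, modelling groupBy(user, dt, item).agg(min(time_value), max(time_value)); MinMaxPerItem and minMax are hypothetical names. One caveat worth checking against your data: grouping by item alone merges runs of the same item that are separated by another item (A, B, A), so this shortcut assumes each item occurs in one contiguous run per key.

```scala
// Hypothetical plain-Scala model of the min/max pre-aggregation for ONE (user, dt) key.
object MinMaxPerItem {
  // item -> (first time seen, last time seen), ordered by first appearance
  def minMax(events: Seq[(String, String)]): Seq[(String, String, String)] =
    events
      .groupBy(_._2)
      .map { case (item, evs) =>
        val times = evs.map(_._1)
        (item, times.min, times.max)
      }
      .toSeq
      .sortBy(_._2) // order by min time before applying the window logic

  def main(args: Array[String]): Unit = {
    val events = Seq(
      ("2020-01-01 00:00", "A"),
      ("2020-01-01 10:00", "B"),
      ("2020-01-01 09:00", "A"),
      ("2020-01-01 11:00", "B")
    )
    println(minMax(events))
    // caveat: an A, B, A sequence collapses both A runs into one row
    println(minMax(Seq(("2020-01-01 00:00", "A"), ("2020-01-01 10:00", "B"), ("2020-01-01 12:00", "A"))))
  }
}
```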

Upvotes: 2
