DoktorZhor

Reputation: 13

How to split data into series based on conditions in Apache Spark

I have data in the following format, sorted by timestamp, where each row represents an event:

+----------+--------+---------+
|event_type|  data  |timestamp|
+----------+--------+---------+
|     A    |    d1  |    1    |
|     B    |    d2  |    2    |
|     C    |    d3  |    3    |
|     C    |    d4  |    4    |
|     C    |    d5  |    5    |
|     A    |    d6  |    6    |
|     A    |    d7  |    7    |
|     B    |    d8  |    8    |
|     C    |    d9  |    9    |
|     B    |    d10 |    12   |
|     C    |    d11 |    20   |
+----------+--------+---------+
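
For reproducibility, an equivalent DataFrame can be built like this (just a sketch with the sample values from the table above; in reality the timestamps come from elsewhere):

import spark.implicits._

// sample events matching the table above
val df = Seq(
  ("A", "d1", 1), ("B", "d2", 2), ("C", "d3", 3), ("C", "d4", 4),
  ("C", "d5", 5), ("A", "d6", 6), ("A", "d7", 7), ("B", "d8", 8),
  ("C", "d9", 9), ("B", "d10", 12), ("C", "d11", 20)
).toDF("event_type", "data", "timestamp")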

I need to collect these events into series like so:
1. An event of type C marks the end of a series
2. If there are multiple consecutive events of type C, they belong to the same series and the last one marks its end
3. Each series spans at most 7 days, even if there is no C event to end it

Please also note that there can be multiple series in a single day. In reality the timestamp column contains standard UNIX timestamps; here the numbers represent days for simplicity.

So the desired output would look like this:

+---------------------+--------------------------------------------------------------------+
|first_event_timestamp|                events: List[(event_type, data,  timestamp)]        |
+---------------------+--------------------------------------------------------------------+
|          1          | List((A, d1, 1), (B, d2, 2), (C, d3, 3),  (C, d4, 4), (C, d5, 5))  |
|          6          | List((A, d6, 6), (A, d7, 7), (B, d8, 8),  (C, d9, 9))              |
|          12         | List((B, d10, 12))                                                 |
|          20         | List((C, d11, 20))                                                 |
+---------------------+--------------------------------------------------------------------+

I tried to solve this using window functions, where I would add two columns:
1. A seed column marking each event that directly succeeds an event of type C with some unique id
2. A series_id column filled from the seed column using last(), so that all events in one series carry the same id
3. I would then group the events by series_id

Unfortunately, this doesn't seem possible:

+----------+--------+---------+------+-----------+
|event_type|  data  |timestamp| seed | series_id |
+----------+--------+---------+------+-----------+
|     A    |    d1  |    1    | null |    null   |
|     B    |    d2  |    2    | null |    null   |
|     C    |    d3  |    3    | null |    null   |
|     C    |    d4  |    4    |   0  |     0     |     
|     C    |    d5  |    5    |   1  |     1     |
|     A    |    d6  |    6    |   2  |     2     |
|     A    |    d7  |    7    | null |     2     |
|     B    |    d8  |    8    | null |     2     |
|     C    |    d9  |    9    | null |     2     |
|     B    |    d10 |    12   |   3  |     3     |
|     C    |    d11 |    20   | null |     3     |
+----------+--------+---------+------+-----------+
  1. I don't seem to be able to test the preceding row for equality using lag(); i.e., the following code (see the sketch after this list):
df.withColumn(
    "seed",
    when(
        (lag($"eventType", 1) === ventType.Conversion).over(w), 
        typedLit(DigestUtils.sha256Hex("some fields").substring(0, 32))
    )
)

throws

org.apache.spark.sql.AnalysisException: Expression '(lag(eventType#76, 1, null) = C)' not supported within a window function.

  2. As the table shows, this approach fails when there are multiple consecutive C events, and it also wouldn't work for the first and last series.
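
For point 1, the error seems to come from wrapping the whole comparison in .over(w); applying .over(w) to the lag call alone and doing the comparison outside appears to avoid it. A rough sketch, with monotonically_increasing_id() as a placeholder for the hash-based seed:

// sketch: .over(w) wraps only the window function (lag), the comparison happens outside
df.withColumn(
    "seed",
    when(
        lag($"eventType", 1).over(w) === lit("C"),
        monotonically_increasing_id() // placeholder for a unique id
    )
)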

I'm kind of stuck here; any help would be appreciated (using the DataFrame/Dataset API is preferred).

Upvotes: 1

Views: 199

Answers (1)

Ranga Vure

Reputation: 1932

Here is the approach:

  1. Identify the start of each event series, based on the conditions
  2. Tag that record as a start event
  3. Select the start-event records
  4. Get each series' end time (if we order the start-event records descending, the previous start time becomes the current series' end time)
  5. Join the original data with the dataset above

Here is a UDF to tag a record as "start":

// tag the starting event, based on the conditions
def tagStartEvent: (String, String, Int, Int) => String =
  (prevEvent: String, currEvent: String, prevTimeStamp: Int, currTimeStamp: Int) => {
    // the very first event is tagged as "start" (the lag() default for the first row is "start")
    if (prevEvent == "start")
      "start"
    // a gap of more than 7 days from the previous event also starts a new series
    else if ((currTimeStamp - prevTimeStamp) > 7)
      "start"
    else {
      prevEvent match {
        case "C" =>
          if (currEvent == "A")
            "start"
          else if (currEvent == "B")
            "start"
          else // the current event is also C, so it stays in the same series
            ""
        case _ => ""
      }
    }
  }

val tagStartEventUdf = udf(tagStartEvent)

data.csv

event_type,data,timestamp
A,d1,1
B,d2,2
C,d3,3
C,d4,4
C,d5,5
A,d6,6
A,d7,7
B,d8,8
C,d9,9
B,d10,12
C,d11,20
val df = spark.read.format("csv")
                  .option("header", "true")
                  .option("inferSchema", "true")
                  .load("data.csv")

val window = Window.partitionBy("all").orderBy("timestamp")

// tag the starting event of each series
val dfStart =
  df.withColumn("all", lit(1))
    .withColumn("series_start",
      tagStartEventUdf(
        lag($"event_type", 1, "start").over(window), df("event_type"),
        lag($"timestamp", 1, 1).over(window), df("timestamp")))

val dfStartSeries = dfStart.filter($"series_start" === "start")
  .select($"timestamp".as("series_start_time"), $"all")

val window2 = Window.partitionBy("all").orderBy($"series_start_time".desc)

// get the series end times (the previous start time, in descending order, is the current series' end)
val dfSeriesTimes = dfStartSeries
  .withColumn("series_end_time", lag($"series_start_time", 1, null).over(window2))
  .drop($"all")

// join without a condition: every event row is paired with every (start, end) pair,
// then filtered down to the series the event falls into
val dfSeries =
  df.join(dfSeriesTimes).withColumn("timestamp_series",
      // if series_end_time is null and timestamp >= series_start_time, then series_start_time
      when(col("series_end_time").isNull && col("timestamp") >= col("series_start_time"),
        col("series_start_time"))
        // if the record is >= series_start_time and < series_end_time, then series_start_time
        .otherwise(when(col("timestamp") >= col("series_start_time") && col("timestamp") < col("series_end_time"),
          col("series_start_time")).otherwise(null)))
    .filter($"timestamp_series".isNotNull)

dfSeries.show()
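
To get the output shape from the question (one row per series with the list of its events), the result can then be grouped on timestamp_series. A sketch, assuming dfSeries has the columns produced above:

import org.apache.spark.sql.functions._

// group each series into one row with the list of its events
// note: collect_list does not guarantee ordering within the list; sort afterwards if needed
val dfGrouped = dfSeries
  .groupBy("timestamp_series")
  .agg(collect_list(struct($"event_type", $"data", $"timestamp")).as("events"))
  .withColumnRenamed("timestamp_series", "first_event_timestamp")
  .orderBy("first_event_timestamp")

dfGrouped.show(false)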

Upvotes: 1
