user17146820

Apache Spark: detect key change on a row and group rows

I have the DataFrame below in Spark, where I need to detect a key change (on the column rec) and create a new column called groupId. A new group starts every time a D record is encountered; for example, the first and second rows belong to group 1, and the group id increments when the next D appears.

rec    amount    date
D      250       20220522
C      110       20220522
D      120       20220522
C      100       20220522
C      50        20220522
D      50        20220522
D      50        20220522
D      50        20220522

EXPECTED OUTPUT

rec    amount    date        groupId
D      250       20220522    1
C      110       20220522    1
D      120       20220522    2
C      100       20220522    2
C      50        20220522    2
D      50        20220522    3
D      50        20220522    4
D      50        20220522    5

I tried many ways but couldn't achieve the desired output. I am not sure what I am doing incorrectly; below is what I have tried:

WindowSpec window = Window.orderBy("date");
Dataset<Row> dataset4 = data
        .withColumn("nextRow", functions.lead("rec", 1).over(window))
        .withColumn("prevRow", functions.lag("rec", 1).over(window))
        .withColumn("groupId",
                functions.when(functions.col("nextRow")
                                .equalTo(functions.col("prevRow")),
                        functions.dense_rank().over(window)));

Can someone please tell me what I am doing incorrectly here?

Upvotes: 0

Views: 387

Answers (1)

vilalabinot

Reputation: 1601

Window functions do not quite work like that; here is a workaround, though it may not be the best one.

First, keep track of what the starting value is:

// offset so that group numbering still starts at 1 when the first row is a "C"
val different = if (df.head()(0) == "C") 1 else 0

Then, we assign 0 to C rows and 1 to D rows:

.withColumn("other", when(col("rec").equalTo("C"), 0).otherwise(1))

Next, we create a surrogate id, because no combination of columns uniquely identifies a row:

.withColumn("id", expr("row_number() over (order by date)"))

Finally, we take a running (cumulative) sum of that flag, which increments every time a D is encountered:

.withColumn("group_id",
  sum("other").over(Window.orderBy("id").partitionBy("date")) + different
)

I partitioned by date here; you can remove that, but performance may degrade badly because a window with no partitioning pulls all rows into a single partition. After dropping id, the final result is:

+---+------+--------+-----+--------+
|rec|amount|date    |other|group_id|
+---+------+--------+-----+--------+
|D  |250   |20220522|1    |1       |
|C  |110   |20220522|0    |1       |
|D  |120   |20220522|1    |2       |
|C  |100   |20220522|0    |2       |
|C  |50    |20220522|0    |2       |
|D  |50    |20220522|1    |3       |
|D  |50    |20220522|1    |4       |
|D  |50    |20220522|1    |5       |
+---+------+--------+-----+--------+
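
For completeness, here is a minimal, self-contained Scala sketch that puts the steps above together. The object name, the local SparkSession setup, and the hard-coded sample data are assumptions for illustration, not part of the original code:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, expr, sum, when}

object GroupIdOnKeyChange {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("group-id-on-key-change")
      .master("local[*]")          // assumption: run locally for the example
      .getOrCreate()
    import spark.implicits._

    // Sample data copied from the question
    val df = Seq(
      ("D", 250, "20220522"), ("C", 110, "20220522"),
      ("D", 120, "20220522"), ("C", 100, "20220522"),
      ("C", 50,  "20220522"), ("D", 50,  "20220522"),
      ("D", 50,  "20220522"), ("D", 50,  "20220522")
    ).toDF("rec", "amount", "date")

    // Offset so group numbering still starts at 1 when the first row is a "C"
    val different = if (df.head()(0) == "C") 1 else 0

    val result = df
      // 1 marks the start of a new group ("D"), 0 marks a continuation ("C")
      .withColumn("other", when(col("rec").equalTo("C"), 0).otherwise(1))
      // surrogate id so the running sum has a deterministic order to follow
      .withColumn("id", expr("row_number() over (order by date)"))
      // running count of "D" rows seen so far is the group id
      .withColumn("group_id",
        sum("other").over(Window.partitionBy("date").orderBy("id")) + different)
      .drop("id")

    result.show(false)
    spark.stop()
  }
}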

Good luck!

Upvotes: 1
