Reputation: 305
I have data that looks like this:
id,start,expiration,customerid,content
1,13494,17358,0001,whateveriwanthere
2,14830,28432,0001,somethingelsewoo
3,11943,19435,0001,yes
4,39271,40231,0002,makingfakedata
5,01321,02143,0002,morefakedata
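For reference, a minimal sketch that builds the sample above as a DataFrame (spark-shell style, with every column read as a string purely for illustration):

// spark.implicits._ is assumed to be in scope (as in spark-shell) so toDF is available
val df = Seq(
  ("1", "13494", "17358", "0001", "whateveriwanthere"),
  ("2", "14830", "28432", "0001", "somethingelsewoo"),
  ("3", "11943", "19435", "0001", "yes"),
  ("4", "39271", "40231", "0002", "makingfakedata"),
  ("5", "01321", "02143", "0002", "morefakedata")
).toDF("id", "start", "expiration", "customerid", "content")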
In the data above, I want to group by customerid for overlapping start and expiration (essentially just merge intervals). I am doing this successfully by grouping by the customer id, then aggregating with first("start") and max("expiration"):
df.groupBy("customerid").agg(first("start"), max("expiration"))
However, this drops the id column entirely. I want to keep the id of the row that had the max expiration. For instance, I want my output to look like this:
id,start,expiration,customerid
2,11943,28432,0001
4,39271,40231,0002
5,01321,02143,0002
I am not sure how to add that id column for whichever row had the maximum expiration.
Upvotes: 2
Views: 736
Reputation: 32660
You can use a cumulative conditional sum along with the lag function to define a group column that flags rows that overlap. Then, simply group by customerid + group and get min start and max expiration. To get the id value associated with the max expiration, you can use this trick with struct ordering:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val w = Window.partitionBy("customerid").orderBy("start")

val result = df.withColumn(
  "group",
  // flag = 1 when the current row's start does not fall inside the previous row's
  // [start, expiration] interval; the running sum of these flags is the group id
  sum(
    when(
      col("start").between(lag("start", 1).over(w), lag("expiration", 1).over(w)),
      0
    ).otherwise(1)
  ).over(w)
).groupBy("customerid", "group").agg(
  min(col("start")).as("start"),
  // max over a struct is ordered by its first field (expiration), so the matching id comes along
  max(struct(col("expiration"), col("id"))).as("max")
).select("max.id", "customerid", "start", "max.expiration")
result.show
//+---+----------+-----+----------+
//| id|customerid|start|expiration|
//+---+----------+-----+----------+
//| 5| 0002|01321| 02143|
//| 4| 0002|39271| 40231|
//| 2| 0001|11943| 28432|
//+---+----------+-----+----------+
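To see why max(struct(expiration, id)) returns the id of the row with the greatest expiration: structs compare field by field from left to right, so the maximum is decided by expiration first and the id simply rides along. A small standalone sketch with made-up values, same idea:

// struct ordering in isolation: the comparison is driven by the first field (expiration),
// so the id of the winning row is carried along inside the struct
val demo = Seq(("a", "100", "x"), ("a", "300", "y"), ("a", "200", "z"))
  .toDF("k", "expiration", "id")

demo.groupBy("k")
  .agg(max(struct(col("expiration"), col("id"))).as("max"))
  .select(col("k"), col("max.id"), col("max.expiration"))
  .show
// the row with expiration = 300 wins, so id = y is kept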
Upvotes: 2