Reputation: 343
I need to group the rows of my dataset based on some criteria. My input dataset df1 looks like this:
+-----+-----+-----+---+-------+
|col_1|col_2|col_3|tp |range  |
+-----+-----+-----+---+-------+
|MP   |W    |X    |10 |]0,3]  |
|MP   |W    |X    |20 |]12,30]|
|MP   |W    |X    |18 |]12,30]|
|MP   |W    |X    |18 |]0,3]  |
|MP   |W    |X    |30 |]0,3]  |
|MP   |W    |X    |50 |]12,30]|
|MP   |W    |X    |18 |]12,30]|
|MP   |W    |X    |60 |]12,30]|
|MP   |W    |X    |50 |]12,30]|
|MP   |W    |X    |70 |]12,30]|
|MP   |W    |X    |18 |]12,30]|
|MP   |W    |X    |90 |]12,30]|
|MP   |W    |X    |18 |]36,48]|
|MP   |W    |X    |18 |]36,48]|
|MP   |W    |X    |18 |]12,30]|
|MP   |W    |X    |180|]12,30]|
|MP   |W    |X    |18 |]36,48]|
|MP   |W    |X    |18 |]12,30]|
|MP   |W    |S2E  |19 |]24,36]|
|MP   |W    |S2E  |40 |]24,36]|
+-----+-----+-----+---+-------+
What I would like to do is split the rows of each range group into subgroups. For example, in the output for the range ]12,30] I would have something like:
+-----+-----+-----+---+-------+----------+
|col_1|col_2|col_3|tp |range  |subgroup  |
+-----+-----+-----+---+-------+----------+
|MP   |W    |X    |20 |]12,30]|subgroup_1|
|MP   |W    |X    |18 |]12,30]|subgroup_1|
|MP   |W    |X    |50 |]12,30]|subgroup_2|
|MP   |W    |X    |18 |]12,30]|subgroup_1|
|MP   |W    |X    |60 |]12,30]|subgroup_2|
|MP   |W    |X    |50 |]12,30]|subgroup_2|
|MP   |W    |X    |70 |]12,30]|subgroup_2|
|MP   |W    |X    |90 |]12,30]|subgroup_2|
|MP   |W    |X    |180|]12,30]|subgroup_3|
+-----+-----+-----+---+-------+----------+
Does anyone have a solution? I am working with Spark in Java.
Upvotes: 0
Views: 287
Reputation: 494
First of all, the columns col_1, col_2, col_3, and range don't matter: they can be abstracted away into a single group column. The idea is to use a window function to order the rows in each window by tp, then compare each tp with the previous one (lag). Whenever the current tp is at least twice the previous value, a new subgroup starts and takes the current row number as its id; otherwise the row inherits the last subgroup id seen so far.
The code is in Scala, but it should demonstrate the idea:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
// assumes spark.implicits._ is in scope (it is by default in spark-shell)

val data = Seq(
  (1, 20),
  (1, 18),
  (1, 50),
  (1, 18),
  (1, 60),
  (1, 50),
  (1, 70),
  (1, 90),
  (1, 180),
  (1, 360)
).toDF("group", "tp")

val windowSpec = Window.partitionBy($"group").orderBy($"tp")

val df = data
  // previous tp in the window (0 for the first row)
  .withColumn("lag_tp", lag($"tp", 1, 0).over(windowSpec))
  .withColumn("row_num", row_number().over(windowSpec))
  // ratio of the previous tp to the current one
  .withColumn("reci_yield", $"lag_tp" / $"tp")
  // true when the current tp is at least twice the previous one
  .withColumn("yield_ge_2", $"reci_yield" <= 0.5)
  .withColumn("subGroup",
    // when yield >= 2 is detected, use the current row number as the subGroup id
    when($"yield_ge_2" === true, $"row_num")
      .otherwise(
        // otherwise, take the last non-null subGroup id
        last(
          when($"yield_ge_2" === true, $"row_num"),
          ignoreNulls = true
        ).over(windowSpec)
      )
  )
  // drop intermediate columns
  .drop("row_num", "lag_tp", "reci_yield", "yield_ge_2")
df.show(false)
Output:
+-----+---+--------+
|group|tp |subGroup|
+-----+---+--------+
|1    |18 |1       |
|1    |18 |1       |
|1    |20 |1       |
|1    |50 |4       |
|1    |50 |4       |
|1    |60 |4       |
|1    |70 |4       |
|1    |90 |4       |
|1    |180|9       |
|1    |360|10      |
+-----+---+--------+
Credit: https://stackoverflow.com/a/65373636/3546203
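Since the question asks for Spark in Java, below is a rough Java translation of the same approach; it is a sketch under the same assumptions, not a tested implementation. The class name SubGroupExample and the sample data are illustrative only, and the final dense_rank step (my addition, not part of the original answer) turns the sparse subGroup ids (1, 4, 9, 10 above) into the consecutive subgroup_1, subgroup_2, ... labels shown in the question.

import static org.apache.spark.sql.functions.*;

import java.util.Arrays;
import java.util.List;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class SubGroupExample { // illustrative class name
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("subgroups").master("local[*]").getOrCreate();

        // Sample data: same shape as in the Scala answer above.
        List<Row> rows = Arrays.asList(
                RowFactory.create(1, 20), RowFactory.create(1, 18),
                RowFactory.create(1, 50), RowFactory.create(1, 18),
                RowFactory.create(1, 60), RowFactory.create(1, 50),
                RowFactory.create(1, 70), RowFactory.create(1, 90),
                RowFactory.create(1, 180), RowFactory.create(1, 360));
        StructType schema = new StructType()
                .add("group", DataTypes.IntegerType)
                .add("tp", DataTypes.IntegerType);
        Dataset<Row> data = spark.createDataFrame(rows, schema);

        WindowSpec w = Window.partitionBy(col("group")).orderBy(col("tp"));

        Dataset<Row> result = data
                // previous tp in the window (0 for the first row)
                .withColumn("lag_tp", lag(col("tp"), 1, 0).over(w))
                .withColumn("row_num", row_number().over(w))
                // true when the current tp is at least twice the previous one
                .withColumn("new_group", col("lag_tp").divide(col("tp")).leq(0.5))
                .withColumn("subGroup",
                        // new subgroup: take the current row number as its id;
                        // otherwise inherit the last non-null subGroup id
                        when(col("new_group"), col("row_num"))
                                .otherwise(last(when(col("new_group"), col("row_num")), true)
                                        .over(w)))
                .drop("lag_tp", "row_num", "new_group");

        // Optional (my addition): map sparse subGroup ids to consecutive labels
        // ("subgroup_1", "subgroup_2", ...) as in the question's expected output.
        Dataset<Row> labeled = result.withColumn("subgroup",
                concat(lit("subgroup_"),
                        dense_rank().over(Window.partitionBy(col("group"))
                                .orderBy(col("subGroup"))).cast("string")));

        labeled.show(false);
    }
}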
Upvotes: 1