Shakile

Reputation: 343

Group the dataset rows based on the value

I need to group the rows of my dataset based on some criteria. My input dataset, df1, looks like this:

+------------+----------+----------------------------+------------------+-------+
|     col_1  |     col_2|col_3                       |            tp    |  range|
+------------+----------+----------------------------+------------------+-------+
|MP          |W         |                           X|10                |]0,3]  |
|MP          |W         |                           X|20                |]12,30]|
|MP          |W         |                           X|18                |]12,30]|
|MP          |W         |                           X|18                |]0,3]  |
|MP          |W         |                           X|30                |]0,3]  |
|MP          |W         |                           X|50                |]12,30]|
|MP          |W         |                           X|18                |]12,30]|
|MP          |W         |                           X|60                |]12,30]|
|MP          |W         |                           X|50                |]12,30]|
|MP          |W         |                           X|70                |]12,30]|
|MP          |W         |                           X|18                |]12,30]|
|MP          |W         |                           X|90                |]12,30]|
|MP          |W         |                           X|18                |]36,48]|
|MP          |W         |                           X|18                |]36,48]|
|MP          |W         |                           X|18                |]12,30]|
|MP          |W         |                           X|180               |]12,30]|
|MP          |W         |                           X|18                |]36,48]|
|MP          |W         |                           X|18                |]12,30]|
|MP          |W         |                         S2E|19                |]24,36]|
|MP          |W         |                         S2E|40                |]24,36]|
+------------+----------+----------------------------+------------------+-------+

What I would like to do is:

  1. Group the rows of df1 by range (last column) [df = df1.select("*").groupBy("col_1", "col_2", "col_3", "tp", "range")]
  2. For rows in the same range, create subgroups where the ratio between two yields (column name = tp) of the same subgroup is less than 2 [i.e. tp(i-1)/tp(i) < 2 or tp(i-2)/tp(i) < 2]

In the output for the range ]12,30], I would get something like:

+------------+----------+----------------------------+------------------+-------+------------+
|     col_1  |     col_2|col_3                       |            tp    |  range|  subgroup  |
+------------+----------+----------------------------+------------------+-------+------------+
|MP          |W         |                           X|20                |]12,30]|subgroup_1  |
|MP          |W         |                           X|18                |]12,30]|subgroup_1  |
|MP          |W         |                           X|50                |]12,30]|subgroup_2  |
|MP          |W         |                           X|18                |]12,30]|subgroup_1  |
|MP          |W         |                           X|60                |]12,30]|subgroup_2  |
|MP          |W         |                           X|50                |]12,30]|subgroup_2  |
|MP          |W         |                           X|70                |]12,30]|subgroup_2  |
|MP          |W         |                           X|90                |]12,30]|subgroup_2  |
|MP          |W         |                           X|180               |]12,30]|subgroup_3  |
+------------+----------+----------------------------+------------------+-------+------------+

Does anyone have a solution? I am working with Spark in Java.

Upvotes: 0

Views: 287

Answers (1)

memoryz

Reputation: 494

First of all, the columns col_1, col_2, col_3, and range don't matter for the logic. They can be abstracted away into a single group column.

The idea is to use a window function to order rows in each window by tp value, then:

  1. Create a row number for each row, which will be used as the subgroup id later.
  2. Calculate the ratio between each row and its previous row.
  3. If the ratio is greater than or equal to 2, use the current row's row number as the subgroup id; otherwise, carry over the subgroup id from the previous row.

The code below is in Scala, but it should demonstrate the idea:

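import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
// Needed for toDF and the $"..." syntax; assumes a SparkSession named `spark` (as in spark-shell)
import spark.implicits._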

val data = Seq(
  (1, 20),
  (1, 18),
  (1, 50),
  (1, 18),
  (1, 60),
  (1, 50),
  (1, 70),
  (1, 90),
  (1, 180),
  (1, 360)
).toDF("group", "tp")

val windowSpec = Window.partitionBy($"group").orderBy($"tp")
val df = data
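  // Previous tp in sort order; the default 0 gives the first row a ratio of 0, so it always opens a subgroup
  .withColumn("lag_tp", lag($"tp", 1, 0).over(windowSpec))
  .withColumn("row_num", row_number().over(windowSpec))
  // Ratio previous / current; a value <= 0.5 means the current tp is at least twice the previous one
  .withColumn("reci_yield", $"lag_tp" / $"tp")
  .withColumn("yield_ge_2", $"reci_yield" <= 0.5)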
  .withColumn("subGroup",
    // When a ratio >= 2 is detected, use the current row number as the subGroup id
    when($"yield_ge_2" === true, $"row_num")
      .otherwise(
        // Otherwise, carry over the last non-null subGroup id
        last(
          when($"yield_ge_2" === true, $"row_num"),
          ignoreNulls = true
        ).over(windowSpec)
      )
  )
  // drop intermediate columns
  .drop("row_num", "lag_tp", "reci_yield", "yield_ge_2")

df.show(false)

Output:

+-----+---+--------+
|group|tp |subGroup|
+-----+---+--------+
|1    |18 |1       |
|1    |18 |1       |
|1    |20 |1       |
|1    |50 |4       |
|1    |50 |4       |
|1    |60 |4       |
|1    |70 |4       |
|1    |90 |4       |
|1    |180|9       |
|1    |360|10      |
+-----+---+--------+
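
Since the question mentions Java, here is a minimal sketch of the same three steps in plain Java, without Spark, just to make the subgroup logic easy to check by hand. The class name SubgroupSketch and the method assignSubgroups are hypothetical names introduced for illustration, and the sketch assumes all tp values are positive. It is not a translation of the window-function code, only the same rule applied to a sorted array.

```java
import java.util.Arrays;

// Hypothetical helper for illustration; not part of Spark or of the answer's code.
final class SubgroupSketch {

    // Sorts a copy of tp and returns, for each sorted position, the 1-based
    // row number of the row that opened its subgroup.
    // Assumption: all tp values are positive.
    static int[] assignSubgroups(int[] tp) {
        int[] sorted = tp.clone();
        Arrays.sort(sorted);

        int[] ids = new int[sorted.length];
        int current = 0;
        for (int i = 0; i < sorted.length; i++) {
            int rowNum = i + 1; // step 1: row number within the group
            // steps 2 and 3: the first row always opens a subgroup; afterwards a
            // new subgroup starts when current / previous >= 2
            if (i == 0 || sorted[i] >= 2L * sorted[i - 1]) {
                current = rowNum;
            }
            ids[i] = current; // otherwise the previous subgroup id is carried over
        }
        return ids;
    }

    public static void main(String[] args) {
        int[] tp = {20, 18, 50, 18, 60, 50, 70, 90, 180, 360};
        // prints [1, 1, 1, 4, 4, 4, 4, 4, 9, 10]
        System.out.println(Arrays.toString(assignSubgroups(tp)));
    }
}
```

For the sample data this gives the same ids as the subGroup column in the output above. Note that 180 and 360 each open a new subgroup because their ratio to the previous value is exactly 2, which is not "less than 2".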

Credit: https://stackoverflow.com/a/65373636/3546203

Upvotes: 1
