iLikeKFC

Reputation: 199

Spark SQL: get the value of a column when another column is max value inside a groupBy().agg()

I have a dataframe that looks like this:

root
 |-- value: int (nullable = true)
 |-- date: date (nullable = true)

I'd like to return value from the row where date is the latest date in the dataframe. Does this problem change if I need to do a groupBy and agg? My actual problem looks like this:

val result = df
  .filter(df("date") >= someDate && df("date") <= someOtherDate)
  .groupBy(valueFromColumn1)
  .agg(
    max(date),
    min(valueFromColumn2),
    // here I want to put valueFromColumn4 where date is max, after the filter
  )

I know I can get these values by creating a second dataframe and then making a join. But I'd like to avoid the join operation if possible.

Input sample:

Column 1 | Column 2 | Date | Column 4
    A         1       2006      5
    A         5       2018      2
    A         3       2000      3
    B         13      2007      4

Output sample (filter is date >= 2006, date <= 2018):

Column 1 | Column 2 | Date | Column 4
    A         1       2018      2  <- I got 2 (Column 4) from the row with the highest date
    B         13      2007      4

Upvotes: 0

Views: 2907

Answers (3)

Neha Kumari

Reputation: 787

The operation you want to do is ordering within a group of data (here, grouped on Column 1). This is a perfect use case for a window function, which performs a calculation over a group of records (a window).

Here we can partition the window on Column 1 and pick the maximum date from each such window. Let's define windowedPartition as:

val windowedPartition = Window.partitionBy("col1").orderBy(col("date").desc)

Then we can apply this window function on our data set to select the row with the highest rank. (I have not added the filtering logic in the code below, as it does not bring any complexity here and will not affect the solution.)

Working code:

    scala> import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.expressions.Window


    scala> val data = Seq(("a" , 1, 2006, 5), ("a", 5, 2018, 2), ("a", 3, 2000, 3), ("b", 13, 2007, 4)).toDF("col1", "col2", "date", "col4")
    data: org.apache.spark.sql.DataFrame = [col1: string, col2: int ... 2 more fields]


    scala> data.show
    +----+----+----+----+
    |col1|col2|date|col4|
    +----+----+----+----+
    |   a|   1|2006|   5|
    |   a|   5|2018|   2|
    |   a|   3|2000|   3|
    |   b|  13|2007|   4|
    +----+----+----+----+      

    scala> val windowedPartition = Window.partitionBy("col1").orderBy(col("date").desc)
    windowedPartition: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@39613474

    scala> data.withColumn("row_number", row_number().over(windowedPartition)).show
    +----+----+----+----+----------+
    |col1|col2|date|col4|row_number|
    +----+----+----+----+----------+
    |   b|  13|2007|   4|         1|
    |   a|   5|2018|   2|         1|
    |   a|   1|2006|   5|         2|
    |   a|   3|2000|   3|         3|
    +----+----+----+----+----------+


    scala> data.withColumn("row_number", row_number().over(windowedPartition)).where(col("row_number") === 1).show
    +----+----+----+----+----------+
    |col1|col2|date|col4|row_number|
    +----+----+----+----+----------+
    |   b|  13|2007|   4|         1|
    |   a|   5|2018|   2|         1|
    +----+----+----+----+----------+


    scala> data.withColumn("row_number", row_number().over(windowedPartition)).where(col("row_number") === 1).drop(col("row_number")).show
    +----+----+----+----+
    |col1|col2|date|col4|
    +----+----+----+----+
    |   b|  13|2007|   4|
    |   a|   5|2018|   2|
    +----+----+----+----+

I believe this is a more scalable solution than the struct approach: if the number of columns increases, you would have to add those columns to the struct as well, whereas this solution handles that case automatically.

One question though: in your output, shouldn't the value in Column 2 be 5 (for Column 1 = A)? How is the value of Column 2 changing to 1?

Upvotes: 2

Raphael Roth

Reputation: 27383

You can use either groupBy with a struct:

df
  .groupBy()
  .agg(max(struct($"date",$"value")).as("latest"))
  .select($"latest.*")

or with Window:

df
  .withColumn("rnk",row_number().over(Window.orderBy($"date".desc)))
  .where($"rnk"===1).drop($"rnk")

Upvotes: 3

Oli

Reputation: 10406

A solution would be to use a struct to bind the value and the date together. It would look like this:

val result = df
  .filter(df("date") >= someDate && df("date") <= someOtherDate)
  .withColumn("s", struct(df("date") as "date", df(valueFromColumn4) as "value"))
  .groupBy(valueFromColumn1)
  .agg(
     // since date is the first field of the struct,
     // this selects the tuple that maximizes date, and the associated value
     max(col("s")) as "s",
     min(col(valueFromColumn2))
  )
  .withColumn("date", col("s.date"))
  .withColumn(valueFromColumn4, col("s.value"))

Upvotes: 4
