Soumya

Reputation: 1913

Calculating column value in current row of Spark Dataframe based on the calculated value of a different column in previous row using Scala

Suppose I have a DataFrame like the one below:

Id  A    B   C   D
1   100  10  20  5
2   0    5   10  5
3   0    7   2   3
4   0    1   3   7

It needs to be converted to something like this:

Id  A    B   C   D   E
1   100  10  20  5   75
2   75   5   10  5   60
3   60   7   2   3   50
4   50   1   3   7   40

The logic works as follows:

  1. The dataframe has a new column E which, for row 1, is calculated as col(A) - (max(col(B), col(C)) + col(D)) => 100 - (max(10, 20) + 5) = 75
  2. In the row with Id 2, the value of col E from row 1 is brought forward as the value of col A
  3. So, for row 2, col E is determined as 75 - (max(5, 10) + 5) = 60
  4. Similarly, in the row with Id 3, the value of A becomes 60 and the new value of col E is determined based on it

The problem is that the value of col A depends on the previous row's computed value for every row except the first.
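For illustration, here is the recurrence written as a plain Scala fold over the sample rows (a minimal sketch to pin down the expected arithmetic, not a Spark solution; the Rec case class exists only for this example):

case class Rec(id: Int, a: Int, b: Int, c: Int, d: Int)

val rows = Seq(
  Rec(1, 100, 10, 20, 5),
  Rec(2, 0, 5, 10, 5),
  Rec(3, 0, 7, 2, 3),
  Rec(4, 0, 1, 3, 7)
)

// E(1) = A(1) - (max(B, C) + D); from then on, the previous E plays the role of A
val es = rows.tail.scanLeft(rows.head.a - (math.max(rows.head.b, rows.head.c) + rows.head.d)) {
  (prevE, r) => prevE - (math.max(r.b, r.c) + r.d)
}
// es == List(75, 60, 50, 40)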

Is there a way to solve this using windowing and lag?

Upvotes: 1

Views: 1914

Answers (2)

Vincent Doba

Reputation: 5068

As blackbishop said, you can't use the lag function to retrieve the changing value of a column. Since you're using the Scala API, you can develop your own User-Defined Aggregate Function.

You create the following case classes, representing the row you're currently reading and your aggregator's buffer:

// Columns read from each input row
case class InputRow(A: Integer, B: Integer, C: Integer, D: Integer)

// Mutable aggregation state: the last computed E and the effective A
case class Buffer(var E: Integer, var A: Integer)

Then you use them to define your RecursiveAggregator custom aggregator:

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder

object RecursiveAggregator extends Aggregator[InputRow, Buffer, Buffer] {
  // Empty buffer: no E computed yet
  override def zero: Buffer = Buffer(null, null)

  override def reduce(buffer: Buffer, currentRow: InputRow): Buffer = {
    // For the first row keep its own A; afterwards the previous E becomes A
    buffer.A = if (buffer.E == null) currentRow.A else buffer.E
    buffer.E = buffer.A - (math.max(currentRow.B, currentRow.C) + currentRow.D)
    buffer
  }

  // Never called when the aggregator runs over an ordered window
  override def merge(b1: Buffer, b2: Buffer): Buffer = {
    throw new NotImplementedError("should be used only over ordered window")
  }

  override def finish(reduction: Buffer): Buffer = reduction

  override def bufferEncoder: Encoder[Buffer] = ExpressionEncoder[Buffer]

  override def outputEncoder: Encoder[Buffer] = ExpressionEncoder[Buffer]
}

Finally, you transform your RecursiveAggregator into a user-defined aggregate function and apply it over a window ordered by Id on your input dataframe:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, udaf}

val recursiveAggregator = udaf(RecursiveAggregator)

val window = Window.orderBy("Id")

val result = input
  .withColumn("computed", recursiveAggregator(col("A"), col("B"), col("C"), col("D")).over(window))
  .select("Id", "computed.A", "B", "C", "D", "computed.E")
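For reference, the input dataframe used above could be built like this (a minimal sketch, assuming an active SparkSession named spark):

import spark.implicits._

val input = Seq(
  (1, 100, 10, 20, 5),
  (2, 0, 5, 10, 5),
  (3, 0, 7, 2, 3),
  (4, 0, 1, 3, 7)
).toDF("Id", "A", "B", "C", "D")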

If you take your question's dataframe as input, you get the following result:

+---+---+---+---+---+---+
|Id |A  |B  |C  |D  |E  |
+---+---+---+---+---+---+
|1  |100|10 |20 |5  |75 |
|2  |75 |5  |10 |5  |60 |
|3  |60 |7  |2  |3  |50 |
|4  |50 |1  |3  |7  |40 |
+---+---+---+---+---+---+

Upvotes: 1

blackbishop

Reputation: 32640

You can use the collect_list function over a Window ordered by the Id column to get a cumulative array of structs holding the values of A and of max(B, C) + D (as field T). Then, apply aggregate on that array to calculate column E.

Note that in this particular case you can't use the lag window function, as you want to use calculated values recursively.
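For intuition, the naive lag-based attempt sketched below cannot work: lag can only read values that already exist in the input, while E here must reference its own value from the previous row, so Spark raises an analysis error because E does not exist when lag("E") is resolved:

// Does NOT work: E is being defined by this very expression, so it cannot
// be read back recursively through lag.
// df.withColumn(
//   "E",
//   lag("E", 1).over(Window.orderBy("Id")) - (greatest(col("B"), col("C")) + col("D"))
// )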

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val df2 = df.withColumn(
  "tmp",
  // Cumulative array of structs (A, T = max(B, C) + D) up to the current row
  collect_list(
    struct(col("A"), (greatest(col("B"), col("C")) + col("D")).as("T"))
  ).over(Window.orderBy("Id"))
).withColumn(
  "E",
  // The first element contributes A - T, every later element contributes -T
  expr("aggregate(transform(tmp, (x, i) -> IF(i=0, x.A - x.T, -x.T)), 0, (acc, x) -> acc + x)")
).withColumn(
  "A",
  // Recover the effective A of the row from its E
  col("E") + greatest(col("B"), col("C")) + col("D")
).drop("tmp")

df2.show(false)

//+---+---+---+---+---+---+
//|Id |A  |B  |C  |D  |E  |
//+---+---+---+---+---+---+
//|1  |100|10 |20 |5  |75 |
//|2  |75 |5  |10 |5  |60 |
//|3  |60 |7  |2  |3  |50 |
//|4  |50 |1  |3  |7  |40 |
//+---+---+---+---+---+---+

You can show the intermediate column tmp to understand the logic behind the calculation.
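For example, here is a minimal sketch that recomputes just the cumulative array from the original df:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Show only the cumulative struct array that aggregate() later folds over
df.withColumn(
  "tmp",
  collect_list(
    struct(col("A"), (greatest(col("B"), col("C")) + col("D")).as("T"))
  ).over(Window.orderBy("Id"))
).select("Id", "tmp").show(false)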

Upvotes: 1
