Reputation: 1913
Suppose I have a DataFrame like the one below:
Id | A | B | C | D |
---|---|---|---|---|
1 | 100 | 10 | 20 | 5 |
2 | 0 | 5 | 10 | 5 |
3 | 0 | 7 | 2 | 3 |
4 | 0 | 1 | 3 | 7 |
The above needs to be converted to something like this:
Id | A | B | C | D | E |
---|---|---|---|---|---|
1 | 100 | 10 | 20 | 5 | 75 |
2 | 75 | 5 | 10 | 5 | 60 |
3 | 60 | 7 | 2 | 3 | 50 |
4 | 50 | 1 | 3 | 7 | 40 |
The calculation works as follows:

E = col(A) - (max(col(B), col(C)) + col(D))

For Id 1: E = 100 - (max(10, 20) + 5) = 75.

For Id 2, the value of col E from row 1 is brought forward as the value for col A, and E is determined as 75 - (max(5, 10) + 5) = 60.

For Id 3, the value of col A becomes 60 and the new value for col E is determined based on this, and so on.

The problem is that the value of col A depends on the previous row's values, except for the first row. Is there a possibility to solve this using windowing and lag?
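For reference, a minimal sketch to build this sample dataframe (assuming a spark-shell session where spark.implicits._ is in scope):

import spark.implicits._

// Sample input; all columns assumed to be integers
val df = Seq(
  (1, 100, 10, 20, 5),
  (2, 0, 5, 10, 5),
  (3, 0, 7, 2, 3),
  (4, 0, 1, 3, 7)
).toDF("Id", "A", "B", "C", "D")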
Upvotes: 1
Views: 1914
Reputation: 5068
As blackbishop said, you can't use the lag function to retrieve the changing value of a column. As you're using the Scala API, you can develop your own user-defined aggregate function (UDAF).
You create the following case classes, representing the row you're currently reading and your aggregator's buffer:

// java.lang.Integer rather than Int, so that fields can start out as null
case class InputRow(A: Integer, B: Integer, C: Integer, D: Integer)
case class Buffer(var E: Integer, var A: Integer)
Then you use them to define your RecursiveAggregator custom aggregator:
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder

object RecursiveAggregator extends Aggregator[InputRow, Buffer, Buffer] {
  // Start with an empty buffer: no E has been computed yet
  override def zero: Buffer = Buffer(null, null)

  // For each row, carry the previous row's E forward as the new A, then compute E
  override def reduce(buffer: Buffer, currentRow: InputRow): Buffer = {
    buffer.A = if (buffer.E == null) currentRow.A else buffer.E
    buffer.E = buffer.A - (math.max(currentRow.B, currentRow.C) + currentRow.D)
    buffer
  }

  // Merging partial buffers is undefined: rows must be processed strictly in order
  override def merge(b1: Buffer, b2: Buffer): Buffer = {
    throw new NotImplementedError("should be used only over ordered window")
  }

  override def finish(reduction: Buffer): Buffer = reduction
  override def bufferEncoder: Encoder[Buffer] = ExpressionEncoder[Buffer]
  override def outputEncoder: Encoder[Buffer] = ExpressionEncoder[Buffer]
}
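To see how the buffer carries state from one row to the next, here is a small manual trace of reduce over the first two example rows (plain Scala, outside of Spark; note that reduce mutates the buffer in place):

val b0 = RecursiveAggregator.zero                                  // Buffer(E = null, A = null)
val b1 = RecursiveAggregator.reduce(b0, InputRow(100, 10, 20, 5))  // now E = 75, A = 100
val b2 = RecursiveAggregator.reduce(b1, InputRow(0, 5, 10, 5))     // now E = 60, A = 75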
Finally, you transform your RecursiveAggregator into a user-defined aggregate function and apply it over a window on your input dataframe:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, udaf}

val recursiveAggregator = udaf(RecursiveAggregator)
val window = Window.orderBy("Id")

val result = input
  .withColumn("computed", recursiveAggregator(col("A"), col("B"), col("C"), col("D")).over(window))
  .select("Id", "computed.A", "B", "C", "D", "computed.E")
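Since the aggregator's output type is Buffer, the computed column is a struct with fields E and A, which the final select flattens back into top-level columns. Also note that Window.orderBy("Id") without a partitionBy clause pulls all rows into a single partition (Spark logs a warning about this); that is unavoidable here, because the computation is inherently sequential over the whole dataframe.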
If you take your question's dataframe as input, you get the following result:
+---+---+---+---+---+---+
|Id |A |B |C |D |E |
+---+---+---+---+---+---+
|1 |100|10 |20 |5 |75 |
|2 |75 |5 |10 |5 |60 |
|3 |60 |7 |2 |3 |50 |
|4 |50 |1 |3 |7 |40 |
+---+---+---+---+---+---+
Upvotes: 1
Reputation: 32640
You can use the collect_list function over a Window ordered by the Id column to get a cumulative array of structs holding the values of A and of max(B, C) + D (as field T). Then, apply aggregate to calculate column E.

Note that in this particular case you can't use the lag window function, as you want to get the calculated values recursively.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val df2 = df.withColumn(
  "tmp",
  collect_list(
    struct(col("A"), (greatest(col("B"), col("C")) + col("D")).as("T"))
  ).over(Window.orderBy("Id"))
).withColumn(
  "E",
  expr("aggregate(transform(tmp, (x, i) -> IF(i=0, x.A - x.T, -x.T)), 0, (acc, x) -> acc + x)")
).withColumn(
  "A",
  col("E") + greatest(col("B"), col("C")) + col("D")
).drop("tmp")

df2.show(false)
//+---+---+---+---+---+---+
//|Id |A |B |C |D |E |
//+---+---+---+---+---+---+
//|1 |100|10 |20 |5 |75 |
//|2 |75 |5 |10 |5 |60 |
//|3 |60 |7 |2 |3 |50 |
//|4 |50 |1 |3 |7 |40 |
//+---+---+---+---+---+---+
You can show the intermediate column tmp to understand the logic behind the calculation.
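For example, re-running just the first step (same imports and df as above) shows the cumulative array:

df.withColumn(
  "tmp",
  collect_list(
    struct(col("A"), (greatest(col("B"), col("C")) + col("D")).as("T"))
  ).over(Window.orderBy("Id"))
).select("Id", "tmp").show(false)

// For Id=2, tmp is [{100, 25}, {0, 15}]: transform yields [100 - 25, -15]
// and aggregate sums it to 60, which matches column E above.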
Upvotes: 1