Reputation: 1108
I have a dataframe as shown below with n columns.
+---+----------+--------+--------+--------+
|id |date      |signal01|signal02|signal03|......signal(n)
+---+----------+--------+--------+--------+
|050|2021-01-14|1       |3       |1       |
|050|2021-01-15|null    |4       |2       |
|050|2021-02-02|2       |5       |3       |
|051|2021-01-14|1       |3       |0       |
|051|2021-01-15|null    |null    |null    |
|051|2021-02-02|3       |3       |2       |
|051|2021-02-03|4       |3       |3       |
|052|2021-03-03|1       |3       |0       |
|052|2021-03-05|null    |3       |null    |
|052|2021-03-06|null    |null    |2       |
|052|2021-03-16|3       |5       |5       |.......value(n)
+---+----------+--------+--------+--------+
I have to add a difference column for each signal, as shown below: the difference is taken against the last non-null value of that signal, null records are excluded (their diff stays null), and the first diff value is kept as 0.
+---+----------+--------+-------------+--------+-------------+--------+-------------+
|id |date      |signal01|signal01_diff|signal02|signal02_diff|signal03|signal03_diff|......signal(n)
+---+----------+--------+-------------+--------+-------------+--------+-------------+
|050|2021-01-14|1       |0            |3       |0            |1       |0            |
|050|2021-01-15|null    |null         |4       |1            |2       |1            |
|050|2021-02-02|2       |1            |5       |1            |3       |1            |
|051|2021-01-14|1       |0            |3       |0            |0       |0            |
|051|2021-01-15|null    |null         |null    |null         |null    |null         |
|051|2021-02-02|3       |2            |3       |0            |2       |2            |
|051|2021-02-03|4       |1            |3       |0            |3       |1            |
|052|2021-03-03|1       |0            |3       |0            |0       |0            |
|052|2021-03-05|null    |null         |3       |0            |null    |null         |
|052|2021-03-06|null    |null         |null    |null         |2       |2            |
|052|2021-03-16|3       |2            |5       |2            |5       |3            |.......value(n)
+---+----------+--------+-------------+--------+-------------+--------+-------------+
I have tried with lag and a window function, but I am not getting the expected output because of the null values.
val w = org.apache.spark.sql.expressions.Window.orderBy("id")
val dfWithLag = df.withColumn("signal01_lag", lag("signal01", 1, 0).over(w))
The above code is for a single column; I would have to execute the same for the remaining n columns.
Is there any optimal way to achieve this?
Upvotes: 1
Views: 683
Reputation: 22439
That's a nice sample dataset for requirement illustration. Based on the expected output requirement, there are a couple of issues with your code:

- orderBy("id"): Window spec w should be partitioned by "id" and ordered by "date"
- lag, as you pointed out, won't be able to handle the null signals among consecutive rows

The approach shown below leverages the Window function last over rowsBetween() to track down the last non-null signals to compute the wanted row-wise signal differences:
import spark.implicits._  // for toDF; already in scope in spark-shell

val df = Seq(
("050", "2021-01-14", Some(1), Some(3), Some(1)),
("050", "2021-01-15", None, Some(4), Some(2)),
("050", "2021-02-02", Some(2), Some(5), Some(3)),
("051", "2021-01-14", Some(1), Some(3), Some(0)),
("051", "2021-01-15", None, None, None),
("051", "2021-02-02", Some(3), Some(3), Some(2)),
("051", "2021-02-03", Some(4), Some(3), Some(3)),
("052", "2021-03-03", Some(1), Some(3), Some(0)),
("052", "2021-03-05", None, Some(3), None),
("052", "2021-03-06", None, None, Some(2)),
("052", "2021-03-16", Some(3), Some(5), Some(5))
).toDF("id", "date", "signal01", "signal02", "signal03")
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Per-id window covering all rows strictly before the current one
val w = Window.partitionBy("id").orderBy("date").
  rowsBetween(Window.unboundedPreceding, -1)
// Separate the signal columns from the remaining columns
val signals = df.columns.filter(_ matches "signal\\d+")
val signalCols = signals.map(col)
val otherCols = df.columns.map(col) diff signalCols
df.select(
    otherCols ++
    signalCols ++
    signals.map(s =>
      // Subtract the last non-null value seen so far; when there is none,
      // fall back to the current value itself, yielding a 0 diff
      (col(s) - coalesce(last(col(s), ignoreNulls=true).over(w), col(s))).as(s"${s}_diff")
    ): _*
  ).
  orderBy("id", "date").  // only for ordered display
  show
/*
+---+----------+--------+--------+--------+-------------+-------------+-------------+
| id| date|signal01|signal02|signal03|signal01_diff|signal02_diff|signal03_diff|
+---+----------+--------+--------+--------+-------------+-------------+-------------+
|050|2021-01-14| 1| 3| 1| 0| 0| 0|
|050|2021-01-15| null| 4| 2| null| 1| 1|
|050|2021-02-02| 2| 5| 3| 1| 1| 1|
|051|2021-01-14| 1| 3| 0| 0| 0| 0|
|051|2021-01-15| null| null| null| null| null| null|
|051|2021-02-02| 3| 3| 2| 2| 0| 2|
|051|2021-02-03| 4| 3| 3| 1| 0| 1|
|052|2021-03-03| 1| 3| 0| 0| 0| 0|
|052|2021-03-05| null| 3| null| null| 0| null|
|052|2021-03-06| null| null| 2| null| null| 2|
|052|2021-03-16| 3| 5| 5| 2| 2| 3|
+---+----------+--------+--------+--------+-------------+-------------+-------------+
*/
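Note that when a row's signal is null, its diff stays null as well (null minus anything is null), and for the first non-null signal of each id the coalesce falls back to the current value itself, producing the required 0.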
Upvotes: 3
Reputation: 89
You can use foldLeft to run over the column list and create the required new columns.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lag

val w = Window.partitionBy("id").orderBy("date")
val cols = df.columns.filter(_ startsWith "signal").toSeq  // only the signal columns

val newDf = cols.foldLeft(df)((acc, c) =>
  // note ${c}_lag: the interpolation $col_lag would look for a variable named col_lag
  acc.withColumn(s"${c}_lag", lag(c, 1, 0).over(w))
)
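If you prefer this foldLeft style for the diff columns themselves, here is a minimal sketch combining it with the null-safe last-over-window trick from the answer above (wDiff and withDiffs are illustrative names, and df is the sample dataframe):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{coalesce, col, last}

// Per-id window over all prior rows, as in the answer above
val wDiff = Window.partitionBy("id").orderBy("date").
  rowsBetween(Window.unboundedPreceding, -1)

val withDiffs = df.columns.filter(_ startsWith "signal").foldLeft(df)((acc, c) =>
  acc.withColumn(s"${c}_diff",
    col(c) - coalesce(last(col(c), ignoreNulls = true).over(wDiff), col(c)))
)

For a large n, though, a single select (as in the answer above) generally keeps the query plan smaller than chaining many withColumn calls.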
Upvotes: 0