Antony

Reputation: 1108

Find difference of column values in Spark using Scala

I have a dataframe as shown below with n columns.

+---+------------+--------+--------+--------+
|id |        date|signal01|signal02|signal03|......signal(n)
+---+------------+--------+--------+--------+
|050|2021-01-14  |1       |3       |1       |
|050|2021-01-15  |null    |4       |2       |
|050|2021-02-02  |2       |5       |3       |

|051|2021-01-14  |1       |3       |0       |
|051|2021-01-15  |null    |null    |null    |
|051|2021-02-02  |3       |3       |2       |
|051|2021-02-03  |4       |3       |3       |

|052|2021-03-03  |1       |3       |0       |
|052|2021-03-05  |null    |3       |null    |
|052|2021-03-06  |null    |null    |2       |
|052|2021-03-16  |3       |5       |5       |.......value(n)
+---+------------+--------+--------+--------+

For each signal column I have to add a corresponding difference column, as shown below: the difference is taken against the last non-null value within the same id, null records keep a null diff, and the first value for each id gets a diff of 0.

+---+------------+--------+-------------+--------+-------------+--------+-------------+
|id |        date|signal01|signal01_diff|signal02|signal02_diff|signal03|signal03_diff|......signal(n)
+---+------------+--------+-------------+--------+-------------+--------+-------------+
|050|2021-01-14  |1       |0            |3       |0            |1       |0            |
|050|2021-01-15  |null    |null         |4       |1            |2       |1            |
|050|2021-02-02  |2       |1            |5       |1            |3       |1            |
                                                                                      
|051|2021-01-14  |1       |0            |3       |0            |0       |0            |
|051|2021-01-15  |null    |null         |null    |null         |null    |null         |
|051|2021-02-02  |3       |2            |3       |0            |2       |2            |
|051|2021-02-03  |4       |1            |3       |0            |3       |1            |
                                                                                      
|052|2021-03-03  |1       |0            |3       |0            |0       |0            |
|052|2021-03-05  |null    |null         |3       |0            |null    |null         |
|052|2021-03-06  |null    |null         |null    |null         |2       |2            |
|052|2021-03-16  |3       |2            |5       |2            |5       |3            |.......value(n)
+---+------------+--------+-------------+--------+-------------+--------+-------------+

I have tried with lag and a window function, but I am not getting the expected output because of the null values.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lag

val w = Window.orderBy("id")
val dfWithLag = df.withColumn("signal01_lag", lag("signal01", 1, 0).over(w))

The above code handles a single column; I would have to repeat it for the remaining n columns.

Is there any optimal way to achieve this?

Upvotes: 1

Views: 683

Answers (2)

Leo C

Reputation: 22439

That's a nice sample dataset for illustrating the requirement. Based on the expected output, there are a couple of issues with your code:

  1. Rather than orderBy("id"), the Window spec w should be partitioned by "id" and ordered by "date"
  2. The Window function lag, as you pointed out, won't be able to handle the null signals among consecutive rows (see the sketch after this list)
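
For instance, here is a minimal sketch of why a lag-based diff breaks down even with a corrected window spec (assuming the df defined in the snippet further below):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// lag returns the previous row's value as-is, so a null in the
// preceding row propagates into the diff even when a non-null
// signal exists further back.
val wLag = Window.partitionBy("id").orderBy("date")

df.withColumn("signal01_diff", col("signal01") - lag("signal01", 1).over(wLag))
// e.g. id 050, 2021-02-02: the previous signal01 is null, so the diff
// comes out null instead of the expected 2 - 1 = 1.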

The approach shown below leverages the Window function last over rowsBetween() to track the last non-null signal per id, from which the wanted row-wise signal difference is computed; coalesce falls back to the current value when there is no previous non-null signal, which makes the first diff for each id 0:

import spark.implicits._  // for toDF; assumes a SparkSession named `spark`

val df = Seq(
  ("050", "2021-01-14", Some(1), Some(3), Some(1)),
  ("050", "2021-01-15", None,    Some(4), Some(2)),
  ("050", "2021-02-02", Some(2), Some(5), Some(3)),

  ("051", "2021-01-14", Some(1), Some(3), Some(0)),
  ("051", "2021-01-15", None,    None,    None),
  ("051", "2021-02-02", Some(3), Some(3), Some(2)),
  ("051", "2021-02-03", Some(4), Some(3), Some(3)),

  ("052", "2021-03-03", Some(1), Some(3), Some(0)),
  ("052", "2021-03-05", None,    Some(3), None),
  ("052", "2021-03-06", None,    None,    Some(2)),
  ("052", "2021-03-16", Some(3), Some(5), Some(5))
).toDF("id", "date", "signal01", "signal02", "signal03")

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val w = Window.partitionBy("id").orderBy("date").
          rowsBetween(Window.unboundedPreceding, -1)

val signals = df.columns.filter(_ matches "signal\\d+")  // signal01, signal02, ...
val signalCols = signals.map(col)
val otherCols = df.columns.map(col) diff signalCols      // id, date

df.select(
    otherCols ++
    signalCols ++
    signals.map(s =>
      (col(s) - coalesce(last(col(s), ignoreNulls=true).over(w), col(s))).as(s"${s}_diff")
    ): _*
  ).
  orderBy("id", "date").  // only for ordered display
  show
/*
+---+----------+--------+--------+--------+-------------+-------------+-------------+
| id|      date|signal01|signal02|signal03|signal01_diff|signal02_diff|signal03_diff|
+---+----------+--------+--------+--------+-------------+-------------+-------------+
|050|2021-01-14|       1|       3|       1|            0|            0|            0|
|050|2021-01-15|    null|       4|       2|         null|            1|            1|
|050|2021-02-02|       2|       5|       3|            1|            1|            1|
|051|2021-01-14|       1|       3|       0|            0|            0|            0|
|051|2021-01-15|    null|    null|    null|         null|         null|         null|
|051|2021-02-02|       3|       3|       2|            2|            0|            2|
|051|2021-02-03|       4|       3|       3|            1|            0|            1|
|052|2021-03-03|       1|       3|       0|            0|            0|            0|
|052|2021-03-05|    null|       3|    null|         null|            0|         null|
|052|2021-03-06|    null|    null|       2|         null|         null|            2|
|052|2021-03-16|       3|       5|       5|            2|            2|            3|
+---+----------+--------+--------+--------+-------------+-------------+-------------+
*/

Upvotes: 3

Tal

Reputation: 89

You can use foldLeft to iterate over the column list and create the required new columns.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lag

val w = Window.partitionBy("id").orderBy("date")  // per-id window, ordered by date
val cols = df.columns.filter(_ matches "signal\\d+").toSeq

val newDf = cols.foldLeft(df)((df, col) =>
  df.withColumn(s"${col}_lag", lag(col, 1, 0).over(w))
)
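
The same foldLeft pattern could also produce the _diff columns directly by borrowing the last(..., ignoreNulls = true) window from the first answer. A minimal sketch, assuming the df and signal-column naming from the question and the cols and Window import defined above:

import org.apache.spark.sql.functions.{coalesce, col, last}

// Look only at rows strictly before the current one, per id.
val wPrev = Window.partitionBy("id").orderBy("date").
              rowsBetween(Window.unboundedPreceding, -1)

val diffDf = cols.foldLeft(df)((acc, c) =>
  acc.withColumn(
    s"${c}_diff",
    // Fall back to the current value when no previous non-null
    // signal exists, so the first diff per id is 0.
    col(c) - coalesce(last(col(c), ignoreNulls = true).over(wPrev), col(c))
  )
)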

Upvotes: 0
