Javier Monsalve

Reputation: 326

Window function with lag based on another column

I have the following Spark DataFrame:

id month column_1 column_2
A 1 100 0
A 2 200 1
A 3 800 2
A 4 1500 3
A 5 1200 0
A 6 1600 1
A 7 2500 2
A 8 2800 3
A 9 3000 4

I would like to create a new column, 'dif_column1', equal to column_1 minus the value of column_1 a variable number of rows earlier, where that number (the lag) is given by column_2. The desired output would be:

id month column_1 column_2 dif_column1
A 1 100 0 0
A 2 200 1 100
A 3 800 2 700
A 4 1500 3 1400
A 5 1200 0 0
A 6 1600 1 400
A 7 2500 2 1300
A 8 2800 3 1600
A 9 3000 4 1800

I tried the lag function, but it appears to accept only an integer as the offset, not a column, so the following does not work:

w = Window.partitionBy("id")
sdf = sdf.withColumn("dif_column1", F.col("column_1") - F.lag("column_1",F.col("column_2")).over(w))

Upvotes: 0

Views: 540

Answers (1)

mck

Reputation: 42352

You can add a row number column and do a self join that matches each row to the row column_2 positions before it, using the row number and the lag defined in column_2:

from pyspark.sql import functions as F, Window

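# Number the rows within each id in month order, so a lag of n rows is rn - n.
w = Window.partitionBy("id").orderBy("month")

# df is the DataFrame from the question (named sdf there).
df1 = df.withColumn('rn', F.row_number().over(w))

# Self join: t2 is the row that sits column_2 positions before t1 in the same id.
# The left join keeps rows with no match; their dif_column1 is null.
df2 = df1.alias('t1').join(
    df1.alias('t2'),
    F.expr('(t1.id = t2.id) and (t1.rn = t2.rn + t1.column_2)'),
    'left'
).selectExpr(
    't1.*',
    't1.column_1 - t2.column_1 as dif_column1'
).drop('rn')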

df2.show()
+---+-----+--------+--------+-----------+
| id|month|column_1|column_2|dif_column1|
+---+-----+--------+--------+-----------+
|  A|    1|     100|       0|          0|
|  A|    2|     200|       1|        100|
|  A|    3|     800|       2|        700|
|  A|    4|    1500|       3|       1400|
|  A|    5|    1200|       0|          0|
|  A|    6|    1600|       1|        400|
|  A|    7|    2500|       2|       1300|
|  A|    8|    2800|       3|       1600|
|  A|    9|    3000|       4|       1800|
+---+-----+--------+--------+-----------+
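
To see what the join condition t1.rn = t2.rn + t1.column_2 computes, here is a minimal plain-Python sketch of the same logic, with no Spark involved. The function name dynamic_lag_diff and the tuple layout are assumptions made for illustration, not part of any library. Rows whose lag points outside their id group get None, mirroring the null a left join would produce.

```python
from itertools import groupby
from operator import itemgetter


def dynamic_lag_diff(rows):
    """rows: iterable of (id, month, column_1, column_2) tuples (hypothetical layout).

    Returns the same tuples with a fifth field: column_1 minus the column_1
    of the row that sits column_2 positions earlier within the same id.
    """
    out = []
    # Sort by id, then month: this plays the role of partitionBy/orderBy.
    ordered = sorted(rows, key=itemgetter(0, 1))
    for _, group in groupby(ordered, key=itemgetter(0)):
        group = list(group)
        # rn is the position within the id group (0-based here, 1-based in Spark).
        for rn, (id_, month, value, lag) in enumerate(group):
            src = rn - lag  # position of the row to subtract
            if 0 <= src < len(group):
                diff = value - group[src][2]
            else:
                diff = None  # no matching row, like a null from the left join
            out.append((id_, month, value, lag, diff))
    return out


rows = [
    ("A", 1, 100, 0), ("A", 2, 200, 1), ("A", 3, 800, 2),
    ("A", 4, 1500, 3), ("A", 5, 1200, 0), ("A", 6, 1600, 1),
    ("A", 7, 2500, 2), ("A", 8, 2800, 3), ("A", 9, 3000, 4),
]
diffs = [r[4] for r in dynamic_lag_diff(rows)]
print(diffs)  # [0, 100, 700, 1400, 0, 400, 1300, 1600, 1800]
```

This is only meant to make the row arithmetic easy to check by hand; for real data, the join in Spark is what does this work across the cluster.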

Upvotes: 1
