Sreeram TP

Reputation: 11937

Find next different value from lag in pyspark

I have a PySpark DataFrame like this:

+-----+----------+
|value|val_joined|
+-----+----------+
|    3|         3|
|    4|       3+4|
|    5|     3+4+5|
|    5|     3+4+5|
|    5|     3+4+5|
|    2|   3+4+5+2|
+-----+----------+

From this, I need to create another column like this:

+-----+----------+------+
|value|val_joined|result|
+-----+----------+------+
|    3|         3|   4.0|
|    4|       3+4|   5.0|
|    5|     3+4+5|   2.0|
|    5|     3+4+5|   2.0|
|    5|     3+4+5|   2.0|
|    2|   3+4+5+2|   NaN|
+-----+----------+------+

The result column should be built like this: for each item in the value column, find the next different value that comes after it in order. So for value 3 it is 4, and for value 4 it is 5.

But when there are duplicates, like the value 5 that repeats 3 times, a simple lag won't work, because the lag for the first 5 returns 5. I basically want to keep taking the lag until value != lag(value) or lag(value) is null.
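To pin down the intended semantics, here is a minimal pure-Python sketch (no Spark involved; the function name `next_different` is hypothetical). It steps forward from each position past any duplicates until the value changes, and returns None where no different value follows (shown as NaN/null in the tables). It is quadratic and only meant to specify the expected output, not to replace a Spark solution.

```python
def next_different(values):
    """For each position, return the first later value that differs from it, or None."""
    result = []
    for i, current in enumerate(values):
        nxt = None
        # keep stepping forward past duplicates until the value changes
        for later in values[i + 1:]:
            if later != current:
                nxt = later
                break
        result.append(nxt)
    return result


print(next_different([3, 4, 5, 5, 5, 2]))  # [4, 5, 2, 2, 2, None]
```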

How can I do this in PySpark without UDFs or joins?

Upvotes: 2

Views: 1042

Answers (1)

anky

Reputation: 75150

We can use two windows: assign a monotonically_increasing_id, use a window ordered by that id to get the next row's value with lead, and then take the last of those lead values in a second window partitioned by value, like below:

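import pyspark.sql.functions as F
from pyspark.sql import Window

w = Window.orderBy('idx')  # global row order (Spark moves all rows to one partition for this)
# one partition per value; ordering it and spanning the whole partition makes F.last deterministic
w1 = (Window.partitionBy('value').orderBy('idx')
      .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))

(df.withColumn('idx', F.monotonically_increasing_id())
   .withColumn("result", F.last(F.lead("value").over(w)).over(w1))
   .orderBy('idx')
   .drop('idx')).show()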

+-----+----------+------+
|value|val_joined|result|
+-----+----------+------+
|    3|         3|     4|
|    4|       3+4|     5|
|    5|     3+4+5|     2|
|    5|     3+4+5|     2|
|    5|     3+4+5|     2|
|    2|   3+4+5+2|  null|
+-----+----------+------+
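As a plain-Python illustration (not Spark code; the helper name `lead_then_last_by_value` is hypothetical), the same two steps can be mimicked on a list: compute each row's lead, then keep the lead of the last row of each distinct value. Running it on a list where 5 reappears later shows why partitioning by value alone is not enough: all the 5s end up in one partition and share the lead of the final 5.

```python
def lead_then_last_by_value(values):
    """Mimic F.last(F.lead('value')) over a window partitioned by value, ordered by idx."""
    # lead: the next row's value in idx order, None for the final row
    lead = values[1:] + [None]
    # the last assignment per value wins, i.e. the lead of that value's last row
    last_lead = {}
    for value, nxt in zip(values, lead):
        last_lead[value] = nxt
    return [last_lead[v] for v in values]


print(lead_then_last_by_value([3, 4, 5, 5, 5, 2]))     # [4, 5, 2, 2, 2, None]
print(lead_then_last_by_value([3, 4, 5, 5, 5, 2, 5]))  # [4, 5, None, None, None, 5, None]
```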

If numbers in value can repeat later, as in the example below:

+-----+----------+
|value|val_joined|
+-----+----------+
|3    |3         |
|4    |3+4       |
|5    |3+4+5     |
|5    |3+4+5     |
|5    |3+4+5     |
|2    |3+4+5+2   |
|5    |3+4+5+2+5 | <- this value is repeated later
+-----+----------+

Then we have to create a separate group for each run of consecutive equal values and partition the second window by that group:

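import pyspark.sql.functions as F
from pyspark.sql import Window

w = Window.orderBy('idx')
# ordered frame over the whole group so F.last deterministically picks the group's last lead
w1 = (Window.partitionBy('group').orderBy('idx')
      .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))

(df.withColumn('idx', F.monotonically_increasing_id())
   # flag 1 where value differs from the previous row (the first row's null lag gives 0)
   .withColumn("lag", F.when(F.lag("value").over(w) != F.col("value"), F.lit(1))
                       .otherwise(F.lit(0)))
   # running sum of the flags numbers each run of consecutive equal values
   .withColumn("group", F.sum("lag").over(w) + 1).drop("lag")
   .withColumn("result", F.last(F.lead("value").over(w)).over(w1))
   .orderBy('idx')
   .drop('idx', "group")).show()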

+-----+----------+------+
|value|val_joined|result|
+-----+----------+------+
|    3|         3|     4|
|    4|       3+4|     5|
|    5|     3+4+5|     2|
|    5|     3+4+5|     2|
|    5|     3+4+5|     2|
|    2|   3+4+5+2|     5|
|    5| 3+4+5+2+5|  null|
+-----+----------+------+
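The run-numbering trick can be mimicked in plain Python too (again a hypothetical helper, not Spark code): flag each row whose value differs from the previous one, take a running sum of the flags to number the runs, then keep the lead of the last row of each run. This is only a sketch of the logic, assuming rows are already in idx order.

```python
from itertools import accumulate


def lead_then_last_by_run(values):
    """Number runs of equal consecutive values, then return (groups, results)."""
    if not values:
        return [], []
    # 1 where the value differs from the previous row; the first row has no
    # previous row (null lag in Spark), so it gets 0
    flags = [0] + [1 if cur != prev else 0 for prev, cur in zip(values, values[1:])]
    # running sum of the flags plus 1 gives each run its own group id
    groups = [total + 1 for total in accumulate(flags)]
    # lead: the next row's value, None for the final row
    lead = values[1:] + [None]
    # the last assignment per group wins, i.e. the lead of the run's last row
    last_lead = {}
    for group, nxt in zip(groups, lead):
        last_lead[group] = nxt
    return groups, [last_lead[g] for g in groups]


groups, result = lead_then_last_by_run([3, 4, 5, 5, 5, 2, 5])
print(groups)  # [1, 2, 3, 3, 3, 4, 5]
print(result)  # [4, 5, 2, 2, 2, 5, None]
```

Because each run gets its own group id, the later 5 no longer shares a partition with the earlier run of 5s, which is what makes the second table come out right.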

Upvotes: 2
