Hardik Gupta

Reputation: 4790

Rolling correlation and average (last 3) Per Group in PySpark

I have a DataFrame like this:

data = [("ID1", 1, 5), ("ID1", 2, 6), ("ID1", 3, 7),
        ("ID1", 4, 4), ("ID1", 5, 2), ("ID1", 6, 2),
        ("ID2", 1, 4), ("ID2", 2, 6), ("ID2", 3, 1),
        ("ID2", 4, 1), ("ID2", 5, 4)]
df = spark.createDataFrame(data, ["ID", "colA", "colB"])
df.show()

+---+----+----+
| ID|colA|colB|
+---+----+----+
|ID1|   1|   5|
|ID1|   2|   6|
|ID1|   3|   7|
|ID1|   4|   4|
|ID1|   5|   2|
|ID1|   6|   2|
|ID2|   1|   4|
|ID2|   2|   6|
|ID2|   3|   1|
|ID2|   4|   1|
|ID2|   5|   4|
+---+----+----+

I want to calculate, per group, the rolling average of colB and its rolling correlation with colA, both over the last 3 elements.

Hence for ID1, for the first element (5): average = 5, corr = 0
For ID1, for the first 2 elements (5, 6): average = 5.5, corr with colA = 1
For ID1, for the first 3 elements (5, 6, 7): average = 6, corr with colA = 1
For ID1, for the elements (6, 7, 4): average = 5.66, corr with colA = -0.65
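
For reference, here is a minimal pure-Python sketch (standard library only, an illustration of the arithmetic rather than part of the Spark question) that reproduces these rolling numbers for ID1, up to rounding:

from statistics import mean
from math import sqrt

def pearson(xs, ys):
    # plain Pearson correlation; returns 0.0 where it is undefined (single point)
    mx, my = mean(xs), mean(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    den = sqrt(sum((x - mx) ** 2 for x in xs) * sum((y - my) ** 2 for y in ys))
    return cov / den if den else 0.0

col_a = [1, 2, 3, 4, 5, 6]  # ID1's colA
col_b = [5, 6, 7, 4, 2, 2]  # ID1's colB
for i in range(len(col_b)):
    lo = max(0, i - 2)  # window of at most the last 3 rows
    xs, ys = col_a[lo:i + 1], col_b[lo:i + 1]
    print(round(mean(ys), 2), round(pearson(xs, ys), 2))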


The expected output is like this:

+---+----+----+----------+---------+
| ID|colA|colB|corr_last3|avg_last3|
+---+----+----+----------+---------+
|ID1|   1|   5|         0|        5|
|ID1|   2|   6|         1|      5.5|
|ID1|   3|   7|         1|        6|
|ID1|   4|   4|     -0.65|     5.66|
|ID1|   5|   2|     -0.99|     4.33|
|ID1|   6|   2|     -0.86|     2.66|
|ID2|   1|   4|         0|        4|
|ID2|   2|   6|         1|        5|
|ID2|   3|   1|     -0.59|     3.66|
|ID2|   4|   1|     -0.86|     2.66|
|ID2|   5|   4|      0.86|        2|
+---+----+----+----------+---------+

Upvotes: 1

Views: 945

Answers (3)

Shubham Tomar

Reputation: 374

All you need to do is play with rowsBetween in the code below, so please read the documentation of that function; once you have set up the partition, you can do anything.

from pyspark.sql import Window, functions as F

# frame of the current row plus the 3 following rows within each partition
windowSpec = Window.partitionBy("id1", "id2").orderBy("Year", "Week").rowsBetween(Window.currentRow, 3)
df = df.withColumn(
    "correlation_4q",
    F.corr(F.col("price1"), F.col("price2")).over(windowSpec)
)
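
Adapted to the question's schema (my adaptation, not the answerer's original code), a trailing 3-row window would look like this:

from pyspark.sql import Window, functions as F

# trailing frame: the current row plus the 2 preceding rows within each ID
w = Window.partitionBy("ID").orderBy("colA").rowsBetween(-2, Window.currentRow)

df = (df
      .withColumn("corr_last3", F.corr("colA", "colB").over(w))
      .withColumn("avg_last3", F.avg("colB").over(w)))

Note that corr is undefined over a single-row frame, so the first row of each group will not come out as 0 here; the other answers guard against this with when(...).otherwise(0.0).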

Upvotes: 0

Hardik Gupta

Reputation: 4790

The PySpark version of the answer is this:

from pyspark.sql import Window
from pyspark.sql.functions import rank, corr, when, mean, col, round

# trailing frame: colA values within 2 of the current row's colA
w = Window.partitionBy("ID").orderBy("colA").rangeBetween(-2, Window.currentRow)

df = df\
      .withColumn("indices", rank().over(Window.partitionBy("ID").orderBy("colA")))\
      .withColumn("corr_last3", when(col("indices") > 1,
                                     corr(col("indices"), col("colB")).over(w)).otherwise(0.0))\
      .withColumn("avg_last3", mean(col("colB")).over(w))\
      .drop(col("indices"))\
      .orderBy("ID", "colA")

df = df.withColumn("corr_last3", round(col("corr_last3"), 3))\
       .withColumn("avg_last3", round(col("avg_last3"), 3))
df.show()


+---+----+----+----------+---------+
| ID|colA|colB|corr_last3|avg_last3|
+---+----+----+----------+---------+
|ID1|   1|   5|       0.0|      5.0|
|ID1|   2|   6|       1.0|      5.5|
|ID1|   3|   7|       1.0|      6.0|
|ID1|   4|   4|    -0.655|    5.667|
|ID1|   5|   2|    -0.993|    4.333|
|ID1|   6|   2|    -0.866|    2.667|
|ID2|   1|   4|       0.0|      4.0|
|ID2|   2|   6|       1.0|      5.0|
|ID2|   3|   1|    -0.596|    3.667|
|ID2|   4|   1|    -0.866|    2.667|
|ID2|   5|   4|     0.866|      2.0|
+---+----+----+----------+---------+
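
One caveat (my note, not part of the original answer): rangeBetween(-2, Window.currentRow) frames the window by the values of the ordering column, so it matches "the last 3 rows" here only because colA is a gap-free integer sequence within each ID. If colA could have gaps, rowsBetween counts physical rows instead:

from pyspark.sql import Window

# rowsBetween counts physical rows, so the frame stays "the last 3 rows"
# even if colA had gaps (e.g. 1, 2, 5, 9, ...)
w = Window.partitionBy("ID").orderBy("colA").rowsBetween(-2, Window.currentRow)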

Upvotes: 0

Raphael Roth

Reputation: 27373

You can do it with the built-in functions avg and corr; here is the Scala solution:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._ // for the $"colName" syntax

df
  .withColumn("indices", row_number().over(Window.partitionBy($"ID").orderBy($"colA")))
  // corr is undefined over a single-row frame, so force the first row of each group to 0.0
  .withColumn("corr_last3", when($"indices" > 1, corr($"indices", $"colB").over(Window.partitionBy($"ID").orderBy($"colA").rowsBetween(-2L, Window.currentRow))).otherwise(0.0))
  .withColumn("avg_last3", avg($"colB").over(Window.partitionBy($"ID").orderBy($"colA").rowsBetween(-2L, Window.currentRow)))
  .drop($"indices")
  .orderBy($"ID", $"colA")
  .show()

gives:

+---+----+----+-------------------+------------------+
| ID|colA|colB|         corr_last3|         avg_last3|
+---+----+----+-------------------+------------------+
|ID1|   1|   5|                0.0|               5.0|
|ID1|   2|   6|                1.0|               5.5|
|ID1|   3|   7|                1.0|               6.0|
|ID1|   4|   4|-0.6546536707079772| 5.666666666666667|
|ID1|   5|   2|-0.9933992677987828| 4.333333333333333|
|ID1|   6|   2|-0.8660254037844386|2.6666666666666665|
|ID2|   1|   4|                0.0|               4.0|
|ID2|   2|   6|                1.0|               5.0|
|ID2|   3|   1|-0.5960395606792697|3.6666666666666665|
|ID2|   4|   1|-0.8660254037844387|2.6666666666666665|
|ID2|   5|   4| 0.8660254037844387|               2.0|
+---+----+----+-------------------+------------------+

Upvotes: 3
