Reputation: 4790
I have a dataframe like this
data = [("ID1", 1, 5), ("ID1", 2, 6), ("ID1", 3, 7),
        ("ID1", 4, 4), ("ID1", 5, 2), ("ID1", 6, 2),
        ("ID2", 1, 4), ("ID2", 2, 6), ("ID2", 3, 1), ("ID2", 4, 1), ("ID2", 5, 4)]
df = spark.createDataFrame(data, ["ID", "colA", "colB"])
df.show()
+---+----+----+
| ID|colA|colB|
+---+----+----+
|ID1|   1|   5|
|ID1|   2|   6|
|ID1|   3|   7|
|ID1|   4|   4|
|ID1|   5|   2|
|ID1|   6|   2|
|ID2|   1|   4|
|ID2|   2|   6|
|ID2|   3|   1|
|ID2|   4|   1|
|ID2|   5|   4|
+---+----+----+
I want to calculate, per group, the correlation and the average over the last 3 elements (a trailing window of 3 rows).
Hence for ID1, for the first element (5) - Average = 5, corr = 0
for ID1, for the first 2 elements (5, 6) - Average = 5.5, corr with colA = 1
for ID1, for the first 3 elements (5, 6, 7) - Average = 6, corr with colA = 1
for ID1, for elements (6, 7, 4) - Average = 5.66, corr with colA = -0.65
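The last case, for instance, can be checked with plain NumPy (just an illustration of the arithmetic, not part of the Spark code): the window ending at colA = 4 for ID1 covers colA values (2, 3, 4) and colB values (6, 7, 4).
import numpy as np

# window of the last 3 rows for ID1, ending at colA = 4
colA_win = np.array([2, 3, 4])
colB_win = np.array([6, 7, 4])

print(colB_win.mean())                        # 5.666... -> Average = 5.66
print(np.corrcoef(colA_win, colB_win)[0, 1])  # -0.654... -> corr with colA = -0.65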
Expected output is like this
+---+----+----+----------+---------+
| ID|colA|colB|corr_last3|avg_last3|
+---+----+----+----------+---------+
|ID1|   1|   5|         0|        5|
|ID1|   2|   6|         1|      5.5|
|ID1|   3|   7|         1|        6|
|ID1|   4|   4|     -0.65|     5.66|
|ID1|   5|   2|     -0.99|     4.33|
|ID1|   6|   2|     -0.86|     2.66|
|ID2|   1|   4|         0|        4|
|ID2|   2|   6|         1|        5|
|ID2|   3|   1|     -0.59|     3.66|
|ID2|   4|   1|     -0.86|     2.66|
|ID2|   5|   4|      0.86|        2|
+---+----+----+----------+---------+
Upvotes: 1
Views: 945
Reputation: 374
All you need is to play with rowsBetween in the code below, so please read the documentation of that function; once you have partitioned the data you can do anything with the window.
from pyspark.sql import Window, functions as F

windowSpec = Window.partitionBy("id1", "id2").orderBy("Year", "Week").rowsBetween(Window.currentRow, 3)

df = df.withColumn(
    "correlation_4q",
    F.corr(F.col("price1"), F.col("price2")).over(windowSpec)
)
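Note that rowsBetween(Window.currentRow, 3) spans the current row and the 3 rows after it; for the "last 3 elements" asked about here the frame has to look backwards instead. A minimal sketch against the question's columns (ID, colA, colB are taken from the question, not from the snippet above):
from pyspark.sql import Window, functions as F

# trailing frame: the current row plus the 2 rows before it, per ID, ordered by colA
w = Window.partitionBy("ID").orderBy("colA").rowsBetween(-2, Window.currentRow)

df = df.withColumn("avg_last3", F.avg("colB").over(w)) \
       .withColumn("corr_last3", F.corr("colA", "colB").over(w))
# corr over a single-row frame is undefined (null/NaN depending on the Spark version),
# hence the when(...).otherwise(0.0) guard used in the other answers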
Upvotes: 0
Reputation: 4790
The PySpark version of the answer is this:
from pyspark.sql import Window
from pyspark.sql.functions import rank, corr, when, mean, col, round

# trailing window per ID: current row and the two rows (colA values) before it
last3 = Window.partitionBy("ID").orderBy("colA").rangeBetween(-2, Window.currentRow)

df = df\
    .withColumn("indices", rank().over(Window.partitionBy("ID").orderBy("colA")))\
    .withColumn("corr_last3",
                when(col("indices") > 1, corr(col("indices"), col("colB")).over(last3))
                .otherwise(0.0))\
    .withColumn("avg_last3", mean(col("colB")).over(last3))\
    .drop(col("indices"))\
    .orderBy("ID", "colA")

df = df.withColumn("corr_last3", round(col("corr_last3"), 3))\
       .withColumn("avg_last3", round(col("avg_last3"), 3))
df.show()
+---+----+----+----------+---------+
| ID|colA|colB|corr_last3|avg_last3|
+---+----+----+----------+---------+
|ID1|   1|   5|       0.0|      5.0|
|ID1|   2|   6|       1.0|      5.5|
|ID1|   3|   7|       1.0|      6.0|
|ID1|   4|   4|    -0.655|    5.667|
|ID1|   5|   2|    -0.993|    4.333|
|ID1|   6|   2|    -0.866|    2.667|
|ID2|   1|   4|       0.0|      4.0|
|ID2|   2|   6|       1.0|      5.0|
|ID2|   3|   1|    -0.596|    3.667|
|ID2|   4|   1|    -0.866|    2.667|
|ID2|   5|   4|     0.866|      2.0|
+---+----+----+----------+---------+
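One detail to be aware of: rangeBetween(-2, Window.currentRow) defines the frame by the values of the ordering column, while rowsBetween (used in the Scala answer below) defines it by row position. They coincide here only because colA increases by exactly 1 within each ID; with gaps in colA, a positional frame is the safer reading of "last 3 elements". A sketch (the column name avg_last3_rows is just for illustration):
from pyspark.sql import Window
from pyspark.sql.functions import mean, col

# positional frame: always the current row plus the two preceding rows, regardless of gaps in colA
w_rows = Window.partitionBy("ID").orderBy("colA").rowsBetween(-2, Window.currentRow)
df = df.withColumn("avg_last3_rows", mean(col("colB")).over(w_rows))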
Upvotes: 0
Reputation: 27373
You can do it with the built-in functions avg and corr; here is the Scala solution:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._  // for the $"col" syntax (already in scope in spark-shell)

df
  .withColumn("indices", row_number().over(Window.partitionBy($"ID").orderBy($"colA")))
  .withColumn("corr_last3", when($"indices" > 1, corr($"indices", $"colB").over(Window.partitionBy($"ID").orderBy($"colA").rowsBetween(-2L, Window.currentRow))).otherwise(0.0))
  .withColumn("avg_last3", avg($"colB").over(Window.partitionBy($"ID").orderBy($"colA").rowsBetween(-2L, Window.currentRow)))
  .drop($"indices")
  .orderBy($"ID", $"colA")
  .show()
gives:
+---+----+----+-------------------+------------------+
| ID|colA|colB|         corr_last3|         avg_last3|
+---+----+----+-------------------+------------------+
|ID1|   1|   5|                0.0|               5.0|
|ID1|   2|   6|                1.0|               5.5|
|ID1|   3|   7|                1.0|               6.0|
|ID1|   4|   4|-0.6546536707079772| 5.666666666666667|
|ID1|   5|   2|-0.9933992677987828| 4.333333333333333|
|ID1|   6|   2|-0.8660254037844386|2.6666666666666665|
|ID2|   1|   4|                0.0|               4.0|
|ID2|   2|   6|                1.0|               5.0|
|ID2|   3|   1|-0.5960395606792697|3.6666666666666665|
|ID2|   4|   1|-0.8660254037844387|2.6666666666666665|
|ID2|   5|   4| 0.8660254037844387|               2.0|
+---+----+----+-------------------+------------------+
Upvotes: 3