Reputation: 469
I have a pyspark dataframe with the following columns (apart from some others). There are multiple ids for every month, and the active status for every id is determined by the amount column: if amount > 0 then active = 1, else 0.
+---+------+----------+------+
| id|amount|     dates|active|
+---+------+----------+------+
|  X|     0|2019-05-01|     0|
|  X|   120|2019-06-01|     1|
|  Y|    60|2019-06-01|     1|
|  X|     0|2019-07-01|     0|
|  Y|     0|2019-07-01|     0|
|  Z|    50|2019-06-01|     1|
|  Y|     0|2019-07-01|     0|
+---+------+----------+------+
The new column I want to calculate and add is p3mactive. It is based on the active status of the past three months. For example, for id = X and date = 2019-08-01, p3mactive = 1, since X is active in 2019-06-01. If the months before don't exist, then p3mactive = 0, and if only one or two previous months exist, p3mactive can simply be calculated as max(active(month-1), active(month-2)), i.e. on the basis of whatever rows do exist.
+---+------+----------+------+---------+
| id|amount|     dates|active|p3mactive|
+---+------+----------+------+---------+
|  X|     0|2019-05-01|     0|        0|
|  X|   120|2019-06-01|     1|        0|
|  Y|    60|2019-06-01|     1|        0|
|  X|     0|2019-07-01|     0|        1|
|  Y|     0|2019-07-01|     0|        1|
|  Z|    50|2019-06-01|     1|        0|
|  Y|     0|2019-07-01|     0|        1|
+---+------+----------+------+---------+
So basically, for each id and month, p3mactive is the maximum of the active flags over the (up to three) previous months, and so on. Let me know if there are any doubts about the flow.
I want to implement this preferably using DataFrame operations and functions in pyspark. I can easily think of how to do this with pandas or plain Python, but I'm new to Spark and can't see how to loop through the ids, select the previous three months' active status for every given month, and feed them into max(m1, m2, m3), while handling the edge cases where the previous months don't exist. Any help would be greatly appreciated.
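For reference, here is a small snippet that builds the sample data above (just a sketch; the column types are my assumption, with dates kept as plain yyyy-MM-dd strings):

from pyspark.sql import SparkSession
from pyspark.sql.functions import when, col

spark = SparkSession.builder.getOrCreate()

# Sample rows matching the first table; active is derived from amount (> 0 -> 1)
data = [
    ("X", 0,   "2019-05-01"),
    ("X", 120, "2019-06-01"),
    ("Y", 60,  "2019-06-01"),
    ("X", 0,   "2019-07-01"),
    ("Y", 0,   "2019-07-01"),
    ("Z", 50,  "2019-06-01"),
    ("Y", 0,   "2019-07-01"),
]
df = spark.createDataFrame(data, ["id", "amount", "dates"])
df = df.withColumn("active", when(col("amount") > 0, 1).otherwise(0))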
Upvotes: 1
Views: 2285
Reputation: 7419
You can use when and lag with a Window function to do this:
from pyspark.sql.window import Window
from pyspark.sql.functions import when, lag

# Look back over the previous rows of each id, ordered by date
w = Window.partitionBy("id").orderBy("dates")

df = df.withColumn("p3mactive", when(
    (lag(df.active, 1).over(w) == 1) |
    (lag(df.active, 2).over(w) == 1) |
    (lag(df.active, 3).over(w) == 1), 1).otherwise(0))
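If you build the sample DataFrame as in your question, ordering the result per id should reproduce the expected p3mactive column (a quick check, assuming the dates sort correctly as yyyy-MM-dd strings):

# Inspect the result per id, in date order
df.orderBy("id", "dates").show()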
You cannot loop over pyspark dataframes, but you can stride over them by using a Window. You can apply conditions using when, look at previous rows using lag, and look at future rows using lead. If the row before x doesn't exist, lag returns null, so the condition is not met and you get a 0, as your use case requires.
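One side note: lag looks back a fixed number of rows, not calendar months, so if an id can have missing months the two are not the same. A range-based window over a month index keeps the "past three months" semantics even with gaps. This is only a sketch of an alternative, not part of the approach above, and month_idx is a name I'm introducing here:

from pyspark.sql.window import Window
from pyspark.sql.functions import year, month, coalesce, lit, max as max_

# month_idx: months since year 0, so a difference of 1 is exactly one calendar month
df = df.withColumn("month_idx", year("dates") * 12 + month("dates"))

# Frame covering the three calendar months before the current one
w3 = Window.partitionBy("id").orderBy("month_idx").rangeBetween(-3, -1)

# Max of active over that frame; coalesce turns "no prior months" (null) into 0
df = df.withColumn("p3mactive", coalesce(max_("active").over(w3), lit(0)))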
I hope this helps.
Upvotes: 1