Sirapat

Reputation: 494

pyspark Apply DataFrame window function with filter

I have a dataset with the columns: id, timestamp, x, y

id  timestamp   x      y 
0   1443489380  100    1
0   1443489390  200    0
0   1443489400  300    0
0   1443489410  400    1

I defined a window spec: w = Window.partitionBy("id").orderBy("timestamp")

I want to do something like this: create a new column that sums x of the current row with x of the next row.

If the sum >= 500, then set the new column to BIG, else SMALL.

df = df.withColumn("newCol", 
                   when(df.x + lag(df.x,-1).over(w) >= 500 , "BIG")
                   .otherwise("SMALL") )

However, I want to filter the data before doing this, without affecting the original df.

[Only rows with y = 1 should have the above code applied.]

So the code should apply to only these 2 rows:

0 , 1443489380, 100 , 1

0 , 1443489410, 400 , 1

I have done it this way, but it is ugly:

df2 = df.filter(df.y == 1)
df2 = df2.withColumn("newCol", 
                     when(df2.x + lag(df2.x, -1).over(w) >= 500, "BIG")
                     .otherwise("SMALL") )
df = df.join(df2, ["id","timestamp"], "outer")

I want to do something like this, but it's not possible, since it raises AttributeError: 'DataFrame' object has no attribute 'when':

df = df.withColumn("newCol", df.filter(df.y == 1)
                   .when(df.x + lag(df.x,-1).over(w) >= 500 , "BIG")
                   .otherwise("SMALL") )

In conclusion, I just want to apply a temporary filter to the rows with y = 1 before summing x with the next x.

Upvotes: 6

Views: 19165

Answers (1)

Suresh

Reputation: 5870

Your code works fine; I think you didn't import the functions module. I tried your code:

>>> from pyspark.sql import functions as F
>>> df2 = df2.withColumn("newCol", 
                 F.when((df2.x + F.lag(df2.x, -1).over(w)) >= 500, "BIG")
                 .otherwise("SMALL") )
>>> df2.show()
+---+----------+---+---+------+
| id| timestamp|  x|  y|newCol|
+---+----------+---+---+------+
|  0|1443489380|100|  1|   BIG|
|  0|1443489410|400|  1| SMALL|
+---+----------+---+---+------+

Edit: I also tried changing the window partition to the 'id', 'y' columns:

>>> w = Window.partitionBy("id","y").orderBy("timestamp")
>>> df.select("*",
...           F.when(df.y == 1,
...                  F.when((df.x + F.lag("x", -1).over(w)) >= 500, 'BIG')
...                  .otherwise('SMALL'))
...           .otherwise(None).alias('new_col')).show()
+---+----------+---+---+-------+
| id| timestamp|  x|  y|new_col|
+---+----------+---+---+-------+
|  0|1443489380|100|  1|    BIG|
|  0|1443489410|400|  1|  SMALL|
|  0|1443489390|200|  0|   null|
|  0|1443489400|300|  0|   null|
+---+----------+---+---+-------+

Upvotes: 9
