syv

Reputation: 3608

Extracting group data PySpark SQL

I have loaded the following data into a DataFrame:

S.NO  routename        status      tripcount

1      East           STARTED      1
2      West           STARTED      1
4      East           ARRIVED      2
5      East           ARRIVED      3
6      East           STARTED      4 
7      East           STARTED      5 
8      West           ARRIVED      2
9      East           ARRIVED      6

I want to keep only the following rows (by S.NO):

1, 2, 4, 6, 8, 9

Basically, I want to keep only the rows where a route's status alternates between STARTED and ARRIVED, and skip the rest. So far I have run:

dataframe_mysql.select("routename").distinct().show()

From here, do I have to loop inside a lambda expression, or is there a built-in method that will get me the result?

Upvotes: 0

Views: 55

Answers (1)

Ramesh Maharjan

Reputation: 41957

You can solve this with a Window and the lag function, combined with fillna, filter and drop to get your desired result.

from pyspark.sql import functions as F
from pyspark.sql.window import Window as W

# Partition by route and order rows by S_NO within each route
windowSpec = W.partitionBy("routename").orderBy(F.col("S_NO"))

# "remove" holds the previous row's status within the route; rows that repeat it are filtered out
df.withColumnRenamed("S.NO", "S_NO").withColumn("remove", F.lag("status", 1).over(windowSpec))\
    .fillna({"remove":"nullString"})\
    .filter(F.col("status") != F.col("remove"))\
    .drop("remove")

Here the window is partitioned by the routename column and ordered by S_NO. S.NO is renamed because the dot in its name was causing problems after the fillna call. The lag function copies the status of the previous row in the same partition into a new column, remove. The first row of each partition has no previous row, so lag returns null there; fillna replaces that null with the string nullString so that the row passes the != comparison in filter instead of being discarded. Finally, the remove column is dropped.
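To make the partition / order / lag / filter steps concrete, here is a minimal plain-Python sketch of the same logic, run on the sample rows from the question. It is not Spark code and not how Spark executes the query; the helper name keep_status_changes and the tuple layout are assumptions made only for illustration.

```python
from collections import defaultdict

# Sample rows from the question: (S.NO, routename, status, tripcount)
rows = [
    (1, "East", "STARTED", 1),
    (2, "West", "STARTED", 1),
    (4, "East", "ARRIVED", 2),
    (5, "East", "ARRIVED", 3),
    (6, "East", "STARTED", 4),
    (7, "East", "STARTED", 5),
    (8, "West", "ARRIVED", 2),
    (9, "East", "ARRIVED", 6),
]


def keep_status_changes(rows):
    """Keep a row only if its status differs from the previous row of the same route.

    Hypothetical helper that mimics partitionBy + orderBy + lag + filter.
    """
    # Mimics partitionBy("routename"): bucket the rows per route
    partitions = defaultdict(list)
    for row in rows:
        partitions[row[1]].append(row)

    kept = []
    for route_rows in partitions.values():
        # Mimics orderBy("S_NO") inside each partition
        route_rows.sort(key=lambda r: r[0])
        # Placeholder for the null that lag yields on the first row of a partition
        previous_status = "nullString"
        for row in route_rows:
            # Mimics filter(status != remove)
            if row[2] != previous_status:
                kept.append(row)
            previous_status = row[2]

    # Mimics the optional final orderBy("S_NO")
    return sorted(kept, key=lambda r: r[0])


kept_ids = [r[0] for r in keep_status_changes(rows)]
print(kept_ids)  # [1, 2, 4, 6, 8, 9]
```

The kept S.NO values match the rows requested in the question; the Spark version expresses the same comparison declaratively and lets the engine handle the partitioning.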

If you want the output sorted, you can append an additional orderBy to the chain:

.orderBy("S_NO")

You should then get the following output:

+----+---------+-------+---------+
|S_NO|routename|status |tripcount|
+----+---------+-------+---------+
|1   |East     |STARTED|1        |
|2   |West     |STARTED|1        |
|4   |East     |ARRIVED|2        |
|6   |East     |STARTED|4        |
|8   |West     |ARRIVED|2        |
|9   |East     |ARRIVED|6        |
+----+---------+-------+---------+

Hope the answer is more than helpful

Update

As @syv pointed out, lag accepts an optional default value that is returned when there is no previous row. Using it removes the need for the fillna call entirely, and the column no longer has to be renamed (it is referenced with backticks instead):

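from pyspark.sql import functions as F
from pyspark.sql.window import Window as W

# Backticks let the column name containing a dot be used without renaming
windowSpec = W.partitionBy("routename").orderBy(F.col("`S.NO`"))

# The third argument to lag is the default returned for the first row of each partition
df.withColumn("remove", F.lag("status", 1, "nullString").over(windowSpec))\
    .filter(F.col("status") != F.col("remove"))\
    .drop("remove")\
    .orderBy(F.col("`S.NO`"))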

which should give you

+----+---------+-------+---------+
|S.NO|routename|status |tripcount|
+----+---------+-------+---------+
|1   |East     |STARTED|1        |
|2   |West     |STARTED|1        |
|4   |East     |ARRIVED|2        |
|6   |East     |STARTED|4        |
|8   |West     |ARRIVED|2        |
|9   |East     |ARRIVED|6        |
+----+---------+-------+---------+

Upvotes: 2
