Reputation: 3608
I have loaded the following data into a DataFrame, as shown below:
S.NO  routename  status   tripcount
1     East       STARTED  1
2     West       STARTED  1
4     East       ARRIVED  2
5     East       ARRIVED  3
6     East       STARTED  4
7     East       STARTED  5
8     West       ARRIVED  2
9     East       ARRIVED  6
I want to take out only the following rows:
1, 2, 4, 6, 8, 9
Basically, I want to keep the STARTED - ARRIVED pairs and skip the rest. So far I have only run
dataframe_mysql.select("routename").distinct().show()
With this, do I have to loop inside a lambda expression, or is there some other inbuilt method that will help me get the result?
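For reproducibility, the DataFrame above can be recreated with something like the following (a rough sketch only; it assumes an existing SparkSession named spark, while the real data is loaded from MySQL):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Recreate the sample table shown above; in the real case it comes from MySQL
df = spark.createDataFrame(
    [(1, "East", "STARTED", 1),
     (2, "West", "STARTED", 1),
     (4, "East", "ARRIVED", 2),
     (5, "East", "ARRIVED", 3),
     (6, "East", "STARTED", 4),
     (7, "East", "STARTED", 5),
     (8, "West", "ARRIVED", 2),
     (9, "East", "ARRIVED", 6)],
    ["S.NO", "routename", "status", "tripcount"])
df.show()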
Upvotes: 0
Views: 55
Reputation: 41957
You can benefit from the Window and lag functions, and then use fillna, filter and drop to get your desired result.
from pyspark.sql import functions as F
from pyspark.sql.window import Window as W
windowSpec = W.partitionBy("routename").orderBy(F.col("S_NO"))
df.withColumnRenamed("S.NO", "S_NO").withColumn("remove", F.lag("status", 1).over(windowSpec))\
    .fillna({"remove": "nullString"})\
    .filter(F.col("status") != F.col("remove"))\
    .drop("remove")
Here the Window specification partitions by the routename column and orders by S_NO. S.NO is renamed because the dot in the name was creating problems in the fillna step. The lag function copies the previous row's status into a new column remove. fillna replaces the resulting null values (the first row of each partition) with the StringType value nullString so that those rows can still be compared in filter. Finally, the remove column is dropped.
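To see what lag actually produces before the filter, you can display the intermediate remove column (a sketch using the same code as above; the row order of show may differ):

df.withColumnRenamed("S.NO", "S_NO")\
    .withColumn("remove", F.lag("status", 1).over(windowSpec))\
    .fillna({"remove": "nullString"})\
    .show(truncate=False)

+----+---------+-------+---------+----------+
|S_NO|routename|status |tripcount|remove    |
+----+---------+-------+---------+----------+
|1   |East     |STARTED|1        |nullString|
|4   |East     |ARRIVED|2        |STARTED   |
|5   |East     |ARRIVED|3        |ARRIVED   |
|6   |East     |STARTED|4        |ARRIVED   |
|7   |East     |STARTED|5        |STARTED   |
|9   |East     |ARRIVED|6        |STARTED   |
|2   |West     |STARTED|1        |nullString|
|8   |West     |ARRIVED|2        |STARTED   |
+----+---------+-------+---------+----------+

Rows 5 and 7, where status equals remove, are exactly the ones the filter drops.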
If you want the output sorted, you can use an additional orderBy:
.orderBy("S_NO")
And you should get the output as
+----+---------+-------+---------+
|S_NO|routename|status |tripcount|
+----+---------+-------+---------+
|1   |East     |STARTED|1        |
|2   |West     |STARTED|1        |
|4   |East     |ARRIVED|2        |
|6   |East     |STARTED|4        |
|8   |West     |ARRIVED|2        |
|9   |East     |ARRIVED|6        |
+----+---------+-------+---------+
Hope the answer is more than helpful.
Update
As @syv pointed out, lag has a default value parameter that is used when no lagged value is found, so the fillna call can be removed entirely, and the column renaming is not needed at all (the dotted name can be referenced with backticks as `S.NO`):
from pyspark.sql import functions as F
from pyspark.sql.window import Window as W
windowSpec = W.partitionBy("routename").orderBy(F.col("`S.NO`"))
df.withColumn("remove", F.lag("status", 1, "nullString").over(windowSpec))\
    .filter(F.col("status") != F.col("remove"))\
    .drop("remove")\
    .orderBy(F.col("`S.NO`"))
which should give you
+----+---------+-------+---------+
|S.NO|routename|status |tripcount|
+----+---------+-------+---------+
|1   |East     |STARTED|1        |
|2   |West     |STARTED|1        |
|4   |East     |ARRIVED|2        |
|6   |East     |STARTED|4        |
|8   |West     |ARRIVED|2        |
|9   |East     |ARRIVED|6        |
+----+---------+-------+---------+
Upvotes: 2