Problem statement: In PySpark, I have to compare two columns diagonally. In the input DataFrame below, I need to compare stn_fr_cd and stn_to_cd. For example, val_no 1 has 2 rows: I have to compare stn_fr_cd of the first row with stn_to_cd of the second row, and stn_to_cd of the first row with stn_fr_cd of the second row.
For a given val_no, if both diagonal pairs are equal, I have to increment my count by 1.
Below is my input, which has 4 columns: id, val_no, stn_fr_cd, stn_to_cd
id val_no stn_fr_cd stn_to_cd
8A 1 CPH GDN
8A 1 GDN CPH
8A 2 GDN CPH
8A 2 CPH GDN
8A 3 CPH GDN
8A 3 GDN CPH
8A 4 CPH GDN
8A 4 GDN CPH
My expected output is:
8A 4
I get 4 because for each of val_no 1, 2, 3 and 4, both diagonal pairs of stn_fr_cd and stn_to_cd are equal.
Can anyone help me with the logic in PySpark, ideally with code? I really need to get past this hurdle.
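To pin down the rule before writing any Spark code, here is a minimal plain-Python sketch (no PySpark) that counts, per id, the val_no groups whose two rows match diagonally on the sample rows. It assumes every val_no group has exactly two rows, as in the sample; the function name `count_diagonal_matches` is hypothetical. It can serve as a reference for checking a Spark result on small data.

```python
# Hypothetical plain-Python reference for the rule; this is not PySpark.
# Each tuple is (id, val_no, stn_fr_cd, stn_to_cd), copied from the sample input.
rows = [
    ("8A", 1, "CPH", "GDN"), ("8A", 1, "GDN", "CPH"),
    ("8A", 2, "GDN", "CPH"), ("8A", 2, "CPH", "GDN"),
    ("8A", 3, "CPH", "GDN"), ("8A", 3, "GDN", "CPH"),
    ("8A", 4, "CPH", "GDN"), ("8A", 4, "GDN", "CPH"),
]


def count_diagonal_matches(rows):
    """Count, per id, the val_no groups whose two rows match diagonally."""
    groups = {}
    for id_, val_no, fr_cd, to_cd in rows:
        groups.setdefault((id_, val_no), []).append((fr_cd, to_cd))

    counts = {}
    for (id_, _val_no), pair in groups.items():
        counts.setdefault(id_, 0)  # report ids with zero matches too
        if len(pair) != 2:
            continue  # assumption: only two-row groups are compared
        (fr1, to1), (fr2, to2) = pair
        # first row's stn_fr_cd vs second row's stn_to_cd, and the other diagonal
        if fr1 == to2 and to1 == fr2:
            counts[id_] += 1
    return counts


print(count_diagonal_matches(rows))  # {'8A': 4}
```

Because the check is symmetric, swapping the two rows inside a group does not change the result, so row order within a val_no does not matter for two-row groups.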
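I think this is what you want, but I could be wrong; let me know if it works for you or if I should update it. I used a window function partitioned by id and val_no to get the lead (the next row's value) of both columns. If stn_fr_cd equals the next row's stn_to_cd and stn_to_cd equals the next row's stn_fr_cd, that row gets a 1 in a check column, otherwise 0 (the last row of each partition has no next row, so its lead is null and it gets 0). Then I grouped by id and summed the check column. I added 2 more rows (val_no=5) to show that they don't get counted, because they don't satisfy both diagonal conditions.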
df.show()
+---+------+---------+---------+
| id|val_no|stn_fr_cd|stn_to_cd|
+---+------+---------+---------+
| 8A| 1| CPH| GDN|
| 8A| 1| GDN| CPH|
| 8A| 2| GDN| CPH|
| 8A| 2| CPH| GDN|
| 8A| 3| CPH| GDN|
| 8A| 3| GDN| CPH|
| 8A| 4| CPH| GDN|
| 8A| 4| GDN| CPH|
| 8A| 5| GDN| GDN|
| 8A| 5| CPH| GDN|
+---+------+---------+---------+
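from pyspark.sql import functions as F
from pyspark.sql.window import Window
# One window per (id, val_no). Ordering by val_no inside that partition leaves the row order
# arbitrary, which is fine here because the diagonal check is symmetric for a two-row group.
w = Window.partitionBy("id", "val_no").orderBy("val_no")
(df.withColumn("fr", F.lead("stn_fr_cd").over(w))  # next row's stn_fr_cd (null on the last row)
   .withColumn("to", F.lead("stn_to_cd").over(w))  # next row's stn_to_cd (null on the last row)
   # 1 when both diagonals match; a null comparison falls through to otherwise(0)
   .withColumn("check", F.when((F.col("stn_fr_cd") == F.col("to")) & (F.col("stn_to_cd") == F.col("fr")), 1).otherwise(0))
   .groupBy("id")
   .agg(F.sum("check").alias("diagonals"))
   .show())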
+---+---------+
| id|diagonals|
+---+---------+
| 8A| 4|
+---+---------+
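To see why summing the check column gives exactly one point per matching val_no, the sketch below emulates in plain Python what `lead` does over `partitionBy("id", "val_no")`: each row is paired with the next row of its partition, and the last row gets None (null in Spark), so at most one row per pair can score 1. It keeps the input order within each partition, which is an assumption, since Spark does not guarantee an order when the window is ordered by the partition column itself; for a two-row group the check is symmetric, so the total is the same either way. The names `lead_checks` and `diagonals_per_id` are hypothetical.

```python
from itertools import groupby
from operator import itemgetter

# Sample rows including the extra val_no=5 pair: (id, val_no, stn_fr_cd, stn_to_cd)
rows = [
    ("8A", 1, "CPH", "GDN"), ("8A", 1, "GDN", "CPH"),
    ("8A", 2, "GDN", "CPH"), ("8A", 2, "CPH", "GDN"),
    ("8A", 3, "CPH", "GDN"), ("8A", 3, "GDN", "CPH"),
    ("8A", 4, "CPH", "GDN"), ("8A", 4, "GDN", "CPH"),
    ("8A", 5, "GDN", "GDN"), ("8A", 5, "CPH", "GDN"),
]

partition_key = itemgetter(0, 1)  # (id, val_no), like partitionBy("id", "val_no")


def lead_checks(rows):
    """Return ((id, val_no), check) per row, emulating lead() and the check column."""
    out = []
    # sorted() is stable, so rows keep their input order inside each partition
    for key, part in groupby(sorted(rows, key=partition_key), key=partition_key):
        part = list(part)
        for i, (_id, _val_no, fr_cd, to_cd) in enumerate(part):
            # lead(): values from the next row in the partition, None for the last row
            nxt = part[i + 1] if i + 1 < len(part) else None
            fr = nxt[2] if nxt is not None else None
            to = nxt[3] if nxt is not None else None
            # comparing with None is False, just as a null condition falls to otherwise(0)
            out.append((key, 1 if fr_cd == to and to_cd == fr else 0))
    return out


def diagonals_per_id(checks):
    """Emulate groupBy("id").agg(sum("check")) over the check column."""
    totals = {}
    for (id_, _val_no), check in checks:
        totals[id_] = totals.get(id_, 0) + check
    return totals


checks = lead_checks(rows)
print(diagonals_per_id(checks))  # {'8A': 4}
```

For val_no 5 the first row satisfies only one diagonal (GDN equals the next row's stn_to_cd, but GDN does not equal the next row's stn_fr_cd CPH), so that pair contributes 0.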