Naveen Srikanth

Reputation: 789

pyspark compare two columns diagonally

Problem Statement: In PySpark, I have to compare two columns diagonally. For example, in the input dataframe below, I have to compare stn_fr_cd and stn_to_cd: for val_no 1 there are 2 rows. I have to compare the stn_fr_cd of the first row to the stn_to_cd of the second row, and the stn_to_cd of the first row to the stn_fr_cd of the second row.

In the input dataframe below, since for each val_no the stn_fr_cd and stn_to_cd diagonal elements are equal, I have to increment my value by 1.

Below is my input, having 4 columns: id, val_no, stn_fr_cd, stn_to_cd

id  val_no  stn_fr_cd  stn_to_cd
8A    1       CPH        GDN
8A    1       GDN        CPH
8A    2       GDN        CPH
8A    2       CPH        GDN
8A    3       CPH        GDN
8A    3       GDN        CPH
8A    4       CPH        GDN
8A    4       GDN        CPH

Below should be my output

8A 4

How I get 4: for val_no 1, 2, 3 and 4, both the stn_fr_cd and stn_to_cd diagonal elements are equal.
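The comparison rule for a single val_no group can be sketched in plain Python (a minimal illustration; row1/row2 stand for the two rows of a group):

```python
# Two rows of the same val_no group, as (stn_fr_cd, stn_to_cd) pairs.
row1 = ("CPH", "GDN")  # first row of val_no 1
row2 = ("GDN", "CPH")  # second row of val_no 1

# Diagonal comparison: stn_fr_cd of row 1 vs stn_to_cd of row 2,
# and stn_to_cd of row 1 vs stn_fr_cd of row 2.
is_diagonal_match = (row1[0] == row2[1]) and (row1[1] == row2[0])
print(is_diagonal_match)  # True
```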

Can anyone please help me with the logic in PySpark? I really need to cross this hurdle; please help with code.

Upvotes: 1

Views: 930

Answers (1)

murtihash

Reputation: 8410

I think this is what you want (I could be wrong; let me know if it works for you or if I should update it). I used a window function to get the lead of both columns; if the diagonals are both equal, that partition gets a 1, otherwise 0, and then I just grouped by id and summed my check column. I added 2 more rows (val_no=5) to show that they don't get selected, because they don't satisfy both diagonal conditions.

df.show()

+---+------+---------+---------+
| id|val_no|stn_fr_cd|stn_to_cd|
+---+------+---------+---------+
| 8A|     1|      CPH|      GDN|
| 8A|     1|      GDN|      CPH|
| 8A|     2|      GDN|      CPH|
| 8A|     2|      CPH|      GDN|
| 8A|     3|      CPH|      GDN|
| 8A|     3|      GDN|      CPH|
| 8A|     4|      CPH|      GDN|
| 8A|     4|      GDN|      CPH|
| 8A|     5|      GDN|      GDN|
| 8A|     5|      CPH|      GDN|
+---+------+---------+---------+

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Partition by (id, val_no) so lead() pairs each row with the other row of its group.
w = Window().partitionBy("id", "val_no").orderBy("val_no")

df.withColumn("fr", F.lead("stn_fr_cd").over(w))\
  .withColumn("to", F.lead("stn_to_cd").over(w))\
  .withColumn("check", F.when((F.col("stn_fr_cd") == F.col("to")) & (F.col("stn_to_cd") == F.col("fr")),
                              F.lit(1)).otherwise(F.lit(0)))\
  .groupBy("id").agg(F.sum("check").alias("diagonals")).show()

+---+---------+
| id|diagonals|
+---+---------+
| 8A|        4|
+---+---------+
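The same pairwise diagonal check can be verified in plain Python (without Spark) on the sample data above — a minimal sketch, assuming each val_no group has exactly two rows:

```python
from collections import defaultdict

# Sample rows: (id, val_no, stn_fr_cd, stn_to_cd), matching df.show() above.
rows = [
    ("8A", 1, "CPH", "GDN"), ("8A", 1, "GDN", "CPH"),
    ("8A", 2, "GDN", "CPH"), ("8A", 2, "CPH", "GDN"),
    ("8A", 3, "CPH", "GDN"), ("8A", 3, "GDN", "CPH"),
    ("8A", 4, "CPH", "GDN"), ("8A", 4, "GDN", "CPH"),
    ("8A", 5, "GDN", "GDN"), ("8A", 5, "CPH", "GDN"),
]

# Group rows by (id, val_no), mirroring the window partition.
groups = defaultdict(list)
for rid, val_no, fr, to in rows:
    groups[(rid, val_no)].append((fr, to))

# Count the groups whose two rows are diagonal mirrors of each other.
counts = defaultdict(int)
for (rid, _), pair in groups.items():
    (fr1, to1), (fr2, to2) = pair
    if fr1 == to2 and to1 == fr2:
        counts[rid] += 1

print(dict(counts))  # {'8A': 4}
```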

Upvotes: 1
