Reputation: 1256
I am trying to find the duplicate count of rows in a PySpark dataframe. I found a similar answer here, but it only outputs a binary flag. I would like the actual count for each row. To use the original post's example, if I have a dataframe like so:
+--+--+--+--+
|a |b |c |d |
+--+--+--+--+
|1 |0 |1 |2 |
|0 |2 |0 |1 |
|1 |0 |1 |2 |
|0 |4 |3 |1 |
|1 |0 |1 |2 |
+--+--+--+--+
I would like to get a result like:
+--+--+--+--+---------+
|a |b |c |d |row_count|
+--+--+--+--+---------+
|1 |0 |1 |2 |3        |
|0 |2 |0 |1 |0        |
|1 |0 |1 |2 |3        |
|0 |4 |3 |1 |0        |
|1 |0 |1 |2 |3        |
+--+--+--+--+---------+
Is this possible? Thank you.
Upvotes: 0
Views: 602
Reputation: 13387
Assuming df is your input dataframe:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
# One window partition per distinct (a, b, c, d) combination.
w = Window.partitionBy("a", "b", "c", "d")
# count(lit(1)) counts every row in the partition, even when a column is null.
df = df.withColumn("row_count", F.count(F.lit(1)).over(w))
If, as per your example, you want to replace every count of 1 with 0, do:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
# One window partition per distinct (a, b, c, d) combination.
w = Window.partitionBy("a", "b", "c", "d")
df = df.withColumn("row_count", F.count(F.lit(1)).over(w))
# Rows that occur only once get 0 instead of 1.
df = df.withColumn("row_count", F.when(F.col("row_count") == 1, F.lit(0)).otherwise(F.col("row_count")))
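To sanity-check the expected numbers without a Spark session, here is a minimal plain-Python sketch of what counting over a partition of identical rows computes. It is only a reference model, not PySpark code: the helper name duplicate_counts and its unique_as_zero flag are made up for illustration, and the rows are the sample data from the question. Unlike this list-based model, Spark does not guarantee that the output keeps the input row order.

```python
from collections import Counter

# Sample rows from the question, as (a, b, c, d) tuples.
rows = [
    (1, 0, 1, 2),
    (0, 2, 0, 1),
    (1, 0, 1, 2),
    (0, 4, 3, 1),
    (1, 0, 1, 2),
]


def duplicate_counts(rows, unique_as_zero=False):
    """Return, for each row, how many rows in the input are identical to it.

    With unique_as_zero=True, rows that occur only once get 0 instead of 1,
    matching the output requested in the question.
    """
    # Size of each group of identical rows (the analogue of a window partition).
    sizes = Counter(rows)
    counts = [sizes[row] for row in rows]
    if unique_as_zero:
        counts = [0 if c == 1 else c for c in counts]
    return counts


print(duplicate_counts(rows))                       # [3, 1, 3, 1, 3]
print(duplicate_counts(rows, unique_as_zero=True))  # [3, 0, 3, 0, 3]
```

The second call reproduces the row_count column from the question, which is a quick way to confirm the window-based result once it is collected.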
Upvotes: 2