Yanikovic

Reputation: 49

Drop consecutive duplicates in a Spark dataframe

The situation is the following: I have a time-series DataFrame consisting of an index column that orders the sequence and a column of some discrete value, like this:

id    value
0     A
1     A
2     B
3     C
4     A
5     A
6     A
7     B

I now want to drop all consecutive duplicates, so that it looks like this:

id    value
0     A
2     B
3     C
4     A
7     B

I've come up with a solution using a window with lag(), when(), and a filter afterwards. The problem is that a window requires a specific partition column. What I want instead is to drop the consecutive duplicates within each partition first, and only then check the partition borders (since the window works per partition, consecutive duplicates that straddle a partition border survive).

from pyspark.sql.window import Window
from pyspark.sql.functions import col, lag, when

# Assign each row to a block by integer-dividing its id.
df_with_block = df.withColumn(
    "block", (col("id") / df.rdd.getNumPartitions()).cast("int"))

window = Window.partitionBy("block").orderBy("id")

# True for the first row of each run of equal values within a block.
get_last = when(lag("value", 1).over(window) == col("value"), False).otherwise(True)

reduced_df = (df_with_block.withColumn("reduced", get_last)
              .where(col("reduced")).drop("reduced"))

In the first statement, I create a new dataframe whose block column distributes the ids evenly by integer division (e.g. with eight rows and four partitions, ids 0-3 land in block 0 and ids 4-7 in block 1). get_last is then a boolean column that is False whenever a row's value equals the value of the preceding row in its block, and reduced_df filters those duplicates out.

The problem is now the partition borders:

id    value
0     A
2     B
3     C
4     A
6     A
7     B

As you can see, the row with id=6 didn't get removed, since it was processed in a different partition. I'm thinking about different ideas to solve this and am curious how they could work out.
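
For what it's worth, here is a minimal sketch of how that two-pass idea might look: deduplicate within each block, then compare the last surviving value of every block against the first row of the following block. The block_size constant and the firsts/lasts/prev_block_last helpers are illustrative assumptions, not code from the question or any Spark API.

from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(0, "A"), (1, "A"), (2, "B"), (3, "C"),
     (4, "A"), (5, "A"), (6, "A"), (7, "B")],
    ["id", "value"])

# Hypothetical block size, chosen so a border falls between id=5 and id=6.
block_size = 3
df_b = df.withColumn("block", (F.col("id") / block_size).cast("int"))

# Pass 1: drop consecutive duplicates inside each block (runs in parallel).
w = Window.partitionBy("block").orderBy("id")
inner = (df_b.withColumn("prev", F.lag("value").over(w))
         .where(F.col("prev").isNull() | (F.col("prev") != F.col("value")))
         .drop("prev"))

# Pass 2: border check. The last surviving row of a block always carries the
# block's final value (pass 1 only removes rows equal to their predecessor),
# so compare it with the first row of the following block.
lasts = (inner.groupBy("block").agg(F.max("id").alias("id"))
         .join(inner, ["block", "id"])
         .select((F.col("block") + 1).alias("block"),
                 F.col("value").alias("prev_block_last")))
firsts = inner.groupBy("block").agg(F.min("id").alias("first_id"))

reduced = (inner.join(firsts, "block").join(lasts, "block", "left")
           .where(F.col("prev_block_last").isNull()
                  | (F.col("id") != F.col("first_id"))
                  | (F.col("value") != F.col("prev_block_last")))
           .select("id", "value")
           .orderBy("id"))
reduced.show()

With block_size = 3 the border falls between ids 5 and 6, reproducing the failure above; the second pass then drops the row with id=6 and prints the expected five rows.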

Upvotes: 3

Views: 1829

Answers (1)

fathomson

Reputation: 173

Without partitioning:

You can use a window without a partition, applying the same logic you already used. Note that an unpartitioned window forces Spark to move all rows into a single partition (it logs a warning to that effect), so this only works well for data that fits on one executor.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

data = [(0, "A"), (1, "A"), (2, "B"), (3, "C"),
        (4, "A"), (5, "A"), (6, "A"), (7, "B")]
df = spark.createDataFrame(data, ["id", "value"])

# A row is a dupe if its value equals the previous row's value;
# the very first row has no predecessor, so its dupe flag is null.
w = Window.orderBy(F.col("id"))
df = (df.withColumn("dupe", F.col("value") == F.lag("value").over(w))
      .filter((F.col("dupe") == False) | F.col("dupe").isNull())
      .drop("dupe"))

df.show()

Resulting in:

+---+-----+
| id|value|
+---+-----+
|  0|    A|
|  2|    B|
|  3|    C|
|  4|    A|
|  7|    B|
+---+-----+

With partitioning:

Another solution that keeps the work partitioned is to partition the window by value. This assumes the ids are consecutive integers, i.e. the id of a consecutive duplicate is exactly one greater than that of the preceding row.

# Within each value, a gap of exactly 1 to the previous id marks a
# consecutive duplicate; larger gaps start a new run.
w = Window.partitionBy("value").orderBy(F.col("id"))
df = (df.withColumn("dupe", F.col("id") - F.lag("id").over(w))
      .filter((F.col("dupe") > 1) | F.col("dupe").isNull())
      .drop("dupe")
      .orderBy("id"))

Upvotes: 4
