Nik

Reputation: 5745

Dedupe rows in Spark DataFrame by most recent timestamp

I have a DataFrame with the following schema:

root
|- documentId
|- timestamp
|- anotherField

For example,

"d1", "2018-09-20 10:00:00", "blah1"
"d2", "2018-09-20 09:00:00", "blah2"
"d1", "2018-09-20 10:01:00", "blahnew"

Note that for the sake of understanding (and my convenience) I am showing the timestamp as a string. It is in fact a long representing milliseconds since epoch.

As seen here, there are duplicate rows (row 1 and 3) with the same documentId but different timestamp (and possibly different other fields). I want to dedupe and retain only the most recent (based on timestamp) row for each documentId.

A simple df.groupBy("documentId").agg(max("timestamp"), ...) does not seem to work here, because I don't know how to retain the other fields of the row that has max("timestamp").

So, I came up with a complicated way of doing this.

import org.apache.spark.sql.functions.max

// first find the max timestamp corresponding to each documentId
val mostRecent = df
  .select("documentId", "timestamp")
  .groupBy("documentId")
  .agg(max("timestamp").as("timestamp")) // alias, otherwise the column is named "max(timestamp)"

// now join with the original df on documentId and timestamp to retain the other fields
val dedupedDf = df.join(mostRecent, Seq("documentId", "timestamp"), "inner")

This resulting dedupedDf should have only those rows which correspond to the most recent entry for each documentId.

Although this works, I don't feel this is the right (or efficient) approach, since I am using a join which seems needless.

How can I do this better? I am looking for a pure DataFrame-based solution rather than an RDD-based approach (the Databricks folks have repeatedly told us in a workshop to work with DataFrames and not RDDs).

Upvotes: 1

Views: 3320

Answers (1)

Karthick

Reputation: 662

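See whether the code below meets your objective. It uses a window function to number the rows within each documentId by descending timestamp, then keeps only row 1: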

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.row_number
import spark.implicits._ // for toDF and $; already in scope in spark-shell

val df = Seq(
  ("d1", "2018-09-20 10:00:00", "blah1"),
  ("d2", "2018-09-20 09:00:00", "blah2"),
  ("d1", "2018-09-20 10:01:00", "blahnew")
).toDF("documentId", "timestamp", "anotherField")

// order the rows within each documentId, newest timestamp first
val w = Window.partitionBy($"documentId").orderBy($"timestamp".desc)

// keep only the first (most recent) row of each documentId
val resultDf = df.withColumn("rownum", row_number().over(w))
  .where($"rownum" === 1)
  .drop("rownum")

resultDf.show()

input:

+----------+-------------------+------------+
|documentId|          timestamp|anotherField|
+----------+-------------------+------------+
|        d1|2018-09-20 10:00:00|       blah1|
|        d2|2018-09-20 09:00:00|       blah2|
|        d1|2018-09-20 10:01:00|     blahnew|
+----------+-------------------+------------+

output:

+----------+-------------------+------------+
|documentId|          timestamp|anotherField|
+----------+-------------------+------------+
|        d2|2018-09-20 09:00:00|       blah2|
|        d1|2018-09-20 10:01:00|     blahnew|
+----------+-------------------+------------+
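
If you would rather stay with groupBy, one possible alternative (not part of the code above, just a sketch) is to aggregate over a struct whose first field is the timestamp. Spark compares structs field by field, so max(struct(timestamp, ...)) carries the other columns along with the latest timestamp. The object name DedupeSketch, the helper latestPerDocument, and the epoch-millisecond values are made up for illustration; the column names are the ones from the question, and the timestamp is a Long as the question describes.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, max, struct}

// Hypothetical helper object, for illustration only.
object DedupeSketch {

  // Keeps one row per documentId: the one with the largest timestamp.
  // timestamp must be the FIRST struct field, because structs are
  // compared field by field, left to right.
  def latestPerDocument(df: DataFrame): DataFrame =
    df.groupBy("documentId")
      .agg(max(struct(col("timestamp"), col("anotherField"))).as("latest"))
      .select(col("documentId"), col("latest.timestamp"), col("latest.anotherField"))

  def main(args: Array[String]): Unit = {
    // Local session for this sketch; in spark-shell or a notebook `spark` already exists.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("dedupe-sketch")
      .getOrCreate()
    import spark.implicits._

    // Illustrative epoch-millisecond timestamps (assumed values).
    val df = Seq(
      ("d1", 1537437600000L, "blah1"),
      ("d2", 1537434000000L, "blah2"),
      ("d1", 1537437660000L, "blahnew")
    ).toDF("documentId", "timestamp", "anotherField")

    latestPerDocument(df).show()
    spark.stop()
  }
}
```

One difference to keep in mind when two rows share a documentId and the same maximum timestamp: the join from the question keeps both rows, row_number keeps exactly one (which one is not specified), and the struct version keeps the one whose remaining struct fields compare largest.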

Upvotes: 8
