I have a DataFrame with the following schema:
root
 |-- documentId
 |-- timestamp
 |-- anotherField
For example,
"d1", "2018-09-20 10:00:00", "blah1"
"d2", "2018-09-20 09:00:00", "blah2"
"d1", "2018-09-20 10:01:00", "blahnew"
Note that for the sake of understanding (and my convenience) I am showing the timestamp as a string. It is in fact a long
representing milliseconds since epoch.
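(Purely for display, such a long can be rendered as a readable timestamp; a minimal sketch, assuming the DataFrame is df and the column holds epoch milliseconds:)
import org.apache.spark.sql.functions.col

// display-only: epoch-millis long -> timestamp column;
// ordering and max() work directly on the long, so this is never required
val readable = df.withColumn("ts", (col("timestamp") / 1000).cast("timestamp"))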
As seen here, there are duplicate rows (rows 1 and 3) with the same documentId but different timestamp (and possibly different other fields). I want to dedupe and retain only the most recent row (based on timestamp) for each documentId.
A simple df.groupBy("documentId").agg(max("timestamp"), ...) does not seem likely to work here, because I don't know how to retain the other fields of the row that satisfies max("timestamp").
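(A sketch of one aggregation-only way that could retain them is the max-of-struct trick, assuming timestamp is placed first in the struct so it drives the comparison:)
import org.apache.spark.sql.functions.{col, max, struct}

// max over a struct compares its fields left to right, so timestamp must come first
val latest = df
  .groupBy("documentId")
  .agg(max(struct("timestamp", "anotherField")).as("latest"))
  .select(col("documentId"), col("latest.timestamp"), col("latest.anotherField"))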
Instead, I came up with a complicated way of doing this:
import org.apache.spark.sql.functions.max

// first find the max timestamp for each documentId
val mostRecent = df
  .select("documentId", "timestamp")
  .groupBy("documentId")
  .agg(max("timestamp").as("timestamp")) // alias back so the join keys line up
// now join with the original df on both keys to retain the full rows
val dedupedDf = df.join(mostRecent, Seq("documentId", "timestamp"), "inner")
The resulting dedupedDf should contain only those rows which correspond to the most recent entry for each documentId.
Although this works, I don't feel it is the right (or efficient) approach, since I am using a join that seems needless. How can I do this better? I am looking for pure DataFrame-based solutions as opposed to RDD-based approaches (since the Databricks folks have repeatedly told us in a workshop to work with DataFrames and not RDDs).
The code below achieves your objective:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.row_number
import spark.implicits._ // for $"..." and toDF; assumes a SparkSession named spark

val df = Seq(
  ("d1", "2018-09-20 10:00:00", "blah1"),
  ("d2", "2018-09-20 09:00:00", "blah2"),
  ("d1", "2018-09-20 10:01:00", "blahnew")
).toDF("documentId", "timestamp", "anotherField")

// number the rows within each documentId, most recent timestamp first
val w = Window.partitionBy($"documentId").orderBy($"timestamp".desc)

// keep only the top-ranked (most recent) row, then drop the helper column
val resultDf = df.withColumn("rownum", row_number().over(w))
  .where($"rownum" === 1)
  .drop("rownum")

resultDf.show()
input:
+----------+-------------------+------------+
|documentId| timestamp|anotherField|
+----------+-------------------+------------+
| d1|2018-09-20 10:00:00| blah1|
| d2|2018-09-20 09:00:00| blah2|
| d1|2018-09-20 10:01:00| blahnew|
+----------+-------------------+------------+
output:
+----------+-------------------+------------+
|documentId| timestamp|anotherField|
+----------+-------------------+------------+
| d2|2018-09-20 09:00:00| blah2|
| d1|2018-09-20 10:01:00| blahnew|
+----------+-------------------+------------+
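Compared with the groupBy-plus-join in the question, this needs only a single shuffle on documentId, and row_number guarantees exactly one row per documentId even if two rows tie on timestamp. Note that a bare dropDuplicates is not a substitute here: without an ordering it keeps an arbitrary row per key, not the most recent one.
// NOT equivalent: keeps an arbitrary row per documentId, not the latest
val arbitrary = df.dropDuplicates("documentId")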