Reputation: 73
I want to group by and aggregate a PySpark DataFrame while removing duplicates (keeping the last value) based on another column of this DataFrame.
In short, I would like to apply dropDuplicates to a GroupedData object, so that for each group I keep only one row per value of some column, dynamically.
The straightforward group aggregation for the DataFrame below would be:
from pyspark.sql import SparkSession, functions

spark = SparkSession.builder.getOrCreate()

dataframe = spark.createDataFrame(
    [
        (1, "2020-01-01", 1, 1),
        (2, "2020-01-01", 2, 1),
        (3, "2020-01-02", 1, 1),
        (2, "2020-01-02", 1, 1),
    ],
    ("id", "ts", "feature", "h3"),
).withColumn("ts", functions.col("ts").cast("timestamp"))
# +---+-------------------+-------+---+
# | id| ts|feature| h3|
# +---+-------------------+-------+---+
# | 1|2020-01-01 00:00:00| 1| 1|
# | 2|2020-01-01 00:00:00| 2| 1|
# | 3|2020-01-02 00:00:00| 1| 1|
# | 2|2020-01-02 00:00:00| 1| 1|
# +---+-------------------+-------+---+
aggregated = dataframe.groupby(
    "h3",
    functions.window(
        timeColumn="ts",
        windowDuration="3 days",
        slideDuration="1 day",
    ),
).agg(
    functions.sum("feature")
)
aggregated.show(truncate=False)
resulting in the following dataframe:
+---+------------------------------------------+------------+
|h3 |window |sum(feature)|
+---+------------------------------------------+------------+
|1 |[2019-12-30 00:00:00, 2020-01-02 00:00:00]|3 |
|1 |[2019-12-31 00:00:00, 2020-01-03 00:00:00]|5 |
|1 |[2020-01-01 00:00:00, 2020-01-04 00:00:00]|5 |
|1 |[2020-01-02 00:00:00, 2020-01-05 00:00:00]|2 |
+---+------------------------------------------+------------+
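To see where these sums come from, the sliding-window assignment can be modelled in plain Python. This is a minimal sketch, not Spark code: it assumes UTC timestamps at midnight, windows aligned to the Unix epoch (the default startTime of zero), and Spark's convention that a window includes its start and excludes its end. With a 3-day window sliding by 1 day, each row lands in three windows, which is why both rows of id=2 are counted in the two middle windows. `windows_for` and `window_sums` are hypothetical helper names.

```python
from collections import defaultdict
from datetime import date, timedelta

# Assumed: windowDuration="3 days", slideDuration="1 day", epoch-aligned, UTC
WINDOW_DAYS = 3
SLIDE_DAYS = 1
EPOCH = date(1970, 1, 1)

# (id, ts, feature, h3), the same rows as the question's DataFrame
rows = [
    (1, date(2020, 1, 1), 1, 1),
    (2, date(2020, 1, 1), 2, 1),
    (3, date(2020, 1, 2), 1, 1),
    (2, date(2020, 1, 2), 1, 1),
]


def windows_for(ts):
    """Return every (start, end) window containing ts; start inclusive, end exclusive."""
    offset = (ts - EPOCH).days
    last_start = offset - offset % SLIDE_DAYS  # latest aligned window start <= ts
    return [
        (EPOCH + timedelta(days=s), EPOCH + timedelta(days=s + WINDOW_DAYS))
        for s in range(last_start, offset - WINDOW_DAYS, -SLIDE_DAYS)
    ]


def window_sums(rows):
    """Plain sum(feature) per (h3, window start, window end)."""
    sums = defaultdict(int)
    for _id, ts, feature, h3 in rows:
        for start, end in windows_for(ts):
            sums[(h3, start, end)] += feature
    return dict(sums)


for key, total in sorted(window_sums(rows).items()):
    print(key, total)
```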
I want the aggregation to use only the latest state of each id. In this case, id=2 was updated to feature=1 at ts=2020-01-02 00:00:00, so every window that includes that timestamp should use only this state of the feature column for id=2. The expected aggregated dataframe is:
+---+------------------------------------------+------------+
|h3 |window |sum(feature)|
+---+------------------------------------------+------------+
|1 |[2019-12-30 00:00:00, 2020-01-02 00:00:00]|3 |
|1 |[2019-12-31 00:00:00, 2020-01-03 00:00:00]|3 |
|1 |[2020-01-01 00:00:00, 2020-01-04 00:00:00]|3 |
|1 |[2020-01-02 00:00:00, 2020-01-05 00:00:00]|2 |
+---+------------------------------------------+------------+
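The intended semantics can be pinned down with a small plain-Python model (again not Spark code, and under the same assumptions: UTC midnight timestamps, epoch-aligned windows, start inclusive and end exclusive): inside each window, keep only the newest row of each id, then sum the surviving features. The helper names are hypothetical, and on equal timestamps the row seen last wins, an arbitrary tie-break that the question does not specify.

```python
from collections import defaultdict
from datetime import date, timedelta

WINDOW_DAYS, SLIDE_DAYS = 3, 1  # assumed "3 days" window sliding by "1 day"
EPOCH = date(1970, 1, 1)

rows = [  # (id, ts, feature, h3), as in the question
    (1, date(2020, 1, 1), 1, 1),
    (2, date(2020, 1, 1), 2, 1),
    (3, date(2020, 1, 2), 1, 1),
    (2, date(2020, 1, 2), 1, 1),
]


def windows_for(ts):
    """Every (start, end) window containing ts, end exclusive."""
    offset = (ts - EPOCH).days
    last_start = offset - offset % SLIDE_DAYS
    return [
        (EPOCH + timedelta(days=s), EPOCH + timedelta(days=s + WINDOW_DAYS))
        for s in range(last_start, offset - WINDOW_DAYS, -SLIDE_DAYS)
    ]


def latest_state_sums(rows):
    """sum(feature) per (h3, window), using only the newest row of each id."""
    latest = defaultdict(dict)  # (h3, start, end) -> {id: (ts, feature)}
    for id_, ts, feature, h3 in rows:
        for start, end in windows_for(ts):
            state = latest[(h3, start, end)]
            # Replace the stored row when this one is at least as recent
            if id_ not in state or ts >= state[id_][0]:
                state[id_] = (ts, feature)
    return {key: sum(f for _ts, f in state.values()) for key, state in latest.items()}


for key, total in sorted(latest_state_sums(rows).items()):
    print(key, total)
```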
How can I do this with PySpark?
I assumed that a MapType column cannot have duplicate keys in Spark. Based on that assumption, I thought I could aggregate the column into a map id -> feature
and then simply aggregate the map values with sum (or whatever the final aggregation should be).
So I did:
aggregated = dataframe.groupby(
    "h3",
    functions.window(
        timeColumn="ts",
        windowDuration="3 days",
        slideDuration="1 day",
    ),
).agg(
    functions.map_from_entries(
        functions.collect_list(
            functions.struct("id", "feature")
        )
    ).alias("id_feature")
)
aggregated.show(truncate=False)
But then I found that maps can have duplicate keys:
+---+------------------------------------------+--------------------------------+
|h3 |window |id_feature |
+---+------------------------------------------+--------------------------------+
|1 |[2020-01-01 00:00:00, 2020-01-04 00:00:00]|[1 -> 1, 2 -> 2, 3 -> 1, 2 -> 1]|
|1 |[2019-12-31 00:00:00, 2020-01-03 00:00:00]|[1 -> 1, 2 -> 2, 3 -> 1, 2 -> 1]|
|1 |[2019-12-30 00:00:00, 2020-01-02 00:00:00]|[1 -> 1, 2 -> 2] |
|1 |[2020-01-02 00:00:00, 2020-01-05 00:00:00]|[3 -> 1, 2 -> 1] |
+---+------------------------------------------+--------------------------------+
So it doesn't solve my problem. Instead, I found another one: when I use the display function in a Databricks notebook, it shows the MapType column without the duplicated keys.
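According to the Spark 3.0 migration notes, the behaviour of a map with duplicate keys is undefined in Spark 2.4 (for example, a key lookup returns the first matching entry, while collecting the map keeps only the last one, which is likely why `display` hides the duplicates). From Spark 3.0, building such a map throws an error by default, and setting `spark.sql.mapKeyDedupPolicy` to `LAST_WIN` keeps the last value instead. Even with last-wins deduplication, the map idea only works if the entries are ordered by `ts`, and `collect_list` does not guarantee any order after a shuffle. The plain-Python sketch below models a last-wins map with a `dict` for the window [2019-12-31, 2020-01-03); it is not Spark code, and `naive_total` and `last_wins_total` are hypothetical names.

```python
# Entries collected for the window [2019-12-31, 2020-01-03) as (ts, id, feature),
# in an arbitrary order, as collect_list may return them after a shuffle
entries = [
    ("2020-01-02", 2, 1),
    ("2020-01-01", 1, 1),
    ("2020-01-01", 2, 2),
    ("2020-01-02", 3, 1),
]


def naive_total(entries):
    """Last-wins map built in collection order: the 'last' id=2 may be the stale one."""
    id_to_feature = dict((id_, feature) for _ts, id_, feature in entries)
    return sum(id_to_feature.values())


def last_wins_total(entries):
    """Sort oldest first so the newest feature of each id is written last."""
    ordered = sorted(entries)  # ISO date strings sort chronologically
    id_to_feature = dict((id_, feature) for _ts, id_, feature in ordered)
    return sum(id_to_feature.values())


print(naive_total(entries), last_wins_total(entries))
```

With this particular order, the naive map keeps the stale feature=2 for id=2 and sums to 4, whereas the ordered one gives the expected 3.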
Upvotes: 6
Views: 2561
Reputation: 13551
First, find the latest record (maximum ts) for each id and time window, then join these latest records back to the original dataframe to get their feature values.
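from pyspark.sql import functions as F

df = dataframe  # the question's input DataFrame
time_window = F.window(timeColumn="ts", windowDuration="3 days", slideDuration="1 day")
# Latest timestamp of each id within each (h3, window) group
df2 = df.groupBy("h3", time_window, "id").agg(F.max("ts").alias("latest"))
# Join back on (id, ts) to fetch the feature of that latest row, then sum per window
df2.alias("a").join(df.alias("b"), (F.col("a.id") == F.col("b.id")) & (F.col("a.latest") == F.col("b.ts")), "left") \
    .select("a.*", "b.feature") \
    .groupBy("h3", "window") \
    .agg(F.sum("feature")) \
    .orderBy("window") \
    .show(truncate=False)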
The resulting sums match your expected output:
+---+------------------------------------------+------------+
|h3 |window |sum(feature)|
+---+------------------------------------------+------------+
|1 |[2019-12-29 00:00:00, 2020-01-01 00:00:00]|3 |
|1 |[2019-12-30 00:00:00, 2020-01-02 00:00:00]|3 |
|1 |[2019-12-31 00:00:00, 2020-01-03 00:00:00]|3 |
|1 |[2020-01-01 00:00:00, 2020-01-04 00:00:00]|2 |
+---+------------------------------------------+------------+
Upvotes: 1
Reputation: 13998
Since you are using Spark 2.4+, one option is the Spark SQL higher-order function aggregate, as below:
aggregated = dataframe.groupby(
    "h3",
    functions.window(
        timeColumn="ts",
        windowDuration="3 days",
        slideDuration="1 day",
    ),
).agg(
    functions.sort_array(
        functions.collect_list(functions.struct("ts", "id", "feature")),
        asc=False,
    ).alias("id_feature")
)
I added the ts field to the array of structs produced by functions.collect_list, and used functions.sort_array to sort that array by ts in descending order, so that the latest record comes first when an id is duplicated. In the following aggregate function, the zero_value is a struct with two fields: ids, a MapType that caches every id processed so far, and total, which adds y.feature only when the current id does not yet exist in ids.
aggregated.selectExpr("h3", "window", """
    aggregate(
        id_feature,
        /* zero_value */
        (map() as ids, 0L as total),
        /* merge */
        (acc, y) -> named_struct(
            /* add y.id into the ids map */
            'ids', map_concat(acc.ids, map(y.id, 1)),
            /* add to total only when y.id doesn't exist in the acc.ids map */
            'total', acc.total + IF(acc.ids[y.id] is null, y.feature, 0)
        ),
        /* finish: take only acc.total, discard the acc.ids map */
        acc -> acc.total
    ) as id_feature
""").show()
+---+--------------------+----------+
| h3| window|id_feature|
+---+--------------------+----------+
| 1|[2020-01-01 00:00...| 3|
| 1|[2019-12-31 00:00...| 3|
| 1|[2019-12-30 00:00...| 3|
| 1|[2020-01-02 00:00...| 2|
+---+--------------------+----------+
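The fold can be mirrored step by step in plain Python with `functools.reduce`, which may make the role of the `ids` map easier to follow. This is only a model of the SQL expression, not Spark code: `fold_latest_total` is a hypothetical name, and a Python `dict` stands in for the MapType accumulator. Because the array is sorted in descending order, the first entry seen for each id is its latest one, so every later (older) entry with the same id adds 0.

```python
from functools import reduce

# (ts, id, feature) structs collected for the window [2019-12-31, 2020-01-03)
entries = [
    ("2020-01-01", 1, 1),
    ("2020-01-01", 2, 2),
    ("2020-01-02", 3, 1),
    ("2020-01-02", 2, 1),
]


def fold_latest_total(entries):
    # Like sort_array(..., asc=False): newest ts first
    ordered = sorted(entries, reverse=True)
    zero_value = ({}, 0)  # (ids, total)

    def merge(acc, y):
        ids, total = acc
        _ts, y_id, y_feature = y
        # Add y's feature only if y's id has not been seen yet
        new_total = total + (y_feature if y_id not in ids else 0)
        # Like map_concat(acc.ids, map(y.id, 1)): record y's id as processed
        return ({**ids, y_id: 1}, new_total)

    _ids, total = reduce(merge, ordered, zero_value)
    return total  # finish: keep acc.total, discard the ids map


print(fold_latest_total(entries))
```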
Upvotes: 1