Reputation: 118
I have a DataFrame that contains a person, a weight, and a timestamp, as such:
+-----------+-------------------+------+
| person| timestamp|weight|
+-----------+-------------------+------+
| 1|2019-12-02 14:54:17| 49.94|
| 1|2019-12-03 08:58:39| 50.49|
| 1|2019-12-06 10:44:01| 50.24|
| 2|2019-12-02 08:58:39| 62.32|
| 2|2019-12-04 10:44:01| 65.64|
+-----------+-------------------+------+
I want to fill this in so that every person has a value for every date, meaning that the above should become:
+-----------+-------------------+------+
| person| timestamp|weight|
+-----------+-------------------+------+
| 1|2019-12-02 14:54:17| 49.94|
| 1|2019-12-03 08:58:39| 50.49|
| 1|2019-12-04 00:00:01| 50.49|
| 1|2019-12-05 00:00:01| 50.49|
| 1|2019-12-06 10:44:01| 50.24|
| 1|2019-12-07 00:00:01| 50.24|
| 1|2019-12-08 00:00:01| 50.24|
| 2|2019-12-02 08:58:39| 62.32|
| 2|2019-12-03 00:00:01| 62.32|
| 2|2019-12-04 10:44:01| 65.64|
| 2|2019-12-05 00:00:01| 65.64|
| 2|2019-12-06 00:00:01| 65.64|
| 2|2019-12-07 00:00:01| 65.64|
| 2|2019-12-08 00:00:01| 65.64|
+-----------+-------------------+------+
I have defined a new table that uses datediff to contain all dates between the min and max date:
min_max_date = df_person_weights.select(min("timestamp"), max("timestamp")) \
    .withColumnRenamed("min(timestamp)", "min_date") \
    .withColumnRenamed("max(timestamp)", "max_date")

min_max_date = min_max_date.withColumn("datediff", datediff("max_date", "min_date")) \
    .withColumn("repeat", expr("split(repeat(',', datediff), ',')")) \
    .select("*", posexplode("repeat").alias("date", "val")) \
    .withColumn("date", expr("date_add(min_date, date)"))
This gives me a new DataFrame that contains the dates like:
+----------+
| date|
+----------+
|2019-12-02|
|2019-12-03|
|2019-12-04|
|2019-12-05|
|2019-12-06|
|2019-12-07|
|2019-12-08|
+----------+
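(As a side note, I believe that on Spark 2.4+ the same range of dates could also be built with sequence and explode instead of the repeat/posexplode trick, something like the untested snippet below; but my real problem is the fill/join step that follows.)
from pyspark.sql import functions as F

# Untested alternative, assumes Spark 2.4+ where sequence() is available:
# one row per calendar date between the min and max timestamp.
bounds = df_person_weights.select(
    F.to_date(F.min("timestamp")).alias("min_date"),
    F.to_date(F.max("timestamp")).alias("max_date"))
all_dates = bounds.select(
    F.explode(F.sequence("min_date", "max_date")).alias("date"))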
I have tried different joins like:
min_max_date.join(df_price_history, min_max_date.date != df_price_history.date, "leftouter")
But I'm not getting the results that I need. Can someone help with this? How do I merge the information I have now?
Upvotes: 3
Views: 4997
Reputation: 13459
You’re looking to forward-fill a dataset. This is made a bit more complex because you need to do it per category (person).
One way to do it would be like this: first, create a new DataFrame that has all the dates you want to have a value for, per person (see below, this is just dates_by_person).
Then, left-join the original DataFrame to this one, so you start creating the missing rows.
Next, use a windowing function to find, within each group of person sorted by date, the last non-null weight. In case you can have multiple entries per date (so one person has multiple filled-in records on one specific date), you must also order by the timestamp column.
Finally, coalesce the columns so that any null field gets replaced by the intended value.
from datetime import datetime, timedelta
from itertools import product
import pyspark.sql.functions as psf
from pyspark.sql import Window
data = ( # recreate the DataFrame
(1, datetime(2019, 12, 2, 14, 54, 17), 49.94),
(1, datetime(2019, 12, 3, 8, 58, 39), 50.49),
(1, datetime(2019, 12, 6, 10, 44, 1), 50.24),
(2, datetime(2019, 12, 2, 8, 58, 39), 62.32),
(2, datetime(2019, 12, 4, 10, 44, 1), 65.64))
df = spark.createDataFrame(data, schema=("person", "timestamp", "weight"))
min_max_timestamps = df.agg(psf.min(df.timestamp), psf.max(df.timestamp)).head()
first_date, last_date = [ts.date() for ts in min_max_timestamps]
all_days_in_range = [first_date + timedelta(days=d)
for d in range((last_date - first_date).days + 1)]
people = [row.person for row in df.select("person").distinct().collect()]
dates_by_person = spark.createDataFrame(product(people, all_days_in_range),
schema=("person", "date"))
df2 = (dates_by_person.join(df,
(psf.to_date(df.timestamp) == dates_by_person.date)
& (dates_by_person.person == df.person),
how="left")
.drop(df.person)
)
wind = (Window
.partitionBy("person")
.rangeBetween(Window.unboundedPreceding, -1)
.orderBy(psf.unix_timestamp("date"))
)
df3 = df2.withColumn("last_weight",
psf.last("weight", ignorenulls=True).over(wind))
df4 = df3.select(
df3.person,
psf.coalesce(df3.timestamp, psf.to_timestamp(df3.date)).alias("timestamp"),
psf.coalesce(df3.weight, df3.last_weight).alias("weight"))
df4.show()
# +------+-------------------+------+
# |person| timestamp|weight|
# +------+-------------------+------+
# | 1|2019-12-02 14:54:17| 49.94|
# | 1|2019-12-03 08:58:39| 50.49|
# | 1|2019-12-04 00:00:00| 50.49|
# | 1|2019-12-05 00:00:00| 50.49|
# | 1|2019-12-06 10:44:01| 50.24|
# | 2|2019-12-02 08:58:39| 62.32|
# | 2|2019-12-03 00:00:00| 62.32|
# | 2|2019-12-04 10:44:01| 65.64|
# | 2|2019-12-05 00:00:00| 65.64|
# | 2|2019-12-06 00:00:00| 65.64|
# +------+-------------------+------+
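If you want the filled-in rows to show 00:00:01 rather than midnight, as in the expected output in the question, you could tweak the timestamp coalesce, for example (an untested variation on the df4 step above):
df4 = df3.select(
    df3.person,
    psf.coalesce(df3.timestamp,
                 # filled-in rows: the date at one second past midnight
                 psf.to_timestamp(df3.date) + psf.expr("INTERVAL 1 SECOND")
                 ).alias("timestamp"),
    psf.coalesce(df3.weight, df3.last_weight).alias("weight"))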
Edit: as suggested by David in the comments, if you have a very large number of people, the construction of dates_by_person can be done in a way that doesn’t require bringing everything to the driver. In this example, we’re talking about a small number of integers, nothing big. But if it gets big, try:
dates = spark.createDataFrame(((d,) for d in all_days_in_range),
schema=("date",))
people = df.select("person").distinct()
dates_by_person = dates.crossJoin(people)
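Either way, dates_by_person ends up with one row per (person, date) combination, so the join, window, and coalesce steps above work unchanged. If you also want to avoid building all_days_in_range on the driver, the date range itself could in principle be generated on the cluster as well (e.g. with sequence/explode on Spark 2.4+) before the crossJoin.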
Upvotes: 3