I am trying to fill in a series of observations in a Spark DataFrame. Basically, I have a list of days, and I need to create the missing ones for each group.
In pandas there is the reindex function, which is not available in PySpark.
I tried to implement a pandas UDF:
@pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
def reindex_by_date(df):
    df = df.set_index('dates')
    dates = pd.date_range(df.index.min(), df.index.max())
    return df.reindex(dates, fill_value=0).ffill()
This looks like it should do what I need; however, it fails with this message:
AttributeError: Can only use .dt accessor with datetimelike values
What am I doing wrong here?
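Here is the full code:
import pandas as pd
from pyspark.sql.functions import col, pandas_udf, PandasUDFType
from pyspark.sql.types import StructType, StructField, IntegerType, DateType, DoubleType

# Assumes an existing SparkSession named spark
data = spark.createDataFrame(
    [(1, "2020-01-01", 0),
     (1, "2020-01-03", 42),
     (2, "2020-01-01", -1),
     (2, "2020-01-03", -2)],
    ('id', 'dates', 'value'))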
data = data.withColumn('dates', col('dates').cast("date"))
schema = StructType([
    StructField('id', IntegerType()),
    StructField('dates', DateType()),
    StructField('value', DoubleType())])

@pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
def reindex_by_date(df):
    df = df.set_index('dates')
    dates = pd.date_range(df.index.min(), df.index.max())
    return df.reindex(dates, fill_value=0).ffill()

data = data.groupby('id').apply(reindex_by_date)
Ideally I would like something like this:
+---+----------+-----+
| id| dates|value|
+---+----------+-----+
| 1|2020-01-01| 0|
| 1|2020-01-02| 0|
| 1|2020-01-03| 42|
| 2|2020-01-01| -1|
| 2|2020-01-02| 0|
| 2|2020-01-03| -2|
+---+----------+-----+
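The error most likely comes from the shape of the frame the function returns. After set_index('dates') the dates live in the index, so the returned frame has only id and value columns and no longer lines up with the three-field schema; when Spark converts the result, a non-date column probably ends up treated as the DateType field, which is where a .dt accessor error would come from. In addition, reindex with fill_value=0 sets id to 0 on the inserted rows, and depending on the pandas version, datetime.date index labels may not match the Timestamps produced by pd.date_range. The following is a pandas-only sketch of a corrected group function under those assumptions (it assumes each group has unique dates), so it can be checked locally without Spark.

```python
import datetime

import pandas as pd


def reindex_by_date(pdf: pd.DataFrame) -> pd.DataFrame:
    """Add one row per missing day for a single id group (sketch)."""
    # Spark usually passes DateType values as datetime.date objects; convert
    # them so the index matches the Timestamps that pd.date_range produces.
    pdf = pdf.assign(dates=pd.to_datetime(pdf["dates"])).set_index("dates")
    full_range = pd.date_range(pdf.index.min(), pdf.index.max(), freq="D")
    out = pdf.reindex(full_range)
    # All rows in a group share one id, so copy it instead of filling with 0.
    out["id"] = pdf["id"].iloc[0]
    # Missing days get a value of 0, as in the desired output.
    out["value"] = out["value"].fillna(0).astype(float)
    # Move the dates back from the index into a column of datetime.date objects.
    out = out.rename_axis("dates").reset_index()
    out["dates"] = out["dates"].dt.date
    # Return the columns in the order of the Spark output schema.
    return out[["id", "dates", "value"]]


# Local check with one group shaped like the question's data (id 1).
group = pd.DataFrame({
    "id": [1, 1],
    "dates": [datetime.date(2020, 1, 1), datetime.date(2020, 1, 3)],
    "value": [0.0, 42.0],
})
print(reindex_by_date(group))
```

On Spark 3.0 and later, the undecorated function could be passed as data.groupby('id').applyInPandas(reindex_by_date, schema=schema); on Spark 2.x it would keep the @pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP) decorator from the question.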
I would try to keep the UDF as small as possible. In this case, I would only calculate the date range per ID in the UDF and use native Spark functions for everything else.
import pandas as pd
from pyspark.sql import types as T
from pyspark.sql import functions as F
# Get min and max date per ID
date_ranges = data.groupby('id').agg(F.min('dates').alias('date_min'), F.max('dates').alias('date_max'))
# Calculate the date range for each ID
@F.udf(returnType=T.ArrayType(T.DateType()))
def get_date_range(date_min, date_max):
    return [t.date() for t in pd.date_range(date_min, date_max)]
# To get one row per potential date, we need to explode the UDF output
date_ranges = date_ranges.withColumn(
    'dates',
    F.explode(get_date_range(F.col('date_min'), F.col('date_max')))
)
date_ranges = date_ranges.drop('date_min', 'date_max')
# Add the value for existing entries and add 0 for others
result = date_ranges.join(
    data,
    ['id', 'dates'],
    'left'
)
result = result.fillna({'value': 0})
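As a rough local check of the steps in the code above (min and max per ID, date range per ID, explode, left join and fill), the same pipeline can be mirrored in plain pandas. This is an illustrative equivalent rather than Spark code; the function name fill_missing_dates_per_id is hypothetical, and DataFrame.explode needs pandas 0.25 or later.

```python
import pandas as pd


def fill_missing_dates_per_id(df: pd.DataFrame) -> pd.DataFrame:
    # Step 1: min and max date per id (Spark: groupby + F.min / F.max)
    ranges = (
        df.groupby("id")["dates"]
        .agg(date_min="min", date_max="max")
        .reset_index()
    )
    # Step 2: the full daily range per id (the job of the get_date_range UDF)
    ranges["dates"] = [
        list(pd.date_range(lo, hi, freq="D"))
        for lo, hi in zip(ranges["date_min"], ranges["date_max"])
    ]
    # Step 3: one row per (id, date), like F.explode
    ranges = ranges.drop(columns=["date_min", "date_max"]).explode("dates")
    ranges["dates"] = pd.to_datetime(ranges["dates"])
    # Step 4: left join the observed values and fill the gaps with 0
    result = ranges.merge(df, on=["id", "dates"], how="left")
    result["value"] = result["value"].fillna(0)
    return result.reset_index(drop=True)


data = pd.DataFrame({
    "id": [1, 1, 2, 2],
    "dates": pd.to_datetime(["2020-01-01", "2020-01-03", "2020-01-01", "2020-01-03"]),
    "value": [0, 42, -1, -2],
})
result = fill_missing_dates_per_id(data)
print(result)
```

On Spark 2.4 and later, the built-in F.sequence function should be able to build the per-ID date array natively (for dates its default step is one day), for example F.explode(F.sequence('date_min', 'date_max')), which would avoid the Python UDF entirely; treat this as an untested suggestion.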
I think there is no need to use a UDF here. What you want can be achieved in a different way: first, get all possible IDs and all necessary dates. Second, crossJoin them, which gives you all possible combinations. Third, left join the original data onto the combinations. Fourth, replace the resulting null values with 0.
# Get all unique ids
ids_df = data.select('id').distinct()
# Get the date series
date_min, date_max = data.agg(F.min('dates'), F.max('dates')).collect()[0]
dates = [[t.date()] for t in list(pd.date_range(date_min, date_max))]
dates_df = spark.createDataFrame(data=dates, schema="dates:date")
# Calculate all combinations
all_combinations = ids_df.crossJoin(dates_df)
# Add the value column
result = all_combinations.join(
    data,
    ['id', 'dates'],
    'left'
)
# Replace all null values with 0
result = result.fillna({'value': 0})
Please be aware of the following limitations with this solution:
[EDIT] Split into two cases, as I first thought all IDs had the same date range.
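To see why the two cases differ, the crossJoin variant can also be mirrored in plain pandas: every id is paired with one global date range, so an id whose own observations start later still gets rows for the earlier days. The data below is hypothetical (id 2 has no observation before 2020-01-02) to make that visible; merge(how="cross") needs pandas 1.2 or later, and the function name fill_with_global_range is illustrative.

```python
import pandas as pd


def fill_with_global_range(df: pd.DataFrame) -> pd.DataFrame:
    # All unique ids (Spark: select('id').distinct())
    ids = df[["id"]].drop_duplicates()
    # One date range shared by every id, from the global min to the global max
    all_dates = pd.DataFrame(
        {"dates": pd.date_range(df["dates"].min(), df["dates"].max(), freq="D")}
    )
    # Every (id, date) combination, like Spark's crossJoin
    combos = ids.merge(all_dates, how="cross")
    # Left join the observed values and replace the gaps with 0
    result = combos.merge(df, on=["id", "dates"], how="left")
    result["value"] = result["value"].fillna(0)
    return result


# Hypothetical data: id 2 has no observation before 2020-01-02.
data = pd.DataFrame({
    "id": [1, 1, 2, 2],
    "dates": pd.to_datetime(["2020-01-01", "2020-01-03", "2020-01-02", "2020-01-03"]),
    "value": [0, 42, -1, -2],
})
result = fill_with_global_range(data)
print(result)
```

Here id 2 receives a 2020-01-01 row with value 0 even though its own range starts a day later, which is the situation the per-ID version with the date-range UDF avoids.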