Reputation: 1058
I am modifying the function described here to work with PySpark.
Input
from pyspark.sql import functions as F
data_in = spark.createDataFrame([
[1, "2017-1-1", "2017-6-30"], [1, "2017-1-1", "2017-1-3"], [1, "2017-5-1", "2017-9-30"],
[1, "2018-5-1", "2018-9-30"], [1, "2018-5-2", "2018-10-31"], [1, "2017-4-1", "2017-5-30"],
[1, "2017-10-3", "2017-10-3"], [1, "2016-12-5", "2016-12-31"], [1, "2016-12-1", "2016-12-2"],
[2, "2016-12-1", "2016-12-2"], [2, "2016-12-3", "2016-12-25"]
], schema=["id","start_dt","end_dt"])
data_in = data_in.select("id", F.to_date("start_dt", "yyyy-M-d").alias("start_dt"),
                         F.to_date("end_dt", "yyyy-M-d").alias("end_dt")).sort(["id", "start_dt", "end_dt"])
Aggregate function to apply
from datetime import datetime

mydt = datetime(1970, 1, 1).date()   # fill value for the first row's shifted end_dt

def merge_dates(grp):
    # flag rows whose start is more than 1 day after the previous end; cumsum labels contiguous blocks
    dt_groups = ((grp["start_dt"] - grp["end_dt"].shift(fill_value=mydt)).dt.days > 1).cumsum()
    # collapse each block to its overall min start / max end
    grouped = grp.groupby(dt_groups).agg({"start_dt": "min", "end_dt": "max"})
    # repeat until no further merging happens
    return grouped if len(grp) == len(grouped) else merge_dates(grouped)
Testing using Pandas
df = data_in.toPandas()
df.groupby("id").apply(merge_dates).reset_index().drop('level_1', axis=1)
Output
id start_dt end_dt
0 1 2016-12-01 2016-12-02
1 1 2016-12-05 2017-09-30
2 1 2017-10-03 2017-10-03
3 1 2018-05-01 2018-10-31
4 2 2016-12-01 2016-12-25
When I try to run this using Spark
data_out = data_in.groupby("id").applyInPandas(merge_dates, schema=data_in.schema)
display(data_out)
I get the following error
PythonException: 'RuntimeError: Number of columns of the returned pandas.DataFrame doesn't match specified schema. Expected: 3 Actual: 2'. Full traceback below:
When I change the schema to data_in.schema[1:], I get back only the date columns, which are computed correctly (they match the Pandas output), but the result does not include the id field, which is obviously required. How can I fix this so that the final output has the id as well?
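For reference, merge_dates only ever returns the two date columns; the id survives only as the groupby index, which I assume is why Spark reports Expected: 3 Actual: 2. A quick sanity check on a single group (diagnostic only):
merge_dates(df[df["id"] == 2]).columns.tolist()   # ['start_dt', 'end_dt']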
Upvotes: 1
Views: 431
Reputation: 75080
With Spark only, replicating what you have in pandas would look like this:
from pyspark.sql import functions as F
from pyspark.sql.window import Window as W

# window over each id, keeping the current (already sorted) row order
w = W.partitionBy("id").orderBy(F.monotonically_increasing_id())
# cumulative version of the same window, for a running sum
w1 = w.rangeBetween(W.unboundedPreceding, 0)

out = (data_in
       # True when the current start is more than 1 day after the previous end (lag = pandas shift)
       .withColumn("helper", F.datediff(F.col("start_dt"), F.lag("end_dt").over(w)) > 1)
       # the first row of each id has no previous end_dt, so treat the null as the start of a new block
       .fillna({"helper": True})
       # running sum of the flag numbers the blocks of overlapping/adjacent ranges
       .withColumn("helper2", F.sum(F.col("helper").cast("int")).over(w1))
       .groupBy("id", "helper2")
       .agg(F.min("start_dt").alias("start_dt"),
            F.max("end_dt").alias("end_dt"))
       .drop("helper2"))
out.show()
+---+----------+----------+
| id| start_dt| end_dt|
+---+----------+----------+
| 1|2016-12-01|2016-12-02|
| 1|2016-12-05|2017-09-30|
| 1|2017-10-03|2017-10-03|
| 1|2018-05-01|2018-10-31|
| 2|2016-12-01|2016-12-25|
+---+----------+----------+
Note that this assumes mydt = datetime(1970,1,1).date() is just a placeholder for the nulls produced when shifting the values; I have used fillna with True for the same purpose here. If that is not the case, you can fillna right after the lag, which is the equivalent of shift.
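If you would rather keep the applyInPandas approach from your question, one option is a small wrapper (a sketch only, assuming Spark 3.x, where the UDF may also take the grouping key as its first argument; merge_dates_keyed is just a name used here) that re-attaches the id before returning:
def merge_dates_keyed(key, grp):
    # merge_dates returns only start_dt / end_dt, so put the grouping key back as a column
    merged = merge_dates(grp).reset_index(drop=True)
    merged.insert(0, "id", key[0])
    return merged

data_out = data_in.groupby("id").applyInPandas(merge_dates_keyed, schema=data_in.schema)
That keeps your original function untouched and makes the returned frame match the three-column schema.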
Upvotes: 2