ironv

Reputation: 1058

pyspark pandas UDF merging date ranges in multiple rows

I am adapting the function described here to work with PySpark.

Input

from pyspark.sql import functions as F

data_in = spark.createDataFrame([
    [1, "2017-1-1", "2017-6-30"], [1, "2017-1-1", "2017-1-3"], [1, "2017-5-1", "2017-9-30"],
    [1, "2018-5-1", "2018-9-30"], [1, "2018-5-2", "2018-10-31"], [1, "2017-4-1", "2017-5-30"],
    [1, "2017-10-3", "2017-10-3"], [1, "2016-12-5", "2016-12-31"], [1, "2016-12-1", "2016-12-2"],
    [2, "2016-12-1", "2016-12-2"], [2, "2016-12-3", "2016-12-25"]
  ], schema=["id","start_dt","end_dt"])

data_in = data_in.select("id", F.to_date("start_dt","yyyy-M-d").alias("start_dt"), 
               F.to_date("end_dt","yyyy-M-d").alias("end_dt")).sort(["id","start_dt","end_dt"])

Aggregate function to apply

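from datetime import datetime

# Placeholder "previous end" for the first row of a group; it lies far enough
# in the past that the first row always opens a new range
mydt = datetime(1970,1,1).date()
def merge_dates(grp):
  # True where a row starts more than one day after the previous row's end;
  # the running sum turns these flags into one range label per row
  dt_groups = ((grp["start_dt"]-grp["end_dt"].shift(fill_value=mydt)).dt.days > 1).cumsum()
  # Collapse each label into a single range
  grouped = grp.groupby(dt_groups).agg({"start_dt":"min", "end_dt":"max"})
  # Repeat until a pass merges nothing more
  return grouped if len(grp)==len(grouped) else merge_dates(grouped)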
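merge_dates relies on a common labelling trick: flag every row that starts more than one day after the previous row's end, then take a running sum of those flags so that all rows belonging to one merged range share the same label. The sketch below runs a single pass on made-up rows. It assumes datetime64 columns and uses a pd.Timestamp placeholder instead of mydt, so it runs with plain pandas and no Spark session.

```python
import pandas as pd

# Made-up ranges for one id, already sorted by start_dt
grp = pd.DataFrame({
    "start_dt": pd.to_datetime(["2017-01-01", "2017-01-02", "2017-01-10", "2017-01-11"]),
    "end_dt":   pd.to_datetime(["2017-01-03", "2017-01-05", "2017-01-10", "2017-01-20"]),
})

# Days between each start and the previous row's end; the first row is compared
# with a placeholder far in the past (the role mydt plays in merge_dates)
gap = (grp["start_dt"] - grp["end_dt"].shift(fill_value=pd.Timestamp("1970-01-01"))).dt.days

# True = "this row opens a new range"; the running sum turns flags into labels
labels = (gap > 1).cumsum()
print(labels.tolist())  # [1, 1, 2, 2]

# One row per label: earliest start, latest end
merged = grp.groupby(labels).agg({"start_dt": "min", "end_dt": "max"})
print(merged)
```

The gap of exactly one day between 2017-01-10 and 2017-01-11 does not open a new range, because the condition is "> 1"; ranges on consecutive days are treated as contiguous.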

Testing using Pandas

df = data_in.toPandas()
df.groupby("id").apply(merge_dates).reset_index().drop('level_1', axis=1)

Output

   id    start_dt      end_dt
0   1  2016-12-01  2016-12-02
1   1  2016-12-05  2017-09-30
2   1  2017-10-03  2017-10-03
3   1  2018-05-01  2018-10-31
4   2  2016-12-01  2016-12-25

When I try to run this in Spark:

data_out = data_in.groupby("id").applyInPandas(merge_dates, schema=data_in.schema)
display(data_out)

I get the following error:

PythonException: 'RuntimeError: Number of columns of the returned pandas.DataFrame doesn't match specified schema. Expected: 3 Actual: 2'. Full traceback below:

When I change the schema to data_in.schema[1:], I get back only the date columns. They are computed correctly (they match the pandas output), but the id field is not returned, and I obviously need it. How can I fix this so that the final output includes id as well?
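Why the pandas test keeps id but Spark does not: groupby("id").apply(...) puts the group key into the index and reset_index() turns it back into a column, whereas applyInPandas only sees the columns the function returns, and the agg in merge_dates keeps just start_dt and end_dt. The smallest change is probably to carry id through the aggregation, e.g. agg({"id": "first", "start_dt": "min", "end_dt": "max"}). The sketch below goes a little further under stated assumptions: merge_dates_keep_id is a hypothetical name; it sorts each group itself, since Spark does not promise any row order inside a group; and it swaps the recursion for a single pass using a running maximum of end_dt (pandas cummax), which is a different technique from the question's. It works in datetime64 internally and converts back to datetime.date on output.

```python
from datetime import date

import pandas as pd


def merge_dates_keep_id(pdf: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical applyInPandas function: merge one id's ranges, keep the id column."""
    # Spark does not guarantee row order inside a group, so sort explicitly
    pdf = (pdf.assign(start_dt=pd.to_datetime(pdf["start_dt"]),
                      end_dt=pd.to_datetime(pdf["end_dt"]))
              .sort_values(["start_dt", "end_dt"])
              .reset_index(drop=True))
    # Latest end date among all earlier rows (NaT for the first row)
    prev_end = pdf["end_dt"].cummax().shift()
    # A row opens a new range if it starts more than one day after prev_end
    new_range = (pdf["start_dt"] - prev_end).dt.days.gt(1) | prev_end.isna()
    range_no = new_range.cumsum().rename("range_no")
    # Named aggregation keeps id and fixes the column order to id, start_dt, end_dt
    out = pdf.groupby(range_no).agg(id=("id", "first"),
                                   start_dt=("start_dt", "min"),
                                   end_dt=("end_dt", "max"))
    # Back to datetime.date values, assumed to suit the DateType fields of data_in.schema
    out["start_dt"] = out["start_dt"].dt.date
    out["end_dt"] = out["end_dt"].dt.date
    return out.reset_index(drop=True)


# Local check without Spark: the question's rows, deliberately left unsorted
rows = [
    (1, date(2017, 1, 1), date(2017, 6, 30)), (1, date(2017, 1, 1), date(2017, 1, 3)),
    (1, date(2017, 5, 1), date(2017, 9, 30)), (1, date(2018, 5, 1), date(2018, 9, 30)),
    (1, date(2018, 5, 2), date(2018, 10, 31)), (1, date(2017, 4, 1), date(2017, 5, 30)),
    (1, date(2017, 10, 3), date(2017, 10, 3)), (1, date(2016, 12, 5), date(2016, 12, 31)),
    (1, date(2016, 12, 1), date(2016, 12, 2)),
    (2, date(2016, 12, 1), date(2016, 12, 2)), (2, date(2016, 12, 3), date(2016, 12, 25)),
]
df = pd.DataFrame(rows, columns=["id", "start_dt", "end_dt"])
# Mimic applyInPandas: call the function once per id and stack the results
result = pd.concat([merge_dates_keep_id(g) for _, g in df.groupby("id")], ignore_index=True)
print(result)
```

If the local check matches the pandas output, the Spark call would then presumably be data_in.groupby("id").applyInPandas(merge_dates_keep_id, schema=data_in.schema), since the returned columns are named and ordered like data_in.schema.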

Upvotes: 1

Views: 431

Answers (1)

anky

Reputation: 75080

Using Spark only, a replication of what you have in pandas would look like this:

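from pyspark.sql import functions as F
from pyspark.sql import Window as W

# Rows of each id in the order of data_in (which was sorted by id, start_dt, end_dt)
w = W.partitionBy("id").orderBy(F.monotonically_increasing_id())
# Frame from the first row of the id up to the current row (running total)
w1 = w.rangeBetween(W.unboundedPreceding, 0)

out = (data_in
       # True when the range starts more than one day after the previous row's end
       .withColumn("helper", F.datediff(F.col("start_dt"), F.lag("end_dt").over(w)) > 1)
       # The first row of each id has no previous row (lag is null): treat it as a new range
       .fillna({"helper": True})
       # Running count of range starts = a label shared by all rows of one merged range
       .withColumn("helper2", F.sum(F.col("helper").cast("int")).over(w1))
       .groupBy("id", "helper2")
       .agg(F.min("start_dt").alias("start_dt"),
            F.max("end_dt").alias("end_dt"))
       .drop("helper2"))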

out.show()

+---+----------+----------+
| id|  start_dt|    end_dt|
+---+----------+----------+
|  1|2016-12-01|2016-12-02|
|  1|2016-12-05|2017-09-30|
|  1|2017-10-03|2017-10-03|
|  1|2018-05-01|2018-10-31|
|  2|2016-12-01|2016-12-25|
+---+----------+----------+

Note that this assumes mydt = datetime(1970,1,1).date() is only a placeholder for the missing value when shifting; I used fillna with True for the same purpose. If that is not the case, you can fillna right after the lag, which is the Spark counterpart of shift.
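One difference worth checking: the Spark version makes a single pass and compares each start_dt only with the previous row's end_dt (via lag), whereas merge_dates repeats its pass until nothing more merges. On the question's data a single pass gives the same result, but when a short range sits inside an earlier, longer one, the next range can be flagged as a false gap. The pandas sketch below uses made-up rows to compare the previous row's end with a running maximum of end_dt; the running maximum is a swapped-in technique, not part of this answer. In Spark, the analogous change would presumably be replacing F.lag("end_dt").over(w) with F.max("end_dt").over(w.rowsBetween(W.unboundedPreceding, -1)), which is an untested assumption.

```python
import pandas as pd

# Made-up ranges for one id, sorted by start_dt: the first range covers the other two
grp = pd.DataFrame({
    "start_dt": pd.to_datetime(["2017-01-01", "2017-01-02", "2017-02-01"]),
    "end_dt":   pd.to_datetime(["2017-06-30", "2017-01-03", "2017-03-01"]),
})


def range_labels(starts: pd.Series, prev_end: pd.Series) -> list:
    """Label rows; a new range starts when start_dt is more than one day after prev_end."""
    gap_days = (starts - prev_end).dt.days
    # The first row has no previous end (NaT), so it always opens a range
    return (gap_days.gt(1) | prev_end.isna()).cumsum().tolist()


# Compare with the previous row's end only (what lag / shift do)
lag_labels = range_labels(grp["start_dt"], grp["end_dt"].shift())
# Compare with the latest end among all earlier rows (running maximum)
cummax_labels = range_labels(grp["start_dt"], grp["end_dt"].cummax().shift())

print(lag_labels)     # [1, 1, 2]
print(cummax_labels)  # [1, 1, 1]
```

With the previous-row comparison, 2017-02-01 looks like a new range because it is compared with 2017-01-03; the running maximum compares it with 2017-06-30 instead. The recursive merge_dates reaches the single range too, but only on its second pass.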

Upvotes: 2
