Reputation: 181
I have the following dataframe in Spark.
id | start | end |
---|---|---|
140396 | 2002-06-18 | 2003-06-18 |
140396 | 2007-07-29 | 2015-07-29 |
140396 | 2008-02-05 | 2010-02-05 |
140396 | 2009-01-18 | 2010-01-18 |
140396 | 2013-01-19 | 2021-08-30 |
140396 | 2017-05-15 | 2021-08-30 |
I need to analyze the date ranges and produce ranges with no intersections between them, while still covering all of the original dates. Result:
id | start | end |
---|---|---|
140396 | 2002-06-18 | 2003-06-18 |
140396 | 2007-07-29 | 2021-08-30 |
Another example would go from:
id | start | end |
---|---|---|
140396 | 2002-06-18 | 2003-06-18 |
140396 | 2007-07-29 | 2015-07-29 |
140396 | 2014-02-05 | 2016-02-05 |
140396 | 2017-05-15 | 2021-08-30 |
to
id | start | end |
---|---|---|
140396 | 2002-06-18 | 2003-06-18 |
140396 | 2007-07-29 | 2016-02-05 |
140396 | 2017-05-15 | 2021-08-30 |
Keep in mind that there will be other users with their own dates, so the problem has to be partitioned by id (e.g. with a window function).
Would someone know how to solve this?
Thank you very much in advance
Upvotes: 1
Views: 467
Reputation: 5078
You can group your dataframe by `id`, convert the result to a `KeyValueGroupedDataset` (Spark 3.0+), and then run a merge-overlapping-intervals algorithm over each group.
You first need to define a case class representing a row of your dataframe:
case class Line(id: Int, start: String, end: String)
And then use it in the main part:
import sparkSession.implicits._

dataframe
  .groupBy("id")
  .as[Int, Line] // KeyValueGroupedDataset[Int, Line]
  .flatMapGroups((id, grouped) =>
    // Sort the group's intervals by start (ISO-formatted dates compare
    // correctly as strings) and fold, keeping the interval currently
    // being built at the head of the accumulator.
    grouped.toSeq.sortBy(_.start).foldLeft(Seq.empty[Line])((acc, line) => (acc, line) match {
      case (Nil, line) => Seq(line) // first interval of the group
      case (x :: xs, line) if x.end >= line.end => x :: xs // fully contained: skip it
      case (x :: xs, line) if x.end < line.end && x.end >= line.start =>
        Line(id, x.start, line.end) +: xs // overlaps the head: extend it
      case (xs, line) => line +: xs // disjoint: open a new interval
    }))
  .orderBy("id", "start")
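As a quick check, the snippet might be wired up against the question's sample data like this (a minimal sketch; it assumes an existing `SparkSession` named `sparkSession`, as in the code above):

import sparkSession.implicits._

// Build the question's sample data; Line is the case class defined above.
val dataframe = Seq(
  Line(140396, "2002-06-18", "2003-06-18"),
  Line(140396, "2007-07-29", "2015-07-29"),
  Line(140396, "2008-02-05", "2010-02-05"),
  Line(140396, "2009-01-18", "2010-01-18"),
  Line(140396, "2013-01-19", "2021-08-30"),
  Line(140396, "2017-05-15", "2021-08-30")
).toDF() // columns: id, start, end

// Applying the transformation above should leave exactly two rows:
//   140396 | 2002-06-18 | 2003-06-18
//   140396 | 2007-07-29 | 2021-08-30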
Upvotes: 1
Reputation: 494
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
val df = Seq(
(140396, "2002-06-18", "2003-06-18"),
(140396, "2007-07-29", "2015-07-29"),
(140396, "2008-02-05", "2010-02-05"),
(140396, "2009-01-18", "2010-01-18"),
(140396, "2013-01-19", "2021-08-30"),
(140396, "2017-05-15", "2021-08-30"),
(140397, "2002-06-18", "2003-06-18"),
(140397, "2007-07-29", "2015-07-29"),
(140397, "2014-02-05", "2016-02-05"),
(140397, "2017-05-15", "2021-08-30")
).toDF("id", "start", "end")
// windowSpec1 covers every earlier row in the partition (the running history);
// windowSpec2 covers every row up to and including the current one.
val windowSpec1 = Window.partitionBy(col("id"))
  .orderBy(col("start"), col("end"))
  .rowsBetween(Window.unboundedPreceding, Window.currentRow - 1)
val windowSpec2 = Window.partitionBy(col("id"))
  .orderBy(col("start"), col("end"))
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)
// left_edge marks the start of each merged range: a row contributes its own
// start only when it is not covered by the running max(end) of earlier rows.
val result = df
  .withColumn("left_edge",
    max(when(col("start") < max(col("end")).over(windowSpec1), null)
      .otherwise(col("start"))).over(windowSpec2))
  .groupBy(col("id"), col("left_edge")) // one group per merged range
  .agg(min(col("start")).alias("start"), max(col("end")).alias("end"))
  .select("id", "start", "end")
  .orderBy("id", "start")
display(result) // Databricks notebooks; use result.show() in plain Spark
Result:
id | start | end |
---|---|---|
140396 | 2002-06-18 | 2003-06-18 |
140396 | 2007-07-29 | 2021-08-30 |
140397 | 2002-06-18 | 2003-06-18 |
140397 | 2007-07-29 | 2016-02-05 |
140397 | 2017-05-15 | 2021-08-30 |
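For intuition, here is how `left_edge` evolves for id 140396 (traced by hand from the window definitions above): a row contributes its own `start` as a new edge only when that `start` is not covered by the running `max(end)` of all earlier rows, and the outer `max` over `windowSpec2` propagates the most recently opened edge to every row of the same overlapping run.
start | end | max(end) over prior rows | left_edge |
---|---|---|---|
2002-06-18 | 2003-06-18 | null | 2002-06-18 |
2007-07-29 | 2015-07-29 | 2003-06-18 | 2007-07-29 |
2008-02-05 | 2010-02-05 | 2015-07-29 | 2007-07-29 |
2009-01-18 | 2010-01-18 | 2015-07-29 | 2007-07-29 |
2013-01-19 | 2021-08-30 | 2015-07-29 | 2007-07-29 |
2017-05-15 | 2021-08-30 | 2021-08-30 | 2007-07-29 |
Grouping by (`id`, `left_edge`) and taking `min(start)`/`max(end)` then collapses each run into a single merged range.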
Reference: https://wiki.postgresql.org/wiki/Range_aggregation
Upvotes: 3