Reputation: 354
I have a PySpark dataframe with the columns start_time and end_time, which define an interval per row. It also contains a column is_duplicated, set to True if an interval is overlapped by at least one other interval, and to False otherwise.
There is also a column rate, and I want to know whether a sub-interval (which, by definition, is overlapped) carries different rate values; if it does, I want to keep the record holding the latest update, given by the column updated_at, as the ground truth.
As an intermediary step, I was thinking of creating a column is_validated set to:

- None when the sub-interval is not overlapped
- True when the sub-interval is overlapped by another one containing a different rate value and is the last updated
- False when the sub-interval is overlapped by another one containing a different rate value and is NOT the last updated

Note: the intermediary step is not mandatory; I provided it just to make the explanation clearer.
Inputs:
from pyspark.sql import Row

# So this:
input_rows = [Row(start_time='2018-01-01 00:00:00', end_time='2018-01-04 00:00:00', rate=10, updated_at='2021-02-25 00:00:00'), # OVERLAP: (1,4) and (2,3) and (3,5) and rate=10/20
Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10, updated_at='2021-02-25 00:00:00'), # OVERLAP: full overlap for (2,3) with (1,4)
Row(start_time='2018-01-03 00:00:00', end_time='2018-01-05 00:00:00', rate=20, updated_at='2021-02-20 00:00:00'), # OVERLAP: (3,5) and (1,4) and rate=10/20
Row(start_time='2018-01-06 00:00:00', end_time='2018-01-07 00:00:00', rate=30, updated_at='2021-02-25 00:00:00'), # NO OVERLAP: hole between (5,6)
Row(start_time='2018-01-07 00:00:00', end_time='2018-01-08 00:00:00', rate=30, updated_at='2021-02-25 00:00:00')] # NO OVERLAP
df = spark.createDataFrame(input_rows)
df.show()
>>>
+-------------------+-------------------+----+-------------------+
| start_time| end_time|rate| updated_at|
+-------------------+-------------------+----+-------------------+
|2018-01-01 00:00:00|2018-01-04 00:00:00| 10|2021-02-25 00:00:00|
|2018-01-02 00:00:00|2018-01-03 00:00:00| 10|2021-02-25 00:00:00|
|2018-01-03 00:00:00|2018-01-05 00:00:00| 20|2021-02-20 00:00:00|
|2018-01-06 00:00:00|2018-01-07 00:00:00| 30|2021-02-25 00:00:00|
|2018-01-07 00:00:00|2018-01-08 00:00:00| 30|2021-02-25 00:00:00|
+-------------------+-------------------+----+-------------------+
# Will become:
tmp_rows = [Row(start_time='2018-01-01 00:00:00', end_time='2018-01-02 00:00:00', rate=10, updated_at='2021-02-25 00:00:00', is_duplicated=False, is_validated=None),
Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10, updated_at='2021-02-25 00:00:00', is_duplicated=True, is_validated=True),
Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10, updated_at='2021-02-25 00:00:00', is_duplicated=True, is_validated=True),
Row(start_time='2018-01-03 00:00:00', end_time='2018-01-04 00:00:00', rate=10, updated_at='2021-02-25 00:00:00', is_duplicated=True, is_validated=True),
Row(start_time='2018-01-03 00:00:00', end_time='2018-01-04 00:00:00', rate=20, updated_at='2021-02-20 00:00:00', is_duplicated=True, is_validated=False),
Row(start_time='2018-01-04 00:00:00', end_time='2018-01-05 00:00:00', rate=20, updated_at='2021-02-25 00:00:00', is_duplicated=False, is_validated=None),
Row(start_time='2018-01-06 00:00:00', end_time='2018-01-07 00:00:00', rate=30, updated_at='2021-02-25 00:00:00', is_duplicated=False, is_validated=None),
Row(start_time='2018-01-07 00:00:00', end_time='2018-01-08 00:00:00', rate=30, updated_at='2021-02-25 00:00:00', is_duplicated=False, is_validated=None)
]
tmp_df = spark.createDataFrame(tmp_rows)
tmp_df.show()
>>>
+-------------------+-------------------+----+-------------------+-------------+------------+
| start_time| end_time|rate| updated_at|is_duplicated|is_validated|
+-------------------+-------------------+----+-------------------+-------------+------------+
|2018-01-01 00:00:00|2018-01-02 00:00:00| 10|2021-02-25 00:00:00| false| null|
|2018-01-02 00:00:00|2018-01-03 00:00:00| 10|2021-02-25 00:00:00| true| true|
|2018-01-02 00:00:00|2018-01-03 00:00:00| 10|2021-02-25 00:00:00| true| true|
|2018-01-03 00:00:00|2018-01-04 00:00:00| 10|2021-02-25 00:00:00| true| true|
|2018-01-03 00:00:00|2018-01-04 00:00:00| 20|2021-02-20 00:00:00| true| false|
|2018-01-04 00:00:00|2018-01-05 00:00:00| 20|2021-02-25 00:00:00| false| null|
|2018-01-06 00:00:00|2018-01-07 00:00:00| 30|2021-02-25 00:00:00| false| null|
|2018-01-07 00:00:00|2018-01-08 00:00:00| 30|2021-02-25 00:00:00| false| null|
+-------------------+-------------------+----+-------------------+-------------+------------+
# To give you:
output_rows = [Row(start_time='2018-01-01 00:00:00', end_time='2018-01-02 00:00:00', rate=10),
Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10),
Row(start_time='2018-01-03 00:00:00', end_time='2018-01-04 00:00:00', rate=10),
Row(start_time='2018-01-04 00:00:00', end_time='2018-01-05 00:00:00', rate=20),
Row(start_time='2018-01-06 00:00:00', end_time='2018-01-07 00:00:00', rate=30),
Row(start_time='2018-01-07 00:00:00', end_time='2018-01-08 00:00:00', rate=30)
]
final_df = spark.createDataFrame(output_rows)
final_df.show()
>>>
+-------------------+-------------------+----+
| start_time| end_time|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-02 00:00:00| 10|
|2018-01-02 00:00:00|2018-01-03 00:00:00| 10|
|2018-01-03 00:00:00|2018-01-04 00:00:00| 10|
|2018-01-04 00:00:00|2018-01-05 00:00:00| 20|
|2018-01-06 00:00:00|2018-01-07 00:00:00| 30|
|2018-01-07 00:00:00|2018-01-08 00:00:00| 30|
+-------------------+-------------------+----+
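For reference, here is a rough sketch of how I imagine the intermediary step and the final selection could look, assuming the sub-intervals have already been exploded into a tmp_df shaped like the one above (with is_duplicated already set); expected_df is just an illustrative name:
from pyspark.sql import Window, functions as F

# Flag each sub-interval row: None if it is not overlapped, True if it carries
# the latest updated_at within its (start_time, end_time) group, False otherwise.
# (The ISO-formatted timestamp strings compare chronologically.)
w = Window.partitionBy('start_time', 'end_time')
tmp_df = tmp_df.withColumn(
    'is_validated',
    F.when(~F.col('is_duplicated'), F.lit(None).cast('boolean'))
     .otherwise(F.col('updated_at') == F.max('updated_at').over(w)))

# Keep non-overlapped rows plus, for overlapped sub-intervals, only the latest
# update; drop exact duplicates such as the two (2,3) rows.
expected_df = (tmp_df
               .filter(F.col('is_validated').isNull() | F.col('is_validated'))
               .select('start_time', 'end_time', 'rate')
               .dropDuplicates()
               .orderBy('start_time'))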
Upvotes: 1
Views: 1343
Reputation: 354
This works:
from pyspark.sql import functions as F, Row, SparkSession, SQLContext, Window
from pyspark.sql.types import BooleanType
spark = (SparkSession.builder
         .master("local")
         .appName("Octopus")
         .config('spark.sql.autoBroadcastJoinThreshold', -1)
         .getOrCreate())
input_rows = [Row(idx=0, interval_start='2018-01-01 00:00:00', interval_end='2018-01-04 00:00:00', rate=10, updated_at='2021-02-25 00:00:00'), # OVERLAP: (1,4) and (2,3) and (3,5) and rate=10/20
Row(idx=0, interval_start='2018-01-02 00:00:00', interval_end='2018-01-03 00:00:00', rate=10, updated_at='2021-02-25 00:00:00'), # OVERLAP: full overlap for (2,3) with (1,4)
Row(idx=0, interval_start='2018-01-03 00:00:00', interval_end='2018-01-05 00:00:00', rate=20, updated_at='2021-02-20 00:00:00'), # OVERLAP: (3,5) and (1,4) and rate=10/20
Row(idx=0, interval_start='2018-01-06 00:00:00', interval_end='2018-01-07 00:00:00', rate=30, updated_at='2021-02-25 00:00:00'), # NO OVERLAP: hole between (5,6)
Row(idx=0, interval_start='2018-01-07 00:00:00', interval_end='2018-01-08 00:00:00', rate=30, updated_at='2021-02-25 00:00:00')] # NO OVERLAP
df = spark.createDataFrame(input_rows)
df.show()
# Compute overlapping intervals
sc = spark.sparkContext
sql_context = SQLContext(sc, spark)
def overlap(start_first, end_first, start_second, end_second):
    return ((start_first < start_second < end_first) or (start_first < end_second < end_first)
            or (start_second < start_first < end_second) or (start_second < end_first < end_second))
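# Note: strict comparisons, so intervals that merely touch (one ending exactly
# where the next starts) are not counted as overlapping.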
sql_context.registerFunction('overlap', overlap, BooleanType())
df.registerTempTable("df1")
df.registerTempTable("df2")
df = df.cache()
overlap_df = spark.sql("""
SELECT df1.idx, df1.interval_start, df1.interval_end, df1.rate AS rate FROM df1 JOIN df2
ON df1.idx == df2.idx
WHERE overlap(df1.interval_start, df1.interval_end, df2.interval_start, df2.interval_end)
""")
overlap_df = overlap_df.cache()
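# The self-join can return the same interval several times when it overlaps
# more than one other row; that is harmless here, since only distinct boundary
# points and a left anti join are derived from overlap_df.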
# Compute NON overlapping intervals
non_overlap_df = df.join(overlap_df, ['interval_start', 'interval_end'], 'leftanti')
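# leftanti keeps only the rows of df that have no match in overlap_df,
# i.e. the intervals with no overlapping partner.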
# Stack overlapping points
interval_point = overlap_df.select('interval_start').union(overlap_df.select('interval_end'))
interval_point = interval_point.withColumnRenamed('interval_start', 'p').distinct().sort('p')
# Construct continuous overlapping intervals
w = Window.rowsBetween(1, Window.unboundedFollowing)
interval_point = interval_point.withColumn('interval_end', F.min('p').over(w)).dropna(subset=['p', 'interval_end'])
interval_point = interval_point.withColumnRenamed('p', 'interval_start')
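# Pairing each sorted boundary point with the next one yields the elementary
# sub-intervals; the last point has no successor and is dropped by dropna above.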
# Stack continuous overlapping intervals and non overlapping intervals
df3 = interval_point.select('interval_start', 'interval_end').union(non_overlap_df.select('interval_start', 'interval_end'))
# Point in interval range join
# https://docs.databricks.com/delta/join-performance/range-join.html
df3.registerTempTable("df3")
df.registerTempTable("df")
sql = """SELECT df3.interval_start, df3.interval_end, df.rate, df.updated_at
FROM df3 JOIN df ON df3.interval_start BETWEEN df.interval_start and df.interval_end - INTERVAL 1 seconds"""
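# Subtracting one second makes the upper bound exclusive, so a sub-interval is
# only matched to source rows that actually cover it, not to a row whose
# interval ends exactly where the sub-interval starts.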
df4 = spark.sql(sql)
df4.sort('interval_start').show()
# select non overlapped intervals and keep most up to date rate value for overlapping intervals
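# Taking the max of a (updated_at, rate) struct compares fields left to right,
# so the row with the latest updated_at wins and its rate is kept per sub-interval.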
(df4.groupBy('interval_start', 'interval_end')
    .agg(F.max(F.struct('updated_at', 'rate'))['rate'].alias('rate'))
    .orderBy("interval_start")).show()
+-------------------+-------------------+----+
| interval_start| interval_end|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-02 00:00:00| 10|
|2018-01-02 00:00:00|2018-01-03 00:00:00| 10|
|2018-01-03 00:00:00|2018-01-04 00:00:00| 10|
|2018-01-04 00:00:00|2018-01-05 00:00:00| 20|
|2018-01-06 00:00:00|2018-01-07 00:00:00| 30|
|2018-01-07 00:00:00|2018-01-08 00:00:00| 30|
+-------------------+-------------------+----+
Upvotes: 1
Reputation: 42392
You can explode each interval into one-day sub-intervals (sequences of timestamps), just like your intermediate dataframe, and then group by the sub-interval start and end times to get the latest rate according to the update time. (Note that this splits at whole-day granularity, which matches your sample data.)
import pyspark.sql.functions as F
output = df.selectExpr(
    """
    inline(arrays_zip(
        sequence(timestamp(start_time), timestamp(end_time) - interval 1 day, interval 1 day),
        sequence(timestamp(start_time) + interval 1 day, timestamp(end_time), interval 1 day)
    )) as (start_time, end_time)
    """,
    "rate", "updated_at"
).groupBy(
    'start_time', 'end_time'
).agg(
    F.max(F.struct('updated_at', 'rate'))['rate'].alias('rate')
).orderBy("start_time")
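# max over the (updated_at, rate) struct keeps, for each one-day sub-interval,
# the rate from the row with the latest updated_at.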
output.show()
+-------------------+-------------------+----+
| start_time| end_time|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-02 00:00:00| 10|
|2018-01-02 00:00:00|2018-01-03 00:00:00| 10|
|2018-01-03 00:00:00|2018-01-04 00:00:00| 10|
|2018-01-04 00:00:00|2018-01-05 00:00:00| 20|
|2018-01-06 00:00:00|2018-01-07 00:00:00| 30|
|2018-01-07 00:00:00|2018-01-08 00:00:00| 30|
+-------------------+-------------------+----+
Upvotes: 0