Florian

Reputation: 354

Split a list of overlapping intervals into non-overlapping sub-intervals in a PySpark dataframe and check whether values are valid on the overlapped intervals

I have a PySpark dataframe with columns start_time and end_time that define an interval per row. It also contains a column is_duplicated, set to True if the interval overlaps at least one other interval, and to False otherwise.

There is also a column rate, and I want to know whether different rate values exist for the same sub-interval (which, by definition, is overlapped); if they do, I want to keep the record carrying the latest update in the column updated_at as the ground truth.

As an intermediary step, I was thinking of creating a column is_validated set to:
- None if the sub-interval is not overlapped
- True if the sub-interval is overlapped and its rate comes from the row with the latest updated_at
- False if the sub-interval is overlapped and its rate does not come from the row with the latest updated_at

Note: the intermediary step is not mandatory; I provide it only to make the explanation clearer.

Inputs:

# So this:
from pyspark.sql import Row

input_rows = [Row(start_time='2018-01-01 00:00:00', end_time='2018-01-04 00:00:00', rate=10, updated_at='2021-02-25 00:00:00'),  # OVERLAP: (1,4) and (2,3) and (3,5) and rate=10/20          
              Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10, updated_at='2021-02-25 00:00:00'),  # OVERLAP: full overlap for (2,3) with (1,4)               
              Row(start_time='2018-01-03 00:00:00', end_time='2018-01-05 00:00:00', rate=20, updated_at='2021-02-20 00:00:00'),  # OVERLAP: (3,5) and (1,4) and rate=10/20                          
              Row(start_time='2018-01-06 00:00:00', end_time='2018-01-07 00:00:00', rate=30, updated_at='2021-02-25 00:00:00'),  # NO OVERLAP: hole between (5,6)                                            
              Row(start_time='2018-01-07 00:00:00', end_time='2018-01-08 00:00:00', rate=30, updated_at='2021-02-25 00:00:00')]  # NO OVERLAP

df = spark.createDataFrame(input_rows)
df.show()
>>> +-------------------+-------------------+----+-------------------+
    |         start_time|           end_time|rate|         updated_at|
    +-------------------+-------------------+----+-------------------+
    |2018-01-01 00:00:00|2018-01-04 00:00:00|  10|2021-02-25 00:00:00|
    |2018-01-02 00:00:00|2018-01-03 00:00:00|  10|2021-02-25 00:00:00|
    |2018-01-03 00:00:00|2018-01-05 00:00:00|  20|2021-02-20 00:00:00|
    |2018-01-06 00:00:00|2018-01-07 00:00:00|  30|2021-02-25 00:00:00|
    |2018-01-07 00:00:00|2018-01-08 00:00:00|  30|2021-02-25 00:00:00|
    +-------------------+-------------------+----+-------------------+
# Will become:
tmp_rows = [Row(start_time='2018-01-01 00:00:00', end_time='2018-01-02 00:00:00', rate=10, updated_at='2021-02-25 00:00:00', is_duplicated=False, is_validated=None),
            Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10, updated_at='2021-02-25 00:00:00', is_duplicated=True,  is_validated=True),
            Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10, updated_at='2021-02-25 00:00:00', is_duplicated=True,  is_validated=True),
            Row(start_time='2018-01-03 00:00:00', end_time='2018-01-04 00:00:00', rate=10, updated_at='2021-02-25 00:00:00', is_duplicated=True,  is_validated=True),
            Row(start_time='2018-01-03 00:00:00', end_time='2018-01-04 00:00:00', rate=20, updated_at='2021-02-20 00:00:00', is_duplicated=True,  is_validated=False),
            Row(start_time='2018-01-04 00:00:00', end_time='2018-01-05 00:00:00', rate=20, updated_at='2021-02-20 00:00:00', is_duplicated=False, is_validated=None),
            Row(start_time='2018-01-06 00:00:00', end_time='2018-01-07 00:00:00', rate=30, updated_at='2021-02-25 00:00:00', is_duplicated=False, is_validated=None),
            Row(start_time='2018-01-07 00:00:00', end_time='2018-01-08 00:00:00', rate=30, updated_at='2021-02-25 00:00:00', is_duplicated=False, is_validated=None)
           ]
tmp_df = spark.createDataFrame(tmp_rows)
tmp_df.show()
>>> 
+-------------------+-------------------+----+-------------------+-------------+------------+
|         start_time|           end_time|rate|         updated_at|is_duplicated|is_validated|
+-------------------+-------------------+----+-------------------+-------------+------------+
|2018-01-01 00:00:00|2018-01-02 00:00:00|  10|2021-02-25 00:00:00|        false|        null|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|2021-02-25 00:00:00|         true|        true|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|2021-02-25 00:00:00|         true|        true|
|2018-01-03 00:00:00|2018-01-04 00:00:00|  10|2021-02-25 00:00:00|         true|        true|
|2018-01-03 00:00:00|2018-01-04 00:00:00|  20|2021-02-20 00:00:00|         true|       false|
|2018-01-04 00:00:00|2018-01-05 00:00:00|  20|2021-02-20 00:00:00|        false|        null|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|2021-02-25 00:00:00|        false|        null|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|2021-02-25 00:00:00|        false|        null|
+-------------------+-------------------+----+-------------------+-------------+------------+

# To give you: 
output_rows = [Row(start_time='2018-01-01 00:00:00', end_time='2018-01-02 00:00:00', rate=10),
               Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10),
               Row(start_time='2018-01-03 00:00:00', end_time='2018-01-04 00:00:00', rate=10),
               Row(start_time='2018-01-04 00:00:00', end_time='2018-01-05 00:00:00', rate=20),
               Row(start_time='2018-01-06 00:00:00', end_time='2018-01-07 00:00:00', rate=30),
               Row(start_time='2018-01-07 00:00:00', end_time='2018-01-08 00:00:00', rate=30)
              ]
final_df = spark.createDataFrame(output_rows)
final_df.show()
>>> 
+-------------------+-------------------+----+
|         start_time|           end_time|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-02 00:00:00|  10|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|
|2018-01-03 00:00:00|2018-01-04 00:00:00|  10|
|2018-01-04 00:00:00|2018-01-05 00:00:00|  20|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|
+-------------------+-------------------+----+

Upvotes: 1

Views: 1343

Answers (2)

Florian

Reputation: 354

This works:

from pyspark.sql import functions as F, Row, SparkSession, SQLContext, Window
from pyspark.sql.types import BooleanType

spark = (SparkSession.builder 
    .master("local") 
    .appName("Octopus") 
    .config('spark.sql.autoBroadcastJoinThreshold', -1)
    .getOrCreate())

input_rows = [Row(idx=0, interval_start='2018-01-01 00:00:00', interval_end='2018-01-04 00:00:00', rate=10, updated_at='2021-02-25 00:00:00'),  # OVERLAP: (1,4) and (2,3) and (3,5) and rate=10/20          
              Row(idx=0, interval_start='2018-01-02 00:00:00', interval_end='2018-01-03 00:00:00', rate=10, updated_at='2021-02-25 00:00:00'),  # OVERLAP: full overlap for (2,3) with (1,4)               
              Row(idx=0, interval_start='2018-01-03 00:00:00', interval_end='2018-01-05 00:00:00', rate=20, updated_at='2021-02-20 00:00:00'),  # OVERLAP: (3,5) and (1,4) and rate=10/20                          
              Row(idx=0, interval_start='2018-01-06 00:00:00', interval_end='2018-01-07 00:00:00', rate=30, updated_at='2021-02-25 00:00:00'),  # NO OVERLAP: hole between (5,6)                                            
              Row(idx=0, interval_start='2018-01-07 00:00:00', interval_end='2018-01-08 00:00:00', rate=30, updated_at='2021-02-25 00:00:00')]  # NO OVERLAP


df = spark.createDataFrame(input_rows)
df.show()

# Compute overlapping intervals
sc = spark.sparkContext
sql_context = SQLContext(sc, spark)

def overlap(start_first, end_first, start_second, end_second):
    return ((start_first < start_second < end_first) or (start_first < end_second < end_first)
           or (start_second < start_first < end_second) or (start_second < end_first < end_second))
sql_context.registerFunction('overlap', overlap, BooleanType())
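# Note: on Spark 2.3+ the same UDF can be registered with
# spark.udf.register('overlap', overlap, BooleanType()), avoiding the deprecated
# SQLContext API. Also, because overlap() uses strict inequalities, two identical
# intervals would not be flagged as overlapping; that edge case does not occur here.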

df.registerTempTable("df1")
df.registerTempTable("df2")
df = df.cache()

overlap_df = spark.sql("""
     SELECT df1.idx, df1.interval_start, df1.interval_end, df1.rate AS rate FROM df1 JOIN df2
     ON df1.idx == df2.idx
     WHERE overlap(df1.interval_start, df1.interval_end, df2.interval_start, df2.interval_end)
""")
overlap_df = overlap_df.cache()

# Compute NON overlapping intervals
non_overlap_df = df.join(overlap_df, ['interval_start', 'interval_end'], 'leftanti')

# Stack overlapping points
interval_point = overlap_df.select('interval_start').union(overlap_df.select('interval_end'))
interval_point = interval_point.withColumnRenamed('interval_start', 'p').distinct().sort('p')

# Construct continuous overlapping intervals (order the window by the point value
# so the frame deterministically picks the next boundary)
w = Window.orderBy('p').rowsBetween(1, Window.unboundedFollowing)

interval_point = interval_point.withColumn('interval_end', F.min('p').over(w)).dropna(subset=['p', 'interval_end'])
interval_point = interval_point.withColumnRenamed('p', 'interval_start')
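# An equivalent formulation of the two steps above, sketched with F.lead over an
# ordered window (same result on this data):
#   interval_point = (interval_point
#       .withColumn('interval_end', F.lead('p').over(Window.orderBy('p')))
#       .dropna(subset=['interval_end'])
#       .withColumnRenamed('p', 'interval_start'))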

# Stack continuous overlapping intervals and non overlapping intervals
df3 = interval_point.select('interval_start', 'interval_end').union(non_overlap_df.select('interval_start', 'interval_end'))

# Point in interval range join
# https://docs.databricks.com/delta/join-performance/range-join.html
df3.registerTempTable("df3")
df.registerTempTable("df")
sql = """SELECT df3.interval_start, df3.interval_end, df.rate, df.updated_at
         FROM df3 JOIN df ON df3.interval_start BETWEEN df.interval_start and df.interval_end - INTERVAL 1 seconds"""
df4 = spark.sql(sql)
df4.sort('interval_start').show()

# select non overlapped intervals and keep most up to date rate value for overlapping intervals
(df4.groupBy('interval_start', 'interval_end')
    .agg(F.max(F.struct('updated_at', 'rate'))['rate'].alias('rate'))
    .orderBy("interval_start")).show()

+-------------------+-------------------+----+
|     interval_start|       interval_end|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-02 00:00:00|  10|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|
|2018-01-03 00:00:00|2018-01-04 00:00:00|  10|
|2018-01-04 00:00:00|2018-01-05 00:00:00|  20|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|
+-------------------+-------------------+----+
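For reference, the final "keep the most recently updated rate" step can also be written with a window function instead of the max(struct(...)) trick; a minimal sketch over the same df4 (equivalent result on this data, not part of the original pipeline):

from pyspark.sql import functions as F, Window

w_latest = Window.partitionBy('interval_start', 'interval_end').orderBy(F.col('updated_at').desc())
(df4
    .withColumn('rn', F.row_number().over(w_latest))  # rank candidate rates per sub-interval, newest first
    .filter(F.col('rn') == 1)                         # keep only the most recently updated rate
    .select('interval_start', 'interval_end', 'rate')
    .orderBy('interval_start')
    .show())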

Upvotes: 1

mck

Reputation: 42392

You can explode each interval into sequences of daily timestamps, which reproduces your intermediate dataframe, and then group by the start and end times to get the latest rate according to the update time.

import pyspark.sql.functions as F

output = df.selectExpr(
    """
    inline(arrays_zip(
        sequence(timestamp(start_time), timestamp(end_time) - interval 1 day, interval 1 day),
        sequence(timestamp(start_time) + interval 1 day, timestamp(end_time), interval 1 day)
    )) as (start_time, end_time)
    """,
    "rate", "updated_at"
).groupBy(
    'start_time', 'end_time'
).agg(
    F.max(F.struct('updated_at', 'rate'))['rate'].alias('rate')
).orderBy("start_time")

output.show()
+-------------------+-------------------+----+
|         start_time|           end_time|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-02 00:00:00|  10|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|
|2018-01-03 00:00:00|2018-01-04 00:00:00|  10|
|2018-01-04 00:00:00|2018-01-05 00:00:00|  20|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|
+-------------------+-------------------+----+
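Note that the daily sequence/inline trick assumes every start_time and end_time falls exactly on a day boundary, as in the sample data; otherwise the generated sub-intervals will not line up with the original boundaries. A quick sanity check along these lines (a sketch, not part of the answer) can confirm that assumption first:

import pyspark.sql.functions as F

# rows whose boundaries are not aligned to midnight; 0 means the daily explode is safe
misaligned = df.filter(
    (F.date_trunc('day', F.col('start_time').cast('timestamp')) != F.col('start_time').cast('timestamp'))
    | (F.date_trunc('day', F.col('end_time').cast('timestamp')) != F.col('end_time').cast('timestamp'))
).count()
print(misaligned)  # expect 0 for the sample data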

Upvotes: 0
