tomruarol

Reputation: 69

Pyspark window function to calculate number of transits between stops

I am using PySpark and I would like to create a function that performs the following operation:

Given data describing the transactions of train users:

+----+-------------------+-------+---------------+-------------+-------------+
|USER|       DATE        |LINE_ID|      STOP     | TOPOLOGY_ID |TRANSPORT_ID |
+----+-------------------+-------+---------------+-------------+-------------+
|John|2021-01-27 07:27:34|      7| King Cross    |       171235|       03    |
|John|2021-01-27 07:28:00|     40| White Chapell |       123582|       03    |  
|John|2021-01-27 07:35:30|      4| Reaven        |       171565|       03    |  
|Tom |2021-01-27 07:27:23|      7| King Cross    |       171235|       03    |    
|Tom |2021-01-27 07:28:30|     40| White Chapell |       123582|       03    |                   
+----+-------------------+-------+---------------+-------------+-------------+

I would like to get the number of times each combination of stops (A-B, B-C, etc.) has been made, grouped into 30-minute windows.

So, let's say user "John" goes from stop "King Cross" to "White Chapell" at 7:27 and then goes from "White Chapell" to "Reaven" at 7:35.
Meanwhile, "Tom" goes from "King Cross" to "White Chapell" at 7:27 and then from "White Chapell" to "Oxford Circus" at 7:32.

The result of the operation would have to be something like:

+----------------------+-----------------+---------------+-----------+
|          DATE        |   ORIG_STOP     |   DEST_STOP   | NUM_TRANS |
+----------------------+-----------------+---------------+-----------+
|   2021-01-27 07:00:00|  King Cross     | White Chapell |       2   |
|   2021-01-27 07:30:00|  White Chapell  | Reaven        |       1   |              
+----------------------+-----------------+---------------+-----------+

I have tried using window functions, but I can't manage to get what I really want.

Upvotes: 2

Views: 170

Answers (1)

ggordon

Reputation: 10035

You may try running the following.

Using Spark SQL

The first CTE, initial_stop_groups, uses the LEAD function to pair each stop with the following stop and its time for the same USER and TRANSPORT_ID. The next CTE, stop_groups, derives the associated 30-minute intervals using CASE expressions and date functions, and filters out rows with no destination stop. The final projection then groups by the time interval, origin stop and destination stop to count NUM_TRANS for transits whose origin and destination buckets are at most 30 minutes apart. For example, with the sample data above, the first CTE pairs John's rows as (King Cross → White Chapell) and (White Chapell → Reaven), plus a trailing (Reaven → NULL) row that the DEST_STOP IS NOT NULL filter removes.

Assuming your data is in input_df

input_df.createOrReplaceTempView("input_df")

output_df = sparkSession.sql("""
 WITH initial_stop_groups AS (
        SELECT
            DATE as ORIG_DATE,
            LEAD(DATE) OVER (
                PARTITION BY USER,TRANSPORT_ID
                ORDER BY DATE
            ) as STOP_DATE,
            STOP as ORIG_STOP,
            LEAD(STOP) OVER (
                PARTITION BY USER,TRANSPORT_ID
                ORDER BY DATE
            ) as DEST_STOP
        FROM
            input_df
    ),
    stop_groups AS (
        SELECT 
            CAST(CONCAT(
              CAST(ORIG_DATE as DATE),
              ' ',
              hour(ORIG_DATE),
              ':',
              CASE WHEN minute(ORIG_DATE) < 30 THEN '00' ELSE '30' END,
              ':00'
            ) AS TIMESTAMP) as ORIG_TIME,
            CASE WHEN STOP_DATE IS NOT NULL THEN CAST(CONCAT(
              CAST(STOP_DATE as DATE),
              ' ',
              hour(STOP_DATE),
              ':',
              CASE WHEN minute(STOP_DATE) < 30 THEN '00' ELSE '30' END,
              ':00'
            ) AS TIMESTAMP) ELSE NULL END as STOP_TIME,
            ORIG_STOP,
            DEST_STOP
        FROM 
            initial_stop_groups
        WHERE
            DEST_STOP IS NOT NULL
    )
    SELECT
        STOP_TIME as DATE, 
        ORIG_STOP,
        DEST_STOP,
        COUNT(1) as NUM_TRANS
    FROM
        stop_groups
    WHERE
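        -- keep only transits whose origin and destination buckets are at most 30 minutes apart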
        (unix_timestamp(STOP_TIME) - unix_timestamp(ORIG_TIME)) <=30*60
        
    GROUP BY
        STOP_TIME, ORIG_STOP, DEST_STOP
    
""")

output_df.show()
+------------------------+-------------+-------------+---------+
|DATE                    |ORIG_STOP    |DEST_STOP    |NUM_TRANS|
+------------------------+-------------+-------------+---------+
|2021-01-27T07:00:00.000Z|King Cross   |White Chapell|2        |
|2021-01-27T07:30:00.000Z|White Chapell|Reaven       |1        |
+------------------------+-------------+-------------+---------+

View on DB Fiddle

  • CAST((STOP_TIME - ORIG_TIME) as STRING) IN ('0 seconds','30 minutes') was replaced by (unix_timestamp(STOP_TIME) - unix_timestamp(ORIG_TIME)) <=30*60

Using the Spark API

Actual code

from pyspark.sql import functions as F
from pyspark.sql import Window

next_stop_window = Window.partitionBy("USER", "TRANSPORT_ID").orderBy("DATE")

output_df = (
    input_df.select(
        F.col("DATE").alias("ORIG_DATE"),
        F.lead("DATE").over(next_stop_window).alias("STOP_DATE"),
        F.col("STOP").alias("ORIG_STOP"),
        F.lead("STOP").over(next_stop_window).alias("DEST_STOP"),
    ).where(
        F.col("DEST_STOP").isNotNull()
    ).select(
        F.concat(
            F.col("ORIG_DATE").cast("DATE"),
            F.lit(' '),
            F.hour("ORIG_DATE"),
            F.lit(':'),
            F.when(
                F.minute("ORIG_DATE") < 30, '00'
            ).otherwise('30'),
            F.lit(':00')
        ).cast("TIMESTAMP").alias("ORIG_TIME"),
        F.concat(
            F.col("STOP_DATE").cast("DATE"),
            F.lit(' '),
            F.hour("STOP_DATE"),
            F.lit(':'),
            F.when(
                F.minute("STOP_DATE") < 30, '00'
            ).otherwise('30'),
            F.lit(':00')
        ).cast("TIMESTAMP").alias("STOP_TIME"),
        F.col("ORIG_STOP"),
        F.col("DEST_STOP")
    ).where(
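        # keep only transits whose origin and destination buckets are at most 30 minutes apart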
        (F.unix_timestamp("STOP_TIME") - F.unix_timestamp("ORIG_TIME")) <= 30*60
        # (F.col("STOP_TIME")-F.col("ORIG_TIME")).cast("STRING").isin(['0 seconds','30 minutes'])
    ).groupBy(
        F.col("STOP_TIME"),
        F.col("ORIG_STOP"),
        F.col("DEST_STOP"),
    ).count().select(
        F.col("STOP_TIME").alias("DATE"),
        F.col("ORIG_STOP"),
        F.col("DEST_STOP"),
        F.col("count").alias("NUM_TRANS"),
    )
    
)
output_df.show()

+------------------------+-------------+-------------+---------+
|DATE                    |ORIG_STOP    |DEST_STOP    |NUM_TRANS|
+------------------------+-------------+-------------+---------+
|2021-01-27T07:00:00.000Z|King Cross   |White Chapell|2        |
|2021-01-27T07:30:00.000Z|White Chapell|Reaven       |1        |
+------------------------+-------------+-------------+---------+

Resulting Schema

output_df.printSchema()
root
 |-- DATE: timestamp (nullable = true)
 |-- ORIG_STOP: string (nullable = true)
 |-- DEST_STOP: string (nullable = true)
 |-- NUM_TRANS: long (nullable = false)
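
As a possible variation (a minimal sketch, assuming the built-in window function is acceptable): the 30-minute buckets built above via string concatenation could also be derived with F.window, which returns the half-hour slot a timestamp falls into.

from pyspark.sql import functions as F

# Sketch: derive the 30-minute bucket with F.window instead of string concatenation.
# window(col, "30 minutes") yields a struct<start, end>; its start field is the slot's
# start time, e.g. 2021-01-27 07:00:00 for an event at 07:27:34.
with_bucket = (
    input_df
    .withColumn("slot", F.window(F.col("DATE"), "30 minutes"))
    .withColumn("BUCKET", F.col("slot.start"))
    .drop("slot")
)
with_bucket.select("USER", "DATE", "BUCKET").show(truncate=False)

The LEAD pairing and the final group-by count would stay as shown; only the ORIG_TIME/STOP_TIME derivation changes.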

Setup code for reproducibility

data="""+----+-------------------+-------+---------------+-------------+-------------+
|USER|       DATE        |LINE_ID|      STOP     | TOPOLOGY_ID |TRANSPORT_ID |
+----+-------------------+-------+---------------+-------------+-------------+
|John|2021-01-27 07:27:34|      7| King Cross    |       171235|       03    |
|John|2021-01-27 07:28:00|     40| White Chapell |       123582|       03    |  
|John|2021-01-27 07:35:30|      4| Reaven        |       171565|       03    |  
|Tom |2021-01-27 07:27:23|      7| King Cross    |       171235|       03    |    
|Tom |2021-01-27 07:28:30|     40| White Chapell |       123582|       03    |                   
+----+-------------------+-------+---------------+-------------+-------------+
"""

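# Parse the ASCII table above: headers come from the second line, data rows from the lines between the borders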
rows = [ [ pc.strip() for pc in line.strip().split("|")[1:-1]] for line in data.strip().split("\n")[3:-1]]
headers = [pc.strip() for pc in data.strip().split("\n")[1].split("|")[1:-1]]

from pyspark.sql import functions as F
input_df = sparkSession.createDataFrame(rows,schema=headers)
input_df = input_df.withColumn("DATE",F.col("DATE").cast("TIMESTAMP"))
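
The snippets assume a SparkSession bound to the name sparkSession; if one is not already available, a minimal local session could look like this:

from pyspark.sql import SparkSession

# Minimal local session; the variable name matches the one used in the snippets above,
# and the app name is arbitrary
sparkSession = (
    SparkSession.builder
    .master("local[*]")
    .appName("stop-transits")
    .getOrCreate()
)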


Let me know if this works for you.

Upvotes: 2
