ryrodrig

Reputation: 19

PySpark Window using rangeBetween and rowsBetween together

I am trying to write a window function that sums the amount of money spent by a user over the last minute, limited to that user's 5 most recent transactions within the calculation.

Given the below data:

from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

test_data = [
    ("txn 1", 1640995200, "user A", 1.0),
    ("txn 2", 1640995201, "user A", 1.0),
    ("txn 3", 1640995202, "user A", 1.0),
    ("txn 4", 1640995203, "user B", 1.0),
    ("txn 5", 1640995204, "user B", 1.0),
    ("txn 6", 1640995205, "user A", 1.0),
    ("txn 7", 1640995206, "user A", 1.0),
    ("txn 8", 1640995207, "user A", 1.0),
    ("txn 9", 1640995208, "user A", 1.0),
    ("txn 10", 1640995209, "user A", 1.0),
  ]

test_schema = StructType([
    StructField("txn_id", StringType(), True),
    StructField("epoch_time", IntegerType(), True),
    StructField("user_id", StringType(), True),
    StructField("txn_amt", DoubleType(), True)
  ])

test_df = spark.createDataFrame(data=test_data, schema=test_schema)

# test_df returns the below df
+------+----------+-------+-------+
|txn_id|epoch_time|user_id|txn_amt|
+------+----------+-------+-------+
| txn 1|1640995200| user A|    1.0|
| txn 2|1640995201| user A|    1.0|
| txn 3|1640995202| user A|    1.0|
| txn 4|1640995203| user B|    1.0|
| txn 5|1640995204| user B|    1.0|
| txn 6|1640995205| user A|    1.0|
| txn 7|1640995206| user A|    1.0|
| txn 8|1640995207| user A|    1.0|
| txn 9|1640995208| user A|    1.0|
|txn 10|1640995209| user A|    1.0|
+------+----------+-------+-------+

I'm aiming to create the following df:

+------+----------+-------+-------+-----------+
|txn_id|epoch_time|user_id|txn_amt|user_sum_1m|
+------+----------+-------+-------+-----------+
| txn 4|1640995203| user B|    1.0|       null|
| txn 5|1640995204| user B|    1.0|        1.0|
| txn 1|1640995200| user A|    1.0|       null|
| txn 2|1640995201| user A|    1.0|        1.0|
| txn 3|1640995202| user A|    1.0|        2.0|
| txn 6|1640995205| user A|    1.0|        3.0|
| txn 7|1640995206| user A|    1.0|        4.0|
| txn 8|1640995207| user A|    1.0|        5.0|
| txn 9|1640995208| user A|    1.0|        5.0|
|txn 10|1640995209| user A|    1.0|        5.0|
+------+----------+-------+-------+-----------+

I've tried the following attempt at combining a rowsBetween (to reflect the limitation of the 5 most recent txns per user) with a rangeBetween (to capture txns within the time range of interest). However, the rowsBetween seems to be ignored.

from pyspark.sql import Window
from pyspark.sql.functions import col, sum

w = (Window()
        .partitionBy(col('user_id'))
        .orderBy('epoch_time')
        .rowsBetween(-5, -1)
        .rangeBetween(-60, -1))

sum_test = test_df.withColumn('user_sum_1m', sum('txn_amt').over(w))
# sum_test returns 
+------+----------+-------+-------+-----------+
|txn_id|epoch_time|user_id|txn_amt|user_sum_1m|
+------+----------+-------+-------+-----------+
| txn 4|1640995203| user B|    1.0|       null|
| txn 5|1640995204| user B|    1.0|        1.0|
| txn 1|1640995200| user A|    1.0|       null|
| txn 2|1640995201| user A|    1.0|        1.0|
| txn 3|1640995202| user A|    1.0|        2.0|
| txn 6|1640995205| user A|    1.0|        3.0|
| txn 7|1640995206| user A|    1.0|        4.0|
| txn 8|1640995207| user A|    1.0|        5.0|
| txn 9|1640995208| user A|    1.0|        6.0| # should not sum more than the 5 most recent txns
|txn 10|1640995209| user A|    1.0|        7.0| # should not sum more than the 5 most recent txns
+------+----------+-------+-------+-----------+

Given that I have specified the window should look at rows -5 to -1, I cannot figure out why additional rows are included in the sum.

In this example, all transactions are within the 1 minute time range, but this would not be the case with my actual data set. Any suggestions to fix this issue?

Thanks in advance!

Upvotes: 1

Views: 2770

Answers (1)

blackbishop

Reputation: 32710

Given that I have specified the window should look at rows -5 to -1, I cannot figure out why additional rows are included in the sum.

You can't use rowsBetween and rangeBetween at the same time for the window frame. In your code, the window frame is in fact defined as .rangeBetween(-60, -1) because it's the last one you called, so it overrides the .rowsBetween(-5, -1). If you remove the rangeBetween, you'll see that it gives the expected output, but it does not guarantee that the 5 rows fall within the last minute.
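
For instance, here is a minimal sketch (reusing test_df from the question) showing that the rowsBetween frame on its own produces the expected 5-row cap:

import pyspark.sql.functions as F
from pyspark.sql import Window

# 5 preceding rows per user, no time bound at all
w_rows = Window().partitionBy('user_id').orderBy('epoch_time').rowsBetween(-5, -1)

test_df.withColumn('user_sum_5rows', F.sum('txn_amt').over(w_rows)).show()
# txn 9 and txn 10 now get 5.0, but a transaction older than 60 seconds
# would still be counted if it were among the 5 preceding rows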


That said, you can use this trick to achieve what you're looking for. Define your window frame as rangeBetween(-60, -1) and collect the list of txn_amt values, then slice the last 5 values from the list and sum them up using the aggregate function on arrays (both slice and aggregate require Spark 2.4+):

import pyspark.sql.functions as F
from pyspark.sql import Window

w = Window().partitionBy(F.col('user_id')).orderBy('epoch_time').rangeBetween(-60, -1)

sum_test = test_df.withColumn(
    'user_sum_1m',
    F.collect_list('txn_amt').over(w)
).withColumn(
    'user_sum_1m',
    # slice is 1-based: the last 5 elements start at index size - 4,
    # or at 1 when the list holds fewer than 5 elements
    F.expr("""aggregate(
                slice(user_sum_1m, greatest(size(user_sum_1m) - 4, 1), 5),
                0D,
                (acc, x) -> acc + x
            )""")
)

sum_test.show()
#+------+----------+-------+-------+-----------+
#|txn_id|epoch_time|user_id|txn_amt|user_sum_1m|
#+------+----------+-------+-------+-----------+
#| txn 1|1640995200| user A|      1|        0.0|
#| txn 2|1640995201| user A|      1|        1.0|
#| txn 3|1640995202| user A|      1|        2.0|
#| txn 6|1640995205| user A|      1|        3.0|
#| txn 7|1640995206| user A|      1|        4.0|
#| txn 8|1640995207| user A|      1|        5.0|
#| txn 9|1640995208| user A|      1|        5.0|
#|txn 10|1640995209| user A|      1|        5.0|
#| txn 4|1640995203| user B|      1|        0.0|
#| txn 5|1640995204| user B|      1|        1.0|
#+------+----------+-------+-------+-----------+
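
Note that the first transaction of each user gets 0.0 rather than null, because aggregating an empty array with the 0D start value yields 0.0. If you need null as in your expected output, one option (a sketch, keeping the collected array under a separate, hypothetical column name txn_amts) is to wrap the aggregation in a when condition, since when() without otherwise() returns null:

sum_test = test_df.withColumn(
    'txn_amts', F.collect_list('txn_amt').over(w)
).withColumn(
    'user_sum_1m',
    # empty frame -> size is 0 -> no match -> null
    F.when(F.size('txn_amts') > 0, F.expr("""aggregate(
        slice(txn_amts, greatest(size(txn_amts) - 4, 1), 5),
        0D,
        (acc, x) -> acc + x
    )"""))
).drop('txn_amts')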

Upvotes: 3
