Reputation: 19
I am trying to write a window function that sums the amount of money spent by a user over the last 1 minute, with the limitation of looking only at the last 5 transactions from that user during the calculation.
Given the below data:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

test_data = [
    ("txn 1", 1640995200, "user A", "1"),
    ("txn 2", 1640995201, "user A", "1"),
    ("txn 3", 1640995202, "user A", "1"),
    ("txn 4", 1640995203, "user B", "1"),
    ("txn 5", 1640995204, "user B", "1"),
    ("txn 6", 1640995205, "user A", "1"),
    ("txn 7", 1640995206, "user A", "1"),
    ("txn 8", 1640995207, "user A", "1"),
    ("txn 9", 1640995208, "user A", "1"),
    ("txn 10", 1640995209, "user A", "1"),
]
test_schema = StructType([
    StructField("txn_id", StringType(), True),
    StructField("epoch_time", IntegerType(), True),
    StructField("user_id", StringType(), True),
    StructField("txn_amt", StringType(), True)
])
test_df = spark.createDataFrame(data=test_data, schema=test_schema)
# test_df returns the below df
+------+----------+-------+-------+
|txn_id|epoch_time|user_id|txn_amt|
+------+----------+-------+-------+
| txn 1|1640995200| user A| 1|
| txn 2|1640995201| user A| 1|
| txn 3|1640995202| user A| 1|
| txn 4|1640995203| user B| 1|
| txn 5|1640995204| user B| 1|
| txn 6|1640995205| user A| 1|
| txn 7|1640995206| user A| 1|
| txn 8|1640995207| user A| 1|
| txn 9|1640995208| user A| 1|
|txn 10|1640995209| user A| 1|
+------+----------+-------+-------+
I'm aiming to create the following df
+------+----------+-------+-------+-----------+
|txn_id|epoch_time|user_id|txn_amt|user_sum_1m|
+------+----------+-------+-------+-----------+
| txn 4|1640995203| user B| 1| null|
| txn 5|1640995204| user B| 1| 1.0|
| txn 1|1640995200| user A| 1| null|
| txn 2|1640995201| user A| 1| 1.0|
| txn 3|1640995202| user A| 1| 2.0|
| txn 6|1640995205| user A| 1| 3.0|
| txn 7|1640995206| user A| 1| 4.0|
| txn 8|1640995207| user A| 1| 5.0|
| txn 9|1640995208| user A| 1| 5.0|
|txn 10|1640995209| user A| 1|        5.0|
+------+----------+-------+-------+-----------+
I've tried the following attempt at combining a rowsBetween (to reflect the limitation of 5 most recent txns per user) and a rangeBetween (to capture txns within the time range of interest). However, the rowsBetween seems to be ignored.
from pyspark.sql import Window
from pyspark.sql.functions import col, sum

w = (Window()
     .partitionBy(col('user_id'))
     .orderBy('epoch_time')
     .rowsBetween(-5, -1)
     .rangeBetween(-60, -1))
sum_test = test_df.withColumn('user_sum_1m', sum('txn_amt').over(w))
# sum_test returns
+------+----------+-------+-------+-----------+
|txn_id|epoch_time|user_id|txn_amt|user_sum_1m|
+------+----------+-------+-------+-----------+
| txn 4|1640995203| user B| 1| null|
| txn 5|1640995204| user B| 1| 1.0|
| txn 1|1640995200| user A| 1| null|
| txn 2|1640995201| user A| 1| 1.0|
| txn 3|1640995202| user A| 1| 2.0|
| txn 6|1640995205| user A| 1| 3.0|
| txn 7|1640995206| user A| 1| 4.0|
| txn 8|1640995207| user A| 1| 5.0|
| txn 9|1640995208| user A| 1| 6.0| #should not be able to sum > 5 most recent txns
|txn 10|1640995209| user A| 1| 7.0| #should not be able to sum > 5 most recent txns
+------+----------+-------+-------+-----------+
Given that I have specified the window should look at rows -5 to -1, I cannot figure out why additional rows are included in the sum.
In this example, all transactions are within the 1 minute time range, but this would not be the case with my actual data set. Any suggestions to fix this issue?
Thanks in advance!
Upvotes: 1
Views: 2770
Reputation: 32710
Given that I have specified the window should look at rows -5 to -1, I cannot figure out why additional rows are included in the sum.
You can't use rowsBetween and rangeBetween at the same time for the window frame. In your code, the window frame is in fact defined as .rangeBetween(-60, -1), because it's the last one you called, so it overrides the .rowsBetween(-5, -1). If you remove the rangeBetween you'll see that it gives the expected output, but that does not guarantee the 5 rows fall within the last 1 minute.
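For reference, here is a minimal sketch of that rows-only variant (reusing the question's test_df); it matches the expected output on this sample only because every transaction happens to fall within the same minute:

from pyspark.sql import Window
import pyspark.sql.functions as F

# Frame of the 5 preceding rows per user, ordered by time, with no time bound.
w_rows = (Window()
          .partitionBy('user_id')
          .orderBy('epoch_time')
          .rowsBetween(-5, -1))

test_df.withColumn('user_sum_1m', F.sum('txn_amt').over(w_rows)).show()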
That said, you can use this trick to achieve what you're looking for. Define your window frame as rangeBetween(-60, -1) and collect the list of txn_amt, then slice the last 5 values from the list and sum them up using the aggregate function on arrays:
import pyspark.sql.functions as F
from pyspark.sql import Window

w = Window().partitionBy(F.col('user_id')).orderBy('epoch_time').rangeBetween(-60, -1)

sum_test = test_df.withColumn(
    'user_sum_1m',
    F.collect_list('txn_amt').over(w)  # all txn_amt values in the last 60s, excluding the current row
).withColumn(
    'user_sum_1m',
    F.expr("""aggregate(
        slice(user_sum_1m, greatest(size(user_sum_1m) - 4, 1), 5),
        0D,
        (acc, x) -> acc + x
    )""")                              # keep only the last 5 collected values and sum them
)
sum_test.show()
#+------+----------+-------+-------+-----------+
#|txn_id|epoch_time|user_id|txn_amt|user_sum_1m|
#+------+----------+-------+-------+-----------+
#| txn 1|1640995200| user A| 1| 0.0|
#| txn 2|1640995201| user A| 1| 1.0|
#| txn 3|1640995202| user A| 1| 2.0|
#| txn 6|1640995205| user A| 1| 3.0|
#| txn 7|1640995206| user A| 1| 4.0|
#| txn 8|1640995207| user A| 1| 5.0|
#| txn 9|1640995208| user A| 1| 5.0|
#|txn 10|1640995209| user A| 1| 5.0|
#| txn 4|1640995203| user B| 1| 0.0|
#| txn 5|1640995204| user B| 1| 1.0|
#+------+----------+-------+-------+-----------+
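One optional tweak, not part of the output above: the question's expected result shows null when a user has no preceding transactions, whereas aggregate starts from 0D and therefore yields 0.0 for those rows. If the null matters, you can guard on an empty list (a sketch reusing w and F from the snippet above, with amts_1m as a hypothetical intermediate column name):

sum_test_null = test_df.withColumn(
    'amts_1m', F.collect_list('txn_amt').over(w)     # preceding txns within the last 60s
).withColumn(
    'user_sum_1m',
    F.when(F.size('amts_1m') == 0, F.lit(None).cast('double'))   # no prior txns -> null
     .otherwise(F.expr("""aggregate(
         slice(amts_1m, greatest(size(amts_1m) - 4, 1), 5),
         0D,
         (acc, x) -> acc + x
     )"""))
).drop('amts_1m')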
Upvotes: 3