Reputation: 23
I have a DataFrame as follows:
+---+----------+-----------+
|id |date_1    |date_2     |
+---+----------+-----------+
|0 |2017-01-21|2017-04-01 |
|1 |2017-01-22|2017-04-24 |
|2 |2017-02-23|2017-04-30 |
|3 |2017-02-27|2017-04-30 |
|4 |2017-04-23|2017-05-27 |
|5 |2017-04-29|2017-06-30 |
|6 |2017-06-13|2017-07-05 |
|7 |2017-06-13|2017-07-18 |
|8 |2017-06-16|2017-07-19 |
|9 |2017-07-09|2017-08-02 |
|10 |2017-07-18|2017-08-07 |
|11 |2017-07-28|2017-08-11 |
|12 |2017-07-28|2017-08-13 |
|13 |2017-08-04|2017-08-13 |
|14 |2017-08-13|2017-08-13 |
|15 |2017-08-13|2017-08-13 |
|16 |2017-08-13|2017-08-25 |
|17 |2017-08-13|2017-09-10 |
|18 |2017-08-31|2017-09-21 |
|19 |2017-10-03|2017-09-22 |
+---+----------+-----------+
I know there are many ways to do what I am asking using different pyspark APIs, however I would like to use the Window API to accomplish the following.
It is essentially a double for loop in any other situation.
For every date in date_1, look at every date in date_2 that is in the same or subsequent rows, and count the number of occurrences where the difference falls within a week, month, ... (the timeframe is irrelevant, but for consistency's sake, let's go with a week). Use these results to add another column with the count.
The challenge is getting the right Window(s) combination to consider both date columns.
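In plain Python, the double loop I have in mind looks roughly like this (just a sketch to show the intent; rows is a list of (id, date_1, date_2) tuples ordered by id, and the 7-day bound stands in for "within a week"):
from datetime import timedelta

def count_within_week(rows):
    counts = []
    for i, (_, d1, _) in enumerate(rows):
        n = 0
        for _, _, d2 in rows[i:]:  # same or subsequent rows only
            if timedelta(0) <= d2 - d1 <= timedelta(days=7):
                n += 1
        counts.append(n)
    return counts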
Upvotes: 2
Views: 1128
Reputation: 4990
If I understood the question's author correctly, for each row X in the dataframe we want to go over all rows starting from it (ordered by, e.g., id) and, for each such row Y, compare X.date_1 with Y.date_2. The number of rows Y for which the difference between X.date_1 and Y.date_2 is less than, e.g., one week should then be added to row X as a new column (e.g. X.result).
Unfortunately, window functions do not provide a way to reference X.date_1 from inside the window frame, so this cannot be achieved with plain window functions.
This seems to be very similar to this question, where the author tries to do a similar thing for Postgres.
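For reference, without window functions this is usually expressed as a conditional self-join; a sketch only (assuming df is the question's dataframe with columns id, date_1 and date_2):
from pyspark.sql import functions as F

x, y = df.alias('x'), df.alias('y')
self_join_result = (
    x.join(y, (F.col('y.id') >= F.col('x.id')) &
              F.col('y.date_2').between(F.col('x.date_1'), F.date_add(F.col('x.date_1'), 7)),
           'left')
     .groupBy('x.id', 'x.date_1', 'x.date_2')
     .agg(F.count('y.id').alias('result'))  # number of matching Y rows per X row
)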
There is, however, a way to actually do it through a bit of cheating: "materialize" the window frame for each row into an array and then perform the needed operations on that array. Not sure if this counts in your view, but it is the only way the Window API can be used to solve the problem. A possible solution could look like this (assuming we want to count the number of rows Y that come no earlier than X w.r.t. id, with Y.date_2 between X.date_1 and X.date_1 + 7 days):
import datetime
rawdata = [l.strip('|').replace('|', ' ').split() for l in '''|0 |2017-01-21|2017-04-01 |
|1 |2017-01-22|2017-04-24 |
|2 |2017-02-23|2017-04-30 |
|3 |2017-02-27|2017-04-30 |
|4 |2017-04-23|2017-05-27 |
|5 |2017-04-29|2017-06-30 |
|6 |2017-06-13|2017-07-05 |
|7 |2017-06-13|2017-07-18 |
|8 |2017-06-16|2017-07-19 |
|9 |2017-07-09|2017-08-02 |
|10 |2017-07-18|2017-08-07 |
|11 |2017-07-28|2017-08-11 |
|12 |2017-07-28|2017-08-13 |
|13 |2017-08-04|2017-08-13 |
|14 |2017-08-13|2017-08-13 |
|15 |2017-08-13|2017-08-13 |
|16 |2017-08-13|2017-08-25 |
|17 |2017-08-13|2017-09-10 |
|18 |2017-08-31|2017-09-21 |
|19 |2017-10-03|2017-09-22 |'''.split('\n')]
data = [(int(d[0]), datetime.date.fromisoformat(d[1]), datetime.date.fromisoformat(d[2])) for d in rawdata]
df = spark.createDataFrame(data, schema='id: bigint, date_1: Date, date_2: Date')
from pyspark.sql.window import Window
import pyspark.sql.functions as func
# frame: the current row and all following rows, ordered by id
window_spec = Window.orderBy('id').rowsBetween(Window.currentRow, Window.unboundedFollowing)
# materialize every date_2 in the frame into an array, then count the values falling
# within 7 days after the current row's date_1 (filter is a Spark 2.4+ higher-order function)
new_df = df.withColumn('materialized_frame_date_2', func.collect_list(df['date_2']).over(window_spec)) \
    .withColumn('result', func.expr('size(filter(materialized_frame_date_2, x -> datediff(x, date_1) BETWEEN 0 AND 7))')) \
    .drop('materialized_frame_date_2')
new_df.show()
The result:
+---+----------+----------+------+
| id| date_1| date_2|result|
+---+----------+----------+------+
| 0|2017-01-21|2017-04-01| 0|
| 1|2017-01-22|2017-04-24| 0|
| 2|2017-02-23|2017-04-30| 0|
| 3|2017-02-27|2017-04-30| 0|
| 4|2017-04-23|2017-05-27| 0|
| 5|2017-04-29|2017-06-30| 0|
| 6|2017-06-13|2017-07-05| 0|
| 7|2017-06-13|2017-07-18| 0|
| 8|2017-06-16|2017-07-19| 0|
| 9|2017-07-09|2017-08-02| 0|
| 10|2017-07-18|2017-08-07| 0|
| 11|2017-07-28|2017-08-11| 0|
| 12|2017-07-28|2017-08-13| 0|
| 13|2017-08-04|2017-08-13| 0|
| 14|2017-08-13|2017-08-13| 2|
| 15|2017-08-13|2017-08-13| 1|
| 16|2017-08-13|2017-08-25| 0|
| 17|2017-08-13|2017-09-10| 0|
| 18|2017-08-31|2017-09-21| 0|
| 19|2017-10-03|2017-09-22| 0|
+---+----------+----------+------+
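The same pattern adapts to other timeframes by changing the predicate inside filter; for example, with an illustrative 30-day bound as a stand-in for "within a month" (reusing window_spec and func from above):
month_df = df.withColumn('materialized_frame_date_2',
                         func.collect_list(df['date_2']).over(window_spec)) \
    .withColumn('result_month',
                func.expr('size(filter(materialized_frame_date_2, x -> datediff(x, date_1) BETWEEN 0 AND 30))')) \
    .drop('materialized_frame_date_2')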
Upvotes: 2
Reputation: 6323
Perhaps this is helpful -
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._ // for .toDS() and the $"..." column syntax

val data =
"""
|id |date_1 |date_2
|0 |2017-01-21|2017-04-01
|1 |2017-01-22|2017-04-24
|2 |2017-02-23|2017-04-30
|3 |2017-02-27|2017-04-30
|4 |2017-04-23|2017-05-27
|5 |2017-04-29|2017-06-30
|6 |2017-06-13|2017-07-05
|7 |2017-06-13|2017-07-18
|8 |2017-06-16|2017-07-19
|9 |2017-07-09|2017-08-02
|10 |2017-07-18|2017-08-07
|11 |2017-07-28|2017-08-11
|12 |2017-07-28|2017-08-13
|13 |2017-08-04|2017-08-13
|14 |2017-08-13|2017-08-13
|15 |2017-08-13|2017-08-13
|16 |2017-08-13|2017-08-25
|17 |2017-08-13|2017-09-10
|18 |2017-08-31|2017-09-21
|19 |2017-10-03|2017-09-22
""".stripMargin
val stringDS = data.split(System.lineSeparator())
.map(_.split("\\|").map(_.replaceAll("""^[ \t]+|[ \t]+$""", "")).mkString(","))
.toSeq.toDS()
val df = spark.read
.option("sep", ",")
.option("inferSchema", "true")
.option("header", "true")
.option("nullValue", "null")
.csv(stringDS)
df.show(false)
df.printSchema()
/**
* +---+-------------------+-------------------+
* |id |date_1 |date_2 |
* +---+-------------------+-------------------+
* |0 |2017-01-21 00:00:00|2017-04-01 00:00:00|
* |1 |2017-01-22 00:00:00|2017-04-24 00:00:00|
* |2 |2017-02-23 00:00:00|2017-04-30 00:00:00|
* |3 |2017-02-27 00:00:00|2017-04-30 00:00:00|
* |4 |2017-04-23 00:00:00|2017-05-27 00:00:00|
* |5 |2017-04-29 00:00:00|2017-06-30 00:00:00|
* |6 |2017-06-13 00:00:00|2017-07-05 00:00:00|
* |7 |2017-06-13 00:00:00|2017-07-18 00:00:00|
* |8 |2017-06-16 00:00:00|2017-07-19 00:00:00|
* |9 |2017-07-09 00:00:00|2017-08-02 00:00:00|
* |10 |2017-07-18 00:00:00|2017-08-07 00:00:00|
* |11 |2017-07-28 00:00:00|2017-08-11 00:00:00|
* |12 |2017-07-28 00:00:00|2017-08-13 00:00:00|
* |13 |2017-08-04 00:00:00|2017-08-13 00:00:00|
* |14 |2017-08-13 00:00:00|2017-08-13 00:00:00|
* |15 |2017-08-13 00:00:00|2017-08-13 00:00:00|
* |16 |2017-08-13 00:00:00|2017-08-25 00:00:00|
* |17 |2017-08-13 00:00:00|2017-09-10 00:00:00|
* |18 |2017-08-31 00:00:00|2017-09-21 00:00:00|
* |19 |2017-10-03 00:00:00|2017-09-22 00:00:00|
* +---+-------------------+-------------------+
*
* root
* |-- id: integer (nullable = true)
* |-- date_1: timestamp (nullable = true)
* |-- date_2: timestamp (nullable = true)
*/
// week
val weekDiff = 7
val w = Window.orderBy("id", "date_1", "date_2")
.rangeBetween(Window.currentRow, Window.unboundedFollowing)
df.withColumn("count", sum(
when(datediff($"date_1", $"date_2") <= weekDiff, 1).otherwise(0)
).over(w))
.orderBy("id")
.show(false)
/**
* +---+-------------------+-------------------+-----+
* |id |date_1 |date_2 |count|
* +---+-------------------+-------------------+-----+
* |0 |2017-01-21 00:00:00|2017-04-01 00:00:00|19 |
* |1 |2017-01-22 00:00:00|2017-04-24 00:00:00|18 |
* |2 |2017-02-23 00:00:00|2017-04-30 00:00:00|17 |
* |3 |2017-02-27 00:00:00|2017-04-30 00:00:00|16 |
* |4 |2017-04-23 00:00:00|2017-05-27 00:00:00|15 |
* |5 |2017-04-29 00:00:00|2017-06-30 00:00:00|14 |
* |6 |2017-06-13 00:00:00|2017-07-05 00:00:00|13 |
* |7 |2017-06-13 00:00:00|2017-07-18 00:00:00|12 |
* |8 |2017-06-16 00:00:00|2017-07-19 00:00:00|11 |
* |9 |2017-07-09 00:00:00|2017-08-02 00:00:00|10 |
* |10 |2017-07-18 00:00:00|2017-08-07 00:00:00|9 |
* |11 |2017-07-28 00:00:00|2017-08-11 00:00:00|8 |
* |12 |2017-07-28 00:00:00|2017-08-13 00:00:00|7 |
* |13 |2017-08-04 00:00:00|2017-08-13 00:00:00|6 |
* |14 |2017-08-13 00:00:00|2017-08-13 00:00:00|5 |
* |15 |2017-08-13 00:00:00|2017-08-13 00:00:00|4 |
* |16 |2017-08-13 00:00:00|2017-08-25 00:00:00|3 |
* |17 |2017-08-13 00:00:00|2017-09-10 00:00:00|2 |
* |18 |2017-08-31 00:00:00|2017-09-21 00:00:00|1 |
* |19 |2017-10-03 00:00:00|2017-09-22 00:00:00|0 |
* +---+-------------------+-------------------+-----+
*/
Upvotes: 2