CaptainKirk

Reputation: 23

PySpark Multiple Columns Using Windows

I have a Dataframe as follows:

+---+----------+-----------+
|id |date_1    |date_2     |
+---+----------+-----------+
|0  |2017-01-21|2017-04-01 |
|1  |2017-01-22|2017-04-24 |
|2  |2017-02-23|2017-04-30 |
|3  |2017-02-27|2017-04-30 |
|4  |2017-04-23|2017-05-27 |
|5  |2017-04-29|2017-06-30 |
|6  |2017-06-13|2017-07-05 |
|7  |2017-06-13|2017-07-18 |
|8  |2017-06-16|2017-07-19 |
|9  |2017-07-09|2017-08-02 |
|10 |2017-07-18|2017-08-07 |
|11 |2017-07-28|2017-08-11 |
|12 |2017-07-28|2017-08-13 |
|13 |2017-08-04|2017-08-13 |
|14 |2017-08-13|2017-08-13 |
|15 |2017-08-13|2017-08-13 |
|16 |2017-08-13|2017-08-25 |
|17 |2017-08-13|2017-09-10 |
|18 |2017-08-31|2017-09-21 |
|19 |2017-10-03|2017-09-22 |
+---+----------+-----------+

I know there are many ways to do what I am asking using different PySpark APIs, but I would like to accomplish the following using the Window API.

In any other setting this would essentially be a double for loop.

For every date in date_1, look at every date in date_2 in the same or subsequent rows, and count the occurrences where the difference falls within a week, a month, etc. (the exact timeframe is irrelevant, but for consistency's sake, let's go with a week). Use these counts to add another column.

The challenge is getting the right Window(s) combination to consider both date columns.
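For concreteness, here is a minimal pure-Python sketch of that double loop (my illustration of the intended semantics, not part of the question; it reads "within a week" as date_2 falling on or within `days` days after date_1, and `naive_counts` is a hypothetical helper):

import datetime

# rows: list of (id, date_1, date_2) tuples, ordered by id
def naive_counts(rows, days=7):
    results = []
    for i, (_, date_1, _) in enumerate(rows):
        # only the same row and subsequent rows are considered
        count = sum(
            1
            for _, _, date_2 in rows[i:]
            if datetime.timedelta(0) <= date_2 - date_1 <= datetime.timedelta(days=days)
        )
        results.append(count)
    return results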

Upvotes: 2

Views: 1128

Answers (2)

Alexander Pivovarov

Reputation: 4990

If I understood the question's author correctly, for each row X in the dataframe we want to go over all rows starting from that row (ordered by, e.g., id), and for each such row Y compare X.date_1 with Y.date_2. The number of rows Y for which the difference between X.date_1 and Y.date_2 is less than, e.g., one week should then be added as a column to row X (e.g., X.result).

Unfortunately, window functions provide no way to access X.date_1 from inside the aggregate expression evaluated over the frame, so this is impossible to achieve with window functions alone.
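To make that limitation concrete (this snippet is my illustration, not from the answer; it assumes the question's dataframe is loaded as df): inside a windowed aggregate, every column reference resolves to the frame rows rather than the current row, so a naive attempt silently compares each frame row's own dates:

from pyspark.sql.window import Window
import pyspark.sql.functions as func

w = Window.orderBy('id').rowsBetween(Window.currentRow, Window.unboundedFollowing)

# Intended: compare the *current* row's date_1 with each frame row's date_2.
# Actual: date_1 and date_2 both come from each frame row, so this counts
# frame rows whose own two dates are within a week of each other.
naive = df.withColumn('wrong_result',
    func.sum(func.when(func.datediff('date_2', 'date_1').between(0, 7), 1)
        .otherwise(0)).over(w))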

This seems to be very similar to this question, where the author tries to do a similar thing in Postgres.

There is, however, a way to do it through a bit of cheating: "materialize" the window frame for each row into an array, then perform the needed operations on that array. Not sure whether this counts in your view, but it is the only way the Window API can be used to solve the problem. A possible solution could look like this (assuming we want to count the rows Y at or after X with respect to id whose Y.date_2 falls between X.date_1 and X.date_1 + 7 days):

import datetime

# Parse the sample table from the question into lists of [id, date_1, date_2] strings.
rawdata = [l.strip('|').replace('|', ' ').split() for l in '''|0  |2017-01-21|2017-04-01 |
|1  |2017-01-22|2017-04-24 |
|2  |2017-02-23|2017-04-30 |
|3  |2017-02-27|2017-04-30 |
|4  |2017-04-23|2017-05-27 |
|5  |2017-04-29|2017-06-30 |
|6  |2017-06-13|2017-07-05 |
|7  |2017-06-13|2017-07-18 |
|8  |2017-06-16|2017-07-19 |
|9  |2017-07-09|2017-08-02 |
|10 |2017-07-18|2017-08-07 |
|11 |2017-07-28|2017-08-11 |
|12 |2017-07-28|2017-08-13 |
|13 |2017-08-04|2017-08-13 |
|14 |2017-08-13|2017-08-13 |
|15 |2017-08-13|2017-08-13 |
|16 |2017-08-13|2017-08-25 |
|17 |2017-08-13|2017-09-10 |
|18 |2017-08-31|2017-09-21 |
|19 |2017-10-03|2017-09-22 |'''.split('\n')]
# Build typed rows and create the DataFrame; an active SparkSession named
# `spark` is assumed (as in the pyspark shell).
data = [(int(d[0]), datetime.date.fromisoformat(d[1]), datetime.date.fromisoformat(d[2])) for d in rawdata]
df = spark.createDataFrame(data, schema='id: bigint, date_1: Date, date_2: Date')

from pyspark.sql.window import Window
import pyspark.sql.functions as func

# Frame: the current row and all following rows, ordered by id.
window_spec = Window.orderBy('id').rowsBetween(Window.currentRow, Window.unboundedFollowing)

# Materialize each frame's date_2 values into an array, count the entries that
# fall within [date_1, date_1 + 7 days] via the filter() higher-order function,
# then drop the helper column.
new_df = df.withColumn('materialized_frame_date_2', func.collect_list(df['date_2']).over(window_spec)) \
  .withColumn('result', func.expr('size(filter(materialized_frame_date_2, x -> datediff(x, date_1) BETWEEN 0 AND 7))')) \
  .drop('materialized_frame_date_2')
new_df.show()

The result:

+---+----------+----------+------+
| id|    date_1|    date_2|result|
+---+----------+----------+------+
|  0|2017-01-21|2017-04-01|     0|
|  1|2017-01-22|2017-04-24|     0|
|  2|2017-02-23|2017-04-30|     0|
|  3|2017-02-27|2017-04-30|     0|
|  4|2017-04-23|2017-05-27|     0|
|  5|2017-04-29|2017-06-30|     0|
|  6|2017-06-13|2017-07-05|     0|
|  7|2017-06-13|2017-07-18|     0|
|  8|2017-06-16|2017-07-19|     0|
|  9|2017-07-09|2017-08-02|     0|
| 10|2017-07-18|2017-08-07|     0|
| 11|2017-07-28|2017-08-11|     0|
| 12|2017-07-28|2017-08-13|     0|
| 13|2017-08-04|2017-08-13|     0|
| 14|2017-08-13|2017-08-13|     2|
| 15|2017-08-13|2017-08-13|     1|
| 16|2017-08-13|2017-08-25|     0|
| 17|2017-08-13|2017-09-10|     0|
| 18|2017-08-31|2017-09-21|     0|
| 19|2017-10-03|2017-09-22|     0|
+---+----------+----------+------+
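Since the question notes that the timeframe is arbitrary (a week, a month, ...), the cutoff is easy to lift into a parameter. A small sketch along the same lines (the helper name count_date2_within is mine, not from the answer):

import pyspark.sql.functions as func
from pyspark.sql.window import Window

def count_date2_within(df, days):
    # Same materialize-then-filter trick as above, with the cutoff in days
    # supplied as a parameter.
    w = Window.orderBy('id').rowsBetween(Window.currentRow, Window.unboundedFollowing)
    return (df
            .withColumn('frame', func.collect_list('date_2').over(w))
            .withColumn('result', func.expr(
                f'size(filter(frame, x -> datediff(x, date_1) BETWEEN 0 AND {days}))'))
            .drop('frame'))

count_date2_within(df, 30).show()  # e.g. a 30-day "month" instead of a week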

Upvotes: 2

Som

Reputation: 6323

Perhaps this is helpful:

Load the test data provided

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._
    import spark.implicits._ // for .toDS(); assumes an active SparkSession named `spark`

    val data =
      """
        |id |date_1    |date_2
        |0  |2017-01-21|2017-04-01
        |1  |2017-01-22|2017-04-24
        |2  |2017-02-23|2017-04-30
        |3  |2017-02-27|2017-04-30
        |4  |2017-04-23|2017-05-27
        |5  |2017-04-29|2017-06-30
        |6  |2017-06-13|2017-07-05
        |7  |2017-06-13|2017-07-18
        |8  |2017-06-16|2017-07-19
        |9  |2017-07-09|2017-08-02
        |10 |2017-07-18|2017-08-07
        |11 |2017-07-28|2017-08-11
        |12 |2017-07-28|2017-08-13
        |13 |2017-08-04|2017-08-13
        |14 |2017-08-13|2017-08-13
        |15 |2017-08-13|2017-08-13
        |16 |2017-08-13|2017-08-25
        |17 |2017-08-13|2017-09-10
        |18 |2017-08-31|2017-09-21
        |19 |2017-10-03|2017-09-22
      """.stripMargin

    // Turn each "|a |b |c" row into "a,b,c", then read the result back as CSV.
    val stringDS = data.split(System.lineSeparator())
      .map(_.split("\\|").map(_.replaceAll("""^[ \t]+|[ \t]+$""", "")).mkString(","))
      .toSeq.toDS()
    val df = spark.read
      .option("sep", ",")
      .option("inferSchema", "true")
      .option("header", "true")
      .option("nullValue", "null")
      .csv(stringDS)

    df.show(false)
    df.printSchema()
    /**
      * +---+-------------------+-------------------+
      * |id |date_1             |date_2             |
      * +---+-------------------+-------------------+
      * |0  |2017-01-21 00:00:00|2017-04-01 00:00:00|
      * |1  |2017-01-22 00:00:00|2017-04-24 00:00:00|
      * |2  |2017-02-23 00:00:00|2017-04-30 00:00:00|
      * |3  |2017-02-27 00:00:00|2017-04-30 00:00:00|
      * |4  |2017-04-23 00:00:00|2017-05-27 00:00:00|
      * |5  |2017-04-29 00:00:00|2017-06-30 00:00:00|
      * |6  |2017-06-13 00:00:00|2017-07-05 00:00:00|
      * |7  |2017-06-13 00:00:00|2017-07-18 00:00:00|
      * |8  |2017-06-16 00:00:00|2017-07-19 00:00:00|
      * |9  |2017-07-09 00:00:00|2017-08-02 00:00:00|
      * |10 |2017-07-18 00:00:00|2017-08-07 00:00:00|
      * |11 |2017-07-28 00:00:00|2017-08-11 00:00:00|
      * |12 |2017-07-28 00:00:00|2017-08-13 00:00:00|
      * |13 |2017-08-04 00:00:00|2017-08-13 00:00:00|
      * |14 |2017-08-13 00:00:00|2017-08-13 00:00:00|
      * |15 |2017-08-13 00:00:00|2017-08-13 00:00:00|
      * |16 |2017-08-13 00:00:00|2017-08-25 00:00:00|
      * |17 |2017-08-13 00:00:00|2017-09-10 00:00:00|
      * |18 |2017-08-31 00:00:00|2017-09-21 00:00:00|
      * |19 |2017-10-03 00:00:00|2017-09-22 00:00:00|
      * +---+-------------------+-------------------+
      *
      * root
      * |-- id: integer (nullable = true)
      * |-- date_1: timestamp (nullable = true)
      * |-- date_2: timestamp (nullable = true)
      */

Count the number of occurrences where the difference (date_1 - date_2) falls within a week:

    // week: over the current row and all following rows, count the rows whose
    // own datediff(date_1, date_2) is at most 7 days.
    val weekDiff = 7
    val w = Window.orderBy("id", "date_1", "date_2")
      .rangeBetween(Window.currentRow, Window.unboundedFollowing)
    df.withColumn("count", sum(
      when(datediff($"date_1", $"date_2") <= weekDiff, 1).otherwise(0)
    ).over(w))
      .orderBy("id")
      .show(false)

    /**
      * +---+-------------------+-------------------+-----+
      * |id |date_1             |date_2             |count|
      * +---+-------------------+-------------------+-----+
      * |0  |2017-01-21 00:00:00|2017-04-01 00:00:00|19   |
      * |1  |2017-01-22 00:00:00|2017-04-24 00:00:00|18   |
      * |2  |2017-02-23 00:00:00|2017-04-30 00:00:00|17   |
      * |3  |2017-02-27 00:00:00|2017-04-30 00:00:00|16   |
      * |4  |2017-04-23 00:00:00|2017-05-27 00:00:00|15   |
      * |5  |2017-04-29 00:00:00|2017-06-30 00:00:00|14   |
      * |6  |2017-06-13 00:00:00|2017-07-05 00:00:00|13   |
      * |7  |2017-06-13 00:00:00|2017-07-18 00:00:00|12   |
      * |8  |2017-06-16 00:00:00|2017-07-19 00:00:00|11   |
      * |9  |2017-07-09 00:00:00|2017-08-02 00:00:00|10   |
      * |10 |2017-07-18 00:00:00|2017-08-07 00:00:00|9    |
      * |11 |2017-07-28 00:00:00|2017-08-11 00:00:00|8    |
      * |12 |2017-07-28 00:00:00|2017-08-13 00:00:00|7    |
      * |13 |2017-08-04 00:00:00|2017-08-13 00:00:00|6    |
      * |14 |2017-08-13 00:00:00|2017-08-13 00:00:00|5    |
      * |15 |2017-08-13 00:00:00|2017-08-13 00:00:00|4    |
      * |16 |2017-08-13 00:00:00|2017-08-25 00:00:00|3    |
      * |17 |2017-08-13 00:00:00|2017-09-10 00:00:00|2    |
      * |18 |2017-08-31 00:00:00|2017-09-21 00:00:00|1    |
      * |19 |2017-10-03 00:00:00|2017-09-22 00:00:00|0    |
      * +---+-------------------+-------------------+-----+
      */

Upvotes: 2
