Luiz Viola

Reputation: 2436

Find month to date and month to go on a PySpark dataframe

I have the following dataframe in Spark (using PySpark):

DT_BORD_REF: Timestamp column,
COUNTRY_ALPHA: Country Alpha-3 code,
working_day_flag: whether the date is a working day in that country ('Y') or not ('N')

I need to add two fields:

month_to_date: the number of working days elapsed so far in the month, per country
month_to_go: the number of working days remaining in the month, per country

It seems like an application of a window function, but I can't figure it out. Here is the sample data:

+-------------------+-------------+----------------+
|        DT_BORD_REF|COUNTRY_ALPHA|working_day_flag|
+-------------------+-------------+----------------+
|2021-01-01 00:00:00|          FRA|               N|
|2021-01-01 00:00:00|          ITA|               N|
|2021-01-01 00:00:00|          BRA|               N|
|2021-01-02 00:00:00|          BRA|               N|
|2021-01-02 00:00:00|          FRA|               N|
|2021-01-02 00:00:00|          ITA|               N|
|2021-01-03 00:00:00|          ITA|               N|
|2021-01-03 00:00:00|          BRA|               N|
|2021-01-03 00:00:00|          FRA|               N|
|2021-01-04 00:00:00|          BRA|               Y|
|2021-01-04 00:00:00|          FRA|               Y|
|2021-01-04 00:00:00|          ITA|               Y|
|2021-01-05 00:00:00|          FRA|               Y|
|2021-01-05 00:00:00|          BRA|               Y|
|2021-01-05 00:00:00|          ITA|               Y|
|2021-01-06 00:00:00|          ITA|               N|
|2021-01-06 00:00:00|          FRA|               Y|
|2021-01-06 00:00:00|          BRA|               Y|
|2021-01-07 00:00:00|          ITA|               Y|
+-------------------+-------------+----------------+
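
In case it helps, here's a minimal snippet that rebuilds the sample above (assuming an active SparkSession named spark):

import pyspark.sql.functions as F
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Sample rows from the table above; timestamps are parsed from strings.
rows = [
    ("2021-01-01", "FRA", "N"), ("2021-01-01", "ITA", "N"), ("2021-01-01", "BRA", "N"),
    ("2021-01-02", "BRA", "N"), ("2021-01-02", "FRA", "N"), ("2021-01-02", "ITA", "N"),
    ("2021-01-03", "ITA", "N"), ("2021-01-03", "BRA", "N"), ("2021-01-03", "FRA", "N"),
    ("2021-01-04", "BRA", "Y"), ("2021-01-04", "FRA", "Y"), ("2021-01-04", "ITA", "Y"),
    ("2021-01-05", "FRA", "Y"), ("2021-01-05", "BRA", "Y"), ("2021-01-05", "ITA", "Y"),
    ("2021-01-06", "ITA", "N"), ("2021-01-06", "FRA", "Y"), ("2021-01-06", "BRA", "Y"),
    ("2021-01-07", "ITA", "Y"),
]
df = spark.createDataFrame(rows, ["DT_BORD_REF", "COUNTRY_ALPHA", "working_day_flag"]) \
    .withColumn("DT_BORD_REF", F.col("DT_BORD_REF").cast("timestamp"))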

Upvotes: 2

Views: 1040

Answers (2)

mck

Reputation: 42422

You can do a conditional count using count_if (available in Spark SQL since 3.0):

df.createOrReplaceTempView('df')

result = spark.sql("""
select *,
    count_if(working_day_flag = 'Y')
        over(partition by country_alpha, trunc(dt_bord_ref, 'month') order by dt_bord_ref)
        month_to_date,
    count_if(working_day_flag = 'Y')
        over(partition by country_alpha, trunc(dt_bord_ref, 'month') order by dt_bord_ref
             rows between 1 following and unbounded following)
        month_to_go    
from df
""")

result.show()
+-------------------+-------------+----------------+-------------+-----------+
|        DT_BORD_REF|COUNTRY_ALPHA|working_day_flag|month_to_date|month_to_go|
+-------------------+-------------+----------------+-------------+-----------+
|2021-01-01 00:00:00|          BRA|               N|            0|          3|
|2021-01-02 00:00:00|          BRA|               N|            0|          3|
|2021-01-03 00:00:00|          BRA|               N|            0|          3|
|2021-01-04 00:00:00|          BRA|               Y|            1|          2|
|2021-01-05 00:00:00|          BRA|               Y|            2|          1|
|2021-01-06 00:00:00|          BRA|               Y|            3|          0|
|2021-01-01 00:00:00|          ITA|               N|            0|          3|
|2021-01-02 00:00:00|          ITA|               N|            0|          3|
|2021-01-03 00:00:00|          ITA|               N|            0|          3|
|2021-01-04 00:00:00|          ITA|               Y|            1|          2|
|2021-01-05 00:00:00|          ITA|               Y|            2|          1|
|2021-01-06 00:00:00|          ITA|               N|            2|          1|
|2021-01-07 00:00:00|          ITA|               Y|            3|          0|
|2021-01-01 00:00:00|          FRA|               N|            0|          3|
|2021-01-02 00:00:00|          FRA|               N|            0|          3|
|2021-01-03 00:00:00|          FRA|               N|            0|          3|
|2021-01-04 00:00:00|          FRA|               Y|            1|          2|
|2021-01-05 00:00:00|          FRA|               Y|            2|          1|
|2021-01-06 00:00:00|          FRA|               Y|            3|          0|
+-------------------+-------------+----------------+-------------+-----------+

If you want a similar solution using the PySpark API:

import pyspark.sql.functions as F
from pyspark.sql.window import Window

# F.count skips nulls, so F.count(F.when(cond, 1)) with no otherwise()
# counts only the rows where cond holds, like count_if in SQL.
result = df.withColumn(
    'month_to_date',
    F.count(
        F.when(F.col('working_day_flag') == 'Y', 1)
    ).over(
        Window.partitionBy('country_alpha', F.trunc('dt_bord_ref', 'month'))
              .orderBy('dt_bord_ref')
    )
).withColumn(
    'month_to_go',
    F.count(
        F.when(F.col('working_day_flag') == 'Y', 1)
    ).over(
        # Same window, but counting only the rows after the current one.
        Window.partitionBy('country_alpha', F.trunc('dt_bord_ref', 'month'))
              .orderBy('dt_bord_ref')
              .rowsBetween(1, Window.unboundedFollowing)
    )
)
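
A note on frames: when a window has an orderBy but no explicit frame, Spark defaults to RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, which is what makes month_to_date a running count. If you prefer the frame spelled out, here is an equivalent sketch (ROWS and RANGE coincide here because each country has at most one row per date):

w_mtd = (
    Window.partitionBy('country_alpha', F.trunc('dt_bord_ref', 'month'))
          .orderBy('dt_bord_ref')
          # Explicit version of the default frame Spark applies above.
          .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)

result = df.withColumn(
    'month_to_date',
    F.count(F.when(F.col('working_day_flag') == 'Y', 1)).over(w_mtd)
)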

Upvotes: 1

blackbishop

Reputation: 32720

Use a running sum over a window. To limit the window to one month and one country, partition by COUNTRY_ALPHA and DATE_TRUNC('MONTH', DT_BORD_REF). Then, with ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, you get the sum of worked days up to and including the current date. The same logic with ROWS BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING gives the working days remaining in the month.

To count only days with working_day_flag = 'Y', use a conditional sum with CASE/WHEN (casting the 'Y'/'N' flag to boolean works here, since Spark reads 'Y' as true and 'N' as false).

Here's a working example with the sample data you provided in your question:

df.createOrReplaceTempView("df")

sql_query = """
SELECT
  *,
  SUM(CASE
    WHEN BOOLEAN(working_day_flag) THEN 1
    ELSE 0
  END) OVER (
  PARTITION BY COUNTRY_ALPHA, DATE_TRUNC('MONTH', DT_BORD_REF)
  ORDER BY DT_BORD_REF ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
  ) AS month_to_date,

  COALESCE(SUM(CASE
    WHEN BOOLEAN(working_day_flag) THEN 1
    ELSE 0
  END) OVER (
  PARTITION BY COUNTRY_ALPHA, DATE_TRUNC('MONTH', DT_BORD_REF)
  ORDER BY DT_BORD_REF ROWS BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING
  ), 0) AS month_to_go

FROM df
""" 

spark.sql(sql_query).show()

#+-------------------+-------------+----------------+-------------+-----------+
#|        DT_BORD_REF|COUNTRY_ALPHA|working_day_flag|month_to_date|month_to_go|
#+-------------------+-------------+----------------+-------------+-----------+
#|2021-01-01 00:00:00|          BRA|               N|            0|          3|
#|2021-01-02 00:00:00|          BRA|               N|            0|          3|
#|2021-01-03 00:00:00|          BRA|               N|            0|          3|
#|2021-01-04 00:00:00|          BRA|               Y|            1|          2|
#|2021-01-05 00:00:00|          BRA|               Y|            2|          1|
#|2021-01-06 00:00:00|          BRA|               Y|            3|          0|
#|2021-01-01 00:00:00|          FRA|               N|            0|          3|
#|2021-01-02 00:00:00|          FRA|               N|            0|          3|
#|2021-01-03 00:00:00|          FRA|               N|            0|          3|
#|2021-01-04 00:00:00|          FRA|               Y|            1|          2|
#|2021-01-05 00:00:00|          FRA|               Y|            2|          1|
#|2021-01-06 00:00:00|          FRA|               Y|            3|          0|
#|2021-01-01 00:00:00|          ITA|               N|            0|          3|
#|2021-01-02 00:00:00|          ITA|               N|            0|          3|
#|2021-01-03 00:00:00|          ITA|               N|            0|          3|
#|2021-01-04 00:00:00|          ITA|               Y|            1|          2|
#|2021-01-05 00:00:00|          ITA|               Y|            2|          1|
#|2021-01-06 00:00:00|          ITA|               N|            2|          1|
#|2021-01-07 00:00:00|          ITA|               Y|            3|          0|
#+-------------------+-------------+----------------+-------------+-----------+
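
If you'd rather stay in the DataFrame API, the same conditional running sums translate directly. Here's a sketch of the equivalent (comparing the flag to 'Y' instead of casting it to boolean):

import pyspark.sql.functions as F
from pyspark.sql.window import Window

month_window = Window.partitionBy(
    "COUNTRY_ALPHA", F.date_trunc("MONTH", "DT_BORD_REF")
).orderBy("DT_BORD_REF")

# 1 for working days, 0 otherwise, mirroring the CASE expression in the SQL.
worked = F.when(F.col("working_day_flag") == "Y", 1).otherwise(0)

result = df.withColumn(
    "month_to_date",
    F.sum(worked).over(
        month_window.rowsBetween(Window.unboundedPreceding, Window.currentRow)
    ),
).withColumn(
    "month_to_go",
    # The frame is empty on the last row of each month, so SUM yields NULL there;
    # coalesce it to 0 exactly as the SQL version does.
    F.coalesce(
        F.sum(worked).over(month_window.rowsBetween(1, Window.unboundedFollowing)),
        F.lit(0),
    ),
)

result.show()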

Upvotes: 1
