Reputation: 2436
I have the following dataframe in Spark (using PySpark):

DT_BORD_REF: timestamp column
COUNTRY_ALPHA: country Alpha-3 code
working_day_flag: whether the date is a working day in that country

I need to add two fields: the number of working days since the beginning of the month in that country (month to date), and the number of working days remaining until the end of the month (month to go).
It seems like an application of a window function, but I can't figure out how to write it.
+-------------------+-------------+----------------+
| DT_BORD_REF|COUNTRY_ALPHA|working_day_flag|
+-------------------+-------------+----------------+
|2021-01-01 00:00:00| FRA| N|
|2021-01-01 00:00:00| ITA| N|
|2021-01-01 00:00:00| BRA| N|
|2021-01-02 00:00:00| BRA| N|
|2021-01-02 00:00:00| FRA| N|
|2021-01-02 00:00:00| ITA| N|
|2021-01-03 00:00:00| ITA| N|
|2021-01-03 00:00:00| BRA| N|
|2021-01-03 00:00:00| FRA| N|
|2021-01-04 00:00:00| BRA| Y|
|2021-01-04 00:00:00| FRA| Y|
|2021-01-04 00:00:00| ITA| Y|
|2021-01-05 00:00:00| FRA| Y|
|2021-01-05 00:00:00| BRA| Y|
|2021-01-05 00:00:00| ITA| Y|
|2021-01-06 00:00:00| ITA| N|
|2021-01-06 00:00:00| FRA| Y|
|2021-01-06 00:00:00| BRA| Y|
|2021-01-07 00:00:00| ITA| Y|
+-------------------+-------------+----------------+
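For reference, a minimal sketch to rebuild this sample dataframe (assuming an active SparkSession named spark):

import pyspark.sql.functions as F

rows = [
    ("2021-01-01", "FRA", "N"), ("2021-01-01", "ITA", "N"), ("2021-01-01", "BRA", "N"),
    ("2021-01-02", "BRA", "N"), ("2021-01-02", "FRA", "N"), ("2021-01-02", "ITA", "N"),
    ("2021-01-03", "ITA", "N"), ("2021-01-03", "BRA", "N"), ("2021-01-03", "FRA", "N"),
    ("2021-01-04", "BRA", "Y"), ("2021-01-04", "FRA", "Y"), ("2021-01-04", "ITA", "Y"),
    ("2021-01-05", "FRA", "Y"), ("2021-01-05", "BRA", "Y"), ("2021-01-05", "ITA", "Y"),
    ("2021-01-06", "ITA", "N"), ("2021-01-06", "FRA", "Y"), ("2021-01-06", "BRA", "Y"),
    ("2021-01-07", "ITA", "Y"),
]
df = (
    spark.createDataFrame(rows, ["DT_BORD_REF", "COUNTRY_ALPHA", "working_day_flag"])
    # the table shows midnight timestamps, so cast the date strings accordingly
    .withColumn("DT_BORD_REF", F.to_timestamp("DT_BORD_REF"))
)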
Upvotes: 2
Views: 1040
Reputation: 42422
You can do a conditional count using count_if (available in Spark SQL since version 3.0, as far as I know):
df.createOrReplaceTempView('df')
result = spark.sql("""
select *,
count_if(working_day_flag = 'Y')
over(partition by country_alpha, trunc(dt_bord_ref, 'month') order by dt_bord_ref)
month_to_date,
count_if(working_day_flag = 'Y')
over(partition by country_alpha, trunc(dt_bord_ref, 'month') order by dt_bord_ref
rows between 1 following and unbounded following)
month_to_go
from df
""")
result.show()
+-------------------+-------------+----------------+-------------+-----------+
| DT_BORD_REF|COUNTRY_ALPHA|working_day_flag|month_to_date|month_to_go|
+-------------------+-------------+----------------+-------------+-----------+
|2021-01-01 00:00:00| BRA| N| 0| 3|
|2021-01-02 00:00:00| BRA| N| 0| 3|
|2021-01-03 00:00:00| BRA| N| 0| 3|
|2021-01-04 00:00:00| BRA| Y| 1| 2|
|2021-01-05 00:00:00| BRA| Y| 2| 1|
|2021-01-06 00:00:00| BRA| Y| 3| 0|
|2021-01-01 00:00:00| ITA| N| 0| 3|
|2021-01-02 00:00:00| ITA| N| 0| 3|
|2021-01-03 00:00:00| ITA| N| 0| 3|
|2021-01-04 00:00:00| ITA| Y| 1| 2|
|2021-01-05 00:00:00| ITA| Y| 2| 1|
|2021-01-06 00:00:00| ITA| N| 2| 1|
|2021-01-07 00:00:00| ITA| Y| 3| 0|
|2021-01-01 00:00:00| FRA| N| 0| 3|
|2021-01-02 00:00:00| FRA| N| 0| 3|
|2021-01-03 00:00:00| FRA| N| 0| 3|
|2021-01-04 00:00:00| FRA| Y| 1| 2|
|2021-01-05 00:00:00| FRA| Y| 2| 1|
|2021-01-06 00:00:00| FRA| Y| 3| 0|
+-------------------+-------------+----------------+-------------+-----------+
If you want a similar solution using the PySpark DataFrame API:
import pyspark.sql.functions as F
from pyspark.sql.window import Window

result = df.withColumn(
    'month_to_date',
    # when() without otherwise() yields null for 'N' rows, and count()
    # ignores nulls, so only 'Y' rows are counted (same as count_if)
    F.count(
        F.when(F.col('working_day_flag') == 'Y', 1)
    ).over(
        Window.partitionBy('country_alpha', F.trunc('dt_bord_ref', 'month'))
        .orderBy('dt_bord_ref')
    )
).withColumn(
    'month_to_go',
    F.count(
        F.when(F.col('working_day_flag') == 'Y', 1)
    ).over(
        Window.partitionBy('country_alpha', F.trunc('dt_bord_ref', 'month'))
        .orderBy('dt_bord_ref')
        # start the frame one row after the current row
        .rowsBetween(1, Window.unboundedFollowing)
    )
)
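To display the result (the orderBy is only there to make the printed row order deterministic; show() alone doesn't guarantee it):

result.orderBy('country_alpha', 'dt_bord_ref').show()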
Upvotes: 1
Reputation: 32720
Use a running sum over a window. To limit the window to one month and one country, partition by COUNTRY_ALPHA and DATE_TRUNC('MONTH', DT_BORD_REF). Then, with ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, the sum gives the number of worked days up to and including the current date. The same logic gives the remaining days in the month with ROWS BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING.
To count only days with working_day_flag = 'Y', use a conditional sum with CASE/WHEN; here BOOLEAN(working_day_flag) does the job, since Spark casts the string 'Y' to true and 'N' to false.
Here's a working example with the sample data you provided in your question:
df.createOrReplaceTempView("df")
sql_query = """
SELECT
*,
SUM(CASE
WHEN BOOLEAN(working_day_flag) THEN 1
ELSE 0
END) OVER (
PARTITION BY COUNTRY_ALPHA, DATE_TRUNC('MONTH', DT_BORD_REF)
ORDER BY DT_BORD_REF ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
) AS month_to_date,
COALESCE(SUM(CASE
WHEN BOOLEAN(working_day_flag) THEN 1
ELSE 0
END) OVER (
PARTITION BY COUNTRY_ALPHA, DATE_TRUNC('MONTH', DT_BORD_REF)
ORDER BY DT_BORD_REF ROWS BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING
), 0) AS month_to_go
FROM df
"""
spark.sql(sql_query).show()
#+-------------------+-------------+----------------+-------------+-----------+
#| DT_BORD_REF|COUNTRY_ALPHA|working_day_flag|month_to_date|month_to_go|
#+-------------------+-------------+----------------+-------------+-----------+
#|2021-01-01 00:00:00| BRA| N| 0| 3|
#|2021-01-02 00:00:00| BRA| N| 0| 3|
#|2021-01-03 00:00:00| BRA| N| 0| 3|
#|2021-01-04 00:00:00| BRA| Y| 1| 2|
#|2021-01-05 00:00:00| BRA| Y| 2| 1|
#|2021-01-06 00:00:00| BRA| Y| 3| 0|
#|2021-01-01 00:00:00| FRA| N| 0| 3|
#|2021-01-02 00:00:00| FRA| N| 0| 3|
#|2021-01-03 00:00:00| FRA| N| 0| 3|
#|2021-01-04 00:00:00| FRA| Y| 1| 2|
#|2021-01-05 00:00:00| FRA| Y| 2| 1|
#|2021-01-06 00:00:00| FRA| Y| 3| 0|
#|2021-01-01 00:00:00| ITA| N| 0| 3|
#|2021-01-02 00:00:00| ITA| N| 0| 3|
#|2021-01-03 00:00:00| ITA| N| 0| 3|
#|2021-01-04 00:00:00| ITA| Y| 1| 2|
#|2021-01-05 00:00:00| ITA| Y| 2| 1|
#|2021-01-06 00:00:00| ITA| N| 2| 1|
#|2021-01-07 00:00:00|          ITA|               Y|            3|          0|
#+-------------------+-------------+----------------+-------------+-----------+
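If you'd rather stay in the DataFrame API, the same conditional running sum can be written roughly like this (a sketch, assuming the same df as above):

import pyspark.sql.functions as F
from pyspark.sql.window import Window

# 1 for working days, 0 otherwise (same as the CASE/WHEN above)
flag = F.when(F.col("working_day_flag") == "Y", 1).otherwise(0)

# one window per country and month, ordered by date
w = Window.partitionBy("COUNTRY_ALPHA", F.date_trunc("MONTH", "DT_BORD_REF")).orderBy("DT_BORD_REF")

result = (
    df.withColumn(
        "month_to_date",
        F.sum(flag).over(w.rowsBetween(Window.unboundedPreceding, Window.currentRow)),
    )
    .withColumn(
        "month_to_go",
        # the frame is empty on the last day of the month, so SUM returns null; default it to 0
        F.coalesce(F.sum(flag).over(w.rowsBetween(1, Window.unboundedFollowing)), F.lit(0)),
    )
)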
Upvotes: 1