

PySpark Sum calculation within a date range

I have a pyspark.sql.DataFrame like the one below, with columns id, year and money. For simplicity I have included only one id, but there could be multiple.

id    year    money
1     2019    10
1     2018    15
1     2013    13
1     2009    10
1     2015    10
1     2014    11

In the resulting DataFrame, for each id and year, I want the sum of money over the previous 3 consecutive years, excluding the record's own year.

For example, for the year 2019, I want to take the sum of money for 2018, 2017 and 2016 only. Since we only have 2018, the sum would be 15.

Similarly, for the year 2015, I want the sum of money for 2014, 2013 and 2012. Since only the first two are present, the sum is 11 + 13 = 24.

The resulting DataFrame would look like below.

id    year    sum_money
1     2019    15
1     2018    10
1     2015    24
1     2014    13
1     2013    0
1     2009    0
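To pin down the rule before writing any Spark code, here is a minimal plain-Python sketch (no Spark involved) that computes the expected sum_money for every (id, year): the sum of money over the years year - 3 through year - 1, with missing years counting as 0. The helper name `previous_years_sum` is hypothetical, and it assumes each (id, year) pair appears only once.

```python
from collections import defaultdict


def previous_years_sum(rows, window=3):
    """Hypothetical reference helper: for each (id, year), sum money over
    the years (year - window) .. (year - 1); missing years count as 0.

    rows: iterable of (id, year, money) tuples, (id, year) assumed unique.
    """
    money_by_id = defaultdict(dict)
    for id_, year, money in rows:
        money_by_id[id_][year] = money

    result = {}
    for id_, money_by_year in money_by_id.items():
        for year in money_by_year:
            # range(year - window, year) stops before the record's own year
            result[(id_, year)] = sum(
                money_by_year.get(y, 0) for y in range(year - window, year)
            )
    return result


rows = [(1, 2019, 10), (1, 2018, 15), (1, 2013, 13),
        (1, 2009, 10), (1, 2015, 10), (1, 2014, 11)]
for (id_, year), total in sorted(previous_years_sum(rows).items(), reverse=True):
    print(id_, year, total)
```

It loops row by row, so it is only meant as a checker for small samples, not as a replacement for a window function.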

How can I achieve the desired result? Does the lag function provide any functionality to look only at the years I want, or is there another approach?

My Approach

My approach is to take the cumulative sum of money over the years, accumulating from the earliest year. Then, for each id and year, find the largest year that is strictly less than year - window.

Say for the year 2019 and window = 3, the start year would be 2016, so the year we need is the largest year present in the dataset before 2016, which is 2015. For that row, take the cum_sum at 2015 (res_sum).

Then the final result is the current cum_sum minus res_sum minus the current year's money. So for 2019 it would be 69 - 44 - 10 = 15. The same applies to the other records (id & year). The final data would look like below.

id    year    money   cum_sum     min_year    res_sum    diff
1      2019    10         69        2015        44        15
1      2018    15         59        2014        34        10
1      2015    10         44        2009        10        24
1      2014    11         34        2009        10        13
1      2013    13         23        2009        10        0
1      2009    10         10        0            0        0
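The cumulative-sum approach above can be sketched in plain Python for a single id. `bisect_left` locates the largest year strictly before year - window, and its running total plays the role of res_sum; the helper name `cumsum_approach` is hypothetical.

```python
from bisect import bisect_left
from itertools import accumulate


def cumsum_approach(money_by_year, window=3):
    """Hypothetical helper: the question's cum_sum / res_sum / diff idea for one id.

    money_by_year: dict mapping year -> money.
    Returns a dict mapping year -> diff, in ascending year order.
    """
    years = sorted(money_by_year)
    cum_sum = list(accumulate(money_by_year[y] for y in years))
    diff = {}
    for i, year in enumerate(years):
        start = year - window                  # e.g. 2019 - 3 = 2016
        j = bisect_left(years, start)          # first index with years[j] >= start
        res_sum = cum_sum[j - 1] if j > 0 else 0  # cum_sum at the largest year < start
        diff[year] = cum_sum[i] - res_sum - money_by_year[year]
    return diff


print(cumsum_approach({2019: 10, 2018: 15, 2013: 13, 2009: 10, 2015: 10, 2014: 11}))
# {2009: 0, 2013: 0, 2014: 13, 2015: 24, 2018: 10, 2019: 15}
```

Since the years are integers, res_sum is the running total up to year - window - 1, so diff is exactly the sum over the years year - window through year - 1, which is what a value-based window frame computes directly.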

I am trying to figure out a simpler approach.


Answers (1)

Sebastian Wozny

In PySpark we can use rangeBetween, as pointed out by @samkart:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F  # avoids shadowing Python's built-in sum
from pyspark.sql.window import Window

data = [
    {'id': 1, 'year': 2019, 'money': 10},
    {'id': 1, 'year': 2018, 'money': 15},
    {'id': 1, 'year': 2013, 'money': 13},
    {'id': 1, 'year': 2009, 'money': 10},
    {'id': 1, 'year': 2015, 'money': 10},
    {'id': 1, 'year': 2014, 'money': 11},
]

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(data)

# Range frame on the year value: years (year - 3) .. (year - 1), current year excluded
window = Window.partitionBy("id").orderBy("year").rangeBetween(-3, -1)
df = df.withColumn("previous_3_years_sum", F.sum("money").over(window))
# An empty frame yields null; replace it with 0
df = df.fillna(0, subset=["previous_3_years_sum"])
df.orderBy("id", "year").show()


| id|money|year|previous_3_years_sum|
|  1|   10|2009|                   0|
|  1|   13|2013|                   0|
|  1|   11|2014|                  13|
|  1|   10|2015|                  24|
|  1|   15|2018|                  10|
|  1|   10|2019|                  15|
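On the lag question: lag, like rowsBetween, counts rows rather than years, so when years are missing it picks up years outside the 3-year range. rangeBetween instead compares values of the orderBy column. The difference can be mimicked in plain Python; both helper names are hypothetical, and unlike Spark (which returns null for an empty frame, hence the fillna above) they return 0.

```python
def rows_frame_sum(money_by_year, n=3):
    """Mimics rowsBetween(-n, -1): sum of the previous n rows in year order."""
    ordered = sorted(money_by_year.items())
    return {year: sum(m for _, m in ordered[max(0, i - n):i])
            for i, (year, _) in enumerate(ordered)}


def range_frame_sum(money_by_year, lo=-3, hi=-1):
    """Mimics rangeBetween(lo, hi): sum of rows whose year lies in [year + lo, year + hi]."""
    return {year: sum(m for y, m in money_by_year.items() if year + lo <= y <= year + hi)
            for year in money_by_year}


money = {2019: 10, 2018: 15, 2013: 13, 2009: 10, 2015: 10, 2014: 11}
print(rows_frame_sum(money)[2019])   # rows 2014, 2015, 2018: 11 + 10 + 15 = 36
print(range_frame_sum(money)[2019])  # only 2018 lies in 2016..2018: 15
```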

This solution also strikes me as much more elegant. Window specs are flexible: we don't need to create placeholder rows, and no shifting is necessary. Initially I provided a solution using pandas:

import pandas as pd

data = {
    'id': [1, 1, 1, 1, 1, 1],
    'year': [2019, 2018, 2013, 2009, 2015, 2014],
    'money': [10, 15, 13, 10, 10, 11],
}

df = pd.DataFrame(data)

# Fill in the missing years with 0s (this works for a single id only)
df = df.set_index('year')
all_years = range(df.index.min(), df.index.max() + 1)
df = df.reindex(all_years).fillna(0)
df = df.rename_axis('year').reset_index()

# Years are now sorted and consecutive: take the cumulative sum, subtract the
# value 3 rows earlier (0 before the first year) to get the 3-year sum ending
# at each year, then shift by 1 to exclude the current year
df['cum'] = df['money'].cumsum()
df['previous_3_years_sum'] = (df['cum'] - df['cum'].shift(3, fill_value=0)).shift(1).fillna(0)
# Filter out the artificially inserted years again
# (note: this would also drop real rows with money == 0)
df = df.query('money > 0')[['id', 'year', 'previous_3_years_sum']]
print(df)


    id  year    previous_3_years_sum
0   1.0 2009    0.0
4   1.0 2013    0.0
5   1.0 2014    13.0
6   1.0 2015    24.0
9   1.0 2018    10.0
10  1.0 2019    15.0
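The pandas idea (insert the missing years with 0, then difference the cumulative sums) can also be sketched without pandas for one id. This hypothetical `dense_prefix_sum` helper reports only the years that were actually present instead of filtering on money > 0, so a genuine row with money == 0 is kept.

```python
def dense_prefix_sum(money_by_year, window=3):
    """Hypothetical helper: previous-`window`-years sum via a dense year array."""
    first, last = min(money_by_year), max(money_by_year)
    # One slot per calendar year; the inserted (missing) years get 0
    dense = [money_by_year.get(y, 0) for y in range(first, last + 1)]
    prefix = [0]  # prefix[k] == sum(dense[:k])
    for m in dense:
        prefix.append(prefix[-1] + m)
    result = {}
    for year in money_by_year:  # only the real years, not the inserted ones
        k = year - first        # slot of `year` in dense
        # sum of dense[k - window : k], clamped at the first year
        result[year] = prefix[k] - prefix[max(0, k - window)]
    return result


print(dense_prefix_sum({2019: 10, 2018: 15, 2013: 13, 2009: 10, 2015: 10, 2014: 11}))
```

The max(0, k - window) clamp ensures that a year only one or two years after the first one still sees the earlier money.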


I suspect the real question is about quarter-specific reports. I solved this with a mapping: multiply the year by 4 to make room for 4 quarters, then add the quarter number minus 1. So 2000 is mapped to 8000, where 8000 represents 2000-Q1, 8001 represents 2000-Q2, and so on. Then rangeBetween(-12, -1) covers the previous 12 quarters, i.e. 3 years.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Create SparkSession
spark = SparkSession.builder.getOrCreate()

# Create the data
data = [
    (1, 10, "2009-Q1"),
    (1, 13, "2013-Q1"),
    (1, 11, "2014-Q1"),
    (1, 10, "2015-Q1"),
    (1, 15, "2018-Q1"),
    (1, 10, "2019-Q1"),
]

# Create the DataFrame
df = spark.createDataFrame(data, ["id", "money", "quarter"])

# Split "YYYY-Qn" into an integer year column and an integer quarter column
df = df.withColumn("year", F.split(F.col("quarter"), "-").getItem(0).cast("int"))
df = df.withColumn("quarter", F.split(F.col("quarter"), "-").getItem(1))
df = df.withColumn("quarter", F.regexp_replace(F.col("quarter"), "Q", "").cast("int"))

# Map each (year, quarter) to a consecutive integer: year * 4 + (quarter - 1)
df = df.withColumn("new_column", F.col("year") * 4 + (F.col("quarter") - 1))

# Sum over the previous 12 quarters (3 years), excluding the current quarter
window_spec = Window.partitionBy("id").orderBy("new_column").rangeBetween(-12, -1)
df = df.withColumn("previous_3_years_sum", F.sum("money").over(window_spec))

# Fill null values (empty frames) with 0
df = df.fillna(0, subset=["previous_3_years_sum"])
df.orderBy("id", "new_column").show()


| id|money|quarter|year|new_column|previous_3_years_sum|
|  1|   10|      1|2009|      8036|                   0|
|  1|   13|      1|2013|      8052|                   0|
|  1|   11|      1|2014|      8056|                  13|
|  1|   10|      1|2015|      8060|                  24|
|  1|   15|      1|2018|      8072|                  10|
|  1|   10|      1|2019|      8076|                  15|
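The quarter mapping can be checked in isolation. A minimal plain-Python sketch, assuming labels of the form "YYYY-Qn" as in the example data; `quarter_index` and `from_quarter_index` are hypothetical helper names.

```python
def quarter_index(label):
    """Map 'YYYY-Qn' to year * 4 + (n - 1), so consecutive quarters get consecutive integers."""
    year, quarter = label.split("-")
    return int(year) * 4 + int(quarter.lstrip("Q")) - 1


def from_quarter_index(idx):
    """Inverse mapping: integer index back to a 'YYYY-Qn' label."""
    year, q0 = divmod(idx, 4)
    return f"{year}-Q{q0 + 1}"


idx = quarter_index("2019-Q1")  # 2019 * 4 + 0 = 8076
# rangeBetween(-12, -1) covers idx - 12 .. idx - 1
print(from_quarter_index(idx - 12), "to", from_quarter_index(idx - 1))  # 2016-Q1 to 2018-Q4
```

Because Q4 of one year and Q1 of the next are adjacent integers, a 12-quarter frame spans year boundaries correctly.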
