user175025

Reputation: 434

Find cumulative average based on column values in spark dataframe

This is my dataframe:

# +--------+--------+------+------+----------+------------+-------------------+-------------+
# |Location|   Month| Brand|Sector| TrueValue|PickoutValue| month_in_timestamp|TotalSumValue|
# +--------+--------+------+------+----------+------------+-------------------+-------------+
# |     USA|1/1/2021|brand2| cars2|16383104.2|    16666667|2021-01-01 00:00:00|   16383104.2|
# |     USA|2/1/2021|brand2| cars2|26812874.2|    16666667|2021-01-02 00:00:00|   43195978.4|
# |     USA|3/1/2021|brand2| cars2|      null|        null|2021-01-03 00:00:00|   43195978.4|
# |     USA|1/1/2021|brand3| cars3|      75.6|        70.0|2021-01-01 00:00:00|         75.6|
# |     USA|2/1/2021|brand3| cars3|      77.1|        70.0|2021-01-02 00:00:00|         76.4|
# |     USA|3/1/2021|brand3| cars3|      73.1|        70.0|2021-01-03 00:00:00|         75.3|
# |     USA|4/1/2021|brand3| cars3|       0.0|        45.3|2021-01-04 00:00:00|         56.4|
# |     USA|5/1/2021|brand3| cars3|       0.0|        67.9|2021-01-05 00:00:00|         45.1|

Here I'm deriving the TotalSumValue column by calculating the cumulative sum of the "brand2" rows and the cumulative average of the "brand3" rows. I have data for the months Jan to Dec for all the different brands.

While calculating the cumulative average for brand3, the TrueValue column contains "0.0" values for the upcoming months. Currently, these "0.0" values are also included in the cumulative average. Instead, I need to consider only the months whose values are other than "0.0".

So, my expected dataframe is:

# +--------+--------+------+------+----------+------------+-------------------+-------------+
# |Location|   Month| Brand|Sector| TrueValue|PickoutValue| month_in_timestamp|TotalSumValue|
# +--------+--------+------+------+----------+------------+-------------------+-------------+
# |     USA|1/1/2021|brand2| cars2|16383104.2|    16666667|2021-01-01 00:00:00|   16383104.2|
# |     USA|2/1/2021|brand2| cars2|26812874.2|    16666667|2021-01-02 00:00:00|   43195978.4|
# |     USA|3/1/2021|brand2| cars2|      null|        null|2021-01-03 00:00:00|   43195978.4|
# |     USA|1/1/2021|brand3| cars3|      75.6|        70.0|2021-01-01 00:00:00|         75.6|
# |     USA|2/1/2021|brand3| cars3|      77.1|        70.0|2021-01-02 00:00:00|         76.4|
# |     USA|3/1/2021|brand3| cars3|      73.1|        70.0|2021-01-03 00:00:00|         75.3|
# |     USA|4/1/2021|brand3| cars3|       0.0|        45.3|2021-01-04 00:00:00|          0.0|
# |     USA|5/1/2021|brand3| cars3|       0.0|        67.9|2021-01-05 00:00:00|          0.0|

This is my code block:

windowval=(Window.partitionBy('Location','Brand').orderBy('month_in_timestamp')
               .rangeBetween(Window.unboundedPreceding, 0))

df = df.withColumn('TotalSumValue',
         F.when(F.col('Brand').isin('brand2'), F.sum('TrueValue').over(windowval)) \
         when(F.col('Brand').isin('brand3'), F.avg('TrueValue').over(windowval)))

Upvotes: 1

Views: 96

Answers (1)

Christophe

Reputation: 696

You missed a dot (well, I missed it here as well, but you could have found the problem or asked in the comments ;-) )

windowval=(Window.partitionBy('Location','Brand').orderBy('month_in_timestamp')
               .rangeBetween(Window.unboundedPreceding, 0))

df = df.withColumn('TotalSumValue',
           F.when(F.col('Brand').isin('brand2'), F.sum('TrueValue').over(windowval)) \
           .when(F.col('Brand').isin('brand3'), F.avg('TrueValue').over(windowval)))  # the leading dot chains the second when()
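
As for the original requirement of leaving the "0.0" months out of the cumulative average, a minimal sketch (not part of the answer above, and assuming the same column names) is to null out the zeros before averaging, since avg() ignores nulls:

from pyspark.sql import functions as F, Window

windowval = (Window.partitionBy('Location', 'Brand').orderBy('month_in_timestamp')
                   .rangeBetween(Window.unboundedPreceding, 0))

# avg() skips nulls, so mapping 0.0 to null keeps those months out of the
# running average while leaving the brand2 cumulative sum untouched.
df = df.withColumn(
    'TotalSumValue',
    F.when(F.col('Brand').isin('brand2'), F.sum('TrueValue').over(windowval))
     .when(F.col('Brand').isin('brand3'),
           F.avg(F.when(F.col('TrueValue') != 0.0, F.col('TrueValue'))).over(windowval)))

With this variant, the brand3 rows average only the non-zero months seen so far.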

Upvotes: 1
