This is my dataframe:
# +--------+--------+------+------+----------+------------+-------------------+-------------+
# |Location| Month| Brand|Sector| TrueValue|PickoutValue| month_in_timestamp|TotalSumValue|
# +--------+--------+------+------+----------+------------+-------------------+-------------+
# | USA|1/1/2021|brand2| cars2|16383104.2| 16666667|2021-01-01 00:00:00| 16383104.2|
# | USA|2/1/2021|brand2| cars2|26812874.2| 16666667|2021-01-02 00:00:00| 43195978.4|
# | USA|3/1/2021|brand2| cars2| null| null|2021-01-03 00:00:00| 43195978.4|
# | USA|1/1/2021|brand3| cars3| 75.6| 70.0|2021-01-01 00:00:00| 75.6|
# | USA|2/1/2021|brand3| cars3| 77.1| 70.0|2021-01-02 00:00:00| 76.4|
# | USA|3/1/2021|brand3| cars3| 73.1| 70.0|2021-01-03 00:00:00| 75.3|
# | USA|4/1/2021|brand3| cars3| 0.0| 45.3|2021-01-04 00:00:00| 56.4|
# | USA|5/1/2021|brand3| cars3| 0.0| 67.9|2021-01-05 00:00:00| 45.1|
Here I'm deriving the TotalSumValue column by calculating the cumulative sum of the "brand2" rows and the cumulative average of the "brand3" rows. I have data for the months Jan to Dec for all the different brands.
While calculating the cumulative average for brand3, the TrueValue column has "0.0" values for the upcoming months. Currently, these "0.0" values are also counted when calculating the cumulative average. Instead, I need to consider only the months whose values are not "0.0".
So, my expected dataframe is:
# +--------+--------+------+------+----------+------------+-------------------+-------------+
# |Location| Month| Brand|Sector| TrueValue|PickoutValue| month_in_timestamp|TotalSumValue|
# +--------+--------+------+------+----------+------------+-------------------+-------------+
# | USA|1/1/2021|brand2| cars2|16383104.2| 16666667|2021-01-01 00:00:00| 16383104.2|
# | USA|2/1/2021|brand2| cars2|26812874.2| 16666667|2021-01-02 00:00:00| 43195978.4|
# | USA|3/1/2021|brand2| cars2| null| null|2021-01-03 00:00:00| 43195978.4|
# | USA|1/1/2021|brand3| cars3| 75.6| 70.0|2021-01-01 00:00:00| 75.6|
# | USA|2/1/2021|brand3| cars3| 77.1| 70.0|2021-01-02 00:00:00| 76.4|
# | USA|3/1/2021|brand3| cars3| 73.1| 70.0|2021-01-03 00:00:00| 75.3|
# | USA|4/1/2021|brand3| cars3| 0.0| 45.3|2021-01-04 00:00:00| 0.0|
# | USA|5/1/2021|brand3| cars3| 0.0| 67.9|2021-01-05 00:00:00| 0.0|
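To make the effect of the "0.0" months concrete, here is a small plain-Python sketch (not Spark code) of a running average over the brand3 TrueValue values from the table, once counting the 0.0 entries and once skipping them. The helper name cumulative_avg is hypothetical. Note that the expected table above shows 0.0 itself in TotalSumValue for the 0.0 months, rather than the carried-forward average this sketch produces for those rows.

```python
def cumulative_avg(values, skip_zeros=False):
    """Running mean of values[0..i] for each i (an unbounded-preceding frame).

    None entries are ignored, the way Spark's avg ignores nulls. With
    skip_zeros=True, 0.0 entries are ignored as well. Positions where
    nothing has been counted yet get None.
    """
    result = []
    total, count = 0.0, 0
    for v in values:
        if v is not None and not (skip_zeros and v == 0.0):
            total += v
            count += 1
        result.append(total / count if count else None)
    return result


# brand3 TrueValue column from the sample, in month order
brand3 = [75.6, 77.1, 73.1, 0.0, 0.0]
print([round(x, 2) for x in cumulative_avg(brand3)])
# -> [75.6, 76.35, 75.27, 56.45, 45.16]   (zeros counted)
print([round(x, 2) for x in cumulative_avg(brand3, skip_zeros=True)])
# -> [75.6, 76.35, 75.27, 75.27, 75.27]   (zeros skipped)
```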
This is my code block:
windowval = (Window.partitionBy('Location', 'Brand').orderBy('month_in_timestamp')
             .rangeBetween(Window.unboundedPreceding, 0))
df = df.withColumn('TotalSumValue',
                   F.when(F.col('Brand').isin('brand2'), F.sum('TrueValue').over(windowval)) \
                   when(F.col('Brand').isin('brand3'), F.avg('TrueValue').over(windowval)))
You missed a dot before the chained when(...) call (well, I missed it here as well, but you could have found the problem or asked in the comments ;-) )
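from pyspark.sql import Window
from pyspark.sql import functions as F
# Running frame: from the first row of each (Location, Brand) partition up to the current row
windowval = (Window.partitionBy('Location', 'Brand').orderBy('month_in_timestamp')
             .rangeBetween(Window.unboundedPreceding, 0))
df = df.withColumn(
    'TotalSumValue',
    F.when(F.col('Brand').isin('brand2'), F.sum('TrueValue').over(windowval))
     # the chained .when needs its leading dot
     .when(F.col('Brand').isin('brand3'), F.avg('TrueValue').over(windowval)))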
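The corrected expression still averages the 0.0 months together with the others. One possible way to match the expected output, offered as an untested suggestion rather than part of this answer: make the brand3 branch F.when(F.col('TrueValue') == 0, F.lit(0.0)).otherwise(F.avg(F.when(F.col('TrueValue') != 0, F.col('TrueValue'))).over(windowval)). This relies on when() without otherwise() producing null and avg() skipping nulls. The plain-Python model below reproduces that per-partition logic so the expected TotalSumValue column can be checked without a Spark session; the function names total_sum_value and row and the row-dict layout are hypothetical.

```python
from collections import defaultdict


def total_sum_value(rows):
    """Model of the TotalSumValue logic over a list of row dicts (hypothetical layout).

    brand2: running sum of TrueValue, skipping None (like F.sum).
    brand3: 0.0 where TrueValue is 0.0, otherwise the running average of the
            non-zero TrueValues so far (like F.avg over F.when(TrueValue != 0, ...)).
    Other brands: None, as an F.when chain without otherwise() yields null.
    Assumes one row per month per (Location, Brand), so the RANGE frame up to
    the current row behaves like a simple running total.
    """
    partitions = defaultdict(list)  # Window.partitionBy('Location', 'Brand')
    for i, r in enumerate(rows):
        partitions[(r['Location'], r['Brand'])].append(i)

    out = [None] * len(rows)
    for (_, brand), idxs in partitions.items():
        # orderBy('month_in_timestamp'); ISO strings sort chronologically
        idxs.sort(key=lambda i: rows[i]['month_in_timestamp'])
        total, count = 0.0, 0
        for i in idxs:
            v = rows[i]['TrueValue']
            if brand == 'brand2':
                if v is not None:
                    total += v
                    count += 1
                out[i] = total if count else None
            elif brand == 'brand3':
                if v is not None and v != 0.0:
                    total += v
                    count += 1
                if v == 0.0:
                    out[i] = 0.0
                else:
                    out[i] = total / count if count else None
    return out


def row(month, brand, value):
    """Build one hypothetical sample row; Location is fixed to 'USA'."""
    return {'Location': 'USA', 'Brand': brand,
            'month_in_timestamp': month, 'TrueValue': value}


sample = [
    row('2021-01-01', 'brand2', 16383104.2),
    row('2021-01-02', 'brand2', 26812874.2),
    row('2021-01-03', 'brand2', None),
    row('2021-01-01', 'brand3', 75.6),
    row('2021-01-02', 'brand3', 77.1),
    row('2021-01-03', 'brand3', 73.1),
    row('2021-01-04', 'brand3', 0.0),
    row('2021-01-05', 'brand3', 0.0),
]
for r, total in zip(sample, total_sum_value(sample)):
    print(r['Brand'], r['month_in_timestamp'], total)
```

The sketch deliberately ignores ties in month_in_timestamp; with a RANGE frame, Spark would include all rows sharing the current timestamp in the same frame.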