Hunter Mitchell

Reputation: 399

Adding column with sum of all rows above in same grouping

I need to create a 'rolling count' column which takes the previous count and adds the new count for each day and company. I have already organized and sorted the dataframe into groups of ascending dates per company with the corresponding count. I also added an 'ix' column which indexes each grouping, like so:

+--------------------+--------------------+-----+---+
|     Normalized_Date|             company|count| ix|
+--------------------+--------------------+-----+---+
|09/25/2018 00:00:...|[5c40c8510fb7c017...|    7|  1|
|09/25/2018 00:00:...|[5bdb2b543951bf07...|    9|  1|
|11/28/2017 00:00:...|[593b0d9f3f21f9dd...|    7|  1|
|11/29/2017 00:00:...|[593b0d9f3f21f9dd...|   60|  2|
|01/09/2018 00:00:...|[593b0d9f3f21f9dd...|    1|  3|
|04/27/2018 00:00:...|[593b0d9f3f21f9dd...|    9|  4|
|09/25/2018 00:00:...|[593b0d9f3f21f9dd...|   29|  5|
|11/20/2018 00:00:...|[593b0d9f3f21f9dd...|   42|  6|
|12/11/2018 00:00:...|[593b0d9f3f21f9dd...|  317|  7|
|01/04/2019 00:00:...|[593b0d9f3f21f9dd...|    3|  8|
|02/13/2019 00:00:...|[593b0d9f3f21f9dd...|   15|  9|
|04/01/2019 00:00:...|[593b0d9f3f21f9dd...|    1| 10|
+--------------------+--------------------+-----+---+

The output I need would simply add up all the counts up to that date for each company. Like so:

+--------------------+--------------------+-----+---+------------+
|     Normalized_Date|             company|count| ix|RollingCount|
+--------------------+--------------------+-----+---+------------+
|09/25/2018 00:00:...|[5c40c8510fb7c017...|    7|  1|           7|
|09/25/2018 00:00:...|[5bdb2b543951bf07...|    9|  1|           9|
|11/28/2017 00:00:...|[593b0d9f3f21f9dd...|    7|  1|           7|
|11/29/2017 00:00:...|[593b0d9f3f21f9dd...|   60|  2|          67|
|01/09/2018 00:00:...|[593b0d9f3f21f9dd...|    1|  3|          68|
|04/27/2018 00:00:...|[593b0d9f3f21f9dd...|    9|  4|          77|
|09/25/2018 00:00:...|[593b0d9f3f21f9dd...|   29|  5|         106|
|11/20/2018 00:00:...|[593b0d9f3f21f9dd...|   42|  6|         148|
|12/11/2018 00:00:...|[593b0d9f3f21f9dd...|  317|  7|         465|
|01/04/2019 00:00:...|[593b0d9f3f21f9dd...|    3|  8|         468|
|02/13/2019 00:00:...|[593b0d9f3f21f9dd...|   15|  9|         483|
|04/01/2019 00:00:...|[593b0d9f3f21f9dd...|    1| 10|         484|
+--------------------+--------------------+-----+---+------------+

I figured the lag function would be of use, and I was able to get each row of RollingCount with ix > 1 to add the count directly above it with the following code:

w = Window.partitionBy('company').orderBy(F.unix_timestamp('Normalized_Date', 'MM/dd/yyyy HH:mm:ss aaa').cast('timestamp'))
refined_DF = solutionDF.withColumn("rn", F.row_number().over(w))
solutionDF = refined_DF.withColumn('RollingCount', F.when(refined_DF['rn'] > 1, refined_DF['count'] + F.lag(refined_DF['count'], 1).over(w)).otherwise(refined_DF['count']))

which yields the following df:

+--------------------+--------------------+-----+---+------------+
|     Normalized_Date|             company|count| ix|RollingCount|
+--------------------+--------------------+-----+---+------------+
|09/25/2018 00:00:...|[5c40c8510fb7c017...|    7|  1|           7|
|09/25/2018 00:00:...|[5bdb2b543951bf07...|    9|  1|           9|
|11/28/2017 00:00:...|[593b0d9f3f21f9dd...|    7|  1|           7|
|11/29/2017 00:00:...|[593b0d9f3f21f9dd...|   60|  2|          67|
|01/09/2018 00:00:...|[593b0d9f3f21f9dd...|    1|  3|          61|
|04/27/2018 00:00:...|[593b0d9f3f21f9dd...|    9|  4|          10|
|09/25/2018 00:00:...|[593b0d9f3f21f9dd...|   29|  5|          38|
|11/20/2018 00:00:...|[593b0d9f3f21f9dd...|   42|  6|          71|
|12/11/2018 00:00:...|[593b0d9f3f21f9dd...|  317|  7|         359|
|01/04/2019 00:00:...|[593b0d9f3f21f9dd...|    3|  8|         320|
|02/13/2019 00:00:...|[593b0d9f3f21f9dd...|   15|  9|          18|
|04/01/2019 00:00:...|[593b0d9f3f21f9dd...|    1| 10|          16|
+--------------------+--------------------+-----+---+------------+

I just need it to sum all of the counts in the ix rows above it. I have tried using a udf to compute the 'count' argument passed into the lag function, but I keep getting a "'Column' object is not callable" error, and it wouldn't sum all of the preceding rows anyway. I have also tried using a loop, but that seems impractical because it would create a new dataframe on each pass, and I would need to join them all afterwards. There must be an easier and simpler way to do this. Perhaps a different function than lag?
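For context, a simplified, hypothetical sketch of the kind of thing I attempted (the per-row offset is the part that fails, since lag only accepts a literal integer offset, not a Column):

import pyspark.sql.functions as F
from pyspark.sql import Window

w = Window.partitionBy('company').orderBy('ix')

# A fixed, literal offset works, but it only reaches one specific row back:
prev1 = F.lag('count', 1).over(w)

# What I actually need is a variable reach-back per row, which lag does not support:
# F.lag('count', F.col('ix') - 1).over(w)   # fails -- the offset must be a literal int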

Upvotes: 2

Views: 1917

Answers (2)

cronoik

Reputation: 19320

lag returns a single row at a fixed offset before your current row, but you need a whole range of rows to calculate the cumulative sum. Therefore you have to define the window frame with rangeBetween (or rowsBetween). Have a look at the example below:

import pyspark.sql.functions as F
from pyspark.sql import Window

l = [
    ('09/25/2018', '5c40c8510fb7c017',   7,  1),
    ('09/25/2018', '5bdb2b543951bf07',   9,  1),
    ('11/28/2017', '593b0d9f3f21f9dd',   7,  1),
    ('11/29/2017', '593b0d9f3f21f9dd',  60,  2),
    ('01/09/2018', '593b0d9f3f21f9dd',   1,  3),
    ('04/27/2018', '593b0d9f3f21f9dd',   9,  4),
    ('09/25/2018', '593b0d9f3f21f9dd',  29,  5),
    ('11/20/2018', '593b0d9f3f21f9dd',  42,  6),
    ('12/11/2018', '593b0d9f3f21f9dd', 317,  7),
    ('01/04/2019', '593b0d9f3f21f9dd',   3,  8),
    ('02/13/2019', '593b0d9f3f21f9dd',  15,  9),
    ('04/01/2019', '593b0d9f3f21f9dd',   1, 10)
]

columns = ['Normalized_Date', 'company', 'count', 'ix']

df = spark.createDataFrame(l, columns)

# Parse the date string so the window can be ordered chronologically.
df = df.withColumn('Normalized_Date', F.to_date(df.Normalized_Date, 'MM/dd/yyyy'))

# Frame from the start of the partition up to (and including) the current row.
w = Window.partitionBy('company').orderBy('Normalized_Date').rangeBetween(Window.unboundedPreceding, 0)

df = df.withColumn('Rolling_count', F.sum('count').over(w))
df.show()

Output:

+---------------+----------------+-----+---+-------------+ 
|Normalized_Date|         company|count| ix|Rolling_count| 
+---------------+----------------+-----+---+-------------+ 
|     2018-09-25|5c40c8510fb7c017|    7|  1|            7| 
|     2018-09-25|5bdb2b543951bf07|    9|  1|            9| 
|     2017-11-28|593b0d9f3f21f9dd|    7|  1|            7| 
|     2017-11-29|593b0d9f3f21f9dd|   60|  2|           67| 
|     2018-01-09|593b0d9f3f21f9dd|    1|  3|           68| 
|     2018-04-27|593b0d9f3f21f9dd|    9|  4|           77| 
|     2018-09-25|593b0d9f3f21f9dd|   29|  5|          106| 
|     2018-11-20|593b0d9f3f21f9dd|   42|  6|          148| 
|     2018-12-11|593b0d9f3f21f9dd|  317|  7|          465| 
|     2019-01-04|593b0d9f3f21f9dd|    3|  8|          468| 
|     2019-02-13|593b0d9f3f21f9dd|   15|  9|          483| 
|     2019-04-01|593b0d9f3f21f9dd|    1| 10|          484| 
+---------------+----------------+-----+---+-------------+
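One note on the frame choice: with orderBy('Normalized_Date') and a range frame, rows of the same company that share the same date are treated as peers and all receive the same cumulative total. If you instead want a strictly row-by-row running sum, you could order by the ix column and use rowsBetween, for example (a sketch reusing the df and imports above):

w_rows = Window.partitionBy('company').orderBy('ix').rowsBetween(Window.unboundedPreceding, Window.currentRow)

df = df.withColumn('Rolling_count', F.sum('count').over(w_rows))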

Upvotes: 1

C.S.Reddy Gadipally

Reputation: 1758

Try this. You need the sum of all rows from the start of the partition up to the current row in the window frame, i.e. a frame from unboundedPreceding to currentRow.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.expressions.WindowSpec
import org.apache.spark.sql.functions._

val df = Seq(
  ("5c40c8510fb7c017",   7,  1),
  ("5bdb2b543951bf07",   9,  1),
  ("593b0d9f3f21f9dd",   7,  1),
  ("593b0d9f3f21f9dd",  60,  2),
  ("593b0d9f3f21f9dd",   1,  3),
  ("593b0d9f3f21f9dd",   9,  4),
  ("593b0d9f3f21f9dd",  29,  5),
  ("593b0d9f3f21f9dd",  42,  6),
  ("593b0d9f3f21f9dd", 317,  7),
  ("593b0d9f3f21f9dd",   3,  8),
  ("593b0d9f3f21f9dd",  15,  9),
  ("593b0d9f3f21f9dd",   1, 10)
).toDF("company", "count", "ix")

scala> df.show(false)
+----------------+-----+---+
|company         |count|ix |
+----------------+-----+---+
|5c40c8510fb7c017|7    |1  |
|5bdb2b543951bf07|9    |1  |
|593b0d9f3f21f9dd|7    |1  |
|593b0d9f3f21f9dd|60   |2  |
|593b0d9f3f21f9dd|1    |3  |
|593b0d9f3f21f9dd|9    |4  |
|593b0d9f3f21f9dd|29   |5  |
|593b0d9f3f21f9dd|42   |6  |
|593b0d9f3f21f9dd|317  |7  |
|593b0d9f3f21f9dd|3    |8  |
|593b0d9f3f21f9dd|15   |9  |
|593b0d9f3f21f9dd|1    |10 |
+----------------+-----+---+


scala> val overColumns = Window.partitionBy("company").orderBy("ix").rowsBetween(Window.unboundedPreceding, Window.currentRow)
overColumns: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@3ed5e17c

scala> val outputDF = df.withColumn("RollingCount", sum("count").over(overColumns))
outputDF: org.apache.spark.sql.DataFrame = [company: string, count: int ... 2 more fields]

scala> outputDF.show(false)
+----------------+-----+---+------------+
|company         |count|ix |RollingCount|
+----------------+-----+---+------------+
|5c40c8510fb7c017|7    |1  |7           |
|5bdb2b543951bf07|9    |1  |9           |
|593b0d9f3f21f9dd|7    |1  |7           |
|593b0d9f3f21f9dd|60   |2  |67          |
|593b0d9f3f21f9dd|1    |3  |68          |
|593b0d9f3f21f9dd|9    |4  |77          |
|593b0d9f3f21f9dd|29   |5  |106         |
|593b0d9f3f21f9dd|42   |6  |148         |
|593b0d9f3f21f9dd|317  |7  |465         |
|593b0d9f3f21f9dd|3    |8  |468         |
|593b0d9f3f21f9dd|15   |9  |483         |
|593b0d9f3f21f9dd|1    |10 |484         |
+----------------+-----+---+------------+
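For completeness, roughly the same window spec in PySpark (a sketch that assumes the question's DataFrame is called solutionDF, as in the question's own code):

import pyspark.sql.functions as F
from pyspark.sql import Window

over_columns = Window.partitionBy('company').orderBy('ix').rowsBetween(Window.unboundedPreceding, Window.currentRow)

output_df = solutionDF.withColumn('RollingCount', F.sum('count').over(over_columns))
output_df.show()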

Upvotes: 0
