Reputation: 399
I need to create a 'rolling count' column which takes the previous count and adds the new count for each day and company. I have already organized and sorted the DataFrame into groups of ascending dates per company with the corresponding count. I also added an 'ix' column which indexes each row within its grouping, like so:
+--------------------+--------------------+-----+---+
| Normalized_Date| company|count| ix|
+--------------------+--------------------+-----+---+
|09/25/2018 00:00:...|[5c40c8510fb7c017...| 7| 1|
|09/25/2018 00:00:...|[5bdb2b543951bf07...| 9| 1|
|11/28/2017 00:00:...|[593b0d9f3f21f9dd...| 7| 1|
|11/29/2017 00:00:...|[593b0d9f3f21f9dd...| 60| 2|
|01/09/2018 00:00:...|[593b0d9f3f21f9dd...| 1| 3|
|04/27/2018 00:00:...|[593b0d9f3f21f9dd...| 9| 4|
|09/25/2018 00:00:...|[593b0d9f3f21f9dd...| 29| 5|
|11/20/2018 00:00:...|[593b0d9f3f21f9dd...| 42| 6|
|12/11/2018 00:00:...|[593b0d9f3f21f9dd...| 317| 7|
|01/04/2019 00:00:...|[593b0d9f3f21f9dd...| 3| 8|
|02/13/2019 00:00:...|[593b0d9f3f21f9dd...| 15| 9|
|04/01/2019 00:00:...|[593b0d9f3f21f9dd...| 1| 10|
+--------------------+--------------------+-----+---+
The output I need would simply add up all the counts up to that date for each company. Like so:
+--------------------+--------------------+-----+---+------------+
| Normalized_Date| company|count| ix|RollingCount|
+--------------------+--------------------+-----+---+------------+
|09/25/2018 00:00:...|[5c40c8510fb7c017...| 7| 1| 7|
|09/25/2018 00:00:...|[5bdb2b543951bf07...| 9| 1| 9|
|11/28/2017 00:00:...|[593b0d9f3f21f9dd...| 7| 1| 7|
|11/29/2017 00:00:...|[593b0d9f3f21f9dd...| 60| 2| 67|
|01/09/2018 00:00:...|[593b0d9f3f21f9dd...| 1| 3| 68|
|04/27/2018 00:00:...|[593b0d9f3f21f9dd...| 9| 4| 77|
|09/25/2018 00:00:...|[593b0d9f3f21f9dd...| 29| 5| 106|
|11/20/2018 00:00:...|[593b0d9f3f21f9dd...| 42| 6| 148|
|12/11/2018 00:00:...|[593b0d9f3f21f9dd...| 317| 7| 465|
|01/04/2019 00:00:...|[593b0d9f3f21f9dd...| 3| 8| 468|
|02/13/2019 00:00:...|[593b0d9f3f21f9dd...| 15| 9| 483|
|04/01/2019 00:00:...|[593b0d9f3f21f9dd...| 1| 10| 484|
+--------------------+--------------------+-----+---+------------+
I figured the lag function would be of use, and I was able to get each RollingCount row with ix > 1 to add the count directly above it with the following code:
w = Window.partitionBy('company').orderBy(F.unix_timestamp('Normalized_Date', 'MM/dd/yyyy HH:mm:ss aaa').cast('timestamp'))
refined_DF = solutionDF.withColumn("rn", F.row_number().over(w))
solutionDF = refined_DF.withColumn('RollingCount', F.when(refined_DF['rn'] > 1, refined_DF['count'] + F.lag(refined_DF['count'], count=1).over(w)).otherwise(refined_DF['count']))
which yields the following df:
+--------------------+--------------------+-----+---+------------+
| Normalized_Date| company|count| ix|RollingCount|
+--------------------+--------------------+-----+---+------------+
|09/25/2018 00:00:...|[5c40c8510fb7c017...| 7| 1| 7|
|09/25/2018 00:00:...|[5bdb2b543951bf07...| 9| 1| 9|
|11/28/2017 00:00:...|[593b0d9f3f21f9dd...| 7| 1| 7|
|11/29/2017 00:00:...|[593b0d9f3f21f9dd...| 60| 2| 67|
|01/09/2018 00:00:...|[593b0d9f3f21f9dd...| 1| 3| 61|
|04/27/2018 00:00:...|[593b0d9f3f21f9dd...| 9| 4| 10|
|09/25/2018 00:00:...|[593b0d9f3f21f9dd...| 29| 5| 38|
|11/20/2018 00:00:...|[593b0d9f3f21f9dd...| 42| 6| 71|
|12/11/2018 00:00:...|[593b0d9f3f21f9dd...| 317| 7| 359|
|01/04/2019 00:00:...|[593b0d9f3f21f9dd...| 3| 8| 320|
|02/13/2019 00:00:...|[593b0d9f3f21f9dd...| 15| 9| 18|
|04/01/2019 00:00:...|[593b0d9f3f21f9dd...| 1| 10| 16|
+--------------------+--------------------+-----+---+------------+
I just need it to sum all of the counts in the ix rows above it. I have tried using a udf to work out the 'count' input to the lag function, but I keep getting a "'Column' object is not callable" error, and it still doesn't sum all of the rows. I have also tried using a loop, but that seems impossible because it would create a new dataframe on each pass, and I would need to join them all afterwards. There must be an easier and simpler way to do this. Perhaps a different function than lag?
Upvotes: 2
Views: 1917
Reputation: 19320
lag returns a single row at a fixed offset before the current one, but you need a whole range of rows to calculate the cumulative sum. Therefore you have to define the window frame with rangeBetween (or rowsBetween). Have a look at the example below:
import pyspark.sql.functions as F
from pyspark.sql import Window

l = [
    ('09/25/2018', '5c40c8510fb7c017', 7, 1),
    ('09/25/2018', '5bdb2b543951bf07', 9, 1),
    ('11/28/2017', '593b0d9f3f21f9dd', 7, 1),
    ('11/29/2017', '593b0d9f3f21f9dd', 60, 2),
    ('01/09/2018', '593b0d9f3f21f9dd', 1, 3),
    ('04/27/2018', '593b0d9f3f21f9dd', 9, 4),
    ('09/25/2018', '593b0d9f3f21f9dd', 29, 5),
    ('11/20/2018', '593b0d9f3f21f9dd', 42, 6),
    ('12/11/2018', '593b0d9f3f21f9dd', 317, 7),
    ('01/04/2019', '593b0d9f3f21f9dd', 3, 8),
    ('02/13/2019', '593b0d9f3f21f9dd', 15, 9),
    ('04/01/2019', '593b0d9f3f21f9dd', 1, 10)
]
columns = ['Normalized_Date', 'company', 'count', 'ix']
df = spark.createDataFrame(l, columns)
df = df.withColumn('Normalized_Date', F.to_date(df.Normalized_Date, 'MM/dd/yyyy'))

# frame covering everything from the start of the partition up to the current row
w = Window.partitionBy('company').orderBy('Normalized_Date').rangeBetween(Window.unboundedPreceding, 0)
df = df.withColumn('Rolling_count', F.sum('count').over(w))
df.show()
Output:
+---------------+----------------+-----+---+-------------+
|Normalized_Date| company|count| ix|Rolling_count|
+---------------+----------------+-----+---+-------------+
| 2018-09-25|5c40c8510fb7c017| 7| 1| 7|
| 2018-09-25|5bdb2b543951bf07| 9| 1| 9|
| 2017-11-28|593b0d9f3f21f9dd| 7| 1| 7|
| 2017-11-29|593b0d9f3f21f9dd| 60| 2| 67|
| 2018-01-09|593b0d9f3f21f9dd| 1| 3| 68|
| 2018-04-27|593b0d9f3f21f9dd| 9| 4| 77|
| 2018-09-25|593b0d9f3f21f9dd| 29| 5| 106|
| 2018-11-20|593b0d9f3f21f9dd| 42| 6| 148|
| 2018-12-11|593b0d9f3f21f9dd| 317| 7| 465|
| 2019-01-04|593b0d9f3f21f9dd| 3| 8| 468|
| 2019-02-13|593b0d9f3f21f9dd| 15| 9| 483|
| 2019-04-01|593b0d9f3f21f9dd| 1| 10| 484|
+---------------+----------------+-----+---+-------------+
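One thing to watch: with rangeBetween and an orderBy on the date, rows that share the same Normalized_Date within a company all fall into the same range and therefore get the same cumulative total. If you want a strictly row-by-row accumulation instead, you can order by the existing ix column and use rowsBetween; a minimal sketch, reusing the df built above:

# strictly row-by-row accumulation, ordered by the pre-computed ix column
w_rows = (Window.partitionBy('company')
                .orderBy('ix')
                .rowsBetween(Window.unboundedPreceding, Window.currentRow))
df.withColumn('Rolling_count', F.sum('count').over(w_rows)).show()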
Upvotes: 1
Reputation: 1758
Try this. You need the sum of all rows from the start of the partition up to the current row in the window frame.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.expressions.WindowSpec
import org.apache.spark.sql.functions._
val df = Seq(
  ("5c40c8510fb7c017", 7, 1),
  ("5bdb2b543951bf07", 9, 1),
  ("593b0d9f3f21f9dd", 7, 1),
  ("593b0d9f3f21f9dd", 60, 2),
  ("593b0d9f3f21f9dd", 1, 3),
  ("593b0d9f3f21f9dd", 9, 4),
  ("593b0d9f3f21f9dd", 29, 5),
  ("593b0d9f3f21f9dd", 42, 6),
  ("593b0d9f3f21f9dd", 317, 7),
  ("593b0d9f3f21f9dd", 3, 8),
  ("593b0d9f3f21f9dd", 15, 9),
  ("593b0d9f3f21f9dd", 1, 10)
).toDF("company", "count", "ix")
scala> df.show(false)
+----------------+-----+---+
|company |count|ix |
+----------------+-----+---+
|5c40c8510fb7c017|7 |1 |
|5bdb2b543951bf07|9 |1 |
|593b0d9f3f21f9dd|7 |1 |
|593b0d9f3f21f9dd|60 |2 |
|593b0d9f3f21f9dd|1 |3 |
|593b0d9f3f21f9dd|9 |4 |
|593b0d9f3f21f9dd|29 |5 |
|593b0d9f3f21f9dd|42 |6 |
|593b0d9f3f21f9dd|317 |7 |
|593b0d9f3f21f9dd|3 |8 |
|593b0d9f3f21f9dd|15 |9 |
|593b0d9f3f21f9dd|1 |10 |
+----------------+-----+---+
scala> val overColumns = Window.partitionBy("company").orderBy("ix").rowsBetween(Window.unboundedPreceding, Window.currentRow)
overColumns: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@3ed5e17c
scala> val outputDF = df.withColumn("RollingCount", sum("count").over(overColumns))
outputDF: org.apache.spark.sql.DataFrame = [company: string, count: int ... 2 more fields]
scala> outputDF.show(false)
+----------------+-----+---+------------+
|company |count|ix |RollingCount|
+----------------+-----+---+------------+
|5c40c8510fb7c017|7 |1 |7 |
|5bdb2b543951bf07|9 |1 |9 |
|593b0d9f3f21f9dd|7 |1 |7 |
|593b0d9f3f21f9dd|60 |2 |67 |
|593b0d9f3f21f9dd|1 |3 |68 |
|593b0d9f3f21f9dd|9 |4 |77 |
|593b0d9f3f21f9dd|29 |5 |106 |
|593b0d9f3f21f9dd|42 |6 |148 |
|593b0d9f3f21f9dd|317 |7 |465 |
|593b0d9f3f21f9dd|3 |8 |468 |
|593b0d9f3f21f9dd|15 |9 |483 |
|593b0d9f3f21f9dd|1 |10 |484 |
+----------------+-----+---+------------+
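For reference, a rough PySpark equivalent of the same rowsBetween frame, assuming the column names from the question ('company', 'count', 'ix') and a DataFrame called df:

import pyspark.sql.functions as F
from pyspark.sql import Window

# same frame as the Scala version: all rows from the start of the partition up to the current row
over_columns = (Window.partitionBy('company')
                      .orderBy('ix')
                      .rowsBetween(Window.unboundedPreceding, Window.currentRow))
output_df = df.withColumn('RollingCount', F.sum('count').over(over_columns))
output_df.show(truncate=False)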
Upvotes: 0