Matteo

Reputation: 137

Pyspark: sum over a window based on a condition

Consider the simple DataFrame:

from pyspark import SparkContext
import pyspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import *
from pyspark.sql.functions import pandas_udf, PandasUDFType
import pandas as pd
spark = SparkSession.builder.appName('Trial').getOrCreate()

simpleData = (("2000-04-17", "144", 1),
    ("2000-07-06", "015", 1),
    ("2001-01-23", "015", -1),
    ("2001-01-18", "144", -1),
    ("2001-04-17", "198", 1),
    ("2001-04-18", "036", -1),
    ("2001-04-19", "012", -1),
    ("2001-04-19", "188", 1),
    ("2001-04-25", "188", 1),
    ("2001-04-27", "015", 1))

columns = ["dates", "id", "eps"]
df = spark.createDataFrame(data=simpleData, schema=columns)
df.printSchema()
df.show(truncate=False)
df.printSchema()
df.show(truncate=False)

Out:

root
 |-- dates: string (nullable = true)
 |-- id: string (nullable = true)
 |-- eps: long (nullable = true)

+----------+---+---+
|dates     |id |eps|
+----------+---+---+
|2000-04-17|144|1  |
|2000-07-06|015|1  |
|2001-01-23|015|-1 |
|2001-01-18|144|-1 |
|2001-04-17|198|1  |
|2001-04-18|036|-1 |
|2001-04-19|012|-1 |
|2001-04-19|188|1  |
|2001-04-25|188|1  |
|2001-04-27|015|1  |
+----------+---+---+

I would like to sum the values in the eps column over a rolling window, keeping only the last value for any given ID in the id column. For example, defining a window of 5 rows and assuming we are on 2001-04-17, I want to sum only the last eps value for each unique ID. In those 5 rows there are only 3 different IDs, so the sum must be over 3 elements: -1 for ID 144 (fourth row), -1 for ID 015 (third row) and 1 for ID 198 (fifth row), for a total of -1.
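In plain Python, the logic for that single window would look roughly like this (an illustration of the intent only, not how I want to implement it):

# (id, eps) pairs of the first 5 rows shown above (the window ending on 2001-04-17).
window_rows = [("144", 1), ("015", 1), ("015", -1), ("144", -1), ("198", 1)]

# Keep only the most recent eps for each id, then sum those values.
last_per_id = {}
for id_, eps in window_rows:
    last_per_id[id_] = eps

print(sum(last_per_id.values()))  # -1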

In my mind, within the rolling window I should do something like F.sum(groupBy('id').agg(F.last('eps'))), which of course cannot be expressed inside a window.

I obtained the desired result using a UDF.

@pandas_udf(IntegerType(), PandasUDFType.GROUPED_AGG)
def fun_sum(id, eps):
    # Build a small pandas DataFrame from the two input Series.
    df = pd.DataFrame({'id': id, 'eps': eps})
    # Keep only the last eps per id within the window frame, then sum to a scalar.
    return int(df.groupby('id')['eps'].last().sum())

And then:

w = Window.orderBy('dates').rowsBetween(-5,0)
df = df.withColumn('sum', fun_sum(F.col('id'), F.col('eps')).over(w))
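For reference, I believe the same UDF can also be written with the Spark 3.x type-hint syntax; this is only a sketch of the equivalent logic (it assumes Spark >= 3.0, and fun_sum_hints is just a renamed copy of the function above):

import pandas as pd
from pyspark.sql.functions import pandas_udf

# Series-to-scalar pandas UDF declared via type hints (Spark >= 3.0).
@pandas_udf('int')
def fun_sum_hints(id: pd.Series, eps: pd.Series) -> int:
    # Last eps per id within the window frame, summed to a single value.
    return int(pd.DataFrame({'id': id, 'eps': eps}).groupby('id')['eps'].last().sum())

df = df.withColumn('sum', fun_sum_hints(F.col('id'), F.col('eps')).over(w))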

The problem is that my dataset contains more than 8 million rows, and performing this task with this UDF takes about 2 hours.

I was wondering whether there is a way to achieve the same result with built-in PySpark functions, avoiding a UDF, or at least whether there is a way to improve the performance of my UDF.

For completeness, the desired output should be:

+----------+---+---+----+
|dates     |id |eps|sum |
+----------+---+---+----+
|2000-04-17|144|1  |1   |
|2000-07-06|015|1  |2   |
|2001-01-23|015|-1 |0   |
|2001-01-18|144|-1 |-2  |
|2001-04-17|198|1  |-1  |
|2001-04-18|036|-1 |-2  |
|2001-04-19|012|-1 |-3  |
|2001-04-19|188|1  |-1  |
|2001-04-25|188|1  |0   |
|2001-04-27|015|1  |0   |
+----------+---+---+----+

EDIT: the result must also be achievable using a .rangeBetween() window.
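For the range-based variant I would expect something along these lines (just a sketch: .rangeBetween() needs a numeric ordering expression, so here the date string is converted to days since the epoch, and the 5-day lookback is an assumption on my part):

# Sketch only: convert the date string to "days since 1970-01-01" so that
# rangeBetween can interpret the offsets as days (5-day lookback assumed).
days = F.datediff(F.to_date('dates'), F.lit('1970-01-01').cast('date'))
w_range = Window.orderBy(days).rangeBetween(-5, 0)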

Upvotes: 0

Views: 5864

Answers (1)

Paul P

Reputation: 3907

In case you haven't figured it out yet, here's one way of achieving it.

Assuming that df is defined and initialised as in your question.

Import the required functions and classes:

from pyspark.sql.functions import row_number, col
from pyspark.sql.window import Window

Create the necessary WindowSpec:

window_spec = (
    Window
    # Partition by 'id'.
    .partitionBy(df.id)
    # Order by 'dates', latest dates first.
    .orderBy(df.dates.desc())
)

Create a DataFrame with partitioned data:

partitioned_df = (
    df
    # Use the window function 'row_number()' to populate a new column
    # containing a sequential number starting at 1 within a window partition.
    .withColumn('row', row_number().over(window_spec))
    # Only select the first entry in each partition (i.e. the latest date).
    .where(col('row') == 1)
)

Just in case you want to double-check the data:

partitioned_df.show()

# +----------+---+---+---+
# |     dates| id|eps|row|
# +----------+---+---+---+
# |2001-04-19|012| -1|  1|
# |2001-04-25|188|  1|  1|
# |2001-04-27|015|  1|  1|
# |2001-04-17|198|  1|  1|
# |2001-01-18|144| -1|  1|
# |2001-04-18|036| -1|  1|
# +----------+---+---+---+

Group and aggregate the data:

sum_rows = (
    partitioned_df
    # Aggregate data.
    .groupBy()
    # Sum all rows in 'eps' column.
    .sum('eps')
    # Get all records as a list of Rows.
    .collect()
)

Get the result:

print(f"sum eps: {sum_rows[0][0]})
# sum eps: 0
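The last two steps can also be collapsed into a single agg() call, which is just a minor variation of the same approach (total_eps is simply an alias chosen here):

from pyspark.sql.functions import sum as sum_

# Same aggregation expressed with agg() instead of groupBy().sum().
total_eps = partitioned_df.agg(sum_('eps').alias('total_eps')).collect()[0]['total_eps']
print(f"sum eps: {total_eps}")
# sum eps: 0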

Upvotes: 2
