Dhruv

Reputation: 445

Pivot vs window in Spark

I have the following requirements:

  1. Pivot the dataframe to sum the amount column based on document type
  2. Join the pivoted dataframe back to the original dataframe to get additional columns
  3. Filter the joined dataframe using a window function

Sample code

Setting up the dataframe

from datetime import date
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DateType
import pyspark.sql.functions as F
from pyspark.sql.window import Window

schema = StructType([
    StructField('company_code', StringType(), True)
    , StructField('line_no', IntegerType(), True)
    , StructField('document_type', StringType(), True)
    , StructField('amount', IntegerType(), True)
    , StructField('posting_date', DateType(), True)
])

data = [
    ['AB', 10, 'RE', 12, date(2019,1,1)]
    , ['AB', 10, 'RE', 13, date(2019,2,10)]
    , ['AB', 20, 'WE', 14, date(2019,1,11)]
    , ['BC', 10, 'WL', 11, date(2019,2,12)]
    , ['BC', 20, 'RE', 15, date(2019,1,21)]
]

df = spark.createDataFrame(data, schema)

First, the pivot approach:

# Partitioning upfront so as to not shuffle twice (once in the groupBy and once in the window)
partition_df = df.repartition('company_code', 'line_no').cache()

pivot_df = (
    partition_df.groupBy('company_code', 'line_no')
    .pivot('document_type', ['RE', 'WE', 'WL'])
    .sum('amount')
)

# It will use a broadcast join because pivot_df is small (it is small in my actual case as well)
join_df = (
    partition_df.join(pivot_df, ['company_code', 'line_no'])
    .select(partition_df['*'], 'RE', 'WE', 'WL')
)

window_spec = Window.partitionBy('company_code', 'line_no').orderBy('posting_date')

final_df = join_df.withColumn("Row_num", F.row_number().over(window_spec)).filter("Row_num == 1").drop("Row_num")

final_df.show()
+------------+-------+-------------+------+------------+----+----+----+
|company_code|line_no|document_type|amount|posting_date|  RE|  WE|  WL|
+------------+-------+-------------+------+------------+----+----+----+
|          AB|     10|           RE|    12|  2019-01-01|  25|NULL|NULL|
|          AB|     20|           WE|    14|  2019-01-11|NULL|  14|NULL|
|          BC|     10|           WL|    11|  2019-02-12|NULL|NULL|  11|
|          BC|     20|           RE|    15|  2019-01-21|  15|NULL|NULL|
+------------+-------+-------------+------+------------+----+----+----+

And the window approach:

t_df = df.withColumns({
    'RE': F.when(F.col('document_type') == 'RE', F.col('amount')).otherwise(0)
    , 'WE': F.when(F.col('document_type') == 'WE', F.col('amount')).otherwise(0)
    , 'WL': F.when(F.col('document_type') == 'WL', F.col('amount')).otherwise(0)
})

window_spec = Window.partitionBy('company_code', 'line_no').orderBy('posting_date')
sum_window_spec = window_spec.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

t2_df = t_df.withColumns({
    'RE': F.sum('RE').over(sum_window_spec)
    , 'WE': F.sum('WE').over(sum_window_spec)
    , 'WL': F.sum('WL').over(sum_window_spec)
    , 'Row_num': F.row_number().over(window_spec)
})

final_df = t2_df.filter("Row_num == 1").drop("Row_num")

final_df.show()
+------------+-------+-------------+------+------------+---+---+---+
|company_code|line_no|document_type|amount|posting_date| RE| WE| WL|
+------------+-------+-------------+------+------------+---+---+---+
|          AB|     10|           RE|    12|  2019-01-01| 25|  0|  0|
|          AB|     20|           WE|    14|  2019-01-11|  0| 14|  0|
|          BC|     10|           WL|    11|  2019-02-12|  0|  0| 11|
|          BC|     20|           RE|    15|  2019-01-21| 15|  0|  0|
+------------+-------+-------------+------+------------+---+---+---+

I have not included the output of explain here as it would make the question lengthy, but there is only one shuffle in both methods. So how do I decide which one will take more time?

I'm using Databricks Runtime 14.3 LTS.
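
For reference, a minimal sketch of how I checked that each method shuffles only once (explain with the formatted mode is standard PySpark; counting the Exchange nodes in its output is just my own quick check):

# Print the formatted physical plan for either final_df variant built above.
# Each Exchange node in the output corresponds to one shuffle.
final_df.explain(mode='formatted')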

Upvotes: 1

Views: 63

Answers (1)

lihao

Reputation: 781

If you have a very large dataset (especially skewed data), you might want to prefer groupBy/pivot over a window function. Even though both approaches trigger only one shuffle, the number of rows involved in that shuffle can differ significantly.

Below is a snippet from the explain output of your first approach using groupBy/pivot:

+- HashAggregate(keys=[company_code#6378, line_no#6379, document_type#6380], functions=[sum(amount#6381)])
    +- HashAggregate(keys=[company_code#6378, line_no#6379, document_type#6380], functions=[partial_sum(amount#6381)])

You can see above the change from functions=[partial_sum(amount#6381)] to functions=[sum(amount#6381)]. This means Spark handles the aggregation in a way similar to rdd.map(..).reduceByKey(add): it takes partial sums within each local partition and then merges the intermediate results across partitions. This significantly reduces the number of rows to shuffle. You can verify it in the Spark Web UI: open the SQL/DataFrame tab, find your SQL entry, and check the Exchange box for the "shuffle records written" metric (make sure to test with a larger dataframe, not just a 5-row one). Pivot takes a similar approach (from partial_pivotsum to pivot_sum).
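
As a rough analogy only (this is not the code Spark actually generates), the partial-then-final aggregation behaves like the reduceByKey pattern on the underlying RDD, where values are combined within each partition before anything crosses the network:

from operator import add

# Illustrative RDD analogy of partial aggregation: reduceByKey combines values
# map-side within each partition first, so far fewer rows hit the shuffle.
partial_then_final = (
    df.rdd
    .map(lambda r: ((r['company_code'], r['line_no'], r['document_type']), r['amount']))
    .reduceByKey(add)
)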

On the other hand, a window function will first move all related rows to the same logical partition defined by the WindowSpec and then do the aggregation once (similar to rdd.map(..).groupByKey().mapValues(sum), with no partial_sum). Thus all rows may be involved in the data movement before the aggregation happens (again, check the Web UI for "shuffle records written"). This is especially bad if your data is skewed, which could lead to OOM issues.
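
Again as an analogy only, the window approach behaves more like groupByKey, where every row has to cross the shuffle before any summing happens:

# Illustrative RDD analogy of the window approach: groupByKey has no map-side
# combine, so every row for a key is shuffled before the sum is computed.
shuffle_then_sum = (
    df.rdd
    .map(lambda r: ((r['company_code'], r['line_no']), r['amount']))
    .groupByKey()
    .mapValues(sum)
)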

My observation is only from the data-shuffling point of view; there are many other factors that might influence query performance, which you will have to check on your own.

One more suggestion on your code: I think you can skip the join, the window function, and the pivot from your first method, and combine it with some of the code from your second method:

final_df = (
    df.groupBy('company_code', 'line_no', 'document_type')
    .agg(
        F.sum('amount').alias('amount'),
        F.min('posting_date').alias('posting_date')
    )
    .withColumns({
        dt: F.when(F.col('document_type') == dt, F.col('amount')).otherwise(0)
        for dt in ['RE', 'WE', 'WL']
    })
)

final_df.show()
+------------+-------+-------------+------+------------+---+---+---+
|company_code|line_no|document_type|amount|posting_date| RE| WE| WL|
+------------+-------+-------------+------+------------+---+---+---+
|          AB|     10|           RE|    25|  2019-01-01| 25|  0|  0|
|          BC|     10|           WL|    11|  2019-02-12|  0|  0| 11|
|          AB|     20|           WE|    14|  2019-01-11|  0| 14|  0|
|          BC|     20|           RE|    15|  2019-01-21| 15|  0|  0|
+------------+-------+-------------+------+------------+---+---+---+

Upvotes: 0
