Reputation: 445
I have the following requirement: for each (company_code, line_no), pivot the sum of amount by document_type into separate columns and keep the earliest row by posting_date. I have implemented it in two ways (shown below) and want to know which one will be faster.
Sample code
Setting up the DataFrame:
from datetime import date
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DateType
import pyspark.sql.functions as F
from pyspark.sql.window import Window
schema = StructType([
    StructField('company_code', StringType(), True)
    , StructField('line_no', IntegerType(), True)
    , StructField('document_type', StringType(), True)
    , StructField('amount', IntegerType(), True)
    , StructField('posting_date', DateType(), True)
])
data = [
    ['AB', 10, 'RE', 12, date(2019, 1, 1)]
    , ['AB', 10, 'RE', 13, date(2019, 2, 10)]
    , ['AB', 20, 'WE', 14, date(2019, 1, 11)]
    , ['BC', 10, 'WL', 11, date(2019, 2, 12)]
    , ['BC', 20, 'RE', 15, date(2019, 1, 21)]
]
df = spark.createDataFrame(data, schema)
First, the pivot approach:
# Partition upfront so we don't shuffle twice (once in the groupBy and once in the window)
partition_df = df.repartition('company_code', 'line_no').cache()
pivot_df = (
    partition_df.groupBy('company_code', 'line_no')
    .pivot('document_type', ['RE', 'WE', 'WL'])
    .sum('amount')
)
# This will be a broadcast join because pivot_df is small (it is small in my actual case as well)
join_df = (
    partition_df.join(pivot_df, ['company_code', 'line_no'])
    .select(partition_df['*'], 'RE', 'WE', 'WL')
)
window_spec = Window.partitionBy('company_code', 'line_no').orderBy('posting_date')
final_df = join_df.withColumn("Row_num", F.row_number().over(window_spec)).filter("Row_num == 1").drop("Row_num")
final_df.show()
+------------+-------+-------------+------+------------+----+----+----+
|company_code|line_no|document_type|amount|posting_date| RE| WE| WL|
+------------+-------+-------------+------+------------+----+----+----+
| AB| 10| RE| 12| 2019-01-01| 25|NULL|NULL|
| AB| 20| WE| 14| 2019-01-11|NULL| 14|NULL|
| BC| 10| WL| 11| 2019-02-12|NULL|NULL| 11|
| BC| 20| RE| 15| 2019-01-21| 15|NULL|NULL|
+------------+-------+-------------+------+------------+----+----+----+
And now the window approach:
t_df = df.withColumns({
    'RE': F.when(F.col('document_type') == 'RE', F.col('amount')).otherwise(0)
    , 'WE': F.when(F.col('document_type') == 'WE', F.col('amount')).otherwise(0)
    , 'WL': F.when(F.col('document_type') == 'WL', F.col('amount')).otherwise(0)
})
window_spec = Window.partitionBy('company_code', 'line_no').orderBy('posting_date')
sum_window_spec = window_spec.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
t2_df = t_df.withColumns({
    'RE': F.sum('RE').over(sum_window_spec)
    , 'WE': F.sum('WE').over(sum_window_spec)
    , 'WL': F.sum('WL').over(sum_window_spec)
    , 'Row_num': F.row_number().over(window_spec)
})
final_df = t2_df.filter("Row_num == 1").drop("Row_num")
final_df.show()
+------------+-------+-------------+------+------------+---+---+---+
|company_code|line_no|document_type|amount|posting_date| RE| WE| WL|
+------------+-------+-------------+------+------------+---+---+---+
| AB| 10| RE| 12| 2019-01-01| 25| 0| 0|
| AB| 20| WE| 14| 2019-01-11| 0| 14| 0|
| BC| 10| WL| 11| 2019-02-12| 0| 0| 11|
| BC| 20| RE| 15| 2019-01-21| 15| 0| 0|
+------------+-------+-------------+------+------------+---+---+---+
I have not included the explain output here because it would make the question lengthy, but there is only one shuffle in both methods. So how do I decide which one will take more time?
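For reference, this is how I look at the plan for each version of final_df (output omitted for brevity):
# the formatted plan shows a single Exchange (shuffle) node for either approach
final_df.explain(mode='formatted')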
I'm using Databricks Runtime 14.3 LTS.
Upvotes: 1
Views: 63
Reputation: 781
If you have a very large data set (especially skewed data), you might want to prefer groupBy/pivot over a Window function. Even though both approaches trigger only one shuffle, the number of rows involved in that shuffle can differ significantly.
Below is a snippet from the explain output of your first approach using groupBy/pivot:
+- HashAggregate(keys=[company_code#6378, line_no#6379, document_type#6380], functions=[sum(amount#6381)])
+- HashAggregate(keys=[company_code#6378, line_no#6379, document_type#6380], functions=[partial_sum(amount#6381)])
You can see above that the plan goes from functions=[partial_sum(amount#6381)] to functions=[sum(amount#6381)]. This means Spark handles the aggregation in a way similar to rdd.map(..).reduceByKey(add): it sums within each local partition first and then merges the intermediate results between partitions, which significantly reduces the number of rows to shuffle. You can verify this in the Spark Web UI: open the SQL/DataFrame tab, find your SQL entry, and check the Exchange box for shuffle records written (make sure to test with a larger DataFrame, not just a 5-row one). Pivot takes a similar approach (from partial_pivotsum to pivot_sum).
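To make the reduceByKey analogy concrete, here is a minimal RDD sketch (illustration only, reusing the df from the question; not the code Spark actually generates):
from operator import add

# key each row by (company_code, line_no, document_type) and keep the amount;
# reduceByKey combines values inside each partition before anything is shuffled,
# so at most one record per key per partition crosses the network
keyed = df.rdd.map(lambda r: ((r['company_code'], r['line_no'], r['document_type']), r['amount']))
summed = keyed.reduceByKey(add)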
On the other hand, the Window function first moves all related rows to the same logical partition defined by the WindowSpec and then does the aggregation once (similar to rdd.map(..).groupByKey().mapValues(sum), with no partial_sum). Thus all rows may be involved in the data movement before any aggregation happens (check the Web UI for shuffle records written). This is especially bad if your data is skewed, which can lead to OOM issues.
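For contrast, the window behaviour is closer to this sketch (again just an illustration; groupByKey has no map-side combine, so every row for a key crosses the shuffle):
# same keying as above, but grouping happens before any summing;
# a heavily skewed key becomes one very large group handled by a single task
keyed = df.rdd.map(lambda r: ((r['company_code'], r['line_no'], r['document_type']), r['amount']))
grouped = keyed.groupByKey().mapValues(sum)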
My observation is only from the data-shuffling point of view; there are many other factors that might influence query performance, which you will have to check on your own.
One more suggestion about your code: I think you can skip the join, Window function, and pivot from your first method and combine it with some of the code from your second method:
final_df = (
    df.groupby('company_code', 'line_no', 'document_type')
    .agg(
        F.sum('amount').alias('amount'),
        F.min('posting_date').alias('posting_date')
    )
    .withColumns({dt: F.when(F.col('document_type') == dt, F.col('amount')).otherwise(0) for dt in ['RE', 'WE', 'WL']})
)
final_df.show()
+------------+-------+-------------+------+------------+---+---+---+
|company_code|line_no|document_type|amount|posting_date| RE| WE| WL|
+------------+-------+-------------+------+------------+---+---+---+
| AB| 10| RE| 25| 2019-01-01| 25| 0| 0|
| BC| 10| WL| 11| 2019-02-12| 0| 0| 11|
| AB| 20| WE| 14| 2019-01-11| 0| 14| 0|
| BC| 20| RE| 15| 2019-01-21| 15| 0| 0|
+------------+-------+-------------+------+------------+---+---+---+
Upvotes: 0