Reputation: 20113
Consider this example:
import pyspark
import pyspark.sql.functions as f

with pyspark.SparkContext(conf=pyspark.SparkConf().setMaster('local[*]')) as sc:
    spark = pyspark.sql.SQLContext(sc)
    df = spark.createDataFrame([
        [2020, 1, 1, 1.0],
        [2020, 1, 2, 2.0],
        [2020, 1, 3, 3.0],
    ], schema=['year', 'id', 't', 'value'])
    df = df.groupBy(['year', 'id']).agg(f.collect_list('value'))
    df = df.where(f.col('year') == 2020)
    df.explain()
which yields the following plan:
== Physical Plan ==
*(2) Filter (isnotnull(year#0L) AND (year#0L = 2020))
+- ObjectHashAggregate(keys=[year#0L, id#1L], functions=[collect_list(value#3, 0, 0)])
   +- Exchange hashpartitioning(year#0L, id#1L, 200), true, [id=#23]
      +- ObjectHashAggregate(keys=[year#0L, id#1L], functions=[partial_collect_list(value#3, 0, 0)])
         +- *(1) Project [year#0L, id#1L, value#3]
            +- *(1) Scan ExistingRDD[year#0L,id#1L,t#2L,value#3]
I would like Spark to push the filter year = 2020 down below the Exchange hashpartitioning (i.e. before the shuffle). If the aggregation function is sum, Spark does this, but it does not do it for collect_list.
Any ideas as to why this is not the case, and whether there is a way to address this?
The reason for asking is that without the filter pushdown, a query over 3 years (e.g. year IN (2020, 2019, 2018)) shuffles the data for all of them together. Also, I need to express the filter after the groupBy in my code.
More importantly, I am trying to understand why Spark pushes the filter down for some aggregations but not for others.
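For comparison, here is the same example with sum instead of collect_list (a minimal sketch reusing the spark setup above); its plan does show the Filter below the Exchange:

    # Minimal sketch, run inside the same `with` block as above:
    # identical data, but aggregated with sum instead of collect_list.
    df_sum = spark.createDataFrame([
        [2020, 1, 1, 1.0],
        [2020, 1, 2, 2.0],
        [2020, 1, 3, 3.0],
    ], schema=['year', 'id', 't', 'value'])
    df_sum = df_sum.groupBy(['year', 'id']).agg(f.sum('value'))
    df_sum = df_sum.where(f.col('year') == 2020)
    df_sum.explain()  # here the Filter on year appears below the Exchange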
Upvotes: 5
Views: 1009
Reputation: 6323
Let's have a look at the aggregate function that you are using: collect_list. From its documentation in the Spark source:
/**
* Aggregate function: returns a list of objects with duplicates.
*
* @note The function is non-deterministic because the order of collected results depends
* on the order of the rows which may be non-deterministic after a shuffle.
*
* @group agg_funcs
* @since 1.6.0
*/
def collect_list(columnName: String): Column = collect_list(Column(columnName))
collect_list is a non-deterministic operation and its result depends on the order of rows.
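You can check this flag directly from PySpark by inspecting the underlying Catalyst expression (a sketch only; _jc is an internal py4j handle, not a stable public API, it requires an active SparkContext, and it is not available under Spark Connect):

# Sketch: inspect the `deterministic` flag of the Catalyst expressions
# behind the two aggregate functions (internal API, version-dependent).
print(f.sum('value')._jc.expr().deterministic())           # expected: True
print(f.collect_list('value')._jc.expr().deterministic())  # expected: False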
Now look at the PushPredicateThroughNonJoin rule in Optimizer.scala:
// SPARK-13473: We can't push the predicate down when the underlying projection output non-
// deterministic field(s). Non-deterministic expressions are essentially stateful. This
// implies that, for a given input row, the output are determined by the expression's initial
// state and all the input rows processed before. In another word, the order of input rows
// matters for non-deterministic expressions, while pushing down predicates changes the order.
// This also applies to Aggregate.
Since the above operation is non-deterministic, i.e. its result depends on the order of rows in the underlying DataFrame, Spark can't push the predicate down, because doing so would change the order of the input rows.
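If you can restructure the query, one way around this (a sketch, not a way to force the optimizer to push the predicate) is to apply the filter yourself before the groupBy, so the predicate never needs to cross the aggregate:

# Sketch: filter manually before the aggregation, so rows for other years
# never reach the shuffle. Equivalent here because `year` is a grouping key.
df = df.where(f.col('year') == 2020)
df = df.groupBy(['year', 'id']).agg(f.collect_list('value'))
df.explain()  # the Filter now sits below the Exchange

The trade-off, as noted in the question, is that the filter then has to be written before the groupBy in code.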
Upvotes: 4