Paul

Reputation: 3361

Does Spark SQL do predicate pushdown on filtered equi-joins?

I am interested in using Spark SQL (1.6) to perform "filtered equi-joins" of the form

A inner join B where A.group_id = B.group_id and pair_filter_udf(A[cols], B[cols])

Here the group_id is coarse: a single value of group_id could be associated with, say, 10,000 records in both A and B.

If the equi-join were performed by itself, without the pair_filter_udf, the coarseness of group_id would create computational issues. For example, for a group_id with 10,000 records in both A and B, there would be 100 million entries in the join. If we had many thousands of such large groups, we would generate an enormous table and we could very easily run out of memory.

Thus, it is essential that we push down pair_filter_udf into the join and have it filter pairs as they are generated, rather than waiting until all pairs have been generated. My question is whether Spark SQL does this.

I set up a simple filtered equi-join and asked Spark what its query plan was:

# run in the PySpark shell
import pyspark.sql.functions as F

sq = sqlContext
n = 100   # rows per table
g = 10    # group size: each grp value covers g consecutive ids

# table A: id_a plus a coarse group key grp
a = sq.range(n)
a = a.withColumn('grp', F.floor(a['id'] / g) * g)
a = a.withColumnRenamed('id', 'id_a')

# table B: same construction
b = sq.range(n)
b = b.withColumn('grp', F.floor(b['id'] / g) * g)
b = b.withColumnRenamed('id', 'id_b')

# equi-join on grp combined with a pairwise filter on (id_a, id_b)
c = a.join(b, (a.grp == b.grp) & (F.abs(a['id_a'] - b['id_b']) < 2)).drop(b['grp'])
c = c.sort('id_a')
c = c[['grp', 'id_a', 'id_b']]
c.explain()

Result:

== Physical Plan ==
Sort [id_a#21L ASC], true, 0
+- ConvertToUnsafe
   +- Exchange rangepartitioning(id_a#21L ASC,200), None
      +- ConvertToSafe
         +- Project [grp#20L,id_a#21L,id_b#24L]
            +- Filter (abs((id_a#21L - id_b#24L)) < 2)
               +- SortMergeJoin [grp#20L], [grp#23L]
                  :- Sort [grp#20L ASC], false, 0
                  :  +- TungstenExchange hashpartitioning(grp#20L,200), None
                  :     +- Project [id#19L AS id_a#21L,(FLOOR((cast(id#19L as double) / 10.0)) * 10) AS grp#20L]
                  :        +- Scan ExistingRDD[id#19L] 
                  +- Sort [grp#23L ASC], false, 0
                     +- TungstenExchange hashpartitioning(grp#23L,200), None
                        +- Project [id#22L AS id_b#24L,(FLOOR((cast(id#22L as double) / 10.0)) * 10) AS grp#23L]
                           +- Scan ExistingRDD[id#22L]

These are the key lines from the plan:

+- Filter (abs((id_a#21L - id_b#24L)) < 2)
    +- SortMergeJoin [grp#20L], [grp#23L]

These lines give the impression that the filter will be done in a separate stage after the join, which is not the desired behavior. But maybe it is being implicitly pushed down into the join, and the query plan just lacks that level of detail.

How can I tell what Spark is doing in this case?
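For reference, explain also accepts an extended flag in PySpark, which prints the parsed, analyzed, and optimized logical plans in addition to the physical plan and can give a little more to go on:

# prints the logical plans as well as the physical plan
c.explain(True)

Running an action such as c.count() and looking at the resulting job's DAG in the Spark UI gives a stage-level view as well.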

Update:

I'm running experiments with n=1e6 and g=1e5, which should be enough to crash my laptop if Spark is not doing pushdown. Since it is not crashing, I guess it is doing pushdown. But it would be interesting to know how it works and what parts of the Spark SQL source are responsible for this awesome optimization.
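A sketch of how such an experiment might be run from the PySpark shell (count() is used so that nothing large is collected back to the driver):

import pyspark.sql.functions as F

n = int(1e6)   # rows per table
g = int(1e5)   # each grp value now matches ~1e5 rows on each side

a = sqlContext.range(n)
a = a.withColumn('grp', F.floor(a['id'] / g) * g).withColumnRenamed('id', 'id_a')

b = sqlContext.range(n)
b = b.withColumn('grp', F.floor(b['id'] / g) * g).withColumnRenamed('id', 'id_b')

c = a.join(b, (a.grp == b.grp) & (F.abs(a['id_a'] - b['id_b']) < 2)).drop(b['grp'])

# count() forces the join and filter to run without collecting rows to the driver
print(c.count())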

Upvotes: 5

Views: 3604

Answers (1)

zero323

Reputation: 330443

Quite a lot depends on what you mean by pushdown. If you are asking whether |a.id_a - b.id_b| < 2 is executed as part of the join logic alongside a.grp = b.grp, the answer is no. Predicates that are not based on equality are not directly included in the join condition.

One way to illustrate that is to look at the DAG instead of the execution plan. It should look more or less like this:

[DAG visualization: the SortMergeJoin is followed by a separate filter transformation]

As you can see, the filter is executed as a separate transformation from the SortMergeJoin. Another approach is to analyze the execution plan when you drop a.grp = b.grp. You'll see that the join expands to a Cartesian product followed by a filter, with no additional optimizations:

d = a.join(b,(F.abs(a['id_a'] - b['id_b']) < 2)).drop(b['grp'])

## == Physical Plan ==
## Project [id_a#2L,grp#1L,id_b#5L]
## +- Filter (abs((id_a#2L - id_b#5L)) < 2)
##    +- CartesianProduct
##       :- ConvertToSafe
##       :  +- Project [id#0L AS id_a#2L,(FLOOR((cast(id#0L as double) / 10.0)) * 10) AS grp#1L]
##       :     +- Scan ExistingRDD[id#0L] 
##       +- ConvertToSafe
##          +- Project [id#3L AS id_b#5L]
##             +- Scan ExistingRDD[id#3L]

Does this mean that your code (not the variant with the Cartesian product, which you really want to avoid in practice) generates a huge intermediate table?

No, it doesn't. Both the SortMergeJoin and the filter are executed within a single stage (see the DAG). While some details of DataFrame operations are applied at a slightly lower level, execution is basically just a chain of transformations over Scala Iterators and, as Justin Pihony has shown in a very illustrative way, different operations can be squashed together without adding any Spark-specific logic. One way or another, both the join and the filter are applied within a single task.
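The intuition carries over to plain Python: if you chain lazy generators, each candidate pair is produced, tested, and either emitted or discarded before the next one is generated, so the full per-group cross product is never materialized. A rough analogy only (this is not Spark code, and the names are made up):

def candidate_pairs(a_group, b_group):
    # yields candidate pairs one at a time within a matched group
    for x in a_group:
        for y in b_group:
            yield (x, y)

def keep(pairs, predicate):
    # generator expression: drops failing pairs as they stream past
    return (p for p in pairs if predicate(*p))

a_group = range(1000)
b_group = range(1000)
survivors = keep(candidate_pairs(a_group, b_group), lambda x, y: abs(x - y) < 2)

# 10^6 candidate pairs are generated and tested one by one; only the
# survivors ever need to be held.
print(sum(1 for _ in survivors))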

Upvotes: 5
