Paul

Reputation: 3361

How does Spark execute a join + filter? Is it scalable?

Say I have two large RDDs, A and B, containing key-value pairs. I want to join A and B using the key, but of the pairs (a,b) that match, I only want a tiny fraction of "good" ones. So I do the join and apply a filter afterwards:

A.join(B).filter(isGoodPair)

where isGoodPair is a boolean function that tells me if a pair (a,b) is good or not.

For this to scale well, Spark's scheduler would ideally avoid forming all pairs in A.join(B) explicitly. Even on a massively distributed basis, this could cause time-consuming disk spills, or even exhaust all memory and disk resources on some nodes. To avoid this, Spark should apply the filter as the pairs (a,b) are generated within each partition.

My questions:

  1. Does Spark actually do this?
  2. What aspects of its architecture enable or prevent the desired behavior?
  3. Should I maybe use cogroup instead? In PySpark it returns an iterator, so I can just apply my filter to the iterator, right?
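
For concreteness, here is a rough, untested sketch of what I have in mind for question 3 (joinViaCogroup is just a placeholder name; I'm assuming isGoodPair receives the same (k,(a,b)) element that filter sees above):

import itertools as it

def joinViaCogroup(A, B):
    # cogroup yields (k, (iterable of a's, iterable of b's)) for each key
    def fun(kv):
        k, (as_, bs) = kv
        # generator: each (a,b) pair is tested as it is produced, so the
        # full cross product for a key is never materialized at once
        return ((k, (a, b)) for (a, b) in it.product(as_, bs)
                if isGoodPair((k, (a, b))))
    return A.cogroup(B).flatMap(fun)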

Upvotes: 3

Views: 4073

Answers (2)

Paul

Reputation: 3361

I ran an experiment in the PySpark shell (running Spark 1.2.1) to answer these questions. The conclusions are the following:

  1. Unfortunately, Spark does not apply the filter as pairs are generated by the join. It generates the entire set of join pairs explicitly before proceeding to filter them.
  2. This is probably because Spark runs RDD transformations one at a time; it is generally not capable of performing this kind of subtle chaining optimization.
  3. By using cogroup instead of join, we can manually implement the desired optimization.

Experiment

I made an RDD containing 100 groups of 10,000 integers each, and counted the pairs of integers within each group that are at most 1 apart:

import itertools as it
g = int(1e2) # number of groups
n = int(1e4) # number of integers in each group
nPart = 32 # standard partitioning: 8 cores, 4 partitions per core
A = sc.parallelize(list(it.product(xrange(g),xrange(n))),nPart) 

def joinAndFilter(A):
    return A.join(A).filter(lambda (k,(x1,x2)): abs(x1 - x2) <= 1)

def cogroupAndFilter(A):
    def fun(xs):
        k,(xs1,xs2) = xs
        return [(x1,x2) for (x1,x2) in it.product(xs1,xs2) if abs(x1 - x2) <= 1]
    return A.cogroup(A).flatMap(fun)

cogroupAndFilter(A).count()
joinAndFilter(A).count() 

I didn't have an easy way to profile the code, so I just watched it run on my Mac in Activity Monitor:

Memory usage spiked big-time when I used joinAndFilter, presumably because it's generating all the pairs before applying the off-by-one filter. I actually had to kill PySpark because it was blowing through all my memory and about to crash the system. With cogroupAndFilter, the pairs are filtered as they are generated, so memory stays under control.
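
A variant I didn't benchmark: fun can return a generator instead of a list, so even the filtered "good" pairs for a key are never collected before flatMap consumes them. A sketch, using the same setup as above:

def cogroupAndFilterLazy(A):
    def fun(xs):
        k,(xs1,xs2) = xs
        # same predicate as before, but pairs are yielded one at a time
        # rather than first being gathered into a per-key list
        return ((x1,x2) for (x1,x2) in it.product(xs1,xs2) if abs(x1 - x2) <= 1)
    return A.cogroup(A).flatMap(fun)

cogroupAndFilterLazy(A).count()

The difference should only matter when a single key has many good pairs; otherwise I'd expect the same memory behavior as cogroupAndFilter.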

Upvotes: 3

Mr. Llama

Reputation: 20919

From what I can find, Spark will not entirely buffer the data between the join and filter.

Both the join and filter output DStreams, which "[represent] a continuous stream of data". This means that the join should be outputting a continuous stream of data which the filter consumes as it becomes available.

However, from what I can tell, join will generate all (a,b) pairs with matching keys, but the filter will quickly throw away unwanted results, preventing the entire result set from being in memory at once.

Upvotes: 1
