Uri Goren

Reputation: 13700

Limit the number of records in a Spark RDD

I would like to reduce the number of records for each reducer and keep the result as an RDD.

Using takeSample seemed like the obvious choice; however, it returns a local collection (a Python list) rather than an RDD.

I came up with this method:

rdd = rdd.zipWithIndex().filter(lambda x:x[1]<limit).map(lambda x:x[0])

However, this method is very slow and inefficient.

Is there a smarter way to take a small sample while keeping the data as an RDD?


Answers (1)

zero323

Reputation: 330413

If you want a small example subset and cannot make any additional assumptions about the data, then take combined with parallelize may be the optimal solution:

sc.parallelize(rdd.take(n))

It touches a relatively small number of partitions (only one in the best case), and for small n the cost of the network traffic should be negligible.
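To see why take touches so few partitions, here is a pure-Python model of the incremental scan it performs, with a list of lists standing in for the partitions. The helper name take_model is hypothetical, and the growth heuristic (try one partition first, then interpolate from the rows found so far, capped at four times the partitions already scanned) is modelled on PySpark's RDD.take; treat the constants as assumptions rather than a specification.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def take_model(partitions: Sequence[Sequence[T]], n: int,
               scale_up: int = 4) -> Tuple[List[T], int]:
    """Model RDD.take(n) on in-memory partitions (hypothetical helper).

    Returns the first n records and the number of partitions evaluated.
    """
    items: List[T] = []
    scanned = 0
    while len(items) < n and scanned < len(partitions):
        if scanned == 0:
            # First round: evaluate a single partition.
            batch = 1
        elif not items:
            # Nothing found yet: multiply the number of partitions tried.
            batch = scanned * scale_up
        else:
            # Interpolate from the rows seen so far, overestimating by 50%,
            # and cap the growth at scale_up times what was already scanned.
            batch = int(1.5 * n * scanned / len(items)) - scanned
            batch = min(max(batch, 1), scanned * scale_up)
        # Each evaluated partition contributes at most the rows still missing.
        left = n - len(items)
        for part in partitions[scanned:scanned + batch]:
            items.extend(part[:left])
        scanned = min(scanned + batch, len(partitions))
    return items[:n], scanned


# Eight partitions of 100 records each.
parts = [list(range(i * 100, (i + 1) * 100)) for i in range(8)]
print(take_model(parts, 10)[1])   # → 1
print(take_model(parts, 150)[1])  # → 2
```

In the real answer the collected list is then redistributed with sc.parallelize, which is cheap as long as n is small.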

Sampling (randomSplit or sample) requires a full scan of the data, just like zipWithIndex followed by filter.
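The full-scan cost of zipWithIndex followed by filter can be illustrated with a pure-Python model over a list of lists standing in for partitions. The helper name is hypothetical; the two passes mirror PySpark's documented behaviour, where zipWithIndex triggers an extra job to count the partitions when there is more than one, and the filter step still reads every record even though only the first limit are kept.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def zip_with_index_limit_model(partitions: Sequence[Sequence[T]],
                               limit: int) -> Tuple[List[T], int]:
    """Model zipWithIndex().filter(index < limit) on in-memory partitions.

    Hypothetical helper: returns the kept records and how many records
    were read across both passes.
    """
    visited = 0
    # Pass 1: with more than one partition, every partition is counted
    # to find the starting index of each one.
    starts = [0]
    if len(partitions) > 1:
        sizes = []
        for part in partitions:
            sizes.append(len(part))
            visited += len(part)
        for size in sizes[:-1]:
            starts.append(starts[-1] + size)
    # Pass 2: every record is indexed, then filtered on its index.
    kept: List[T] = []
    for k, part in enumerate(partitions):
        for i, record in enumerate(part, starts[k]):
            visited += 1
            if i < limit:
                kept.append(record)
    return kept, visited


parts = [list(range(i * 100, (i + 1) * 100)) for i in range(8)]
kept, visited = zip_with_index_limit_model(parts, 10)
print(len(kept), visited)  # → 10 1600
```

Only 10 records are kept, yet all 800 records are read twice.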

Assuming there is no data skew, you can try something like this to avoid the full scan:

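from __future__ import division  # only needed on Python 2

def limitApprox(rdd, n, timeout):
    # Approximate count, returned within `timeout` milliseconds
    count = rdd.countApprox(timeout)
    if count <= n:
        return rdd
    else:
        # Estimated records per partition, assuming an even distribution
        rec_per_part = count // rdd.getNumPartitions()
        # Number of leading partitions expected to hold about n records
        required_parts = n / rec_per_part if rec_per_part else 1
        # Keep the leading partitions; the others return an empty list,
        # leaving their iterators unconsumed
        return rdd.mapPartitionsWithIndex(
            lambda i, it: it if i < required_parts else []
        )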
  • this still accesses every partition, but tries to avoid computing their contents when not necessary
  • it won't work well if there is large data skew, and the result size is only approximate:
    • it can take much more than required, even if the distribution is uniform, when n << the average number of records per partition
    • it may undersample if the distribution is skewed towards high partition indices
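The caveats above can be checked with a pure-Python model of limitApprox's partition-selection arithmetic. It is a sketch under stated assumptions: the helper name is hypothetical, an exact count replaces countApprox, and plain lists stand in for partitions, so it shows only how many records the selection keeps, not how Spark schedules the work.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def limit_approx_model(partitions: Sequence[Sequence[T]], n: int) -> List[T]:
    """Model limitApprox's partition selection on in-memory partitions.

    Hypothetical helper: uses the exact count where limitApprox would
    call countApprox, and plain lists instead of RDD partitions.
    """
    count = sum(len(p) for p in partitions)
    if count <= n:
        return [x for p in partitions for x in p]
    rec_per_part = count // len(partitions)
    required_parts = n / rec_per_part if rec_per_part else 1
    # Keep whole partitions whose index is below required_parts.
    return [x for i, p in enumerate(partitions) if i < required_parts for x in p]


# Uniform: 10 partitions x 1000 records, but only 5 records wanted.
uniform = [list(range(i * 1000, (i + 1) * 1000)) for i in range(10)]
print(len(limit_approx_model(uniform, 5)))  # → 1000 (oversampled)

# Skewed towards the last partition: 9 singletons plus 991 records.
skewed = [[i] for i in range(9)] + [list(range(9, 1000))]
print(len(limit_approx_model(skewed, 500)))  # → 5 (undersampled)
```

Because whole partitions are kept, the result can only ever be a sum of partition sizes, which is why it overshoots when n is much smaller than a single partition.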

If the data can be represented as Rows (and a SparkSession is active, which is what makes toDF available on RDDs), you can try another trick:

rdd.toDF().limit(n).rdd
