Reputation: 13700
I would like to reduce the number of records for each reducer, and keep the resulting variable an RDD.
Using takeSample seemed like the obvious choice; however, it returns a collection rather than an RDD.
I came up with this method:
rdd = rdd.zipWithIndex().filter(lambda x: x[1] < limit).map(lambda x: x[0])
However, this method is slow and inefficient.
Is there a smarter way to take a small sample while keeping the data structure an RDD?
Upvotes: 2
Views: 10236
Reputation: 330413
If you want a small example subset and cannot make any additional assumptions about the data, then take combined with parallelize can be an optimal solution:
sc.parallelize(rdd.take(n))
It touches a relatively small number of partitions (only one in the best-case scenario), and the cost of network traffic for a small n should be negligible.
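For illustration, a minimal sketch (the RDD below is made-up data, assuming a live SparkContext sc):

# Made-up data for illustration only
rdd = sc.parallelize(range(100000), 10)

n = 100
small_rdd = sc.parallelize(rdd.take(n))  # take() collects n records to the driver,
                                         # parallelize() turns them back into an RDD

small_rdd.count()  # 100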
Sampling (randomSplit or sample) requires a full data scan, the same as zipWithIndex with filter.
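For comparison, the sampling APIs look like this (the fraction and seed values are arbitrary), and neither gives you an exact number of records:

# Both require scanning the whole dataset once an action is run
sampled = rdd.sample(withReplacement=False, fraction=0.01, seed=42)
small_split, rest = rdd.randomSplit([0.01, 0.99], seed=42)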
Assuming there is no data skew, you can try something like this to avoid the full scan:
from __future__ import division  # Python 2 only

def limitApprox(rdd, n, timeout):
    # Get an approximate record count within `timeout` milliseconds
    count = rdd.countApprox(timeout)
    if count <= n:
        # Already small enough, return as-is
        return rdd
    else:
        # Estimate how many partitions are needed to cover roughly n records
        rec_per_part = count // rdd.getNumPartitions()
        required_parts = n / rec_per_part if rec_per_part else 1
        # Keep only the first required_parts partitions, drop the rest
        return rdd.mapPartitionsWithIndex(
            lambda i, iter: iter if i < required_parts else []
        )
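Hypothetical usage, with made-up data and an arbitrary timeout of 1000 ms:

# 100 partitions of ~10000 records each, so limitApprox keeps
# only the first partition(s) covering roughly n records
rdd = sc.parallelize(range(1000000), 100)
limited = limitApprox(rdd, 10000, timeout=1000)
limited.count()  # roughly 10000, depending on how evenly the data is distributed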
If the data is representable as a Row, you can try another trick:
rdd.toDF().limit(n).rdd
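A sketch of that approach, assuming Spark 2.x where a SparkSession is available and the records fit into Row objects (the field names here are invented):

from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

# Invented Row data for illustration
rdd = sc.parallelize([Row(id=i, value=i * 2) for i in range(100000)])

limited = rdd.toDF().limit(100).rdd  # back to an RDD of Row objects
limited.count()  # 100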
Upvotes: 5