Mpizos Dimitris

Reputation: 5001

taking N values from each partition in Spark

Assume I have the following data:

val DataSort = Seq(("a",5),("b",13),("b",2),("b",1),("c",4),("a",1),("b",15),("c",3),("c",1))
val DataSortRDD = sc.parallelize(DataSort,2)

And now there are two partitions with:

scala> DataSortRDD.glom().take(2).head
res53: Array[(String, Int)] = Array(("a",5),("b",13),("b",2),("b",1),("c",4))
scala> DataSortRDD.glom().take(2).tail
res54: Array[Array[(String, Int)]] = Array(Array(("a",1),("b",15),("c",3),("c",2),("c",1)))

It is assumed that within every partition the data is already sorted, using something like sortWithinPartitions(col("src").desc, col("rank").desc) (that's for a DataFrame, but it's just to illustrate).
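For illustration only, here is a rough sketch of how that per-partition ordering could be produced on a DataFrame; the column names "letter" and "count" and the SparkSession named spark are assumptions for the example, not part of my actual data:

import org.apache.spark.sql.functions.col
import spark.implicits._  // assumes an existing SparkSession named `spark`

val df = DataSort.toDF("letter", "count").repartition(2)
// sortWithinPartitions orders the rows inside each partition without a shuffle
val sortedDf = df.sortWithinPartitions(col("letter"), col("count").desc)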

What I want is, from each partition, to get the first two values for each letter (if there are more than two values). So in this example the result in each partition should be:

scala> HypotheticalRDD.glom().take(2).head
Array(("a",5),("b",13),("b",2),("c",4))
scala> HypotheticalRDD.glom().take(2).tail
Array(Array(("a",1),("b",15),("c",3),("c",2)))

I know that I have to use the mapPartitions function, but it's not clear in my mind how I can iterate through the values in each partition and take the first two. Any tip?

Edit: More precisely, I know that in each partition the data is already sorted by 'letter' first and then by 'count'. So my main idea is that the function passed to mapPartitions should iterate through the partition and yield the first two values of each letter, which can be done by checking the next value on each iteration. This is how I could write it in Python:

def limit_on_sorted(iterator):
    # assumes the records arrive sorted by key within the partition
    old_key = None
    cnt = 0
    for elem in iterator:      # iterate instead of calling .next() by hand
        cur_key = elem[0]
        if cur_key == old_key:
            cnt += 1
            if cnt >= 2:
                continue       # skip everything past the first two per key
        else:
            old_key = cur_key
            cnt = 0
        yield elem

DataSortRDDpython.mapPartitions(limit_on_sorted, preservesPartitioning=True)
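In Scala, the skeleton I think I need looks roughly like this (just the structure; the per-letter logic is still the missing piece and the pass-through body is a placeholder):

DataSortRDD.mapPartitions({ itr =>
  // iterate through itr here and keep only the first two records per letter
  itr  // placeholder: currently passes every record through unchanged
}, preservesPartitioning = true)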

Upvotes: 0

Views: 1986

Answers (1)

Tzach Zohar

Reputation: 37852

Assuming you don't really care about the partitioning of the result, you can use mapPartitionsWithIndex to incorporate the partition ID into the key by which you groupBy, then you can easily take the first two items for each such key:

import org.apache.spark.rdd.RDD

val result: RDD[(String, Int)] = DataSortRDD
  .mapPartitionsWithIndex {
    // add the partition ID into the "key" of every record:
    case (partitionId, itr) => itr.map { case (k, v) => ((k, partitionId), v) }
  }
  .groupByKey() // groups by letter and partition id
  // take only the first two records per key, and drop the partition id
  .flatMap { case ((k, _), itr) => itr.take(2).toArray.map((k, _)) }

println(result.collect().toList)
// prints:
// List((a,5), (b,15), (b,13), (b,2), (a,1), (c,4), (c,3))

Do note that the end result is not partitioned in the same way (groupByKey changes the partitioning); I'm assuming this isn't critical to what you're trying to do (which, frankly, escapes me).

EDIT: if you want to avoid shuffling and perform all operations within each partition:

val result: RDD[(String, Int)] = DataSortRDD
  .mapPartitions(_.toList.groupBy(_._1).mapValues(_.take(2)).values.flatten.iterator, true)
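Since the data is already sorted within each partition, you could also avoid materializing each partition with toList and stream over the iterator instead; a rough, untested sketch of that variant:

val result: RDD[(String, Int)] = DataSortRDD
  .mapPartitions({ itr =>
    var currentKey: Option[String] = None
    var count = 0
    itr.filter { case (key, _) =>
      if (currentKey.contains(key)) count += 1
      else { currentKey = Some(key); count = 0 }
      count < 2 // keep only the first two consecutive records of each key
    }
  }, preservesPartitioning = true)

Checking result.glom().collect() should then show at most two records per letter in each partition.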

Upvotes: 1
