nick_v1

Reputation: 1664

Getting Top N items per group in pySpark

I am using Spark 1.6.2 and I have the following data structure:

sample = sqlContext.createDataFrame([
    (1, ['potato', 'orange', 'orange']),
    (1, ['potato', 'orange', 'yogurt']),
    (2, ['vodka', 'beer', 'vodka']),
    (2, ['vodka', 'beer', 'juice', 'vinegar'])
], ['cat', 'terms'])

I would like to extract the top N most frequent terms per cat. I have developed the following solution, which seems to work, but I wanted to see if there is a better way to do this.

from collections import Counter
import pyspark.sql.functions as sf

def get_top(it, terms=200):
    # it yields (cat, term) tuples for a single cat; count them
    c = Counter(it)
    # most_common returns ((cat, term), count) pairs; keep just the terms
    return [x[0][1] for x in c.most_common(terms)]

(sample.select('cat', sf.explode('terms'))   # one row per (cat, term)
 .rdd.map(lambda x: (x.cat, x.col))
 .groupBy(lambda x: x[0])                    # group all pairs by cat
 .map(lambda x: (x[0], get_top(x[1], 2)))    # top 2 terms per cat
 .collect()
)

It provides the following output:

[(1, ['orange', 'potato']), (2, ['vodka', 'beer'])]

This is in line with what I am looking for, but I really don't like the fact that I am resorting to Counter. How can I do it with Spark alone?

Thanks

Upvotes: 0

Views: 1439

Answers (1)

AChampion

Reputation: 30258

If this is working, it is probably better to post it on Code Review.

Just as an exercise, I did this without Counter, but you are largely just replicating the same functionality.

  • Count each occurrence of (cat, term)
  • Group by cat
  • Sort the values by count and slice to the number of terms (2)

Code:

from operator import add
import pyspark.sql.functions as sf

(sample.select('cat', sf.explode('terms'))
 .rdd
 .map(lambda x: (x, 1))                      # ((cat, term), 1)
 .reduceByKey(add)                           # ((cat, term), count)
 .groupBy(lambda x: x[0][0])                 # group the counted pairs by cat
 .mapValues(lambda x: [r[1] for r, _ in sorted(x, key=lambda a: -a[1])[:2]])
 .collect())

Output:

[(1, ['orange', 'potato']), (2, ['vodka', 'beer'])]
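
If you would rather stay in the DataFrame API instead of dropping to the RDD, a sketch along these lines should also work. Note that in Spark 1.6 window functions and collect_list generally require a HiveContext, so treat this as an untested sketch rather than a drop-in replacement:

from pyspark.sql import Window
import pyspark.sql.functions as sf

# rank terms within each cat by how often they occur
w = Window.partitionBy('cat').orderBy(sf.col('count').desc())

(sample.select('cat', sf.explode('terms').alias('term'))
 .groupBy('cat', 'term')
 .count()                                      # count per (cat, term)
 .withColumn('rn', sf.row_number().over(w))    # 1 = most frequent term in its cat
 .where(sf.col('rn') <= 2)                     # keep the top 2 per cat
 .groupBy('cat')
 .agg(sf.collect_list('term').alias('terms'))
 .show())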

Upvotes: 2
