eljiwo

Reputation: 846

Filter the top N values in a column for each value in another column, based on a score (PySpark)

I have a Spark DataFrame with the following data:

user_id item   category score
-------------------------------
user_1  item1  categoryA  8
user_1  item2  categoryA  7
user_1  item3  categoryA  6
user_1  item4  categoryD  5
user_1  item5  categoryD  4
user_2  item6  categoryB  7
user_2  item7  categoryB  7
user_2  item8  categoryB  7
user_2  item9  categoryA  4
user_2  item10 categoryE  2
user_2  item11 categoryE  1

I want to filter it so that, for each user, at most 2 items of the same category are kept, choosing the ones with the highest score. The result should look like:

user_id item   category score
-------------------------------
user_1  item1  categoryA  8
user_1  item2  categoryA  7
user_1  item4  categoryD  5
user_1  item5  categoryD  4
user_2  item7  categoryB  7
user_2  item8  categoryB  7
user_2  item9  categoryA  4
user_2  item10 categoryE  2
user_2  item11 categoryE  1

item3 was removed because it was the third categoryA item for user_1 and had the lowest score. item6 was removed because it was the third categoryB item for user_2; since all three categoryB items share a score of 7, one of them should simply be dropped at random so that only two remain.

I tried doing this with rank() over a window partitioned by user and category:

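from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Rank items within each (user_id, category) group by descending score
df2 = test.withColumn(
    'rank',
    F.rank().over(Window.partitionBy('user_id', 'category').orderBy(F.desc('score')))
).filter('rank <= 2')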

That returned:

user_id item   category score
-------------------------------
user_1  item1  categoryA  8
user_1  item2  categoryA  7
user_1  item4  categoryD  5
user_1  item5  categoryD  4
user_2  item6  categoryB  7
user_2  item7  categoryB  7
user_2  item8  categoryB  7
user_2  item9  categoryA  4
user_2  item10 categoryE  2
user_2  item11 categoryE  1

This works for user_1 but not for user_2: items 6, 7 and 8 all belong to categoryB and have a score of 7, so rank() gives all three of them rank 1 and none is filtered out. In this tie case I would like to drop one of them at random so that only 2 remain.
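The difference between rank() and row_number() in this tie case can be sketched in plain Python. This is not PySpark: the helpers `rank` and `row_number` below are hypothetical stand-ins that mimic the SQL semantics for one window partition already sorted by score descending.

```python
def rank(scores):
    """SQL-style rank(): tied scores share a rank, and a gap follows each tie.

    Returns the ranks for the scores sorted in descending order.
    """
    ordered = sorted(scores, reverse=True)
    # The rank of a value is 1 + the position of its first occurrence
    return [ordered.index(s) + 1 for s in ordered]


def row_number(scores):
    """SQL-style row_number(): every row gets a distinct sequential number."""
    return list(range(1, len(scores) + 1))


# The three categoryB rows for user_2 all have score 7
category_b = [7, 7, 7]
print(rank(category_b))        # [1, 1, 1] -> all three pass `rank <= 2`
print(row_number(category_b))  # [1, 2, 3] -> only two pass `rank <= 2`
print(rank([7, 7, 5]))         # [1, 1, 3] -> the gap after a tie
```

Because rank() repeats the number for tied rows, a `rank <= 2` filter can keep more than two rows per group; row_number() never repeats, so it caps each group at exactly two.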

Any idea how to do this kind of filtering?

Upvotes: 1

Views: 265

Answers (1)

mck

Reputation: 42352

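You can use row_number() instead of rank() over a window partitioned by user_id and category. Unlike rank(), row_number() gives every row in a partition a distinct number, so tied rows are numbered 1, 2, 3, ... and the filter keeps exactly two of them: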

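from pyspark.sql import functions as F
from pyspark.sql.window import Window

# row_number() never repeats within a partition, so ties cannot exceed the limit
df2 = df.withColumn(
    'rank',
    F.row_number().over(Window.partitionBy('user_id', 'category').orderBy(F.desc('score')))
).filter('rank <= 2').drop('rank').orderBy('user_id', 'item', 'category', F.desc('score'))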

df2.show()
+-------+------+---------+-----+
|user_id|  item| category|score|
+-------+------+---------+-----+
| user_1| item1|categoryA|    8|
| user_1| item2|categoryA|    7|
| user_1| item4|categoryD|    5|
| user_1| item5|categoryD|    4|
| user_2|item10|categoryE|    2|
| user_2|item11|categoryE|    1|
| user_2| item7|categoryB|    7|
| user_2| item8|categoryB|    7|
| user_2| item9|categoryA|    4|
+-------+------+---------+-----+
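Note that row_number() breaks ties in whatever order Spark happens to encounter the rows, which is arbitrary but not guaranteed to be random, so which two of item6/item7/item8 survive may vary. If a genuinely random choice matters, one option (an untested suggestion, not part of the answer above) is to add a random value as a secondary sort key, e.g. `F.rand()` after `F.desc('score')` in the window's `orderBy`. The same logic is sketched below in plain Python on the question's sample rows; `ROWS` and `top_n_per_group` are hypothetical names for illustration, not Spark APIs.

```python
import random
from collections import defaultdict

# Sample rows from the question: (user_id, item, category, score)
ROWS = [
    ("user_1", "item1", "categoryA", 8),
    ("user_1", "item2", "categoryA", 7),
    ("user_1", "item3", "categoryA", 6),
    ("user_1", "item4", "categoryD", 5),
    ("user_1", "item5", "categoryD", 4),
    ("user_2", "item6", "categoryB", 7),
    ("user_2", "item7", "categoryB", 7),
    ("user_2", "item8", "categoryB", 7),
    ("user_2", "item9", "categoryA", 4),
    ("user_2", "item10", "categoryE", 2),
    ("user_2", "item11", "categoryE", 1),
]


def top_n_per_group(rows, n=2, seed=None):
    """Keep at most n rows per (user_id, category), highest score first.

    Mirrors row_number() over a window ordered by score descending and then
    by a random key, followed by a `row_number <= n` filter.
    """
    rng = random.Random(seed)
    groups = defaultdict(list)
    for row in rows:
        user_id, _item, category, _score = row
        groups[(user_id, category)].append(row)

    kept = []
    for group_rows in groups.values():
        # Sort by score descending; the random key only decides among ties
        ordered = sorted(group_rows, key=lambda r: (-r[3], rng.random()))
        kept.extend(ordered[:n])
    return kept


for row in sorted(top_n_per_group(ROWS, n=2, seed=0)):
    print(row)
```

Passing a fixed `seed` makes the tie-break reproducible between runs; leaving it as `None` gives a different random choice among tied rows each time.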

Upvotes: 2
