Reputation: 846
So I have a dataframe in Spark with the following data:
user_id item category score
-------------------------------
user_1 item1 categoryA 8
user_1 item2 categoryA 7
user_1 item3 categoryA 6
user_1 item4 categoryD 5
user_1 item5 categoryD 4
user_2 item6 categoryB 7
user_2 item7 categoryB 7
user_2 item8 categoryB 7
user_2 item9 categoryA 4
user_2 item10 categoryE 2
user_2 item11 categoryE 1
And I want to filter it so that I keep at most 2 items per category for each user, based on score, so it should look like:
user_id item category score
-------------------------------
user_1 item1 categoryA 8
user_1 item2 categoryA 7
user_1 item4 categoryD 5
user_1 item5 categoryD 4
user_2 item7 categoryB 7
user_2 item8 categoryB 7
user_2 item9 categoryA 4
user_2 item10 categoryE 2
user_2 item11 categoryE 1
Here item3 and item6 were removed because each was a 3rd item of the same category for a user and had the lower score (or, if scores are tied, items are dropped at random until only two items of that category remain).
I tried to do it with a rank over a partitioned window, like this:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# rank items within each (user_id, category) by descending score
df2 = test.withColumn(
    'rank',
    F.rank().over(Window.partitionBy('user_id', 'category').orderBy(F.desc('score')))
).filter('rank <= 2')
That returned:
user_id item category score
-------------------------------
user_1 item1 categoryA 8
user_1 item2 categoryA 7
user_1 item4 categoryD 5
user_1 item5 categoryD 4
user_2 item6 categoryB 7
user_2 item7 categoryB 7
user_2 item8 categoryB 7
user_2 item9 categoryA 4
user_2 item10 categoryE 2
user_2 item11 categoryE 1
This works for user_1, but not for user_2, which has items of the same category with equal scores: items 6, 7 and 8 all belong to categoryB with score 7, so rank assigns them the same rank and none of them get filtered out. In this edge case I would like to drop one at random so that only 2 remain.
Any idea on how to do this type of filtering?
Upvotes: 1
Views: 265
Reputation: 42352
You can use row_number over a window partitioned by user_id and category:
# row_number assigns distinct consecutive ranks even when scores tie,
# so at most 2 rows per (user_id, category) are kept
df2 = df.withColumn(
    'rank',
    F.row_number().over(Window.partitionBy('user_id', 'category').orderBy(F.desc('score')))
).filter('rank <= 2').drop('rank').orderBy('user_id', 'item', 'category', F.desc('score'))
df2.show()
+-------+------+---------+-----+
|user_id| item| category|score|
+-------+------+---------+-----+
| user_1| item1|categoryA| 8|
| user_1| item2|categoryA| 7|
| user_1| item4|categoryD| 5|
| user_1| item5|categoryD| 4|
| user_2|item10|categoryE| 2|
| user_2|item11|categoryE| 1|
| user_2| item7|categoryB| 7|
| user_2| item8|categoryB| 7|
| user_2| item9|categoryA| 4|
+-------+------+---------+-----+
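Note that row_number breaks ties in an arbitrary (partition-dependent) order, so which of item6/item7/item8 gets dropped is not truly random. If you want ties broken at random, as asked in the question, one option is to add a random column and use it as a secondary sort key in the window; a minimal sketch, where the rnd column name is just illustrative:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# rank within each (user_id, category): highest score first, random order among ties
w = Window.partitionBy('user_id', 'category').orderBy(F.desc('score'), 'rnd')

df2 = (df
       .withColumn('rnd', F.rand())          # random tie-breaker, dropped afterwards
       .withColumn('rank', F.row_number().over(w))
       .filter('rank <= 2')
       .drop('rnd', 'rank'))
F.rand() accepts an optional seed if you need the tie-breaking to be reproducible across runs.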
Upvotes: 2