DrakeMurdoch

Reputation: 859

Drop rows randomly from a dataframe such that no rows exist with a count over N

Given a dataframe df of encoded topics and items that looks like

topic    item
0        bucket
1        fish
2        car
0        pail
2        truck
3        glove 

where there are X topics and Y items. A count of items per topic looks like this:

print(df.groupby(by='topic').agg('count'))

                 item
topic                
0                8568   
1                7539  
2               48700   
3               26036   
4                4190  
5                2153 
...               ...
X-2               328
X-1              5942
X               15871

How could I make it so that no topic has more than N associated items? For example, let's say N = 5000. Then, if I did a count, I would get:

print(df.groupby(by='topic').agg('count'))

                 item
topic                
0                5000   
1                5000  
2                5000   
3                5000   
4                4190  
5                2153 
...               ...
X-2               328
X-1              5000
X                5000

Every topic with more than 5000 items is reduced to exactly 5000, and every topic at or below 5000 is left untouched. The rows to drop must also be chosen randomly, not just the first ones that appear.
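One way to pin down the requirement is as a check on the result: for every topic, the new count should equal the smaller of its original count and N, and every remaining row should be an unchanged row of the original frame. The sketch below assumes the columns are named `topic` and `item` as in the sample above; `meets_cap` is a hypothetical helper, not a pandas function.

```python
import pandas as pd

def meets_cap(original, capped, n, by="topic"):
    """Hypothetical helper: True if `capped` keeps min(count, n) rows per group."""
    before = original.groupby(by).size()
    # groups missing from `capped` count as 0 rows
    after = capped.groupby(by).size().reindex(before.index, fill_value=0)
    counts_ok = (after == before.clip(upper=n)).all()
    # every kept row must be an untouched, non-duplicated row of `original`
    rows_ok = (
        capped.index.isin(original.index).all()
        and capped.index.is_unique
        and capped.equals(original.loc[capped.index])
    )
    return bool(counts_ok and rows_ok)

df = pd.DataFrame({"topic": [0, 1, 2, 0, 2, 3],
                   "item": ["bucket", "fish", "car", "pail", "truck", "glove"]})
print(meets_cap(df, df, n=1))                           # False: topics 0 and 2 have 2 rows
print(meets_cap(df, df.drop_duplicates("topic"), n=1))  # True
```

A check like this only verifies the counts and that rows are untouched; it cannot tell whether the dropped rows were chosen randomly.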

Pseudocode:

# Randomly drops rows by topic until there are no topics that have a count
# above 5000
df.drop_rows_by_count(
    based_on='topic',
    above_below='above',
    count=5000,
    how='random'
)

How would I go about doing this?

Upvotes: 1

Views: 40

Answers (1)

jfaccioni

Reputation: 7509

Shuffle the entire dataframe first with sample and a frac argument of 1, which samples every row exactly once and therefore returns all rows in random order. The replace=False argument keeps pandas from selecting the same row twice during the sampling; it is already the default, so passing it explicitly is optional.

Then group by topic and use head to keep the first N rows of each topic (a random selection, since the rows were just shuffled):

df.sample(frac=1, replace=False).groupby('topic').head(5000)
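A self-contained run of the one-liner on the six-row sample from the question, with N = 1 so that the cap actually trims something. The `random_state` argument of `DataFrame.sample` is added here only to make the shuffle reproducible, and the trailing `sort_index()` is an optional extra (not in the answer) that restores the original row order.

```python
import pandas as pd

df = pd.DataFrame({"topic": [0, 1, 2, 0, 2, 3],
                   "item": ["bucket", "fish", "car", "pail", "truck", "glove"]})
N = 1  # small cap so that topics 0 and 2 (2 rows each) get trimmed

capped = (
    df.sample(frac=1, replace=False, random_state=42)  # shuffle every row once
      .groupby("topic")
      .head(N)        # first N rows per topic, in shuffled (random) order
      .sort_index()   # optional: put surviving rows back in original order
)
print(capped)
print(capped.groupby("topic").size())
```

Topics 1 and 3 have a single row and are always kept; topics 0 and 2 each keep one randomly chosen row. Dropping `random_state` gives a different selection on every run.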

Upvotes: 6
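A variant of the same idea, not taken from the answer above, gives every row a random key and keeps a row only if its key ranks within the first N of its topic. This leaves the surviving rows in their original order without shuffling the frame. `cap_per_group` is a hypothetical helper name, and `seed` is an assumed parameter for reproducibility.

```python
import numpy as np
import pandas as pd

def cap_per_group(df, by, n, seed=None):
    """Hypothetical helper: keep at most `n` randomly chosen rows per value of `by`.

    Groups with n or fewer rows are kept whole; surviving rows stay in
    their original order.
    """
    rng = np.random.default_rng(seed)
    # one random key per row; ranking the keys within each group yields a
    # random permutation of 1..group_size for every group
    keys = pd.Series(rng.random(len(df)))
    ranks = keys.groupby(df[by].to_numpy()).rank(method="first")
    # positional boolean mask, so duplicate index labels in df are harmless
    return df[(ranks <= n).to_numpy()]

df = pd.DataFrame({"topic": [0, 1, 2, 0, 2, 3],
                   "item": ["bucket", "fish", "car", "pail", "truck", "glove"]})
capped = cap_per_group(df, by="topic", n=1, seed=0)
print(capped)
print(capped.groupby("topic").size())  # every topic now has exactly 1 row
```

Both approaches choose uniformly at random which rows of an over-cap topic survive; the ranking version just avoids reordering the result.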
