Bilgin

Reputation: 519

How to select a subset of data from each category using a for loop in Python?

I have customer data (in CSV format) as:

index category  text
 0    spam      you win much money
 1    spam      you win 7000 car
 2    not_spam  the weather in Chicago is nice
 3    neutral   we have a party now
 4    neutral   they are driving to downtown
 5    not_spam  pizza is an Italian food
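To try the answers below without the original file, here is a minimal sketch that loads the six sample rows with pd.read_csv from an in-memory string. With the real data you would pass the file path instead; "customer.csv" is a hypothetical name.

```python
import io

import pandas as pd

# Stand-in for the customer CSV file; with a real file you would pass its
# path (e.g. "customer.csv", a hypothetical name) to pd.read_csv instead.
csv_text = """index,category,text
0,spam,you win much money
1,spam,you win 7000 car
2,not_spam,the weather in Chicago is nice
3,neutral,we have a party now
4,neutral,they are driving to downtown
5,not_spam,pizza is an Italian food
"""

# Use the "index" column from the file as the DataFrame index.
customer = pd.read_csv(io.StringIO(csv_text), index_col="index")
print(customer.category.value_counts())
```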

For example, the categories have different counts:

customer.category.value_counts():    

spam       100
not_spam   20
neutral    45

where:

min(customer.category.value_counts()): 20

I want to write a for loop in Python that creates a new data file in which every category has the same number of rows, equal to the count of the smallest category (in this example, not_spam).

My expected output would be:

new_customer.category.value_counts():    

spam       20
not_spam   20
neutral    20
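A minimal sketch of the plain for loop described above, assuming the data is already in a DataFrame named customer. The rows here are made up so that the categories are unbalanced. This version keeps the first rows of each category (not a random subset), and the to_csv filename in the last comment is hypothetical.

```python
import pandas as pd

# Hypothetical unbalanced data: 4 spam, 2 not_spam, 3 neutral.
customer = pd.DataFrame(
    {
        "category": ["spam"] * 4 + ["not_spam"] * 2 + ["neutral"] * 3,
        "text": [f"message {i}" for i in range(9)],
    }
)

# Size of the smallest category (2 here).
min_count = int(customer.category.value_counts().min())

pieces = []
for cat in customer.category.unique():
    rows = customer[customer.category == cat]
    # Keep the first min_count rows of this category.
    pieces.append(rows.head(min_count))

# Concatenate once at the end rather than inside the loop.
new_customer = pd.concat(pieces)
print(new_customer.category.value_counts())
# To write the result: new_customer.to_csv("new_customer.csv")  (hypothetical filename)
```

Collecting the pieces in a list and calling pd.concat once avoids rebuilding the result DataFrame on every iteration.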

Upvotes: 1

Views: 436

Answers (3)

karas27

Reputation: 365

This should work. It takes the first minval rows of each category and concatenates them into one DataFrame:

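import pandas as pd

minval = min(df1.category.value_counts())  # size of the smallest category
# first minval rows of each category, concatenated into one DataFrame
df2 = pd.concat([df1[df1.category == cat].head(minval) for cat in df1.category.unique()])
print(df2)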

Upvotes: 1

Hryhorii Pavlenko

Reputation: 3910

My randomly generated dataframe has 38 rows with the following distribution of categories:

spam        17
not_spam    16
neutral      5
Name: category, dtype: int64

The first step is to find the size of the smallest category. Once you know it, you can call .sample() on each category with that value as n:

import pandas as pd


def sample(df: pd.DataFrame, category: str):
    # size of the smallest category
    threshold = df[category].value_counts().min()
    for cat in df[category].unique():
        data = df.loc[df[category].eq(cat)]
        # draw `threshold` random rows from this category
        yield data.sample(threshold)


data = sample(df, "category")
pd.concat(data, ignore_index=True)


    text    category
0   v   not_spam
1   l   not_spam
2   q   not_spam
3   j   not_spam
4   f   not_spam
5   l   spam
6   t   spam
7   r   spam
8   n   spam
9   k   spam
10  n   neutral
11  n   neutral
12  d   neutral
13  q   neutral
14  l   neutral
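Because .sample() draws a different random subset on every run, the balanced file changes each time. If it needs to be reproducible, one option (a sketch, not part of the answer above) is to pass random_state. pandas 1.1 and later also provide GroupBy.sample, which replaces the generator with a single call. The data below is made up.

```python
import pandas as pd

# Hypothetical unbalanced data: 5 spam, 3 not_spam, 4 neutral.
df = pd.DataFrame(
    {
        "category": ["spam"] * 5 + ["not_spam"] * 3 + ["neutral"] * 4,
        "text": list("abcdefghijkl"),
    }
)

# Size of the smallest category (3 here).
threshold = int(df["category"].value_counts().min())

# GroupBy.sample draws n random rows from every group; fixing
# random_state makes the draw repeatable between runs.
balanced = df.groupby("category").sample(n=threshold, random_state=0)
print(balanced.reset_index(drop=True))
```

The same random_state argument also works with DataFrame.sample inside the generator, e.g. data.sample(threshold, random_state=0).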

Upvotes: 1

It's easier to use groupby:

min_count = df.category.value_counts().min()
df.groupby('category').head(min_count)
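Note that groupby(...).head() keeps the original row order, so the categories stay interleaved in the result. A small sketch with made-up data shows this, plus one optional way to group the rows afterwards: a stable sort by category, which does not reorder rows within a category.

```python
import pandas as pd

# Hypothetical data with interleaved categories: 3 spam, 2 not_spam, 2 neutral.
df = pd.DataFrame(
    {
        "category": ["spam", "not_spam", "spam", "neutral", "spam", "neutral", "not_spam"],
        "text": list("abcdefg"),
    }
)

min_count = int(df.category.value_counts().min())  # 2 here

# head() keeps the first min_count rows of each group, in original order.
balanced = df.groupby("category").head(min_count)

# Optional: a stable sort groups rows by category without changing
# their relative order inside each category.
balanced_sorted = balanced.sort_values("category", kind="stable")
print(balanced_sorted)
```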

That said, if you really want a loop, you can write it as a list comprehension and concatenate the pieces once, which is usually faster than calling pd.concat inside an explicit loop:

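import pandas as pd

categories = df.category.unique()
min_count = df.category.value_counts().min()  # size of the smallest category
# "@cat" in query() refers to the loop variable; [:min_count] keeps the first rows
df = pd.concat([df.query('category==@cat')[:min_count] for cat in categories])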
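As a quick sanity check, the sketch below (made-up data, hypothetical helper names balance_with_query and balance_with_groupby) confirms that the list-comprehension loop and the groupby version select the same rows. Only the order differs: the loop groups rows by category, while groupby keeps the original order.

```python
import pandas as pd


def balance_with_query(df: pd.DataFrame) -> pd.DataFrame:
    """Loop version: first min_count rows of each category via DataFrame.query."""
    min_count = int(df.category.value_counts().min())
    # Inside query(), "@cat" refers to the local Python variable `cat`.
    return pd.concat(
        [df.query("category == @cat")[:min_count] for cat in df.category.unique()]
    )


def balance_with_groupby(df: pd.DataFrame) -> pd.DataFrame:
    """groupby version: first min_count rows of each category, original order."""
    min_count = int(df.category.value_counts().min())
    return df.groupby("category").head(min_count)


# Hypothetical sample data: 3 spam, 2 not_spam, 2 neutral.
df = pd.DataFrame(
    {
        "category": ["spam", "not_spam", "spam", "neutral", "spam", "neutral", "not_spam"],
        "text": list("abcdefg"),
    }
)

loop_result = balance_with_query(df)
groupby_result = balance_with_groupby(df)

# Same rows; sorting by index removes the ordering difference.
print(loop_result.sort_index().equals(groupby_result))  # True
```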

Upvotes: 3
