Reputation: 161
I have a DataFrame like this:
file_name label
../input/image-classification-screening/train/... 1
../input/image-classification-screening/train/... 7
../input/image-classification-screening/train/... 9
../input/image-classification-screening/train/... 9
../input/image-classification-screening/train/... 6
It has 11 classes (0 to 10) with high class imbalance. Below is the output of train['label'].value_counts():
6 6285
3 4139
9 3933
7 3664
2 2778
5 2433
8 2338
0 2166
4 2052
10 1039
1 922
How do I under-sample this data in pandas so that each class will have below 2500 examples? I want to remove data points randomly from majority classes like 6, 3, 9, 7 and 2.
Upvotes: 1
Views: 4740
Reputation:
You could create a mask that identifies which "label"s have at least 2500 items, then use groupby + sample (setting n=n to draw the required number of rows) on those groups, and select all of the labels with fewer than 2500 items as they are. This creates two DataFrames, one sampled down to 2500 and the other kept whole. Then concatenate the two using pd.concat:
n = 2500
# True for rows whose label occurs at least n times
msk = df.groupby('label')['label'].transform('size') >= n
# Down-sample the large classes to n rows each; keep the small classes whole
df = pd.concat((df[msk].groupby('label').sample(n=n), df[~msk]), ignore_index=True)
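Note: DataFrameGroupBy.sample requires pandas 1.1 or newer. If you want the draw to be reproducible, you can pass a random_state (the seed below is just an illustrative choice):

df = pd.concat((df[msk].groupby('label').sample(n=n, random_state=42), df[~msk]),
               ignore_index=True)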
For example, if you had a DataFrame like:
import pandas as pd

df = pd.DataFrame({'ID': range(30),
                   'label': ['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A',
                             'A', 'A', 'A', 'B', 'B', 'B', 'B', 'B', 'B', 'B',
                             'B', 'B', 'B', 'B', 'C', 'C', 'D', 'F', 'F', 'G']})
with label counts:
>>> df['label'].value_counts()
A 13
B 11
C 2
F 2
D 1
G 1
Name: label, dtype: int64
Then the above code with n=3 yields:
ID label
0 7 A
1 0 A
2 10 A
3 20 B
4 18 B
5 21 B
6 24 C
7 25 C
8 26 D
9 27 F
10 28 F
11 29 G
with the new counts:
>>> df['label'].value_counts()
A 3
B 3
C 2
F 2
D 1
G 1
Name: label, dtype: int64
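Applied to the DataFrame from the question (assuming it is named train as in the post), a minimal sketch would be:

n = 2500
msk = train.groupby('label')['label'].transform('size') >= n
train = pd.concat((train[msk].groupby('label').sample(n=n), train[~msk]),
                  ignore_index=True)

# Sanity check: no class should now exceed n rows
assert train['label'].value_counts().max() <= n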
Upvotes: 2
Reputation: 29635
You can use sample in a groupby.apply. Here's a reproducible example with 4 unbalanced labels.
import numpy as np
import pandas as pd

np.random.seed(1)
df = pd.DataFrame({
    'a': range(100),
    'label': np.random.choice(range(4), size=100, p=[0.5, 0.3, 0.18, 0.02])})
print(df['label'].value_counts())
# 0 51
# 1 30
# 2 18
# 3 1
# Name: label, dtype: int64
Now, to select at most 25 rows per label (replace 25 with 2500 in your case), you do:
nMax = 25  # change to 2500
# Sample nMax rows per label, or keep the whole group if it is smaller
res = df.groupby('label').apply(lambda x: x.sample(n=min(nMax, len(x))))
print(res['label'].value_counts())
# 0 25 # see how label 0 and 1 are now 25
# 1 25
# 2 18 # and the smaller groups stay the same
# 3 1
# Name: label, dtype: int64
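Note that apply prepends the group key to the result's index here, so res ends up with a MultiIndex. A sketch of a variant with a flat 0..N index (assuming you don't need to keep the original row positions) is:

res = (df.groupby('label', group_keys=False)
         .apply(lambda x: x.sample(n=min(nMax, len(x))))
         .reset_index(drop=True))

Since the result comes out grouped by label, you may also want to shuffle it before training, e.g. res = res.sample(frac=1).reset_index(drop=True).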
Upvotes: 3