Salmon

Reputation: 381

Use SMOTE to oversample image data

I'm doing binary classification with a CNN, and the data is imbalanced: the ratio of positive to negative medical images is 0.4 : 0.6. I want to use SMOTE to oversample the positive images before training. However, the data is 4-D, with shape (761, 64, 64, 3), which causes this error:

Found array with dim 4. Estimator expected <= 2

So I reshaped the training data to 2-D:

X_res, y_res = smote.fit_resample(X_train.reshape(X_train.shape[0], -1), y_train.ravel())

And it works fine. Before feeding it to the CNN, I reshape it back:

X_res = X_res.reshape(X_res.shape[0], 64, 64, 3)

Now I'm not sure whether this is a correct way to oversample. Will the reshape operation change the images' structure?
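The reshape part of the question can be checked directly. The following is a minimal sketch, assuming a small random array (`images`, a hypothetical stand-in for `X_train`) with the same 64x64x3 layout; it only shows that flattening and restoring with NumPy's default row-major order gives back the original pixels, and says nothing about the quality of the synthetic rows SMOTE adds.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for X_train: 5 RGB images of 64x64 pixels.
images = rng.integers(0, 256, size=(5, 64, 64, 3), dtype=np.uint8)

# Flatten each image into one row, as done before calling SMOTE.
flat = images.reshape(images.shape[0], -1)

# Restore the image layout, as done before feeding the CNN.
restored = flat.reshape(-1, 64, 64, 3)

# True if every pixel is back in its original position.
round_trip_ok = np.array_equal(images, restored)
print(flat.shape, restored.shape, round_trip_ok)
```

So the original images survive the round trip unchanged; whether the extra rows produced in between look like plausible images is a separate question.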

Upvotes: 10

Views: 15942

Answers (3)

Hemanth Kollipara

Reputation: 1141

  • First, flatten each image into a 1-D vector
  • Apply SMOTE to the flattened image data and its labels
  • Reshape the flattened, resampled data back to RGB images

from imblearn.over_sampling import SMOTE

sm = SMOTE(random_state=42)

# Flatten: (80, 100, 100, 3) -> (80, 30000)
train_rows = len(X_train)
X_train = X_train.reshape(train_rows, -1)

# Oversample the minority class, then restore the image shape
X_train, y_train = sm.fit_resample(X_train, y_train)
X_train = X_train.reshape(-1, 100, 100, 3)  # (>80, 100, 100, 3)
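To see what the extra rows contain once they are reshaped back, here is a rough NumPy-only sketch of SMOTE's core interpolation step. It is an illustration under stated assumptions, not imbalanced-learn's implementation: the `minority` array is made-up data, and the sketch always takes the single nearest neighbour, whereas SMOTE picks at random among the k nearest minority neighbours.

```python
import numpy as np

rng = np.random.default_rng(42)
# Hypothetical minority-class images: 6 RGB images of 8x8 pixels in [0, 1].
minority = rng.random((6, 8, 8, 3))
flat = minority.reshape(len(minority), -1)

# Pairwise Euclidean distances between the flattened images.
dists = np.linalg.norm(flat[:, None, :] - flat[None, :, :], axis=-1)
np.fill_diagonal(dists, np.inf)  # a sample is not its own neighbour

i = 0
j = int(np.argmin(dists[i]))  # nearest neighbour of image i (simplification)
lam = rng.random()            # interpolation weight in [0, 1)

# Core SMOTE step: a point on the line segment between the two samples.
synthetic_flat = flat[i] + lam * (flat[j] - flat[i])
synthetic = synthetic_flat.reshape(8, 8, 3)

# The same result written as a pixel-wise cross-fade of the two images.
blend = (1 - lam) * minority[i] + lam * minority[j]
is_blend = np.allclose(synthetic, blend)
```

Each synthetic image is therefore a weighted average of two real images, pixel by pixel. The shape is valid after reshaping, but whether such a blend is a realistic training example depends on the data.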

Upvotes: 2

cerofrais

Reputation: 1327

  1. As soon as you flatten an image you lose its localized (spatial) information; preserving that information is one of the reasons convolutions are used in image-based machine learning.
  2. A shape of 8000x250x250x3 has an inherent meaning: 8000 samples, each an image of width 250 and height 250 with 3 channels. After reshaping to 8000x(250*250*3), each row is just a long run of numbers, which is a poor representation unless you use some kind of sequence network to learn from it.
  3. Oversampling is a poor fit for image data; you can use image augmentation instead (cropping, introducing noise or a Gaussian blur, rotations, translations, etc.).
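As an illustration of the augmentation route in point 3, here is a minimal NumPy-only sketch. The `augment` helper, the `positives` array, and every parameter value are hypothetical; it assumes square images scaled to [0, 1], uses a wrap-around shift as a crude stand-in for translation, and restricts rotations to multiples of 90 degrees. Which transforms are appropriate depends on the imaging modality.

```python
import numpy as np

def augment(image, rng):
    """Return a randomly transformed copy of one square HxWxC image in [0, 1]."""
    out = image
    if rng.random() < 0.5:
        out = out[:, ::-1, :]                                     # horizontal flip
    out = np.rot90(out, k=int(rng.integers(0, 4)), axes=(0, 1))   # 0/90/180/270 degrees
    out = np.roll(out, int(rng.integers(-4, 5)), axis=1)          # wrap-around shift
    out = out + rng.normal(0.0, 0.02, size=out.shape)             # additive Gaussian noise
    return np.clip(out, 0.0, 1.0)                                 # keep the valid pixel range

rng = np.random.default_rng(0)
# Hypothetical positive-class images: 4 RGB images of 64x64 pixels.
positives = rng.random((4, 64, 64, 3))

# Generate 6 extra positive samples by cycling through the originals.
extra = np.stack([augment(positives[i % len(positives)], rng) for i in range(6)])
print(extra.shape)
```

The extra samples can be concatenated with the original positives until the classes are roughly balanced; unlike the flattened approach, each new sample is derived from a single real image.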

Upvotes: 3

Aditya Bhattacharya

Reputation: 1014

I had a similar issue. I used the reshape function to flatten the images:

from imblearn.over_sampling import SMOTE

X_train.shape  # (8000, 250, 250, 3)

# Flatten each image into one row
ReX_train = X_train.reshape(8000, 250 * 250 * 3)
ReX_train.shape  # (8000, 187500)

smt = SMOTE()
# fit_resample replaces the older fit_sample, which recent imbalanced-learn releases no longer provide
Xs_train, ys_train = smt.fit_resample(ReX_train, y_train)

Although this approach is painfully slow, it helped to improve the performance.

Upvotes: 7
