I have a very imbalanced dataset that I imported with image_dataset_from_directory. After training a simple convolutional neural network on it, the model did, as expected, very well on the training data but poorly on the test samples. I wanted to compare those results with training the same model after oversampling the minority classes in the training data. After oversampling, however, the model seems unable to train properly: the loss jumps around from one epoch to the next.
This is the oversampling code:
import numpy as np
import tensorflow as tf
from imblearn.over_sampling import RandomOverSampler

# Unbatch the training dataset
training_ds = training_ds.unbatch()
# Extract the images and the labels from the dataset, where they are stored as tuples
X, y = [], []
for image, label in training_ds:
    # Flatten each image into a 1D array so RandomOverSampler accepts it
    X.append(image.numpy().flatten())
    y.append(label.numpy())
X = np.array(X)
y = np.array(y)
# Initialize the oversampler from the imblearn module and resample the dataset
oversampler = RandomOverSampler()
X_resampled, y_resampled = oversampler.fit_resample(X, y)
# Restore the original image shape (256x256, one channel)
X_resampled = X_resampled.reshape((-1, 256, 256, 1))
# Merge the images and labels back into a dataset object we can use as input for the model
oversampled_ds_images = tf.data.Dataset.from_tensor_slices(X_resampled)
oversampled_ds_labels = tf.data.Dataset.from_tensor_slices(y_resampled)
training_ds_over = tf.data.Dataset.zip((oversampled_ds_images, oversampled_ds_labels))
# Rebatch both datasets with the original batch size
training_ds = training_ds.batch(BATCH_SIZE)
training_ds_over = training_ds_over.batch(BATCH_SIZE)
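One detail that may matter here: as far as I can tell from imblearn's implementation, RandomOverSampler.fit_resample returns all the original samples first and then appends the duplicated minority samples class by class, and the code above batches `training_ds_over` without shuffling it, so the later batches may each contain a single class. The sketch below is a hypothetical NumPy-only stand-in for the resampling step (the function name `random_oversample_shuffled` is made up and is not part of imblearn's API), assuming `X` is the flattened image array and `y` holds integer or one-hot labels as collected above. It performs the same random duplication and then shuffles the result.

```python
import numpy as np


def random_oversample_shuffled(X, y, seed=0):
    """Hypothetical helper: randomly duplicate minority-class rows until every
    class matches the largest class, then shuffle so classes are mixed.

    X: array of shape (n_samples, ...)
    y: integer labels (n_samples,) or one-hot labels (n_samples, n_classes)
    """
    rng = np.random.default_rng(seed)
    # Class ids work for both integer and one-hot labels
    class_ids = y.argmax(axis=1) if y.ndim == 2 else y
    classes, counts = np.unique(class_ids, return_counts=True)
    target = counts.max()

    # Keep every original sample, then append duplicates class by class
    indices = [np.arange(len(class_ids))]
    for cls, count in zip(classes, counts):
        if count < target:
            members = np.flatnonzero(class_ids == cls)
            indices.append(rng.choice(members, size=target - count, replace=True))
    indices = np.concatenate(indices)

    # Without this permutation the duplicates sit in contiguous per-class blocks
    indices = rng.permutation(indices)
    return X[indices], y[indices]


# Toy example: 7 samples of class 0, 3 samples of class 1
X = np.arange(20).reshape(10, 2)
y = np.array([0] * 7 + [1] * 3)
X_res, y_res = random_oversample_shuffled(X, y)
print(np.bincount(y_res))  # [7 7]
```

Staying with the tf.data pipeline instead, calling `training_ds_over.shuffle(buffer_size=len(y_resampled))` before `.batch(BATCH_SIZE)` should give the same mixing effect.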
And this is how the model is compiled and fitted:
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam

# create_simple_cnn() and METRICS are defined elsewhere (not shown)
cnn_over = create_simple_cnn()
opt = Adam(learning_rate=0.001)
cnn_over.compile(optimizer=opt,
                 loss='categorical_crossentropy',
                 metrics=METRICS)
early_stop = EarlyStopping(monitor='val_accuracy', mode='max', patience=5)
history_over = cnn_over.fit(
    training_ds_over,
    epochs=10,
    validation_data=validation_ds,
    callbacks=[early_stop]
)
And this is the training output, where you can see that both the loss and the validation loss fluctuate wildly from epoch to epoch:
Epoch 1/10
287/287 ━━━━━━━━━━━━━━━━━━━━ 191s 663ms/step - accuracy: 0.5111 - f1: 0.3904 - loss: 1.4161 - val_accuracy: 0.3457 - val_f1: 0.1284 - val_loss: 9.4443
Epoch 2/10
287/287 ━━━━━━━━━━━━━━━━━━━━ 191s 664ms/step - accuracy: 0.5532 - f1: 0.3719 - loss: 2.2751 - val_accuracy: 0.3457 - val_f1: 0.1284 - val_loss: 6.6366
Epoch 3/10
287/287 ━━━━━━━━━━━━━━━━━━━━ 187s 653ms/step - accuracy: 0.5296 - f1: 0.3410 - loss: 1.6519 - val_accuracy: 0.3457 - val_f1: 0.1284 - val_loss: 1.3788
Epoch 4/10
287/287 ━━━━━━━━━━━━━━━━━━━━ 192s 668ms/step - accuracy: 0.4525 - f1: 0.2955 - loss: 1.2141 - val_accuracy: 0.3457 - val_f1: 0.1284 - val_loss: 357.0558
Epoch 5/10
287/287 ━━━━━━━━━━━━━━━━━━━━ 194s 676ms/step - accuracy: 0.4466 - f1: 0.3244 - loss: 21.6220 - val_accuracy: 0.3457 - val_f1: 0.1284 - val_loss: 11.5818
Epoch 6/10
287/287 ━━━━━━━━━━━━━━━━━━━━ 191s 666ms/step - accuracy: 0.2812 - f1: 0.2534 - loss: 1.9887 - val_accuracy: 0.3457 - val_f1: 0.1284 - val_loss: 20.8387
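The validation metrics are identical in all six epochs (val_accuracy 0.3457, val_f1 0.1284), which would be consistent with the model assigning every validation image to the same class. A rough way to check this is sketched below in NumPy, under assumptions: `y_val` is a hypothetical array of validation labels collected the same way `y` is collected above, and the custom f1 metric in METRICS is assumed to be a macro-averaged F1 that scores never-predicted classes as 0 (its definition is not shown in the question).

```python
import numpy as np


def constant_prediction_accuracy(y_val):
    """Accuracy obtained by predicting the same class for every sample.

    y_val: integer labels (n_samples,) or one-hot labels (n_samples, n_classes).
    Returns {class_id: accuracy if that class were always predicted}.
    """
    class_ids = y_val.argmax(axis=1) if y_val.ndim == 2 else y_val
    classes, counts = np.unique(class_ids, return_counts=True)
    return {int(c): count / len(class_ids) for c, count in zip(classes, counts)}


def constant_prediction_macro_f1(share, n_classes):
    """Macro-averaged F1 when every sample is assigned to one class.

    The predicted class has precision = share and recall = 1; every other
    class is never predicted, so its F1 is counted as 0 (an assumption
    about how the custom f1 metric treats empty classes).
    """
    f1_predicted = 2 * share / (share + 1)
    return f1_predicted / n_classes


# Hypothetical 4-class case, using the logged val_accuracy as the class share
print(round(constant_prediction_macro_f1(0.3457, 4), 4))  # 0.1284
```

With 4 classes (again an assumption, since the number of classes is not stated), a predictor that always outputs a class making up 34.57% of the validation set would score a macro F1 of about 0.1284, the same as the logged val_f1. Counting `cnn_over.predict(validation_ds).argmax(axis=1)` with `np.bincount` would show directly whether all predictions fall into one class.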