Bao Nguyen

Reputation: 1

CNN overfitting and poor performance on validation set

I have a bird classifier built with a CNN, shown below. It performs reasonably on the training set (about 50% accuracy) but very poorly on the validation set (<1%).

Here is the data:

train_images.shape   # (4829, 100, 100, 3)
train_labels.shape   # (4829, 200)
test_images.shape    # (1204, 100, 100, 3)
test_labels.shape    # (1204, 200)

For the model, I used transfer learning from MobileNetV2 and fine-tuned it.

import tensorflow as tf
from tensorflow.keras import regularizers

# Load the pre-trained MobileNetV2 model
base_model = tf.keras.applications.MobileNetV2(input_shape=(100, 100, 3),
                                               include_top=False,
                                               weights='imagenet')

base_model.trainable = True

# Fine-tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
  layer.trainable = False

To prevent overfitting, I used data augmentation, L2 regularization, dropout, early stopping, and learning-rate decay, but none of them seem to be effective.
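(The augment pipeline is not defined in the snippet below; it is a Keras data-augmentation model. A minimal sketch of what it might look like, assuming standard Keras preprocessing layers — the specific layers and factors are illustrative choices, not the actual ones from my code:)

augment = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),  # mirror images left-right
    tf.keras.layers.RandomRotation(0.1),       # rotate by up to ±10% of a full turn
    tf.keras.layers.RandomZoom(0.1),           # zoom in/out by up to 10%
])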

inputs = tf.keras.Input(shape=(100, 100, 3))
x = augment(inputs, training=True)
x = base_model(x, training=True)

x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(580, activation='relu',
                          kernel_regularizer=regularizers.l2(0.01),
                          bias_regularizer=regularizers.l2(0.01))(x)
x = tf.keras.layers.Dropout(0.2)(x)

outputs = tf.keras.layers.Dense(200, activation='softmax')(x)
model = tf.keras.Model(inputs, outputs)

model.compile(loss="categorical_crossentropy",
              optimizer=tf.keras.optimizers.Adam(),
              metrics=['accuracy'])

from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

# Set up EarlyStopping to stop training if val_loss doesn't improve for 5 epochs
early_stopping = EarlyStopping(monitor="val_loss",         # watch the validation loss
                               patience=5,                 # allow 5 epochs without improvement
                               restore_best_weights=True)  # then roll back to the best weights

reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=1e-6)

history = model.fit(train_images, train_labels, epochs=50,
                    validation_split=0.2,
                    callbacks=[
                      early_stopping,
                      reduce_lr
                    ])

Here is the performance: [plots of training/validation accuracy and loss]

Can someone please suggest solutions to this problem?

Upvotes: -2

Views: 67

Answers (1)

Improve the training data, reduce the model complexity, and reduce the number of features.
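For example, a much simpler head on top of the frozen MobileNetV2 backbone from the question (a minimal sketch, reusing base_model from the question; the 0.3 dropout rate is an illustrative choice):

base_model.trainable = False  # freeze the backbone entirely: far fewer trainable parameters

inputs = tf.keras.Input(shape=(100, 100, 3))
x = base_model(inputs, training=False)           # run BatchNorm layers in inference mode
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.3)(x)              # illustrative dropout rate
outputs = tf.keras.layers.Dense(200, activation='softmax')(x)
model = tf.keras.Model(inputs, outputs)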

Upvotes: 0
