Reputation: 1
I have a bird classification model using a CNN as below. However, it performs quite well on the training set (~50% accuracy) and really badly on the validation set (<1%).
Here is the data:
train_images.shape
(4829, 100, 100, 3)
train_labels.shape
(4829, 200)
test_images.shape
(1204, 100, 100, 3)
test_labels.shape
(1204, 200)
For the model, I used transfer learning from MobileNetV2 and fine-tuned it.
import tensorflow as tf
from tensorflow.keras import regularizers

# Load the pre-trained MobileNetV2 model without its classification head
base_model = tf.keras.applications.MobileNetV2(input_shape=(100, 100, 3),
                                               include_top=False,
                                               weights='imagenet')
base_model.trainable = True

# Fine-tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False
To prevent overfitting, I used data augmentation, L2 regularization, dropout, early stopping and learning-rate decay, but they do not seem to be effective.
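(For reference, `augment` below is a Keras preprocessing pipeline roughly along these lines; the exact layers and strengths shown here are illustrative assumptions, not the original pipeline.)

augment = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),   # mirror images left/right
    tf.keras.layers.RandomRotation(0.1),        # rotate by up to +/-10% of a full turn
    tf.keras.layers.RandomZoom(0.1),            # zoom in/out by up to 10%
])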
inputs = tf.keras.Input(shape=(100, 100, 3))
x = augment(inputs, training=True)
x = base_model(x, training=True)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(580, activation='relu', kernel_regularizer=regularizers.l2(0.01), bias_regularizer=regularizers.l2(0.01))(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(200, activation='softmax')(x)
model = tf.keras.Model(inputs, outputs)
model.compile(loss="categorical_crossentropy",
              optimizer=tf.keras.optimizers.Adam(),
              metrics=['accuracy'])
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

# Stop training if the model's val_loss doesn't improve for 5 epochs
early_stopping = EarlyStopping(monitor="val_loss",         # watch the validation loss
                               patience=5,
                               restore_best_weights=True)  # roll back to the best weights seen

# Lower the learning rate when val_loss plateaus
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=1e-6)
history = model.fit(train_images, train_labels, epochs=50,
                    validation_split=0.2,
                    callbacks=[early_stopping, reduce_lr])
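(A minimal sketch of how the held-out test set from the shapes above could be evaluated after training, assuming the trained `model` and the `test_images`/`test_labels` arrays:)

test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f"Test accuracy: {test_acc:.3f}")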
Here is the performance: [training and validation accuracy/loss plots]
Can someone suggest solutions to this problem, please?
Upvotes: -2
Views: 67
Reputation: 1
Improve the training data, reduce the model complexity, and reduce the number of features.
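For example, one way to reduce the model complexity is to keep the MobileNetV2 backbone frozen and train only a small classification head (a sketch only; the layer choices here are assumptions, not a tested configuration):

import tensorflow as tf

base_model = tf.keras.applications.MobileNetV2(input_shape=(100, 100, 3),
                                               include_top=False,
                                               weights='imagenet')
base_model.trainable = False  # freeze the backbone so only the head is trained

inputs = tf.keras.Input(shape=(100, 100, 3))
x = base_model(inputs, training=False)           # run BatchNorm layers in inference mode
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(200, activation='softmax')(x)
model = tf.keras.Model(inputs, outputs)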
Upvotes: 0