anderici

Reputation: 77

InceptionV3 transfer learning with Keras overfitting too soon

I'm using a pre-trained InceptionV3 in Keras and retraining it for binary image classification (data labeled with 0s and 1s).

I'm reaching about 65% accuracy in my k-fold validation on never-seen data, but the problem is that the model starts overfitting too soon. I need to improve this average accuracy, and I suspect it is related to the overfitting.

Here are the loss values per epoch: [plot of training and validation loss per epoch, image not available]

Here is the code. The dataset and label variables are NumPy arrays.

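import joblib
import numpy as np
import sklearn.model_selection as sk
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras import layers
from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical

# Load the images and their 0/1 labels
dataset = joblib.load(path_to_dataset)
labels = joblib.load(path_to_labels)

# Integer-encode the labels, then one-hot encode them for the 2-unit output
le = LabelEncoder()
labels = le.fit_transform(labels)
labels = to_categorical(labels, 2)

# Split into 60% train, 20% validation, 20% test
X_train, X_test, y_train, y_test = sk.train_test_split(dataset, labels, test_size=0.2)
X_train, X_val, y_train, y_val = sk.train_test_split(X_train, y_train, test_size=0.25)  # 0.25 x 0.8 = 0.2

X_train = np.array(X_train)
y_train = np.array(y_train)
X_val = np.array(X_val)
y_val = np.array(y_val)
X_test = np.array(X_test)
y_test = np.array(y_test)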

# Random augmentation, applied to the training batches only
aug = ImageDataGenerator(
    rotation_range=20,
    zoom_range=0.15,
    horizontal_flip=True,
    fill_mode="nearest")

# ImageNet-pretrained InceptionV3 without its classification head
pre_trained_model = InceptionV3(input_shape=(299, 299, 3),
                                include_top=False,
                                weights='imagenet')

# Freeze the convolutional base; only the new head is trained
for layer in pre_trained_model.layers:
    layer.trainable = False

# New classification head on top of the frozen features
x = layers.Flatten()(pre_trained_model.output)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.2)(x)
x = layers.Dense(2, activation='softmax')(x)  # already tried with sigmoid activation, same behavior

model = Model(pre_trained_model.input, x)
model.compile(optimizer=RMSprop(learning_rate=0.0001),
              loss='binary_crossentropy',
              metrics=['accuracy'])  # already tried with Adam optimizer, same behavior

# Stop on val_loss; save the model with the best val_accuracy
es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=100)
mc = ModelCheckpoint('best_model_inception_rmsprop.h5', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)

# Train on augmented batches; validate on the un-augmented validation set
history = model.fit(x=aug.flow(X_train, y_train, batch_size=32),
                    validation_data=(X_val, y_val),
                    epochs=100,
                    callbacks=[es, mc])

The training set has 2181 images and the validation set has 727 images.

Something is wrong, but I can't tell what...

Any thoughts on what can be done to improve it?

Upvotes: 0

Views: 1797

Answers (3)

Amin Khodamoradi

Reputation: 414

Have you tried the following?

  1. Using a higher dropout value.
  2. A lower learning rate (lr=0.00001 or lr=0.000001, ...).
  3. More data augmentation.
  4. Your amount of data seems low to me, so you may want to use lower ratios for test and validation (10% and 10%).
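Item 4 can be applied with the same two-step `train_test_split` used in the question: the second call's `test_size` has to be the validation fraction divided by whatever is left after the test split. The helpers below are hypothetical (`second_split_fraction` and `train_fraction` are not from any library) and only do that arithmetic, for the original 20%/20% split and the suggested 10%/10% one.

```python
def second_split_fraction(test_frac, val_frac):
    """test_size for the second train_test_split call, so that the
    validation set ends up as val_frac of the *whole* dataset."""
    remaining = 1.0 - test_frac  # share left after the first split
    return val_frac / remaining


def train_fraction(test_frac, val_frac):
    """Share of the whole dataset that ends up in the training set."""
    return 1.0 - test_frac - val_frac


for test_frac, val_frac in [(0.2, 0.2), (0.1, 0.1)]:
    print(f"test={test_frac:.0%} val={val_frac:.0%} -> "
          f"second test_size={second_split_fraction(test_frac, val_frac):.4f}, "
          f"train share={train_fraction(test_frac, val_frac):.0%}")
```

With 20%/20% the second call needs `test_size=0.25` (matching the comment in the question) and 60% of the data is used for training; with 10%/10% it needs about 0.111 (1/9) and the training share grows to 80%. The question's counts fit the first case: 2181 / 727 = 3, the 60:20 ratio.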

Upvotes: 0

Prajot Kuvalekar

Reputation: 6618

From your loss graph, I see that the model generalizes best at an early epoch (where the train and validation curves intersect), so please try using the model saved at that epoch, not one from the later epochs, which seem to overfit.
A second option is to use a lot more training samples.
If you have only a few training samples, use data augmentation.
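Using "the model saved at that epoch" means keeping the weights from the epoch with the lowest validation loss. The sketch below is a simplified pure-Python version of the patience rule behind Keras's `EarlyStopping` (an approximation, not the library's source); `early_stopping_epochs` and the loss values are hypothetical. It also shows why `patience=100` together with `epochs=100`, as in the question, can never stop training early.

```python
def early_stopping_epochs(val_losses, patience):
    """Simplified patience rule, with 0-based epoch indices.

    Returns (best_epoch, last_epoch_run): the epoch with the lowest
    validation loss and the epoch after which training would stop.
    """
    best_loss = float("inf")
    best_epoch = 0
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch  # patience exhausted: stop here
    return best_epoch, len(val_losses) - 1  # ran out of epochs first


# Hypothetical validation losses: improving for 3 epochs, then overfitting
val_losses = [0.70, 0.62, 0.58, 0.60, 0.66, 0.73, 0.81, 0.90]
print(early_stopping_epochs(val_losses, patience=3))    # (2, 5)
print(early_stopping_epochs(val_losses, patience=100))  # (2, 7)
```

In Keras itself, `EarlyStopping(monitor='val_loss', patience=..., restore_best_weights=True)` can restore the weights from the best epoch, and `ModelCheckpoint(..., monitor='val_loss', mode='min', save_best_only=True)` writes the best model to disk. The question's checkpoint monitors `val_accuracy` instead, which may pick a different epoch than the lowest validation loss.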

Upvotes: 1

yakhyo

Reputation: 1656

One way to avoid overfitting is to use a lot of data. Overfitting mainly happens because you have a small dataset and try to learn from it: the algorithm can fit this small dataset too closely, making sure it satisfies every data point exactly. With a large number of data points, the algorithm is forced to generalize and come up with a model that suits most of the points. Suggestions:

  1. Use a lot of data.
  2. Use a shallower network if you have a small number of data samples.
  3. If point 2 applies, don't use a huge number of epochs: training for many epochs essentially forces your model to memorize the training data, so it learns that data well but cannot generalize.
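To put point 2 in numbers for the question's model: with a 299×299 input, InceptionV3 with `include_top=False` outputs an 8×8×2048 feature map, so the trainable `Flatten` → `Dense(1024)` head alone has about 134 million parameters to fit from 2181 training images. The arithmetic below assumes standard dense layers (weights plus biases). Replacing `Flatten` with `GlobalAveragePooling2D` is a common way to shrink the head; it is named here as a swapped-in option, not something the answers propose.

```python
def dense_params(n_in, n_out):
    """Trainable parameters of a Dense layer: weight matrix plus bias."""
    return n_in * n_out + n_out


# InceptionV3 (include_top=False) output shape for a 299x299x3 input
h, w, c = 8, 8, 2048

# Head from the question: Flatten -> Dense(1024) -> Dense(2)
flatten_head = dense_params(h * w * c, 1024) + dense_params(1024, 2)

# Smaller alternative: GlobalAveragePooling2D -> Dense(1024) -> Dense(2)
gap_head = dense_params(c, 1024) + dense_params(1024, 2)

print(f"Flatten head: {flatten_head:,} trainable parameters")
print(f"GAP head:     {gap_head:,} trainable parameters")
print(f"Flatten-head parameters per training image: {flatten_head / 2181:,.0f}")
```

This gives 134,220,802 parameters for the `Flatten` head versus 2,100,226 with global average pooling. In Keras the smaller head would start from `layers.GlobalAveragePooling2D()(pre_trained_model.output)` instead of `layers.Flatten()`, and `model.summary()` reports the actual counts.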

Upvotes: 1
