Riley K

Reputation: 403

Good training accuracy but poor validation accuracy

I am trying to implement a residual network to classify images on the CIFAR-10 dataset for a project. I have a working model whose training accuracy grows logarithmically, but whose validation accuracy plateaus. I used batch normalization and ReLU after most layers and a softmax at the end.

Here is my data split:

from tensorflow.keras import datasets

(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

Here is my code to compile and train the model:

import tensorflow as tf

resNet50.compile(optimizer='adam',
                 loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                 metrics=['accuracy'])
resNet50.fit(train_images, train_labels, epochs=EPOCHS, validation_data=(test_images, test_labels))

[Plots of training vs. validation metrics per epoch: training accuracy keeps improving while validation accuracy plateaus]

What might be causing this validation plateau and what could improve my model?

Thank you in advance for your feedback and comments.

Upvotes: 1

Views: 1931

Answers (1)

gdelab

Reputation: 6220

This is a very common problem; it is a form of overfitting.

I invite you to read the book Deep Learning by Ian Goodfellow, Yoshua Bengio and Aaron Courville, especially this chapter (freely accessible online), which is very informative.

In short, you seem to have chosen a model (ResNet50 + default training parameters) that has too much capacity for your problem and data. If you choose a model that is too simple, the training and evaluation curves will stay very close to one another, but with worse performance than what you could achieve. If you choose a model that is too complex (as seems to be somewhat the case here), you can reach a much better performance on the training data, but the evaluation will not be at the same level, and could even be quite bad. That's called overfitting on the training set.

What you want is the best middle point: the best performance on evaluation data is found with a model complexity that is just before overfitting. You want the two performance curves to be close to one another, but both should be as good as possible.

So you need to decrease the capacity of your model for your problem. There are different ways to do that, and they will not be equally efficient in terms of reducing overfitting, nor in terms of lowering your training performance. The best method is usually to add more training data, if you can. If you can't, the next best thing to add is regularization, such as data augmentation, dropout, L1 or L2 regularization, and early stopping (see the sketch below). The last one is especially useful if your validation performance starts decreasing at some point, instead of just plateauing. That's not your case, so it should not be your first approach.
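For instance, here is a minimal sketch (not code from the question; the augmentation ranges, dropout rate and L2 factor are illustrative assumptions, not tuned values) of how those regularization options could be wired around a ResNet50 in Keras:

import tensorflow as tf
from tensorflow.keras import datasets, layers, regularizers

(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Data augmentation, only active during training.
augment = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomTranslation(0.1, 0.1),
])

# ResNet50 backbone trained from scratch (no ImageNet head).
backbone = tf.keras.applications.ResNet50(
    include_top=False, weights=None, input_shape=(32, 32, 3), pooling="avg")

inputs = tf.keras.Input(shape=(32, 32, 3))
x = layers.Rescaling(1.0 / 255)(inputs)   # scale pixels to [0, 1]
x = augment(x)
x = backbone(x)
x = layers.Dropout(0.3)(x)                # dropout regularization
outputs = layers.Dense(10, kernel_regularizer=regularizers.l2(1e-4))(x)  # logits, with L2 on the classifier
model = tf.keras.Model(inputs, outputs)

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Early stopping keeps the weights from the best validation epoch.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy', patience=10, restore_best_weights=True)

model.fit(train_images, train_labels, epochs=100,
          validation_data=(test_images, test_labels), callbacks=[early_stop])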

If regularization is not enough, then try playing with the learning rate, or with the other parameters mentioned in the book. You should be able to make ResNet50 itself work much better than this on CIFAR-10, but maybe it's not that trivial.
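As an illustration (with made-up values, reusing the `model` from the sketch above), lowering Adam's default learning rate and letting it decay over training could look like this:

import tensorflow as tf

# Cosine decay from a smaller initial rate than Adam's default of 1e-3;
# the values here are illustrative, not tuned for CIFAR-10.
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-4, decay_steps=50_000)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])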

Upvotes: 1
