I am training with TensorFlow 2.0 and tensorflow_datasets, but I don't understand: why are the training accuracy and loss different from the validation accuracy and loss?
This is my code:
import tensorflow as tf
import tensorflow_datasets as tfds

data_name = 'uc_merced'
dataset = tfds.load(data_name)
# the train_data and the test_data are the same dataset
train_data, test_data = dataset['train'], dataset['train']

def parse(img_dict):
    img = tf.image.resize_with_pad(img_dict['image'], 256, 256)
    #img = img / 255.
    label = img_dict['label']
    return img, label

train_data = train_data.map(parse)
train_data = train_data.batch(96)

test_data = test_data.map(parse)
test_data = test_data.batch(96)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.applications.ResNet50(weights=None, classes=21,
                                           input_shape=(256, 256, 3))
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

model.fit(train_data, epochs=50, verbose=2, validation_data=test_data)
It is very simple and you can run it on your own computer. As you can see, my training data and validation data are the same: train_data, test_data = dataset['train'], dataset['train'].
But the training accuracy (loss) is not the same as the validation accuracy (loss). Why does this happen? Is this a bug in TensorFlow 2.0?
Epoch 1/50
22/22 - 51s - loss: 3.3766 - accuracy: 0.2581 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00
Epoch 2/50
22/22 - 30s - loss: 1.8221 - accuracy: 0.4590 - val_loss: 123071.9851 - val_accuracy: 0.0476
Epoch 3/50
22/22 - 30s - loss: 1.4701 - accuracy: 0.5405 - val_loss: 12767.8928 - val_accuracy: 0.0519
Epoch 4/50
22/22 - 30s - loss: 1.2113 - accuracy: 0.6071 - val_loss: 3.9311 - val_accuracy: 0.1186
Epoch 5/50
22/22 - 31s - loss: 1.0846 - accuracy: 0.6567 - val_loss: 23.7775 - val_accuracy: 0.1386
Epoch 6/50
22/22 - 31s - loss: 0.9358 - accuracy: 0.7043 - val_loss: 15.3453 - val_accuracy: 0.1543
Epoch 7/50
22/22 - 32s - loss: 0.8566 - accuracy: 0.7243 - val_loss: 8.0415 - val_accuracy: 0.2548
In short, the culprit here is BatchNorm.
Since you have a small dataset and a large batch size, you only do 22 updates per epoch. The BatchNorm layer has a default momentum of 0.99, so it takes some time to move the BatchNorm running means/variances to values more appropriate for your dataset. Because you do not normalise the pixel values away from the [0, 255] range, those values are pretty far from the typical mean = 0, variance = 1 sort of range that neural networks are generally designed/initialised to expect.
The big discrepancy in train vs. validation loss/accuracy arises because batch norm behaves very differently during training and testing, especially with so few batches: during training it normalises each batch with that batch's own mean/variance, while during validation it uses the running mean/variance instead. The mean of the data flowing through the network during training is very far from the running mean accumulated so far, which updates only slowly because of the default BatchNorm momentum/decay of 0.99.
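To make the train/validation difference concrete, here is a simplified, pure-Python sketch of a single-feature batch-norm layer. It is not the Keras implementation: it has no learned scale/offset and uses the population variance of the batch. It only assumes the Keras defaults of momentum=0.99, epsilon=1e-3, and running statistics initialised to mean 0 and variance 1. The batch values [100, 120, 140] are hypothetical stand-ins for unnormalised pixel intensities, and the same batch is fed 22 times to mimic one epoch of updates.

```python
import math


class SimpleBatchNorm:
    """Simplified single-feature batch norm: no gamma/beta, population variance."""

    def __init__(self, momentum=0.99, epsilon=1e-3):
        self.momentum = momentum
        self.epsilon = epsilon
        self.moving_mean = 0.0  # running mean starts at 0
        self.moving_var = 1.0   # running variance starts at 1

    def __call__(self, batch, training):
        if training:
            # Training: normalise with the statistics of the current batch...
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)
            # ...and nudge the running statistics towards them (exponential moving average).
            self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * mean
            self.moving_var = self.momentum * self.moving_var + (1 - self.momentum) * var
        else:
            # Validation/inference: use the slowly updated running statistics instead.
            mean, var = self.moving_mean, self.moving_var
        return [(x - mean) / math.sqrt(var + self.epsilon) for x in batch]


def train_then_compare(batch, steps=22, momentum=0.99):
    """Run `steps` training updates on the same batch, then return (train_out, eval_out)."""
    bn = SimpleBatchNorm(momentum=momentum)
    train_out = []
    for _ in range(steps):
        train_out = bn(batch, training=True)
    eval_out = bn(batch, training=False)
    return train_out, eval_out


raw = [100.0, 120.0, 140.0]        # hypothetical unnormalised pixel values
scaled = [x / 255.0 for x in raw]  # the same values after img / 255.

train_raw, eval_raw = train_then_compare(raw)
train_scaled, eval_scaled = train_then_compare(scaled)
print("raw:    train", [round(v, 2) for v in train_raw], "eval", [round(v, 2) for v in eval_raw])
print("scaled: train", [round(v, 2) for v in train_scaled], "eval", [round(v, 2) for v in eval_scaled])
```

With the raw values, the evaluation-mode outputs come out roughly ten times larger than anything produced in training mode, because after 22 updates the running mean has moved only about 20% of the way from 0 towards the batch mean of 120. After dividing by 255 the two modes still disagree, but by far less.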
If you reduce your batch size from 96 to, say, 4, you substantially increase the frequency of updates to the BatchNorm running means/variances. Doing this, plus uncommenting the #img = img / 255. line in your data parsing function, alleviates the train/validation discrepancy to a large extent. Doing so gives me this output for the first three epochs:
Epoch 1/7
525/525 - 51s - loss: 3.2650 - accuracy: 0.1633 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00
Epoch 2/7
525/525 - 38s - loss: 2.6455 - accuracy: 0.2152 - val_loss: 12.1067 - val_accuracy: 0.2114
Epoch 3/7
525/525 - 38s - loss: 2.5033 - accuracy: 0.2414 - val_loss: 16.9369 - val_accuracy: 0.2095
You can also keep your code the same and instead modify the keras_applications implementation of ResNet50 to use BatchNormalization(..., momentum=0.9) everywhere. This gives me the following output after two epochs, which I think more or less confirms that this is indeed the main cause of your issue:
Epoch 1/2
22/22 [==============================] - 33s 1s/step - loss: 3.1512 - accuracy: 0.2357 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00
Epoch 2/2
22/22 [==============================] - 16s 748ms/step - loss: 1.7975 - accuracy: 0.4505 - val_loss: 4.1324 - val_accuracy: 0.2810
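The per-epoch arithmetic behind both fixes can be checked directly. This sketch assumes the uc_merced train split has 2,100 images (consistent with the 22 steps at batch size 96 and the 525 steps at batch size 4 in the logs above) and that each training step applies one moving-average update with factor momentum. After n updates, the initial running statistics (mean 0, variance 1) still carry a weight of momentum ** n.

```python
import math

NUM_IMAGES = 2100  # assumed size of the uc_merced 'train' split


def steps_per_epoch(batch_size, num_images=NUM_IMAGES):
    """Number of batches (and hence BatchNorm updates) per epoch; the last batch may be partial."""
    return math.ceil(num_images / batch_size)


def initial_weight(momentum, num_updates):
    """Weight the initial running mean/variance still carry after num_updates EMA updates."""
    return momentum ** num_updates


configs = {
    "batch 96, momentum 0.99 (original)": (96, 0.99),
    "batch 4,  momentum 0.99": (4, 0.99),
    "batch 96, momentum 0.9": (96, 0.9),
}

for name, (batch_size, momentum) in configs.items():
    n = steps_per_epoch(batch_size)
    print(f"{name}: {n} updates/epoch, "
          f"initial-stats weight after 1 epoch = {initial_weight(momentum, n):.4f}")
```

In the original setup about 80% of each running statistic after the first epoch still comes from the initial values, whereas both modifications bring that weight below 10% within a single epoch.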