ivan

Reputation: 111

TensorFlow CNN regression MSE higher for train than test

I'm feeding a CNN with pictures to predict a value in a regression setting.

Input: [NUM_EXAMPLES, HEIGHT, WIDTH, CHANNELS] -> [NUM_EXAMPLES, YPRED]

This is the loss: loss = tf.reduce_mean(tf.squared_difference(Ypreds, labels))
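For reference, the surrounding graph looks roughly like this (a TF1-style sketch; the placeholder names and build_cnn are illustrative, not the actual model code):

import tensorflow as tf

images = tf.placeholder(tf.float32, [None, HEIGHT, WIDTH, CHANNELS])
labels = tf.placeholder(tf.float32, [None, 1])
Ypreds = build_cnn(images)  # hypothetical model function, output shape [None, 1]
loss = tf.reduce_mean(tf.squared_difference(Ypreds, labels))
train_step = tf.train.AdamOptimizer().minimize(loss)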

The training-loop:

for i in range(EPOCHS):
    epoch_train_loss = 0

    for k in range(NUM_BATCHES):
        _, batch_loss = sess.run([train_step, loss], feed_dict={...})
        epoch_train_loss += batch_loss / NUM_BATCHES

    # compute the test loss once per epoch
    epoch_test_loss = sess.run(loss, feed_dict={...})

    # log train and test loss for this epoch
    print(epoch_train_loss, epoch_test_loss)

These are the logging results:

Epoch: 0 (8.21s), Train-Loss: 12844071, Test-Loss: 3802676
Epoch: 1 (4.94s), Train-Loss: 3691994, Test-Loss: 3562206
Epoch: 2 (4.90s), Train-Loss: 3315438, Test-Loss: 2968338
Epoch: 3 (5.00s), Train-Loss: 1841562, Test-Loss: 417192
Epoch: 4 (4.94s), Train-Loss: 164503, Test-Loss: 3531
Epoch: 5 (4.94s), Train-Loss: 97477, Test-Loss: 1843
Epoch: 6 (4.98s), Train-Loss: 96474, Test-Loss: 4676
Epoch: 7 (4.94s), Train-Loss: 89613, Test-Loss: 1080

Upvotes: 2

Views: 408

Answers (1)

Maxim

Reputation: 53758

Your code seems fine, but I'd do it a bit differently:

import numpy as np

epoch_train_losses = []
for k in range(NUM_BATCHES):
    _, batch_loss = sess.run([train_step, loss], feed_dict={...})
    epoch_train_losses.append(batch_loss)  # keep every batch loss, not just a running mean
epoch_train_loss = np.mean(epoch_train_losses)
# print `epoch_train_loss` and `epoch_train_losses` too

Getting the full distribution of the losses instead of a single number (the mean) can help you inspect in detail what's going on.
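For example, a quick way to see the spread once you have the list (assuming epoch_train_losses from the snippet above):

import numpy as np

print('batch-loss percentiles (5/50/95):',
      np.percentile(epoch_train_losses, [5, 50, 95]))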

Here's one possible explanation: the training and test sets aren't properly shuffled, so the test set effectively mimics part of the training set (or may even be part of it). In this case the distribution of training losses across batches will have very high variance: some batch losses will be comparable to the reported test loss, while others will be much higher, pulling the mean up.
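If that's the cause, shuffling once before splitting should fix it. A minimal NumPy sketch, assuming images and targets hold the full dataset (the names are illustrative):

import numpy as np

idx = np.random.permutation(len(images))      # one random order for both arrays
images, targets = images[idx], targets[idx]
split = int(0.8 * len(images))                # e.g. an 80/20 train/test split
train_x, test_x = images[:split], images[split:]
train_y, test_y = targets[:split], targets[split:]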

Upvotes: 1
