whitepanda

Reputation: 488

Using the predict() method after training is completed in TensorFlow

I use a callback to stop training once my loss drops below a certain value. After training finishes, I call the predict() method on the training input, but when I compute the loss manually, I get a quite bad result. Am I using predict() incorrectly, or am I doing something else wrong?

import numpy as np
import random
import tensorflow as tf
from sklearn.metrics import mean_squared_error as my_mse

# x, y: training inputs and targets (defined elsewhere)

class stopAtLossValue(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Stop training once the reported epoch loss drops below eps.
        eps = 0.00001
        if logs.get('loss') <= eps:
            self.model.stop_training = True

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(x.shape[1],)),
    tf.keras.layers.Dense(8, activation='relu'),
    tf.keras.layers.Dense(8, activation='relu'),
    tf.keras.layers.Dense(8, activation='relu'),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(1)
])
model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))
model.fit(x, y, epochs=1000, batch_size=1, verbose=1, callbacks=[stopAtLossValue()])

For example, when I run the code snippet, I reach the desired loss value after 112 epochs.

Epoch 111/1000
20/20 [==============================] - 0s 2ms/step - loss: 0.0294
Epoch 112/1000
20/20 [==============================] - 0s 315us/step - loss: 1.0666e-06
<keras.callbacks.History at 0x153a7b70d30>

Then I call the predict() method and compute the loss myself. By the way, my loss function is just vanilla mean squared error (MSE). The value I get is quite high. In fact, if I print the predictions, they look pretty bad, even though TF stopped the training because of the low MSE.

my_mse(y,model.predict(x))
0.027716089
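
As a cross-check (a sketch assuming the same x and y as above), Keras's own evaluate() computes the compiled mse loss with the weights held fixed, so it should come out close to the sklearn value rather than the loss printed during training:

# evaluate() runs a forward pass with the final, fixed weights and
# returns the compiled 'mse' loss; it should roughly match my_mse above.
model.evaluate(x, y, verbose=0)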

Upvotes: 3

Views: 575

Answers (1)

Dr. Snoopy

Reputation: 56417

The difference arises because the loss displayed during training is the mean of the per-batch losses. A gradient update happens after every batch, so the model weights keep changing while that mean is being accumulated. With predict(), by contrast, the weights are fixed, so you will not get the same loss values.

In the end, the two numbers are not comparable because they are not computed in the same way.
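
To see this directly, here is a small sketch (assuming x and y are the training arrays from the question) that logs both quantities at the end of each epoch: the running mean Keras reports, and the loss recomputed with the end-of-epoch weights, which is what predict() followed by a manual MSE effectively gives you.

import tensorflow as tf

class CompareLosses(tf.keras.callbacks.Callback):
    def __init__(self, x, y):
        super().__init__()
        self.x = x
        self.y = y

    def on_epoch_end(self, epoch, logs=None):
        # logs['loss'] is the mean of the per-batch losses, accumulated
        # while the weights were still being updated batch by batch.
        running_mean = (logs or {}).get('loss', float('nan'))
        # evaluate() does a forward pass with the weights held fixed at
        # their end-of-epoch values, like predict() plus a manual MSE.
        fixed = self.model.evaluate(self.x, self.y, verbose=0)
        print(f"epoch {epoch}: running mean {running_mean:.6f}, "
              f"fixed-weight loss {fixed:.6f}")

model.fit(x, y, epochs=1000, batch_size=1, verbose=0,
          callbacks=[stopAtLossValue(), CompareLosses(x, y)])

With batch_size=1 the weights move after every single sample, so the two printed numbers can differ noticeably, which is the gap seen in the question.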

Upvotes: 1
