Lucas

Reputation: 63

How to implement validation loss in custom training loop?

I've been trying to get early stopping to work on an LSTM VAE. During training, the training loss is computed as it should be, but the validation loss stays at 0. I tried to write a custom val_step function (similar to train_step but without the trackers) to compute the loss, but I think I'm failing to establish the connection between that function and the validation_data argument in the vae.fit() call. The custom model class is shown below:

import tensorflow as tf
from tensorflow.keras import Model, losses, callbacks

class VAE(Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = tf.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = tf.metrics.Mean(name="kl_loss")

    def call(self, x):
        _, _, z = self.encoder(x)
        return self.decoder(z)

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(tf.reduce_sum(losses.mse(data, reconstruction), axis=1))
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def val_step(self, validation_data):
        _, _, z = self.encoder(validation_data)
        val_reconstruction = self.decoder(z)
        val_reconstruction_loss = tf.reduce_mean(tf.reduce_sum(losses.mse(validation_data, val_reconstruction), axis=1))
        val_kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        val_kl_loss = tf.reduce_mean(tf.reduce_sum(val_kl_loss, axis=1))
        val_total_loss = val_reconstruction_loss + val_kl_loss
        return {"total_loss": self.val_total_loss}


es = callbacks.EarlyStopping(monitor='val_total_loss',
                             mode='min',
                             verbose=1,
                             patience=5,
                             restore_best_weights=True,
                             )

vae = VAE(encoder, decoder)
vae.compile(optimizer=tf.optimizers.Adam())

vae.fit(tf_train,
        epochs=100,
        callbacks=[es],
        validation_data=tf_val,
        shuffle=True
        )

This is what the console prints out after every epoch (validation metrics show 0):

38/38 [==============================] - 37s 731ms/step - loss: 3676.8105 - reconstruction_loss: 2402.6206 - kl_loss: 149.5690 - val_total_loss: 0.0000e+00 - val_reconstruction_loss: 0.0000e+00 - val_kl_loss: 0.0000e+00

It'd be great if anyone could tell me what I'm doing wrong. Thank you in advance!

Update 1: Removed 'val_' from the key in the return of the val_step definition. Interestingly, the val_total_loss assignment on the line before the return statement is greyed out because it is not used, so there seems to be a disconnect between those two lines.

Upvotes: 4

Views: 2071

Answers (2)

Flyingmars

Reputation: 196

I think your code may be adapted from the Keras VAE example code. I also struggled with adding val_loss when working from that example, and here is the solution that works for me.

Keras seems to raise an error when validation_data is a tuple with a length of less than 2, so I modified validation_data as follows:

vae.fit(
    tf_train,
    epochs=100,
    callbacks=[es],
    validation_data=(valid_data, valid_data),  # <-- input X twice
    shuffle=True
)
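
If tf_val in the question is a tf.data.Dataset rather than an in-memory array (the names tf_train and tf_val suggest it might be), the same pairing can be done by mapping the dataset to (x, x) tuples. This is only a sketch under that assumption:

# Assumes tf_val is a tf.data.Dataset of input batches only; pairing each
# batch with itself gives fit() the (x, y) structure that test_step unpacks.
tf_val_xy = tf_val.map(lambda x: (x, x))

vae.fit(
    tf_train,
    epochs=100,
    callbacks=[es],
    validation_data=tf_val_xy,
    shuffle=True
)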

With validation_data modified as above, we need to separate X and y after receiving the argument in test_step. Also notice that the value returned is val_total_loss instead of self.val_total_loss:

def test_step(self, input_data):
    validation_data, _ = input_data  # <-- Separate X and y
    z_mean, z_log_var, z = self.encoder(validation_data)
    val_reconstruction = self.decoder(z)
    val_reconstruction_loss = tf.reduce_mean(tf.reduce_sum(losses.mse(validation_data, val_reconstruction), axis=1))
    val_kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
    val_kl_loss = tf.reduce_mean(tf.reduce_sum(val_kl_loss, axis=1))
    val_total_loss = val_reconstruction_loss + val_kl_loss
    return {"total_loss": val_total_loss} # <-- modify the return value here

The training logs will then look like this:

Epoch 00018: val_loss improved from 2304.90210 to 2304.70728, saving model to ./best_model.h5
Epoch 19/10000
31/31 [==============================] - 0s 11ms/step - loss: 2325.7858 - reconstruction_loss: 2318.3337 - kl_loss: 4.9127 - val_total_loss: 2303.8118

Hope this helps :)

Upvotes: 4

geometrikal

Reputation: 3294

The TensorFlow Keras fit function automatically prepends "val_" to the validation losses.

Try just returning "total_loss" instead, e.g.,

return {"total_loss": self.val_total_loss}

Edit:

Also, you are setting val_total_loss but returning self.val_total_loss.
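
Combining both points, the last two lines of the validation step from the question would become something like this (a sketch of just those lines, with the rest of the method unchanged):

    val_total_loss = val_reconstruction_loss + val_kl_loss
    return {"total_loss": val_total_loss}  # local variable, plain key; fit() logs it as val_total_loss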

Upvotes: 0
