user3927666

Reputation: 39

VAE in Keras to visualize latent space on 3 classes of images

I am training a Variational Auto-Encoder (VAE) with unlabelled input images. My interest here is to visualize the 3 classes of unlabelled data in the latent space. I set the latent dimension to 128 and further use PCA to visualize in 2D.

I am new to this and seeking some clarity. Firstly, while I train the network, I see the accuracy and validation accuracy being displayed. Since the input images to the network are not labelled, I wonder what exactly the accuracy is computed from. (According to what I have read, accuracy = number of samples correctly predicted / total number of samples.)

Secondly, my training code looks like this:

    vae.compile(optimizer='rmsprop', loss=kl_reconstruction_loss, metrics=['accuracy'])
    history = vae.fit_generator(X_train, epochs=15,
                                validation_data=next(X_val), validation_steps=5,
                                callbacks=[ReduceLROnPlateau(monitor='val_loss', factor=0.5,
                                                             verbose=2, patience=4,
                                                             cooldown=1, min_lr=0.0001)])

During training, the loss goes to zero and the accuracy to 1 within 2 epochs. Yet for the three classes of unlabelled data, the trained network does not cluster the three classes well in the latent space. I am not clear on why the accuracy shoots up to 1 and the loss drops to 0 while the network does not generalize well on the test dataset:

    Epoch 1/2
    2164/2164 [==============================] - 872s 403ms/step - loss: 6668.9662 - accuracy: 0.7253 - val_loss: 3785921.0000 - val_accuracy: 0.9982
    Epoch 2/2
    2164/2164 [==============================] - 869s 401ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 3857374.2500 - val_accuracy: 0.9983

Any insights/suggestions?

Upvotes: 1

Views: 744

Answers (1)

Tom C

Reputation: 610

The purpose of a VAE is to compress the input into a well-behaved latent representation and then use this latent representation to accurately reconstruct the input. There are two terms in the loss of a VAE, one for each of these two tasks. The first term - the KL divergence term - forces the latent representation to be drawn from a multidimensional unit Gaussian distribution. The second term makes the VAE accurately reconstruct the input. Typically people use something like the L2 loss between the input and the output (there are fancier options, but L2 loss usually does OK).
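As a concrete illustration (a NumPy sketch with hypothetical names, not your actual `kl_reconstruction_loss`), the two terms for a diagonal-Gaussian encoder look like this:

```python
import numpy as np

def vae_loss_terms(x, x_recon, z_mean, z_log_var):
    """Sketch of the two VAE loss terms for a diagonal-Gaussian posterior.

    kl:    KL divergence between N(z_mean, exp(z_log_var)) and the unit
           Gaussian N(0, I), summed over latent dimensions.
    recon: plain L2 (sum of squared errors) between input and
           reconstruction.
    """
    kl = 0.5 * np.sum(np.exp(z_log_var) + z_mean**2 - 1.0 - z_log_var, axis=-1)
    recon = np.sum((x - x_recon)**2, axis=-1)
    return kl, recon

# A posterior that already matches the unit Gaussian has zero KL cost,
# and a perfect reconstruction has zero L2 cost:
kl, recon = vae_loss_terms(np.zeros(4), np.zeros(4), np.zeros(2), np.zeros(2))
print(kl, recon)  # 0.0 0.0
```

The total training loss is the sum (often a weighted sum) of these two terms.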

Since the model is not performing classification, accuracy is not a good metric to use. A better way to see how accurately the VAE is reconstructing the input is to monitor something like mean squared error.
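For example, a minimal reconstruction-error check of the kind you could log instead of accuracy (plain NumPy; the function name is illustrative):

```python
import numpy as np

def reconstruction_mse(x, x_recon):
    """Mean squared error between inputs and their VAE reconstructions.

    A perfect reconstruction gives 0; larger values mean worse
    reconstructions. In Keras, passing metrics=['mse'] to compile()
    reports the same quantity during training.
    """
    return float(np.mean((np.asarray(x) - np.asarray(x_recon))**2))

# Identical input and reconstruction -> error of 0.0
print(reconstruction_mse([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # 0.0
```

Unlike accuracy, this number is meaningful for an unsupervised reconstruction task and will keep shrinking as the reconstructions improve.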

The output of your code (where the training loss goes to 0 after one epoch) indicates that the model is overfitting the training data. Try regularizing your model (or using a lower-capacity one), or reduce the number of steps per training epoch so you can monitor performance on the validation data more often.

Also, next(X_val) grabs only a single batch from your validation generator. You probably want to pass the generator itself as validation_data, not a single batch from it; removing the next() call will achieve that.

Upvotes: 1
