zion
zion

Reputation: 9

VAE with Gumbel softmax on MNIST dataset

What could be the issue of kl loss going to 0? reconstruction loss is small, but every image is the same, and does not represent any digit.

Here is my encoder/decoder architecture I used, I think the architecture is complex enough to effectively capture and create new digits.

import tensorflow as tf
from keras import layers
import gumbel_softmax

class VariationalAutoEncoder(tf.keras.Model):
    def __init__(self, latent_dim, categorical_dim):
        super(VariationalAutoEncoder, self).__init__()
        self.latent_dim = latent_dim
        self.categorical_dim = categorical_dim
        self.z_dim = latent_dim * categorical_dim

        self.encoder = tf.keras.Sequential([
            layers.InputLayer(input_shape=(28, 28, 1)),
            layers.Conv2D(16, (3, 3), activation='relu', strides=2, padding='same'),
            layers.Conv2D(32, (3, 3), activation='relu', strides=2, padding='same'),
            layers.Conv2D(64, (3, 3), activation='relu', strides=2, padding='same'),
            layers.BatchNormalization(),
            layers.Flatten(),
            layers.Dense(128, activation='relu'),
            layers.Dense(self.z_dim),
            layers.Reshape((latent_dim, categorical_dim)),
        ])


        self.decoder = tf.keras.Sequential([
            layers.InputLayer(input_shape=(latent_dim, categorical_dim)),
            layers.Flatten(),
            layers.Dense(7 * 7 * 64, activation='relu'),
            layers.Reshape((7, 7, 64)),
            layers.Conv2DTranspose(64, (3, 3), activation='relu', strides= 2, padding='same'),
            layers.Conv2DTranspose(32, (3, 3), activation='relu',strides = 2,padding='same'),
            layers.Conv2DTranspose(1, (3, 3), activation='sigmoid', padding='same'),
        ])

    def call(self, x, temperature, hard):
        logits = self.encoder(x)
        z = gumbel_softmax.gumbel_softmax(logits, temperature, hard=hard)
        reconstructed = self.decoder(z)
        return tf.reshape(reconstructed,(reconstructed.shape[0],28,28)), tf.nn.softmax(logits, axis=-1)

import tensorflow as tf

def sample_gumbel(shape, eps=1e-15): 
    """Sample from Gumbel(0, 1) distribution."""
    U = tf.random.uniform(shape, minval=0, maxval=1)
    return -tf.math.log(-tf.math.log(U + eps) + eps)

def gumbel_softmax_sample(logits, temperature): 
    """Sampling from the Gumbel-Softmax distribution."""
    y = logits + sample_gumbel(tf.shape(logits))
    return tf.nn.softmax(y / temperature)

def gumbel_softmax(logits, temperature, hard=False):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.
    
    Args:
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: non-negative scalar
        hard: if True, take argmax, but differentiate w.r.t. soft sample y
        
    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True, then the returned sample will be one-hot, otherwise it will
        be a probability distribution that sums to 1 across classes.
    """
    y = gumbel_softmax_sample(logits, temperature) 
    if hard: 
        y_hard = tf.one_hot(tf.argmax(y, axis=-1), tf.shape(logits)[-1])
        y = tf.stop_gradient(y_hard - y) + y 
    
    return y

Here are the parameters I used, I tried experimenting a lot with them, but there was no success.

LR_RATE = 5e-3

BATCH_SIZE=64
NUM_ITERS=900
tau0 = 1
ANNEAL_RATE=5e-5
MIN_TEMP=0.1
EPOCHS = 30

CATEGORICAL_DIM = 10
LATENT_DIM = 32

I calculate the kl loss and reconstruction loss summed over batch_size

def compute_loss(model, x, temperature, hard,beta):
    reconstructed, logits = model(x, temperature = temperature, hard = hard)

    mse = keras.losses.BinaryCrossentropy(reduction="sum_over_batch_size")

    reconstruction_loss = mse(x,reconstructed)
    kl_loss = tf.reduce_mean(
    tf.reduce_sum(
        logits * (tf.math.log(logits + 1e-8) - tf.math.log(1.0 / config.CATEGORICAL_DIM)),
        axis=[1, 2]
    )
)
    total_loss = reconstruction_loss + beta * kl_loss
    return total_loss, reconstruction_loss, beta * kl_loss, reconstructed

Upvotes: 0

Views: 20

Answers (0)

Related Questions