What could cause the KL loss to go to 0? The reconstruction loss is small, but every generated image is the same and does not represent any digit.
Here is the encoder/decoder architecture I used; I think it is complex enough to effectively capture and generate new digits.
import tensorflow as tf
from keras import layers
import gumbel_softmax

class VariationalAutoEncoder(tf.keras.Model):
    def __init__(self, latent_dim, categorical_dim):
        super(VariationalAutoEncoder, self).__init__()
        self.latent_dim = latent_dim
        self.categorical_dim = categorical_dim
        self.z_dim = latent_dim * categorical_dim

        self.encoder = tf.keras.Sequential([
            layers.InputLayer(input_shape=(28, 28, 1)),
            layers.Conv2D(16, (3, 3), activation='relu', strides=2, padding='same'),
            layers.Conv2D(32, (3, 3), activation='relu', strides=2, padding='same'),
            layers.Conv2D(64, (3, 3), activation='relu', strides=2, padding='same'),
            layers.BatchNormalization(),
            layers.Flatten(),
            layers.Dense(128, activation='relu'),
            layers.Dense(self.z_dim),
            layers.Reshape((latent_dim, categorical_dim)),
        ])

        self.decoder = tf.keras.Sequential([
            layers.InputLayer(input_shape=(latent_dim, categorical_dim)),
            layers.Flatten(),
            layers.Dense(7 * 7 * 64, activation='relu'),
            layers.Reshape((7, 7, 64)),
            layers.Conv2DTranspose(64, (3, 3), activation='relu', strides=2, padding='same'),
            layers.Conv2DTranspose(32, (3, 3), activation='relu', strides=2, padding='same'),
            layers.Conv2DTranspose(1, (3, 3), activation='sigmoid', padding='same'),
        ])

    def call(self, x, temperature, hard):
        logits = self.encoder(x)
        z = gumbel_softmax.gumbel_softmax(logits, temperature, hard=hard)
        reconstructed = self.decoder(z)
        return tf.reshape(reconstructed, (reconstructed.shape[0], 28, 28)), tf.nn.softmax(logits, axis=-1)
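For reference, a quick shape check of the forward pass would look like the snippet below (the dummy batch is illustrative only; the latent_dim/categorical_dim values just mirror the hyperparameters listed further down):

# Illustrative shape check, not part of the training code.
model = VariationalAutoEncoder(latent_dim=32, categorical_dim=10)
x = tf.random.uniform((8, 28, 28, 1))
reconstructed, probs = model(x, temperature=1.0, hard=False)
print(reconstructed.shape)  # (8, 28, 28)
print(probs.shape)          # (8, 32, 10)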
# gumbel_softmax.py (imported by the model above)
import tensorflow as tf

def sample_gumbel(shape, eps=1e-15):
    """Sample from Gumbel(0, 1) distribution."""
    U = tf.random.uniform(shape, minval=0, maxval=1)
    return -tf.math.log(-tf.math.log(U + eps) + eps)

def gumbel_softmax_sample(logits, temperature):
    """Sample from the Gumbel-Softmax distribution."""
    y = logits + sample_gumbel(tf.shape(logits))
    return tf.nn.softmax(y / temperature)

def gumbel_softmax(logits, temperature, hard=False):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.

    Args:
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: non-negative scalar
        hard: if True, take argmax, but differentiate w.r.t. soft sample y
    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True, the returned sample will be one-hot; otherwise it will
        be a probability distribution that sums to 1 across classes.
    """
    y = gumbel_softmax_sample(logits, temperature)
    if hard:
        y_hard = tf.one_hot(tf.argmax(y, axis=-1), tf.shape(logits)[-1])
        y = tf.stop_gradient(y_hard - y) + y
    return y
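As a quick sanity check of the sampler itself (illustrative only, not part of my training code): the soft sample should sum to 1 over the categorical axis, and with hard=True the forward value should be exactly one-hot while gradients still flow through the soft sample.

# Sanity check on dummy logits of shape (batch, latent_dim, categorical_dim).
dummy_logits = tf.random.normal((4, 32, 10))
soft = gumbel_softmax(dummy_logits, temperature=1.0, hard=False)
one_hot = gumbel_softmax(dummy_logits, temperature=1.0, hard=True)
print(tf.reduce_sum(soft, axis=-1))     # ~1.0 everywhere
print(tf.reduce_sum(one_hot, axis=-1))  # exactly 1.0, a single 1 per latent variable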
Here are the hyperparameters I used. I experimented with them a lot, but without success.
LR_RATE = 5e-3
BATCH_SIZE = 64
NUM_ITERS = 900
tau0 = 1.0
ANNEAL_RATE = 5e-5
MIN_TEMP = 0.1
EPOCHS = 30
CATEGORICAL_DIM = 10
LATENT_DIM = 32
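For context, tau0, ANNEAL_RATE and MIN_TEMP would typically be combined into the exponential temperature-annealing schedule from the Gumbel-Softmax paper; a minimal sketch (the step counter here is illustrative, my exact update may differ slightly):

import numpy as np

def anneal_temperature(step):
    # Exponential decay of the Gumbel-Softmax temperature, clipped at MIN_TEMP.
    return max(tau0 * np.exp(-ANNEAL_RATE * step), MIN_TEMP)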
Here is how I calculate the KL loss and the reconstruction loss (the reconstruction term uses the sum_over_batch_size reduction):
import tensorflow as tf
import keras
import config

def compute_loss(model, x, temperature, hard, beta):
    reconstructed, probs = model(x, temperature=temperature, hard=hard)

    # Reconstruction term: pixel-wise binary cross-entropy.
    bce = keras.losses.BinaryCrossentropy(reduction="sum_over_batch_size")
    reconstruction_loss = bce(x, reconstructed)

    # KL term: categorical KL between the encoder probabilities (the softmax
    # returned by the model) and a uniform prior, summed over the latent
    # variables and averaged over the batch.
    kl_loss = tf.reduce_mean(
        tf.reduce_sum(
            probs * (tf.math.log(probs + 1e-8) - tf.math.log(1.0 / config.CATEGORICAL_DIM)),
            axis=[1, 2]
        )
    )

    total_loss = reconstruction_loss + beta * kl_loss
    return total_loss, reconstruction_loss, beta * kl_loss, reconstructed
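For completeness, a training step using compute_loss would look roughly like the sketch below (the Adam optimizer and hard=False are assumptions here, not necessarily my exact setup):

optimizer = tf.keras.optimizers.Adam(learning_rate=LR_RATE)

def train_step(model, x, temperature, beta):
    # One gradient step on the total (reconstruction + beta * KL) loss.
    with tf.GradientTape() as tape:
        total_loss, recon_loss, kl_loss, _ = compute_loss(
            model, x, temperature=temperature, hard=False, beta=beta)
    grads = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return total_loss, recon_loss, kl_loss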