Reputation: 7168
I was experimenting with a VAE implementation in TensorFlow on the MNIST dataset. To start, I trained a VAE with an MLP encoder and decoder. It trains just fine: the loss decreases and it generates plausible-looking digits. Here's the decoder code of this MLP-based VAE:
x = sampled_z
x = tf.layers.dense(x, 200, tf.nn.relu)
x = tf.layers.dense(x, 200, tf.nn.relu)
x = tf.layers.dense(x, np.prod(data_shape))
img = tf.reshape(x, [-1] + data_shape)
As a next step, I decided to add convolutional layers. Changing only the encoder worked fine, but when I use deconvolutions (transposed convolutions) in the decoder instead of fully-connected layers, the model doesn't train at all. The loss never decreases, and the output is always black. Here's the code of the deconvolutional decoder:
x = tf.layers.dense(sampled_z, 24, tf.nn.relu)
x = tf.layers.dense(x, 7 * 7 * 64, tf.nn.relu)
x = tf.reshape(x, [-1, 7, 7, 64])
x = tf.layers.conv2d_transpose(x, 64, 3, 2, 'SAME', activation=tf.nn.relu)
x = tf.layers.conv2d_transpose(x, 32, 3, 2, 'SAME', activation=tf.nn.relu)
x = tf.layers.conv2d_transpose(x, 1, 3, 1, 'SAME', activation=tf.nn.sigmoid)
img = tf.reshape(x, [-1, 28, 28])
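As a sanity check on the shapes in the decoder above, here is a small pure-Python sketch of the spatial sizes. It relies on the fact that a transposed convolution with 'SAME' padding multiplies the spatial size by its stride; the helper names are made up for illustration and are not part of TensorFlow.

```python
def transposed_conv_same_size(size, stride):
    # With padding='SAME', a transposed convolution produces
    # an output whose spatial size is input size * stride.
    return size * stride


def decoder_spatial_sizes(start=7, strides=(2, 2, 1)):
    # Track the height/width after each conv2d_transpose layer,
    # starting from the 7x7 feature map produced by the reshape.
    sizes = [start]
    for stride in strides:
        sizes.append(transposed_conv_same_size(sizes[-1], stride))
    return sizes


print(decoder_spatial_sizes())  # [7, 14, 28, 28]
```

So the final reshape to [-1, 28, 28] is consistent with the layer strides, which suggests the problem is not a shape mismatch.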
This seems bizarre, because the code looks fine to me. I narrowed it down to the deconvolutional layers in the decoder: something in there breaks it. For example, if I add a fully-connected layer (even without a nonlinearity!) after the last deconvolution, it works again. Here's the code:
x = tf.layers.dense(sampled_z, 24, tf.nn.relu)
x = tf.layers.dense(x, 7 * 7 * 64, tf.nn.relu)
x = tf.reshape(x, [-1, 7, 7, 64])
x = tf.layers.conv2d_transpose(x, 64, 3, 2, 'SAME', activation=tf.nn.relu)
x = tf.layers.conv2d_transpose(x, 32, 3, 2, 'SAME', activation=tf.nn.relu)
x = tf.layers.conv2d_transpose(x, 1, 3, 1, 'SAME', activation=tf.nn.sigmoid)
x = tf.contrib.layers.flatten(x)
x = tf.layers.dense(x, 28 * 28)
img = tf.reshape(x, [-1, 28, 28])
I'm a little stuck at this point. Does anyone have any idea what might be happening here? I use tf 1.8.0, Adam optimizer, 1e-4 learning rate.
EDIT:
As @Agost pointed out, I should perhaps clarify my loss function and the training process. I model the decoder output (the likelihood) as a Bernoulli distribution and maximize the ELBO, using the negative ELBO as my loss. Inspired by this post. Here's the full code of the encoder, decoder, and the loss:
def make_prior():
    mu = tf.zeros(N_LATENT)
    sigma = tf.ones(N_LATENT)
    return tf.contrib.distributions.MultivariateNormalDiag(mu, sigma)

def make_encoder(x_input):
    x_input = tf.reshape(x_input, shape=[-1, 28, 28, 1])
    x = conv(x_input, 32, 3, 2)
    x = conv(x, 64, 3, 2)
    x = conv(x, 128, 3, 2)
    x = tf.contrib.layers.flatten(x)
    mu = dense(x, N_LATENT)
    sigma = dense(x, N_LATENT, activation=tf.nn.softplus)  # softplus is log(exp(x) + 1)
    return tf.contrib.distributions.MultivariateNormalDiag(mu, sigma)

def make_decoder(sampled_z):
    x = tf.layers.dense(sampled_z, 24, tf.nn.relu)
    x = tf.layers.dense(x, 7 * 7 * 64, tf.nn.relu)
    x = tf.reshape(x, [-1, 7, 7, 64])
    x = tf.layers.conv2d_transpose(x, 64, 3, 2, 'SAME', activation=tf.nn.relu)
    x = tf.layers.conv2d_transpose(x, 32, 3, 2, 'SAME', activation=tf.nn.relu)
    x = tf.layers.conv2d_transpose(x, 1, 3, 1, 'SAME')
    img = tf.reshape(x, [-1, 28, 28])
    img_distribution = tf.contrib.distributions.Bernoulli(img)
    img = img_distribution.probs
    img_distribution = tf.contrib.distributions.Independent(img_distribution, 2)
    return img, img_distribution

def main():
    mnist = input_data.read_data_sets(os.path.join(experiment_dir(EXPERIMENT), 'MNIST_data'))
    tf.reset_default_graph()
    batch_size = 128
    x_input = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28], name='X')
    prior = make_prior()
    posterior = make_encoder(x_input)
    mu, sigma = posterior.mean(), posterior.stddev()
    z = posterior.sample()
    generated_img, output_distribution = make_decoder(z)
    likelihood = output_distribution.log_prob(x_input)
    divergence = tf.distributions.kl_divergence(posterior, prior)
    elbo = tf.reduce_mean(likelihood - divergence)
    loss = -elbo
    global_step = tf.train.get_or_create_global_step()
    optimizer = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step)
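For reference, the per-example quantity that this loss computes can be reproduced by hand on tiny inputs. The following is a minimal pure-Python sketch, not the TensorFlow graph itself. It assumes that the decoder output is interpreted as Bernoulli logits (to my knowledge, `logits` is the first positional parameter of `Bernoulli`), and that the posterior is a diagonal Gaussian compared against a standard normal prior. All function names here are hypothetical.

```python
import math


def bernoulli_log_prob(logit, x):
    # log p(x | logit) = x * logit - softplus(logit),
    # with softplus computed in a numerically stable way.
    softplus = max(logit, 0.0) + math.log1p(math.exp(-abs(logit)))
    return x * logit - softplus


def kl_diag_gaussian_to_std_normal(mu, sigma):
    # KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions.
    return sum(0.5 * (s * s + m * m - 1.0) - math.log(s)
               for m, s in zip(mu, sigma))


def single_sample_elbo(logits, pixels, mu, sigma):
    # Reconstruction term: sum of per-pixel Bernoulli log-probabilities.
    likelihood = sum(bernoulli_log_prob(l, x) for l, x in zip(logits, pixels))
    # ELBO estimate for one example; the loss is its negative.
    return likelihood - kl_diag_gaussian_to_std_normal(mu, sigma)
```

With all logits at zero, every pixel contributes log(0.5) to the likelihood regardless of its value, which is a handy baseline to compare a stuck loss against.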
Upvotes: 2
Views: 238
Reputation: 2072
Could it be your use of sigmoid in the final deconv layer, which restricts the output to the 0-1 range? You don't do this in the MLP-based autoencoder, or when adding a fully-connected layer after the deconvs, so it may be a data range issue.
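To illustrate the range concern, here is a small pure-Python sketch. It assumes (the question does not confirm this) that the sigmoid output of the last layer is then interpreted as logits by the Bernoulli distribution, so that a second sigmoid is effectively applied. The function names are made up for illustration.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def prob_bounds_if_sigmoid_output_used_as_logits():
    # A sigmoid activation yields values in (0, 1). If those values are
    # treated as logits, the pixel probabilities are squeezed into
    # (sigmoid(0), sigmoid(1)) and can never get close to 0 or 1.
    return sigmoid(0.0), sigmoid(1.0)


low, high = prob_bounds_if_sigmoid_output_used_as_logits()
print(low, high)
```

Under that assumption the probabilities stay between 0.5 and roughly 0.73, so the model could not represent black or white pixels, whereas an unbounded final layer (as in the MLP decoder or the extra dense layer) can.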
Upvotes: 2