Aleksei Petrenko

Reputation: 7168

Unable to train VAE with a deconvolutional layer

I was experimenting with a VAE implementation in TensorFlow for the MNIST dataset. To start off, I trained a VAE with an MLP encoder and decoder. It trains just fine: the loss decreases and it generates plausible-looking digits. Here's the code for the decoder of this MLP-based VAE:

x = sampled_z
x = tf.layers.dense(x, 200, tf.nn.relu)
x = tf.layers.dense(x, 200, tf.nn.relu)
x = tf.layers.dense(x, np.prod(data_shape))  # linear output, one value per pixel
img = tf.reshape(x, [-1] + data_shape)  # data_shape is a Python list, e.g. [28, 28]

As a next step, I decided to add convolutional layers. Changing just the encoder worked fine, but when I use deconvolutions (transposed convolutions) in the decoder instead of fully-connected layers, the model doesn't train at all: the loss never decreases, and the output is always black. Here's the code for the deconvolutional decoder:

x = tf.layers.dense(sampled_z, 24, tf.nn.relu)
x = tf.layers.dense(x, 7 * 7 * 64, tf.nn.relu)
x = tf.reshape(x, [-1, 7, 7, 64])
x = tf.layers.conv2d_transpose(x, 64, 3, 2, 'SAME', activation=tf.nn.relu)    # -> [batch, 14, 14, 64]
x = tf.layers.conv2d_transpose(x, 32, 3, 2, 'SAME', activation=tf.nn.relu)    # -> [batch, 28, 28, 32]
x = tf.layers.conv2d_transpose(x, 1, 3, 1, 'SAME', activation=tf.nn.sigmoid)  # -> [batch, 28, 28, 1], values in (0, 1)
img = tf.reshape(x, [-1, 28, 28])
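As a sanity check on the decoder's shapes: with 'SAME' padding, a transposed convolution multiplies each spatial dimension by its stride, while 'VALID' additionally adds max(kernel - stride, 0). The helper below is a hypothetical sketch of that rule (modeled on how Keras-style layers compute transposed-convolution output lengths, ignoring output padding), applied to the three conv2d_transpose calls above.

```python
def deconv_output_length(input_length, kernel_size, stride, padding):
    """Spatial output size of a transposed convolution (no output padding)."""
    length = input_length * stride
    if padding.upper() == 'VALID':
        length += max(kernel_size - stride, 0)
    elif padding.upper() != 'SAME':
        raise ValueError("padding must be 'SAME' or 'VALID'")
    return length


# (kernel, stride, padding) of the three conv2d_transpose layers in the decoder
layers = [(3, 2, 'SAME'), (3, 2, 'SAME'), (3, 1, 'SAME')]

size = 7  # spatial size after tf.reshape(x, [-1, 7, 7, 64])
sizes = [size]
for kernel, stride, padding in layers:
    size = deconv_output_length(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [7, 14, 28, 28]
```

By this arithmetic the last layer ends at 28x28x1, so the final reshape to [-1, 28, 28] is consistent and a shape mismatch seems unlikely to be the cause.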

This seems bizarre; the code looks fine to me. I narrowed the problem down to the deconvolutional layers in the decoder: something in there breaks training. For example, if I add a fully-connected layer (even without a nonlinearity!) after the last deconvolution, it works again! Here's the code:

x = tf.layers.dense(sampled_z, 24, tf.nn.relu)
x = tf.layers.dense(x, 7 * 7 * 64, tf.nn.relu)
x = tf.reshape(x, [-1, 7, 7, 64])
x = tf.layers.conv2d_transpose(x, 64, 3, 2, 'SAME', activation=tf.nn.relu)
x = tf.layers.conv2d_transpose(x, 32, 3, 2, 'SAME', activation=tf.nn.relu)
x = tf.layers.conv2d_transpose(x, 1, 3, 1, 'SAME', activation=tf.nn.sigmoid)
x = tf.contrib.layers.flatten(x)  # [batch, 28, 28, 1] -> [batch, 784]
x = tf.layers.dense(x, 28 * 28)  # extra fully-connected layer, no activation
img = tf.reshape(x, [-1, 28, 28])

I'm a bit stuck at this point. Does anyone have an idea what might be happening here? I'm using TF 1.8.0, the Adam optimizer, and a learning rate of 1e-4.

EDIT:

As @Agost pointed out, I should clarify my loss function and training process. I model the decoder output (the likelihood p(x|z)) as a Bernoulli distribution and use the negative ELBO as my loss, i.e. I maximize the ELBO. This was inspired by this post. Here's the full code for the encoder, decoder, and loss:

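import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# N_LATENT, EXPERIMENT, experiment_dir, conv and dense are defined elsewhere (not shown)


def make_prior():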
    mu = tf.zeros(N_LATENT)
    sigma = tf.ones(N_LATENT)
    return tf.contrib.distributions.MultivariateNormalDiag(mu, sigma)


def make_encoder(x_input):
    x_input = tf.reshape(x_input, shape=[-1, 28, 28, 1])
    x = conv(x_input, 32, 3, 2)
    x = conv(x, 64, 3, 2)
    x = conv(x, 128, 3, 2)
    x = tf.contrib.layers.flatten(x)
    mu = dense(x, N_LATENT)
    sigma = dense(x, N_LATENT, activation=tf.nn.softplus)  # softplus is log(exp(x) + 1)
    return tf.contrib.distributions.MultivariateNormalDiag(mu, sigma)    


def make_decoder(sampled_z):
    x = tf.layers.dense(sampled_z, 24, tf.nn.relu)
    x = tf.layers.dense(x, 7 * 7 * 64, tf.nn.relu)
    x = tf.reshape(x, [-1, 7, 7, 64])

    x = tf.layers.conv2d_transpose(x, 64, 3, 2, 'SAME', activation=tf.nn.relu)
    x = tf.layers.conv2d_transpose(x, 32, 3, 2, 'SAME', activation=tf.nn.relu)
    x = tf.layers.conv2d_transpose(x, 1, 3, 1, 'SAME')

    img = tf.reshape(x, [-1, 28, 28])

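    # Bernoulli's first positional argument is `logits`, so the raw decoder output is read as logits
    img_distribution = tf.contrib.distributions.Bernoulli(logits=img)
    img = img_distribution.probs  # sigmoid(logits): per-pixel probabilities, used for display
    img_distribution = tf.contrib.distributions.Independent(img_distribution, 2)  # 28x28 pixels form one event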
    return img, img_distribution


def main():
    mnist = input_data.read_data_sets(os.path.join(experiment_dir(EXPERIMENT), 'MNIST_data'))

    tf.reset_default_graph()

    batch_size = 128

    x_input = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28], name='X')

    prior = make_prior()
    posterior = make_encoder(x_input)

    mu, sigma = posterior.mean(), posterior.stddev()

    z = posterior.sample()
    generated_img, output_distribution = make_decoder(z)

    likelihood = output_distribution.log_prob(x_input)
    divergence = tf.distributions.kl_divergence(posterior, prior)
    elbo = tf.reduce_mean(likelihood - divergence)
    loss = -elbo

    global_step = tf.train.get_or_create_global_step()
    optimizer = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step)
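For reference, here is a NumPy sketch of what the loss in main() should work out to, assuming the decoder output is treated as Bernoulli logits l and the KL term is the closed form for a diagonal Gaussian against a standard normal prior. Per image, -ELBO = -sum over pixels of (x * l - softplus(l)) + 0.5 * sum over latent dims of (sigma^2 + mu^2 - 1 - 2 log sigma). The function names are hypothetical, and this is a single-sample estimate that mirrors the graph above; it is not TensorFlow's internal implementation.

```python
import numpy as np


def bernoulli_log_prob(x, logits):
    """log p(x | logits), summed over the 28x28 pixels (like Independent(..., 2))."""
    # x*log(sigmoid(l)) + (1-x)*log(1-sigmoid(l)) simplifies to x*l - softplus(l)
    per_pixel = x * logits - np.logaddexp(0.0, logits)
    return per_pixel.sum(axis=(1, 2))


def kl_diag_normal_vs_standard(mu, sigma):
    """Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)), summed over latent dims."""
    return 0.5 * np.sum(sigma**2 + mu**2 - 1.0 - 2.0 * np.log(sigma), axis=1)


def negative_elbo(x, logits, mu, sigma):
    """Single-sample estimate of -ELBO, averaged over the batch."""
    likelihood = bernoulli_log_prob(x, logits)
    divergence = kl_diag_normal_vs_standard(mu, sigma)
    return -np.mean(likelihood - divergence)


# Toy batch: 2 images, 2 latent dims (illustrative values only)
x = np.zeros((2, 28, 28))
x[:, 10:18, 10:18] = 1.0        # a white square on a black background
logits = np.zeros((2, 28, 28))  # every pixel predicted with p = 0.5
mu = np.zeros((2, 2))
sigma = np.ones((2, 2))         # posterior equals the prior, so KL = 0
print(negative_elbo(x, logits, mu, sigma))  # 784 * log(2), about 543.43
```

A loss stuck near 784 * log(2) would mean the decoder is effectively predicting p = 0.5 for every pixel.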

Upvotes: 2

Views: 238

Answers (1)

Burton2000

Reputation: 2072

Could it be your use of a sigmoid in the final deconv layer, which restricts the output to the range 0-1? You don't do this in the MLP-based autoencoder, or when you add a fully-connected layer after the deconvs, so it could be a data-range issue.
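To make the data-range point concrete: a sigmoid keeps the decoder output strictly inside (0, 1). Assuming the sigmoid version of the decoder fed that output into Bernoulli(img) the same way the EDIT's make_decoder does (the first positional argument is logits), every pixel probability would be squeezed between sigmoid(0) = 0.5 and sigmoid(1), about 0.73, so the model could never confidently predict a black background pixel. Whether the sigmoid version was actually wired this way is an assumption, since that part isn't shown. A small NumPy sketch:

```python
import numpy as np


def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


# Pre-activations of the last deconv layer over a wide range
pre_activation = np.linspace(-10.0, 10.0, 1001)

# With activation=tf.nn.sigmoid the layer output is squashed into (0, 1)
decoder_out = sigmoid(pre_activation)

# Hypothetical: if that output is then used as Bernoulli *logits*,
# the implied pixel probabilities only cover (sigmoid(0), sigmoid(1))
pixel_probs = sigmoid(decoder_out)
print(pixel_probs.min(), pixel_probs.max())  # roughly 0.50 and 0.73

# Without the sigmoid, the pre-activations act as logits directly,
# so probabilities can get close to 0 and 1
raw_probs = sigmoid(pre_activation)
print(raw_probs.min(), raw_probs.max())
```

An unconstrained dense layer after the deconvs would likewise remove that restriction, which would be consistent with the variant that trains.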

Upvotes: 2
