natoucs

Implementation of VQ-VAE-2 paper

I am trying to build a two-stage VQ-VAE-2 + PixelCNN as described in the paper "Generating Diverse High-Fidelity Images with VQ-VAE-2" (https://arxiv.org/pdf/1906.00446.pdf). I have three implementation questions:

  1. The paper mentions:

    We allow each level in the hierarchy to separately depend on pixels.

I understand that the second latent space in VQ-VAE-2 must be conditioned on a concatenation of the first latent space and a downsampled version of the image. Is that correct?

  2. The paper "Conditional Image Generation with PixelCNN Decoders" (https://papers.nips.cc/paper/6527-conditional-image-generation-with-pixelcnn-decoders.pdf) says:

    h is a one-hot encoding that specifies a class this is equivalent to adding a class dependent bias at every layer.

As I understand it, the condition is entered as a 1D tensor that is injected into the bias through a convolution. For a two-stage conditional PixelCNN, however, one needs to condition not only on the class vector but also on the latent code of the previous stage. One possibility I see is to append them and feed in a 3D tensor. Does anyone see another way to do this?
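The quoted equivalence can be checked numerically. In the conditional PixelCNN paper the conditioning vector h enters each gated layer through a learned projection V^T h; when h is one-hot, that projection simply selects one row of V, which then acts as a class-specific bias broadcast over every pixel. A minimal NumPy sketch, where the sizes and the random V are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, channels = 10, 8                  # illustrative sizes, not from the paper
V = rng.normal(size=(num_classes, channels))   # learned projection (random stand-in)

def class_bias(label: int) -> np.ndarray:
    """Project a one-hot class vector h through V, giving one bias per channel."""
    h = np.zeros(num_classes)
    h[label] = 1.0
    return h @ V                               # equals V^T h, shape [channels]

# Hidden activations of one PixelCNN layer: [height, width, channels]
features = rng.normal(size=(16, 16, channels))
conditioned = features + class_bias(3)         # broadcast: same bias at every pixel

# For a one-hot h, the projection is just a row lookup in V
print(np.allclose(class_bias(3), V[3]))        # True
```

Because the same vector is added at every spatial position, this path can only carry global information, which is why it is a poor fit for a spatial latent map.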

  3. The loss and optimization are unchanged with two stages: one simply adds the losses of the stages into a final loss that is optimized. Is that correct?

Answer

natoucs

After discussing with one of the authors of the paper, I received answers to all of these questions, which I share below.

Question 1

This is correct, but the downsampling of the image is implemented with strided convolutions rather than a non-parametric resize. It can be absorbed into the encoder architecture along these lines (the number after each variable name indicates its spatial dimension, so for example h64 is [B, 64, 64, D], and so on):

    h128 = Relu(Conv2D(image256, stride=(2, 2)))
    h64 = Relu(Conv2D(h128, stride=(2, 2)))
    h64 = ResNet(h64)
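The spatial sizes encoded in these variable names follow from the usual convolution output-size formula, out = floor((in + 2*padding - kernel) / stride) + 1. The answer does not give kernel sizes or padding, so the sketch below assumes a 4x4 kernel with padding 1, a common choice that exactly halves each spatial dimension at stride 2. It also assumes the ResNet stacks use stride 1 with "same" padding, so they leave the spatial size unchanged. The matching transposed-convolution formula (as in PyTorch's ConvTranspose2d with no output padding) shows the same settings doubling 32 back to 64:

```python
def conv_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 1) -> int:
    """Spatial output size of a 2D convolution along one axis."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 1) -> int:
    """Spatial output size of a 2D transposed convolution along one axis."""
    return (size - 1) * stride - 2 * padding + kernel

sizes = [256]
for _ in range(3):                    # image256 -> h128 -> h64 -> h32
    sizes.append(conv_out(sizes[-1]))

print(sizes)                          # [256, 128, 64, 32]
print(conv_transpose_out(32))         # 64, i.e. q32 upsampled to the h64 grid
```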

Now, to obtain h32 and q32, we can do:

    h32 = Relu(Conv2D(h64, stride=(2, 2)))
    h32 = ResNet(h32)
    q32 = Quantize(h32)

This way, the gradients flow all the way back to the image and hence we have a dependency between h32 and image256.
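Quantize in the pseudocode is the VQ-VAE vector-quantization step: each D-dimensional vector of h32 is replaced by its nearest entry in a learned codebook. The NumPy sketch below shows only the forward pass, and the codebook size K = 512, the width D = 64 and the random inputs are illustrative assumptions. NumPy has no autodiff, so the straight-through gradient copy that lets gradients pass the non-differentiable lookup in a real framework is only noted in a comment:

```python
import numpy as np

def quantize(h: np.ndarray, codebook: np.ndarray):
    """Nearest-neighbour vector quantization (forward pass only).

    h:        [B, H, W, D] encoder output, e.g. h32
    codebook: [K, D] embedding vectors
    Returns (q, indices) with q[b, i, j] == codebook[indices[b, i, j]].
    """
    flat = h.reshape(-1, h.shape[-1])                      # [B*H*W, D]
    # Squared Euclidean distance from every vector to every code: [B*H*W, K]
    dists = (
        np.sum(flat ** 2, axis=1, keepdims=True)
        - 2.0 * flat @ codebook.T
        + np.sum(codebook ** 2, axis=1)
    )
    indices = np.argmin(dists, axis=1)
    q = codebook[indices].reshape(h.shape)
    # In an autodiff framework one would return h + stop_gradient(q - h)
    # (straight-through estimator) so gradients flow from q back into h.
    return q, indices.reshape(h.shape[:-1])

rng = np.random.default_rng(0)
h32 = rng.normal(size=(2, 32, 32, 64))     # [B, 32, 32, D]
codebook = rng.normal(size=(512, 64))      # K = 512 codes of width D = 64
q32, codes32 = quantize(h32, codebook)
print(q32.shape, codes32.shape)            # (2, 32, 32, 64) (2, 32, 32)
```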

Throughout, you can use 1x1 convolutions to adjust the size of the last dimension (the feature channels), strided convolutions for downsampling, and strided transposed convolutions for upsampling the spatial dimensions. So in this example, to quantize the bottom layer, you first need to upsample q32 spatially to 64x64, combine it with h64, and feed the result to the quantizer. For additional expressive power, we also inserted a residual stack in between. It looks like this:

    hq32 = ResNet(Conv2D(q32, (1, 1)))
    hq64 = Conv2DTranspose(hq32, stride=(2, 2))
    h64 = Conv2D(concat([h64, hq64]), (1, 1))
    q64 = Quantize(h64)
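The bottom-level fusion above can be traced with array shapes. The forward-only NumPy sketch below uses random weights, stands in for Conv2D((1, 1)) with a per-pixel matrix multiply and for Conv2DTranspose with a 2x2 kernel at stride 2 (the kernel size is an assumption), and omits the ResNet stack. The feature width D = 64 and the name `fused` (for the new h64 that would feed Quantize) are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)

def conv1x1(x, w):
    """1x1 convolution: the same [C_in, C_out] matrix applied at every pixel."""
    return x @ w

def conv_transpose_2x2_s2(x, w):
    """Transposed conv with a 2x2 kernel and stride 2 (no overlap, no padding).

    x: [B, H, W, C_in], w: [2, 2, C_in, C_out] -> [B, 2H, 2W, C_out];
    input pixel (i, j) writes the 2x2 output patch starting at (2i, 2j).
    """
    b, h, wd, _ = x.shape
    patches = np.einsum("bijc,klcd->bikjld", x, w)   # [B, H, 2, W, 2, C_out]
    return patches.reshape(b, 2 * h, 2 * wd, w.shape[-1])

D = 64                                     # feature width (assumed)
q32 = rng.normal(size=(1, 32, 32, D))      # top-level quantized latents
h64 = rng.normal(size=(1, 64, 64, D))      # bottom-level encoder features

w_in = rng.normal(size=(D, D))             # Conv2D(q32, (1, 1))
w_up = rng.normal(size=(2, 2, D, D))       # Conv2DTranspose, stride 2
w_mix = rng.normal(size=(2 * D, D))        # Conv2D(concat(...), (1, 1))

hq32 = conv1x1(q32, w_in)                  # ResNet stack omitted
hq64 = conv_transpose_2x2_s2(hq32, w_up)   # 32x32 -> 64x64
fused = conv1x1(np.concatenate([h64, hq64], axis=-1), w_mix)  # would feed Quantize
print(hq64.shape, fused.shape)             # (1, 64, 64, 64) (1, 64, 64, 64)
```

The concatenation doubles the channel count to 2D, and the final 1x1 convolution brings it back to D before quantization.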

Question 2

The original PixelCNN paper also describes how to do spatial conditioning with convolutions. Flattening the latent code and appending it to the class embedding as global conditioning is not a good idea. Instead, apply a transposed convolution to align the spatial dimensions, then a 1x1 convolution to match the feature dimension of the PixelCNN's hidden representations, and then add the result to them.
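The conditional PixelCNN paper writes a spatially conditioned gated layer as y = tanh(W_f * x + V_f * s) ⊙ σ(W_g * x + V_g * s), where s is the conditioning map at the layer's resolution and V * s is a 1x1 convolution. Below is a forward-only NumPy sketch of the path described above, for a prior over the 64x64 bottom-level codes conditioned on the 32x32 top-level codes. Several parts are assumptions for illustration: the codes are embedded with a lookup table, the transposed convolution uses a 2x2 kernel at stride 2, the masked-convolution outputs W * x are replaced with random arrays, all sizes and weights are made up, and adding a per-class bias alongside s is one way of combining the paper's class and spatial conditioning:

```python
import numpy as np

rng = np.random.default_rng(0)
K, E, F, C = 512, 32, 64, 10   # codebook size, code embedding width, layer width, classes

def conv_transpose_2x2_s2(x, w):
    """Transposed conv, 2x2 kernel, stride 2: [B, H, W, C_in] -> [B, 2H, 2W, C_out]."""
    b, h, wd, _ = x.shape
    out = np.einsum("bijc,klcd->bikjld", x, w)       # [B, H, 2, W, 2, C_out]
    return out.reshape(b, 2 * h, 2 * wd, w.shape[-1])

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Conditioning inputs: top-level codes from the previous stage, plus a class label
codes32 = rng.integers(0, K, size=(1, 32, 32))
label = 3

# Random stand-ins for learned parameters
embed = rng.normal(size=(K, E))                              # code embedding table
w_up = rng.normal(scale=E ** -0.5, size=(2, 2, E, F))        # align 32x32 -> 64x64
v_f, v_g = rng.normal(scale=F ** -0.5, size=(2, F, F))       # 1x1 convs onto width F
b_f, b_g = rng.normal(size=(2, C, F))                        # per-class biases

# Spatial conditioning map s, aligned with the 64x64 grid of the layer
s = conv_transpose_2x2_s2(embed[codes32], w_up)              # [1, 64, 64, F]

# Stand-ins for the masked-convolution outputs W_f * x and W_g * x
wx_f, wx_g = rng.normal(size=(2, 1, 64, 64, F))

# Gated unit: spatial term (s @ v) plus the global class bias, added pre-activation
y = np.tanh(wx_f + s @ v_f + b_f[label]) * sigmoid(wx_g + s @ v_g + b_g[label])
print(s.shape, y.shape)        # (1, 64, 64, 64) (1, 64, 64, 64)
```

Unlike the flattened-and-appended alternative, each position of s only carries information from the top-level code directly above it, so the conditioning stays spatially aligned with the hidden representation it is added to.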

Question 3

It's a good idea to train the stages separately. Besides isolating the losses and being able to tune an appropriate learning rate for each stage, you will also be able to use the full memory capacity of your GPU/TPU for each stage. These priors keep improving with larger scale, so it's a good idea not to deny them that.
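A toy illustration of why the stages decouple: once the first stage is trained and frozen, the prior only ever sees its discrete code indices, so the two can be optimized in separate loops with their own settings. In the NumPy sketch below, a few k-means iterations on 1-D data stand in for stage 1 (learning the VQ codebook) and a count-based categorical distribution stands in for the PixelCNN prior of stage 2. Both substitutions are simplifications chosen for brevity, not the paper's method, and the data and K = 4 are made up:

```python
import numpy as np

rng = np.random.default_rng(0)
data = np.concatenate([rng.normal(-2.0, 0.3, 600), rng.normal(1.5, 0.3, 400)])

# Stage 1 (stand-in for VQ-VAE training): fit a K-entry codebook with k-means.
K = 4
codebook = rng.choice(data, size=K, replace=False)
for _ in range(20):
    codes = np.argmin(np.abs(data[:, None] - codebook[None, :]), axis=1)
    for k in range(K):
        if np.any(codes == k):                  # leave empty clusters unchanged
            codebook[k] = data[codes == k].mean()

# Freeze stage 1: the prior is trained only on these integer codes.
codes = np.argmin(np.abs(data[:, None] - codebook[None, :]), axis=1)

# Stage 2 (stand-in for the PixelCNN prior): maximum-likelihood categorical prior.
prior = np.bincount(codes, minlength=K) / codes.size

# Sampling: draw codes from the prior, then decode them with the frozen codebook.
samples = codebook[rng.choice(K, size=5, p=prior)]
print(np.round(prior, 3), np.round(samples, 2))
```

Nothing in stage 2 touches the codebook, so stage 2 could use a different optimizer, learning rate, or hardware budget without affecting stage 1.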
