blz

Reputation: 414

Adding noise when using embedding layer in pytorch

I'm building a generator g that receives a latent code (vector of shape 100) and outputs an image. Specifically, I have 1000 MNIST images, and I want the network to learn a latent code z_i for each image x_i such that g(z_i)=x_i (this approach is known as Generative Latent Optimization). So I've used nn.Embedding(1000, embedding_dim=100) and a standard generator architecture that receives the code from the embedding and outputs an image. As for the loss, I combine a reconstruction loss with a regularization term on the embedding-vector weights.

My problem is: I'd like to add noise to the latent-code vector before it is fed into the generator (in order to make the latent code compact). However, I'm a beginner, and I don't know whether I should call detach() when adding the noise or not. I'm not sure of my approach entirely. I don't want to learn the scale of the noise or anything. Here's my attempt:

import torch
import torch.nn as nn

class net(nn.Module):
  def __init__(self):
    super().__init__()
    self.embed = nn.Embedding(1000, embedding_dim=100)
    self.generator = nn.Sequential(nn.Linear(100, 84), ....)
  def forward(self, batch_indices):
    batch_codes = self.embed(batch_indices)
    noise = torch.randn_like(batch_codes) * sigma
    noisy_batch_codes = batch_codes + noise # SHOULD THIS BE batch_codes.detach() + noise ??
    return self.generator(noisy_batch_codes)

g = net()
optim = torch.optim.SGD(g.parameters(), lr=0.01)
for epoch in range(num_epochs):
  for orig_images, orig_images_idx in trainloader:
    optim.zero_grad()
    output = g(orig_images_idx)
    reconstruction_loss = nn.MSELoss()(output, orig_images)
    embed_vector_weights = g.embed.weight[orig_images_idx]
    reg_loss = torch.norm(embed_vector_weights) * reg_coeff
    loss = reconstruction_loss + reg_loss
    loss.backward()
    optim.step()

Upvotes: 1

Views: 1406

Answers (1)

jodag

Reputation: 22224

If you detach before adding the noise, the gradients won't propagate back to your encoder (the embedding layer in this case), so your encoder weights will never be updated. Therefore you should probably not detach if you want the encoder to learn.
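One way to convince yourself is to compare the embedding's gradient with and without detach(). The snippet below is a minimal sketch, not code from the question: the tiny Embedding/Linear sizes and the 0.1 noise scale are placeholders chosen just for illustration.

import torch
import torch.nn as nn

embed = nn.Embedding(10, 4)   # stand-in for the latent-code table
generator = nn.Linear(4, 1)   # stand-in for the generator
idx = torch.tensor([0, 1, 2])

# Without detach: gradients flow back into the embedding table.
codes = embed(idx)
noisy = codes + torch.randn_like(codes) * 0.1
generator(noisy).sum().backward()
print(embed.weight.grad.abs().sum())  # nonzero

# With detach: the graph is cut at the codes, so the embedding gets no gradient.
embed.weight.grad = None
codes = embed(idx)
noisy = codes.detach() + torch.randn_like(codes) * 0.1
generator(noisy).sum().backward()
print(embed.weight.grad)  # None -- these weights would never be updated

In the posted code that means keeping batch_codes + noise as-is: adding the noise does not break the graph, since torch.randn_like produces a tensor that doesn't require gradients, and the gradient of the sum flows straight through to batch_codes.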

Upvotes: 2
