Reputation: 302
I'm new to PyTorch and adversarial networks. I've looked for an answer in the PyTorch documentation and in previous discussions on both the PyTorch and StackOverflow forums, but I couldn't find anything useful.
I'm trying to train a GAN with a Generator and a Discriminator, but I can't tell whether the whole process is working. As far as I understand, I should train the Generator first and then update the Discriminator's weights (similar to this). My code for updating the weights of both models is:
# computing loss_g and loss_d...
optim_g.zero_grad()
loss_g.backward()
optim_g.step()
optim_d.zero_grad()
loss_d.backward()
optim_d.step()
where loss_g is the generator loss, loss_d is the discriminator loss, optim_g is the optimizer for the generator's parameters, and optim_d is the optimizer for the discriminator's parameters.
If I run the code like this, I get an error:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
So I specify loss_g.backward(retain_graph=True), and here is my question: why should I specify retain_graph=True if there are two networks with two different graphs? Am I getting something wrong?
Upvotes: 3
Views: 3025
Reputation: 32972
Having two different networks doesn't necessarily mean that the computational graphs are different. The computational graph only tracks the operations that were performed from the input to the output, and it doesn't matter where those operations take place. In other words, if you use the output of the first model in the second model (e.g. model2(model1(input))), you have the same sequential operations as if they were part of the same model. In fact, that is no different from having different parts of one model, such as multiple convolutions, that you apply one after the other.
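To make this concrete, here is a minimal sketch, assuming two toy nn.Linear layers as hypothetical stand-ins for the two networks (model1, model2, and the tensor shapes are illustrative and not taken from the question). A single backward call on the output of the second model also fills in gradients for the first one, because both sit in the same graph:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the two networks (not from the question)
model1 = nn.Linear(4, 3)
model2 = nn.Linear(3, 1)

x = torch.randn(2, 4)

# Chaining the models builds ONE graph that spans both of them
out = model2(model1(x))
out.sum().backward()

# A single backward call produced gradients for the parameters of both models
model1_has_grad = all(p.grad is not None for p in model1.parameters())
model2_has_grad = all(p.grad is not None for p in model2.parameters())
print(model1_has_grad, model2_has_grad)
```

Both flags should come out as True, which is the same situation as a discriminator being fed the generator's output.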
The error you get indicates that you are trying to backpropagate from the discriminator through the generator, which would mean that the discriminator's output directly adapts the generator's parameters for the discriminator to be successful. In an adversarial setting that is precisely what you want to avoid: the two should be independent of each other. Setting retain_graph=True only hides this bug. In nearly all cases retain_graph=True is not the solution and should be avoided.
To resolve that issue, the two models need to be made independent of each other. The crossover between the two models happens when you use the generator's output as input to the discriminator, since the discriminator should decide whether that input is real or fake. Something along these lines:
fake = generator(noise)
real_prediction = discriminator(real)
# Using the generator's output directly continues the generator's graph.
fake_prediction = discriminator(fake)
Even though fake comes from the generator, as far as the discriminator is concerned, it's merely another input, just like real. Therefore fake should be treated the same as real, which is not attached to any computational graph. That can easily be done with torch.Tensor.detach, which decouples the tensor from the graph.
fake = generator(noise)
real_prediction = discriminator(real)
# Detach to make it independent of the generator
fake_prediction = discriminator(fake.detach())
That is also done in the code you referenced, from erikqu/EnhanceNet-PyTorch - train.py:
hr_imgs = torch.cat([discriminator(hr), discriminator(generated_hr.detach())], dim=0)
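Putting the pieces together, one full training iteration could look roughly like the sketch below. Everything beyond generator, discriminator, optim_g, optim_d, loss_g, loss_d, real, fake, and noise is an assumption for illustration: the toy nn.Linear networks, the SGD optimizers, the BCEWithLogitsLoss criterion, and the tensor shapes are hypothetical. The discriminator is updated first here, which is a common ordering but not the only valid one; the key point is that detach is used in the discriminator step and not in the generator step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy networks, optimizers, and data (assumptions, not from the question)
generator = nn.Linear(8, 16)
discriminator = nn.Linear(16, 1)
optim_g = torch.optim.SGD(generator.parameters(), lr=0.01)
optim_d = torch.optim.SGD(discriminator.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()

real = torch.randn(4, 16)
noise = torch.randn(4, 8)
ones = torch.ones(4, 1)    # target label for "real"
zeros = torch.zeros(4, 1)  # target label for "fake"

fake = generator(noise)

# Discriminator step: fake is detached, so backward stops at the discriminator's input
optim_d.zero_grad()
loss_d = criterion(discriminator(real), ones) + criterion(discriminator(fake.detach()), zeros)
loss_d.backward()
# The generator received no gradients from the discriminator loss
generator_untouched = all(p.grad is None for p in generator.parameters())
optim_d.step()

# Generator step: fake is NOT detached, so gradients flow back into the generator
optim_g.zero_grad()
loss_g = criterion(discriminator(fake), ones)
loss_g.backward()
optim_g.step()
```

Because the discriminator step never backpropagates into the generator, the generator's part of the graph is traversed only once (in the generator step), so no retain_graph=True is needed.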
Upvotes: 11