I'm trying to set up a simple GAN training loop but am getting the following error:
RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.
for epoch in range(N_EPOCHS):
    # gets data for the generator
    for i, batch in enumerate(dataloader, 0):
        # passing target images to the Discriminator
        global_disc.zero_grad()
        output_disc = global_disc(batch.to(device))
        error_target = loss(output_disc, torch.ones(output_disc.shape).cuda())
        error_target.backward()
        # apply mask to the images
        batch = apply_mask(batch)
        # passes fake images to the Discriminator
        global_output, local_output = gen(batch.to(device))
        output_disc = global_disc(global_output.detach())
        error_fake = loss(output_disc, torch.zeros(output_disc.shape).to(device))
        error_fake.backward()
        # combines the errors
        error_total = error_target + error_fake
        optimizer_disc.step()
        # updates the generator
        gen.zero_grad()
        error_gen = loss(output_disc, torch.ones(output_disc.shape).to(device))
        error_gen.backward()
        optimizer_gen.step()
        break
    break
As far as I can tell, I have the operations in the right order, I'm zeroing out the gradients, and I'm detaching the output of the generator before it goes into the discriminator.
This article was helpful, but I'm still running into something I don't understand.
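A minimal reproduction of this failure mode is sketched below. The toy modules are hypothetical stand-ins for the question's gen and global_disc. It shows that reusing output_disc after error_fake.backward() raises the same RuntimeError, and that passing retain_graph=True only hides the symptom: because output_disc was computed from a detached tensor, the generator still receives no gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy stand-ins for the question's gen / global_disc / loss
gen = nn.Linear(4, 4)
disc = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
loss = nn.BCELoss()

fake = gen(torch.randn(8, 4))
output_disc = disc(fake.detach())        # the generator is cut out of this graph
error_fake = loss(output_disc, torch.zeros_like(output_disc))
error_fake.backward()                    # frees the saved tensors of disc's graph

try:
    # Reusing output_disc means backpropagating through the already-freed graph
    error_gen = loss(output_disc, torch.ones_like(output_disc))
    error_gen.backward()
    second_backward_failed = False
except RuntimeError as exc:
    second_backward_failed = "backward through the graph a second time" in str(exc)

print(second_backward_failed)  # True

# With retain_graph=True the error goes away, but the generator gets no gradient,
# because output_disc was computed from fake.detach()
fake = gen(torch.randn(8, 4))
output_disc = disc(fake.detach())
loss(output_disc, torch.zeros_like(output_disc)).backward(retain_graph=True)
loss(output_disc, torch.ones_like(output_disc)).backward()
print(gen.weight.grad)  # None
```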
Two important points come to mind:
You should feed your generator with noise, and not the real input:
global_output, local_output = gen(noise.to(device))
Here, noise should have the appropriate shape (it is the input of your generator).
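As an illustration of "the appropriate shape", here is a minimal sketch assuming a DCGAN-style generator that takes a latent tensor of shape (N, nz, 1, 1). The batch size, nz, and the one-layer generator are hypothetical; unlike the question's gen, this toy returns a single output rather than (global_output, local_output).

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 16   # hypothetical: match your dataloader's batch size
nz = 100          # hypothetical latent size; must match the generator's first layer

# Hypothetical DCGAN-style generator stem: (N, nz, 1, 1) -> (N, 3, 4, 4)
gen = nn.Sequential(
    nn.ConvTranspose2d(nz, 3, kernel_size=4, stride=1, padding=0),
    nn.Tanh(),
).to(device)

# Sample standard-normal noise directly on the target device
noise = torch.randn(batch_size, nz, 1, 1, device=device)
fake = gen(noise)
print(fake.shape)  # torch.Size([16, 3, 4, 4])
```

The key constraint is that the noise's channel dimension equals the in_channels of the generator's first layer; the spatial output size then follows from the transposed-convolution arithmetic.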
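To optimize the generator, you need to recompute the discriminator output. The output_disc you reuse has already been backpropagated through (error_fake.backward() freed its graph), and it was computed from global_output.detach(), so no gradient could reach the generator through it anyway. Simply add this line to recompute output_disc, this time without detach():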
# updates the generator
gen.zero_grad()
output_disc = global_disc(global_output)
# ...
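Putting both suggestions together, one corrected training step could look like the sketch below. The toy linear models, the sizes (nz, data_dim, batch_size), the Adam settings, and the random stand-in for real data are all hypothetical placeholders; the question's apply_mask and two-output generator are not modeled.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

nz, data_dim, batch_size = 8, 16, 32  # hypothetical toy sizes

# Hypothetical toy models standing in for gen / global_disc
gen = nn.Sequential(nn.Linear(nz, data_dim), nn.Tanh()).to(device)
global_disc = nn.Sequential(nn.Linear(data_dim, 1), nn.Sigmoid()).to(device)
loss = nn.BCELoss()
optimizer_gen = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_disc = torch.optim.Adam(global_disc.parameters(), lr=2e-4, betas=(0.5, 0.999))


def train_step(real):
    # Discriminator update on real samples
    global_disc.zero_grad()
    output_disc = global_disc(real)
    error_target = loss(output_disc, torch.ones_like(output_disc))
    error_target.backward()

    # Discriminator update on fake samples; detach so gen is not updated here
    noise = torch.randn(real.size(0), nz, device=device)
    fake = gen(noise)
    output_disc = global_disc(fake.detach())
    error_fake = loss(output_disc, torch.zeros_like(output_disc))
    error_fake.backward()
    optimizer_disc.step()

    # Generator update: a fresh discriminator forward pass on the NON-detached fake
    gen.zero_grad()
    output_disc = global_disc(fake)
    error_gen = loss(output_disc, torch.ones_like(output_disc))
    error_gen.backward()
    optimizer_gen.step()

    error_total = error_target + error_fake
    return error_total.item(), error_gen.item()


# Random data in [-1, 1] as a stand-in for a real batch
real = torch.rand(batch_size, data_dim, device=device) * 2 - 1
d_loss, g_loss = train_step(real)
```

Each backward() now runs through a graph that was built for it, so no retain_graph=True is needed. The generator step also leaves gradients in global_disc's parameters, which the global_disc.zero_grad() at the start of the next step clears.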
Please refer to this tutorial provided by PyTorch for a full walkthrough.