I have two neural networks in torch that are nested, and I am computing multiple losses on the output with respect to different parameters. Below is a simple case:
>>> import torch
>>> import torch.nn as nn
# two neural networks
>>> A = nn.Linear(10, 10)
>>> B = nn.Linear(10, 1)
# dummy input
>>> x = torch.rand(1, 10, requires_grad=True)
# nested computation
>>> y = B(A(x))
# evaluate two separate loss functions on the output (f and g are arbitrary scalar losses)
>>> Loss1 = f(y)
>>> Loss2 = g(y)
# backprop through both losses at once
>>> (Loss1 + Loss2).backward()
I would like Loss1 to track gradients with respect to networks A and B together, but Loss2 to track gradients with respect to network A only. I know I can compute this by breaking the computation into two backpropagation steps, like so:
# two neural networks
>>> A = nn.Linear(10, 10)
>>> B = nn.Linear(10, 1)
# dummy input
>>> x = torch.rand(1, 10, requires_grad=True)
# nested computation
>>> y = B(A(x))
# evaluate the first loss function
>>> Loss1 = f(y)
# backprop through the first loss (accumulates gradients on both A and B)
>>> Loss1.backward()
# disable gradient computation on B's parameters
>>> B.requires_grad_(False)
# recompute the forward pass so the new graph excludes B's parameters
>>> y = B(A(x))
# evaluate the second loss function
>>> Loss2 = g(y)
# backprop through the second loss (accumulates gradients on A only)
>>> Loss2.backward()
# re-enable gradients on B for the next iteration
>>> B.requires_grad_(True)
I do not like this approach, as it requires two full forward and backpropagation computations through the nested neural networks. Is there a way to mark the second loss so that it does not update network B? I was thinking of something similar to g(y).detach(), but that also removes the gradients with respect to network A.
Upvotes: 0
Views: 55
You are describing something similar to a GAN optimization setup, where A would be the generator and B the discriminator, so it is worth looking at how GAN training is typically done in PyTorch. You can't separate two gradient signals within a single backward pass; you must perform two backward passes.
|<------------------- L1
|<---------•••••••••• L2
x ---> A ---> B ---> y
(solid arrow: gradients are computed and applied; dotted segment: L2's gradient flows through B to reach A, but B's parameters are not updated)
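As a sketch of that two-backward-pass idea (the loss functions `f` and `g` below are placeholders I made up, standing in for your arbitrary scalar losses): you can cache the intermediate activation `h = A(x)` and use `torch.autograd.grad` for the second loss, so that B's `.grad` buffers are never touched by Loss2 and neither the forward pass nor the `requires_grad_` toggling has to be repeated.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

A = nn.Linear(10, 10)
B = nn.Linear(10, 1)
x = torch.rand(1, 10)

# stand-ins for the arbitrary scalar losses f and g from the question
f = lambda out: out.pow(2).mean()
g = lambda out: out.abs().mean()

h = A(x)   # cache A's output so the forward pass runs only once
y = B(h)

Loss1 = f(y)
Loss2 = g(y)

# pass 1: Loss1 accumulates .grad on the parameters of both A and B
Loss1.backward(retain_graph=True)

# pass 2: backprop Loss2 only as far as the cached activation h;
# torch.autograd.grad traverses B's graph but does not write into
# B's parameter .grad buffers
grad_h, = torch.autograd.grad(Loss2, h)
# then push that gradient into A's parameters via the cached activation
h.backward(grad_h)
```

After this, `B.weight.grad` holds only Loss1's contribution, while `A.weight.grad` holds contributions from both losses. Note this still performs two backward passes, as stated above; it just avoids the second forward pass through the networks.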