cdmath

Reputation: 23

Locally blocking gradient update for nested neural network

I have two nested neural networks in PyTorch, and I am computing multiple losses on the output with respect to different parameters. Below is a simple case:

# imports
>>> import torch
>>> import torch.nn as nn

# two neural networks
>>> A = nn.Linear(10,10)
>>> B = nn.Linear(10,1)

# dummy input
>>> x = torch.rand(1,10, requires_grad=True)

# nested computation
>>> y = B(A(x))

# evaluate two separate loss functions on the output (f and g are placeholders for any scalar losses)
>>> Loss1 = f(y)
>>> Loss2 = g(y)

# evaluate backprop through both losses
>>> (Loss1+Loss2).backward()
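To see the problem concretely, here is a small self-contained sketch. The definitions of `f` and `g` are hypothetical stand-ins, since the question leaves them unspecified; any scalar losses behave the same way. It compares `B.weight.grad` after the combined backward with the per-loss gradients computed separately via `torch.autograd.grad`, showing that Loss2 also contributes to B.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the question's unspecified f and g
def f(y):
    return y.sum()

def g(y):
    return (y ** 2).sum()

A = nn.Linear(10, 10)
B = nn.Linear(10, 1)
x = torch.rand(1, 10)

# Per-loss gradients w.r.t. B's weight, each from a fresh forward pass
# (torch.autograd.grad returns them without touching .grad)
loss1_grad_B = torch.autograd.grad(f(B(A(x))), B.weight)[0]
loss2_grad_B = torch.autograd.grad(g(B(A(x))), B.weight)[0]

# Combined backward, as in the question
y = B(A(x))
(f(y) + g(y)).backward()

# B.weight.grad holds the sum of both contributions, so Loss2 also moves B
print(torch.allclose(B.weight.grad, loss1_grad_B + loss2_grad_B))  # True
print(torch.allclose(B.weight.grad, loss1_grad_B))                 # False
```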

I would like Loss1 to track the gradients of both networks A and B, but Loss2 to track the gradients with respect to network A only. I know I can achieve this by breaking the computation into two backpropagation steps, like this:

# two neural networks
>>> A = nn.Linear(10,10)
>>> B = nn.Linear(10,1)

# dummy input
>>> x = torch.rand(1,10, requires_grad=True)

# nested computation
>>> y = B(A(x))

# evaluate first loss function
>>> Loss1 = f(y)

# evaluate backprop through first loss
>>> Loss1.backward()

# disable gradient computation on B
>>> B.requires_grad_(False)

# nested computation
>>> y = B(A(x))

# evaluate second loss function
>>> Loss2 = g(y)

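# evaluate backprop through second loss (A's gradients accumulate on top of those from Loss1)
>>> Loss2.backward()

# re-enable gradient computation on B for later steps
>>> B.requires_grad_(True)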

I do not like this approach, as it requires multiple backpropagation computations through the nested neural networks. Is there a way to mark the second loss so that it does not update network B? I am thinking of something similar to g(y).detach(); however, this also removes the gradients with respect to network A.
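The reason `.detach()` does not help is that it cuts the loss off from the graph entirely: the detached tensor is a constant, so it sends no gradient to A either. A minimal sketch, again using hypothetical stand-ins for `f` and `g`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the question's unspecified f and g
def f(y):
    return y.sum()

def g(y):
    return (y ** 2).sum()

A = nn.Linear(10, 10)
B = nn.Linear(10, 1)
x = torch.rand(1, 10)

y = B(A(x))
loss2 = g(y).detach()

# The detached loss is a constant: no grad_fn, no path back to A or B
print(loss2.requires_grad, loss2.grad_fn)  # False None

# Adding it to Loss1 changes the loss value but not the gradients
(f(y) + loss2).backward()

# A's gradient is exactly what Loss1 alone would produce
loss1_only_grad_A = torch.autograd.grad(f(B(A(x))), A.weight)[0]
print(torch.allclose(A.weight.grad, loss1_only_grad_A))  # True
```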

Upvotes: 0

Views: 55

Answers (1)

Ivan

Reputation: 40708

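What you are describing is similar to GAN optimization, where A would be the generator and B the discriminator, so it helps to look at how GANs are trained in a framework such as PyTorch. You can't separate the two gradient signals with a single backward pass; you need two backward passes. L2 still has to be backpropagated through B to reach A (the dotted segment in the diagram), but its gradient must not be accumulated into B's parameters, whereas L1's gradient reaches both A and B.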

|<------------------- L1 
|<---------•••••••••• L2
x ---> A ---> B ---> y 
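
One way to follow the two-backward-pass approach without repeating the forward pass is sketched below. It assumes a PyTorch version whose `Tensor.backward` accepts the `inputs` argument (added in 1.8), and `f` and `g` are hypothetical stand-ins. Loss2 is backpropagated first with `inputs` restricted to A's parameters and `retain_graph=True`, so its gradient still travels through B to reach A but is not accumulated into B's `.grad`; Loss1 is then backpropagated normally. The final check compares the result with gradients computed directly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the question's unspecified f and g
def f(y):
    return y.sum()

def g(y):
    return (y ** 2).sum()

A = nn.Linear(10, 10)
B = nn.Linear(10, 1)
x = torch.rand(1, 10)

# A single forward pass through A and B, shared by both losses
y = B(A(x))
loss1 = f(y)
loss2 = g(y)

# Backward pass 1: Loss2 flows back through B to reach A, but .grad is
# only accumulated for A's parameters. retain_graph=True keeps the shared
# graph alive for the second pass.
loss2.backward(inputs=list(A.parameters()), retain_graph=True)

# Backward pass 2: Loss1 accumulates into both A and B as usual
loss1.backward()

# Check: B saw only Loss1, A saw both losses
g_B = torch.autograd.grad(f(B(A(x))), B.weight)[0]
y_ref = B(A(x))
g_A = torch.autograd.grad(f(y_ref) + g(y_ref), A.weight)[0]
print(torch.allclose(B.weight.grad, g_B), torch.allclose(A.weight.grad, g_A))  # True True
```

The backward computation through B still runs twice, consistent with the two signals needing separate backward passes; what this saves is the second forward pass through A and B.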

Upvotes: 0
