Reputation: 491
I would like to know how to take gradient steps for the following mathematical operation in PyTorch (A, B and C are PyTorch modules whose parameters do not overlap)
This is somewhat different than the cost function of a Generative Adversarial Network (GAN), so I cannot use examples for GANs off the shelf, and I got stuck while trying to adapt them for the above cost.
One approach I thought of is to construct two optimizers. Optimizer opt1 has the parameters for the modules A and B, and optimizer opt2 has the parameters of module C. One can then:
I am sure there must be a better way to do this with PyTorch (maybe using some detach operations), possibly without running the network again. Any help is appreciated.
Upvotes: 3
Views: 845
Reputation: 1131
Yes, it is possible without going through the network twice. Running it twice wastes resources and is also mathematically wrong: after the first update the weights have changed, and so has the loss, so the second update is computed against a delayed objective. That may be interesting in its own right, but it is not what you are trying to achieve.
First, create two optimizers just as you said. Compute the loss, and then call backward. At this point, the gradients of the parameters of A, B and C have all been populated, so you only have to call the step method of each optimizer. For the optimizer minimizing the loss, nothing else is needed. For the one maximizing it, you first need to reverse the sign of the gradients of C's leaf parameter tensors, so that its step moves C uphill on the loss.
import torch

def d(y, x):
    return torch.pow(y.abs(), x + 1)

A = torch.nn.Linear(1, 2)
B = torch.nn.Linear(2, 3)
C = torch.nn.Linear(2, 3)

# optimizer1 minimizes the loss over A and B; optimizer2 maximizes it over C
optimizer1 = torch.optim.Adam((*A.parameters(), *B.parameters()))
optimizer2 = torch.optim.Adam(C.parameters())

x = torch.rand((10, 1))
loss = (d(B(A(x)), x) - d(C(A(x)), x)).sum()

optimizer1.zero_grad()
optimizer2.zero_grad()
loss.backward()  # a single backward pass fills the gradients of A, B and C

for p in C.parameters():
    if p.grad is not None:  # In general, C is a NN, with requires_grad=False for some layers
        p.grad.data.mul_(-1)  # Update of grad.data not tracked in computation graph

optimizer1.step()
optimizer2.step()
NB: I have not checked mathematically if the result is correct but I assume it is.
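One way to sanity-check the sign flip numerically is to swap Adam for plain SGD, where a single step is exactly p - lr * grad, and compare the parameters before and after. The sketch below does that; the learning rate, the seed, and the names opt_min, opt_max, min_ok and max_ok are assumptions made for illustration and are not part of the code above.

```python
import torch

torch.manual_seed(0)

def d(y, x):
    return torch.pow(y.abs(), x + 1)

A = torch.nn.Linear(1, 2)
B = torch.nn.Linear(2, 3)
C = torch.nn.Linear(2, 3)

lr = 0.1  # assumed value, chosen only for this check
params_min = [*A.parameters(), *B.parameters()]
params_max = list(C.parameters())
opt_min = torch.optim.SGD(params_min, lr=lr)  # plain SGD: p <- p - lr * grad
opt_max = torch.optim.SGD(params_max, lr=lr)

x = torch.rand((10, 1))
loss = (d(B(A(x)), x) - d(C(A(x)), x)).sum()

opt_min.zero_grad()
opt_max.zero_grad()
loss.backward()

# Snapshot parameters and the original (unflipped) gradients
old_min = [p.detach().clone() for p in params_min]
old_max = [p.detach().clone() for p in params_max]
grad_min = [p.grad.clone() for p in params_min]
grad_max = [p.grad.clone() for p in params_max]

# Flip the sign of C's gradients so that its SGD step becomes an ascent step
for p in params_max:
    p.grad.neg_()

opt_min.step()
opt_max.step()

# A and B should have moved against the gradient, C along it
min_ok = all(
    torch.allclose(p.detach(), o - lr * g, atol=1e-6)
    for p, o, g in zip(params_min, old_min, grad_min)
)
max_ok = all(
    torch.allclose(p.detach(), o + lr * g, atol=1e-6)
    for p, o, g in zip(params_max, old_max, grad_max)
)
print(min_ok, max_ok)
```

If both flags come out true, A and B took a descent step and C took an ascent step on the same loss from a single backward pass. Recent PyTorch releases also accept a maximize=True keyword on optimizers such as torch.optim.SGD and torch.optim.Adam, which should make the manual sign flip unnecessary; check that your installed version supports it before relying on it.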
Upvotes: 2