dagcilibili

Reputation: 491

Minimization and maximization at the same time in PyTorch

I would like to know how to take gradient steps for the following cost in PyTorch (A, B and C are PyTorch modules whose parameters do not overlap):

[equation not preserved: a single cost built from A, B and C that should be minimized with respect to the parameters of some of the modules and maximized with respect to the parameters of the others]

This is somewhat different from the cost function of a Generative Adversarial Network (GAN), so I cannot use off-the-shelf GAN examples, and I got stuck while trying to adapt them to the above cost.

One approach I thought of is to construct two optimizers. Optimizer opt1 has the parameters for the modules A and B, and optimizer opt2 has the parameters of module C. One can then:

  1. take a step for minimizing the cost function for C
  2. run the network again with the same input to get the costs (and intermediate outputs) again
  3. take a step with respect to A and B (a rough sketch of this is below).
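
Concretely, the two-optimizer, two-forward-pass idea would look roughly like this (the modules and the cost below are just toy placeholders to illustrate the optimizer setup, not my actual model):

import torch

# Toy placeholders standing in for my real modules and cost
A = torch.nn.Linear(1, 2)
B = torch.nn.Linear(2, 3)
C = torch.nn.Linear(2, 3)

def compute_cost(x):
    return (B(A(x)) - C(A(x))).pow(2).sum()

opt1 = torch.optim.Adam(list(A.parameters()) + list(B.parameters()))
opt2 = torch.optim.Adam(C.parameters())

x = torch.rand(10, 1)

# 1. take a step for C
opt2.zero_grad()
compute_cost(x).backward()
opt2.step()

# 2. run the network again with the same input, then
# 3. take a step for A and B
opt1.zero_grad()
compute_cost(x).backward()
opt1.step()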

I am sure there must be a better way to do this in PyTorch (maybe using some detach operations), possibly without running the network again. Any help is appreciated.

Upvotes: 3

Views: 845

Answers (1)

duburcqa

Reputation: 1131

Yes, it is possible without going through the network twice, which both wastes resources and is wrong mathematically: after the first optimizer step the weights, and hence the loss, have changed, so the second pass would be working with a delayed version of the problem, which may be interesting in itself but is not what you are trying to achieve.

First, create two optimizers just as you said. Compute the loss, and then call backward. At this point, the gradients of the parameters of A, B and C have been filled in, so you just have to call the step method of the optimizer minimizing the loss, but not of the one maximizing it. For the latter, you first need to reverse the sign of the gradients of C's leaf parameter tensors.

import torch

def d(y, x):
    # example cost term: |y|^(x + 1)
    return torch.pow(y.abs(), x + 1)

A = torch.nn.Linear(1, 2)
B = torch.nn.Linear(2, 3)
C = torch.nn.Linear(2, 3)

# optimizer1 minimizes the loss w.r.t. A and B; optimizer2 will be used to
# maximize it w.r.t. C (after flipping the sign of C's gradients below)
optimizer1 = torch.optim.Adam((*A.parameters(), *B.parameters()))
optimizer2 = torch.optim.Adam(C.parameters())

x = torch.rand((10, 1))
loss = (d(B(A(x)), x) - d(C(A(x)), x)).sum()

optimizer1.zero_grad()
optimizer2.zero_grad()

loss.backward()
for p in C.parameters(): 
    if p.grad is not None: # In general, C is a NN, with requires_grad=False for some layers
        p.grad.data.mul_(-1) # Update of grad.data not tracked in computation graph

optimizer1.step()
optimizer2.step()

NB: I have not checked mathematically that the result is correct, but I assume it is.
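
As a side note, the same sign flip can be done without touching .data, by negating the gradients in place (a minimal variant of the loop above, assuming a reasonably recent PyTorch version):

with torch.no_grad():
    for p in C.parameters():
        if p.grad is not None:
            p.grad.neg_()  # flip the sign of C's gradients in place

Either way, the key point is the same: negating C's gradients before calling step turns the optimizer's usual descent into ascent for C, which is what implements the maximization.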

Upvotes: 2
