Reputation: 2957
I am trying to use the gradient of a network with respect to its inputs as part of my loss function. However, whenever I try to calculate it, the training proceeds but the weights do not update:
import torch
import torch.optim as optim
import torch.autograd as autograd

ic = torch.rand((25, 3))
ic = torch.tensor(ic, requires_grad=True)
optimizer = optim.RMSprop([ic], lr=1e-2)
for itr in range(1, 50):
    optimizer.zero_grad()
    sol = torch.tanh(.5 * torch.stack(100 * [ic]))  # simplified for minimal working example
    dx = sol[-1, :, 0]
    dxdxy, = autograd.grad(dx,
                           inputs=ic,
                           grad_outputs=torch.ones(ic.shape[0]),  # batchwise
                           retain_graph=True
                           )
    dxdxy = torch.tensor(dxdxy, requires_grad=True)
    loss = torch.sum(dxdxy)
    loss.backward()
    optimizer.step()
    if itr % 5 == 0:
        print(loss)
What am I doing wrong?
Upvotes: 2
Views: 1134
Reputation: 2493
When you run autograd.grad without setting create_graph=True, the output you obtain is not connected to the computational graph, which means you won't be able to further optimize w.r.t. ic (and obtain the higher-order derivative you want here).
From torch.autograd.grad's docstring:

    create_graph (bool, optional): If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Default: False.
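A quick way to see the difference (a minimal standalone sketch, separate from the code in the question): with the default create_graph=False the returned gradient is detached (requires_grad is False, grad_fn is None), while with create_graph=True it stays connected to the graph and can itself be differentiated.

import torch

x = torch.rand(3, requires_grad=True)
y = (x ** 2).sum()

# Without create_graph: the returned gradient is detached from the graph
g_detached, = torch.autograd.grad(y, x, retain_graph=True)
print(g_detached.requires_grad, g_detached.grad_fn)   # False None

# With create_graph: the gradient carries a grad_fn and can be differentiated again
g_connected, = torch.autograd.grad(y, x, create_graph=True)
print(g_connected.requires_grad, g_connected.grad_fn) # True <...Backward0 object ...>
second, = torch.autograd.grad(g_connected.sum(), x)   # second derivative of sum(x^2): 2 everywhere
print(second)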
Using dxdxy = torch.tensor(dxdxy, requires_grad=True) as you've tried here won't help, since the computational graph connected to ic has already been lost by then (because create_graph is False); all you do is create a new computational graph with dxdxy as a leaf node.
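To illustrate the leaf-node point (again a standalone sketch, not your exact setup): re-wrapping the gradient in torch.tensor produces a fresh leaf with no grad_fn, so calling backward() on anything built from it never reaches ic.

import torch

ic = torch.rand(3, requires_grad=True)
g, = torch.autograd.grad((ic ** 2).sum(), ic)  # create_graph defaults to False

wrapped = torch.tensor(g, requires_grad=True)  # new leaf tensor, detached from ic
                                               # (PyTorch also warns to use clone().detach() instead)
print(wrapped.is_leaf, wrapped.grad_fn)        # True None

wrapped.sum().backward()
print(wrapped.grad)  # populated
print(ic.grad)       # None -- the gradient never flows back to ic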
See the solution below. Note that when you create ic you can set requires_grad=True directly, so the second line becomes redundant (that's not a logical problem, just longer code):
import torch
import torch.optim as optim
import torch.autograd as autograd

ic = torch.rand((25, 3), requires_grad=True)   # <-- requires_grad=True here
# ic = torch.tensor(ic, requires_grad=True)    # <-- redundant
optimizer = optim.RMSprop([ic], lr=1e-2)
for itr in range(1, 50):
    optimizer.zero_grad()
    sol = torch.tanh(.5 * torch.stack(100 * [ic]))  # simplified for minimal working example
    dx = sol[-1, :, 0]
    dxdxy, = autograd.grad(dx,
                           inputs=ic,
                           grad_outputs=torch.ones(ic.shape[0]),  # batchwise
                           retain_graph=True,
                           create_graph=True  # <-- important
                           )
    # dxdxy = torch.tensor(dxdxy, requires_grad=True)  # <-- won't do the trick. Remove
    loss = torch.sum(dxdxy)
    loss.backward()
    optimizer.step()
    if itr % 5 == 0:
        print(loss)
Upvotes: 3