In Pytorch, how can I make the gradient of a parameter a function itself?
Here is a simple code snippet:
import torch

def fun(q):
    def result(w):
        l = w * q
        l.backward()
        return w.grad
    return result

w = torch.tensor((2.), requires_grad=True)
q = torch.tensor((3.), requires_grad=True)

f = fun(q)
print(f(w))
In the code above, how can I make f(w) have a gradient with respect to q?
EDIT: Based on the accepted answer I was able to write code that works. Essentially I am alternating between two optimization steps. For dim == 1 it works, but for dim == 2 it does not. I get the error "RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time."
import torch

class f_class():
    def __init__(self, dim):
        self.dim = dim
        if self.dim == 1:
            self.w = torch.tensor((3.), requires_grad=True)
        elif self.dim == 2:
            self.w = [torch.tensor((3.), requires_grad=True), torch.tensor((5.), requires_grad=True)]
        else:
            raise ValueError("dim 1 or 2")

    def forward(self, x):
        if self.dim == 1:
            return torch.mul(self.w, x)
        elif self.dim == 2:
            return torch.mul(torch.mul(self.w[0], self.w[1]), x)

    def set_w(self, w):
        self.w = w

    def get_w(self):
        return self.w

class g_class():
    def __init__(self):
        self.q = torch.tensor((4.), requires_grad=True)

    def forward(self, f):
        return torch.mul(self.q, f)

    def set_q(self, q):
        self.q = q

    def get_q(self):
        return self.q

def w_new(f, g, dim):
    loss_g = g.forward(f.forward(xd))
    if dim == 1:
        grads = torch.autograd.grad(loss_g, f.get_w(), create_graph=True, only_inputs=True)[0]
        temp = f.get_w().detach() + grads
    else:
        grads = torch.autograd.grad(loss_g, f.get_w(), create_graph=True, only_inputs=True)
        temp = [wi.detach() + gi for wi, gi in zip(f.get_w(), grads)]
    return temp

def q_new(f, g):
    loss_f = 2 * f.forward(xd)
    loss_f.backward()
    temp = g.get_q().detach() + g.get_q().grad
    temp.requires_grad = True
    return temp

dim = 1
xd = torch.tensor((2.))
f = f_class(dim)
g = g_class()

for _ in range(3):
    print(f.get_w(), g.get_q())
    wnew = w_new(f, g, dim)
    f.set_w(wnew)
    print(f.get_w(), g.get_q())
    qnew = q_new(f, g)
    g.set_q(qnew)
    print(f.get_w(), g.get_q())
Upvotes: 2
When computing gradients, if you want to construct a computation graph for the gradient itself, you need to specify create_graph=True to autograd.
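As a minimal sketch of the difference (reusing the same w and q values as in the snippets above), the gradient returned by torch.autograd.grad is only differentiable itself when create_graph=True is passed:

import torch

w = torch.tensor((2.), requires_grad=True)
q = torch.tensor((3.), requires_grad=True)
l = w * q

# without create_graph the returned gradient is detached from the graph
g_plain = torch.autograd.grad(l, w, retain_graph=True)[0]
print(g_plain.requires_grad)  # False -- cannot backpropagate through it

# with create_graph=True the gradient dl/dw = q is itself part of a graph,
# so we can backpropagate through it with respect to q
g_graph = torch.autograd.grad(l, w, create_graph=True)[0]
g_graph.backward()
print(q.grad)  # tensor(1.)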
A potential source of error in your code is using Tensor.backward within f. The problem here is that w.grad and q.grad will be populated with the gradient of l. This means that when you call f(w).backward(), the gradients of both f and l will be added to w.grad and q.grad. In effect you will end up with w.grad being equal to dl/dw + df/dw, and similarly for q.grad.

One way to get around this is to zero the gradients after f(w) but before .backward(). A better way is to use torch.autograd.grad within f. Using the latter approach, the grad attribute of w and q will not be populated when calling f, only when calling .backward(). This leaves room for things like gradient accumulation during training.
import torch

def fun(q):
    def result(w):
        l = w * q
        # create_graph=True so that the returned gradient dl/dw is itself
        # part of a computation graph and can be backpropagated through
        return torch.autograd.grad(l, w, only_inputs=True, create_graph=True)[0]
    return result

w = torch.tensor((2.), requires_grad=True)
q = torch.tensor((3.), requires_grad=True)

f = fun(q)
f(w).backward()
print('w.grad:', w.grad)
print('q.grad:', q.grad)
which results in
w.grad: None
q.grad: tensor(1.)
Note that w.grad was not populated. This is because f(w) = dl/dw = q is not a function of w, and therefore w is not part of the computation graph. If you're using a standard pytorch optimizer this is fine, since None gradients are implicitly assumed to be zero.
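For instance, a small sketch of that behaviour (an assumed example, not from the original post): a parameter whose grad is still None is simply skipped when the optimizer steps.

import torch

a = torch.tensor(1., requires_grad=True)  # will receive a gradient
b = torch.tensor(1., requires_grad=True)  # its .grad stays None
opt = torch.optim.SGD([a, b], lr=0.1)

(2 * a).backward()  # populates a.grad only
opt.step()

print(a)  # tensor(0.8000, requires_grad=True)
print(b)  # tensor(1., requires_grad=True) -- unchanged, None grad acts like zero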
If l were instead a non-linear function of w, then w.grad would have been populated after f(w).backward(). For example:
import torch

def fun(q):
    def result(w):
        # now dl/dw = 2 * w * q
        l = w**2 * q
        return torch.autograd.grad(l, w, only_inputs=True, create_graph=True)[0]
    return result

w = torch.tensor((2.), requires_grad=True)
q = torch.tensor((3.), requires_grad=True)

f = fun(q)
f(w).backward()
print('w.grad:', w.grad)
print('q.grad:', q.grad)
which results in
w.grad: tensor(6.)
q.grad: tensor(4.)
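As a follow-up sketch (my own addition, reusing fun from the last snippet, with an arbitrary choice of optimizer and learning rate): once q.grad is populated this way, q can be updated like any other parameter.

import torch

def fun(q):
    def result(w):
        l = w**2 * q
        return torch.autograd.grad(l, w, only_inputs=True, create_graph=True)[0]
    return result

w = torch.tensor((2.), requires_grad=True)
q = torch.tensor((3.), requires_grad=True)
opt = torch.optim.SGD([q], lr=0.1)  # hypothetical optimizer and lr

f = fun(q)
f(w).backward()  # fills q.grad = 4. (and w.grad = 6., as above)
opt.step()       # q <- q - 0.1 * q.grad = 3 - 0.4 = 2.6
print(q)         # tensor(2.6000, requires_grad=True)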
Upvotes: 2