ApPs

Reputation: 33

PyTorch autograd: Make gradient of a parameter a function of another parameter

In PyTorch, how can I make the gradient of a parameter itself a differentiable function?

Here is a simple code snippet:

import torch

def fun(q):

    def result(w):
        l = w * q
        l.backward()
        return w.grad

    return result

w = torch.tensor((2.), requires_grad=True)
q = torch.tensor((3.), requires_grad=True)

f = fun(q)

print(f(w))

In the code above, how can I make f(w) have a gradient with respect to q?

EDIT: Based on the accepted answer, I was able to write code that works. Essentially, I am alternating between two optimization steps. It works for dim == 1 but not for dim == 2, where I get the error "RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time."

import torch 

class f_class():
    def __init__(self, dim):    
        self.dim = dim
        if self.dim == 1:      
            self.w = torch.tensor((3.), requires_grad=True)
        elif self.dim == 2:            
            self.w = [torch.tensor((3.), requires_grad=True), torch.tensor((5.), requires_grad=True)]
        else:
            raise ValueError("dim 1 or 2")
        
    def forward(self, x):
        if self.dim == 1:      
            return torch.mul(self.w, x)
        elif self.dim == 2:  
            return torch.mul(torch.mul(self.w[0], self.w[1]), x)            
           
    def set_w(self, w):
        self.w = w
        
    def get_w(self):
        return self.w
    
class g_class():
    def __init__(self):                
        self.q = torch.tensor((4.), requires_grad=True)
        
    def forward(self, f):
        return torch.mul(self.q, f)
    
    def set_q(self, q):
        self.q = q
        
    def get_q(self):
        return self.q
        
def w_new(f, g, dim):  
    
    loss_g = g.forward(f.forward(xd))
    
    if dim == 1:
        grads = torch.autograd.grad(loss_g, f.get_w(), create_graph=True, only_inputs=True)[0]
           
        temp = f.get_w().detach() + grads        
    else:                
        grads = torch.autograd.grad(loss_g, f.get_w(), create_graph=True, only_inputs=True)
        
        temp = [wi.detach() + gi for wi, gi in zip(f.get_w(), grads)] 

    return temp    

def q_new(f, g):          

    loss_f = 2 * f.forward(xd) 
       
    loss_f.backward()

    temp = g.get_q().detach() + g.get_q().grad
    
    temp.requires_grad = True
    
    return temp

dim = 1

xd = torch.tensor((2.))
        
f = f_class(dim)
g = g_class()

for _ in range(3):
    
    print(f.get_w(), g.get_q())
    
    wnew = w_new(f, g, dim)
    
    f.set_w(wnew)
    print(f.get_w(), g.get_q())

    qnew = q_new(f, g)
    
    g.set_q(qnew)
    
print(f.get_w(), g.get_q())

Upvotes: 2

Views: 1925

Answers (1)

jodag

Reputation: 22184

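When computing gradients, if you want to construct a computation graph for the gradient itself, you need to pass create_graph=True to autograd. In your snippet, l.backward() is called without it, so w.grad is a plain tensor with no history and f(w) cannot be differentiated any further.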
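As a minimal illustration of create_graph=True (a standalone sketch with made-up values, not taken from your code), the first derivative of y = x**3 below is itself differentiated to obtain the second derivative. Without the flag, the returned gradient carries no history.

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x ** 3

# First derivative dy/dx = 3 * x**2. create_graph=True records the operations
# of this backward pass, so the returned gradient has a grad_fn of its own.
(dy_dx,) = torch.autograd.grad(y, x, create_graph=True)

# Second derivative d2y/dx2 = 6 * x, obtained by differentiating dy_dx.
(d2y_dx2,) = torch.autograd.grad(dy_dx, x)
print(dy_dx.item(), d2y_dx2.item())  # 27.0 18.0

# Without create_graph=True the gradient is a plain tensor with no history.
(plain,) = torch.autograd.grad(x ** 3, x)
print(plain.requires_grad)  # False
```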

A potential source of error in your code is calling Tensor.backward inside f. This populates w.grad and q.grad with the gradient of l, so when you later call f(w).backward(), the gradients of both l and f are accumulated into w.grad and q.grad. In effect, w.grad ends up equal to dl/dw + df/dw, and similarly for q.grad.

One way around this is to zero the gradients after calling f(w) but before calling .backward(). A better way is to use torch.autograd.grad inside f: it returns the gradient instead of storing it, so the grad attributes of w and q are populated only when you call .backward(), not when you call f. This leaves room for things like gradient accumulation during training.
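The accumulation effect can be seen in isolation with the sketch below (the values and the second expression m = w * w are chosen only for illustration): a second backward() adds to w.grad rather than overwriting it, whereas torch.autograd.grad returns the gradient and leaves .grad alone.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
q = torch.tensor(3.0, requires_grad=True)

# backward() stores dl/dw = q in w.grad ...
l = w * q
l.backward()
first = w.grad.clone()        # 3.0

# ... and a later backward() adds to it instead of overwriting it.
m = w * w
m.backward()
accumulated = w.grad.clone()  # 3.0 + dm/dw = 3.0 + 4.0 = 7.0

# torch.autograd.grad returns the gradient and does not touch .grad.
w.grad = None                 # reset before the comparison
q.grad = None
(dl_dw,) = torch.autograd.grad(w * q, w, create_graph=True)
print(first.item(), accumulated.item(), dl_dw.item(), w.grad, q.grad)
# 3.0 7.0 3.0 None None
```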

import torch

def fun(q):
    def result(w):
        l = w * q 
        # create_graph=True makes the returned gradient differentiable
        return torch.autograd.grad(l, w, create_graph=True)[0]
    return result


w = torch.tensor((2.), requires_grad=True)
q = torch.tensor((3.), requires_grad=True)

f = fun(q)

f(w).backward()

print('w.grad:', w.grad)
print('q.grad:', q.grad)

which results in

w.grad: None
q.grad: tensor(1.)

Note that w.grad was not populated. This is because f(w) = dl/dw = q does not depend on w, so w is not part of the computation graph of f(w). If you're using a standard PyTorch optimizer, this is fine, since the built-in optimizers skip parameters whose .grad is None.
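To check directly which inputs f(w) depends on, the sketch below (same values as above) asks autograd for the gradient with respect to both w and q with allow_unused=True, which returns None for w instead of raising an error. It then shows that torch.optim.SGD, used here only as an example optimizer, leaves w untouched when w.grad is None.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
q = torch.tensor(3.0, requires_grad=True)

# f(w) = dl/dw for l = w * q; create_graph=True keeps f(w) differentiable
(f_w,) = torch.autograd.grad(w * q, w, create_graph=True)

# allow_unused=True returns None for an input the output does not depend on
# (without it, autograd.grad raises an error because w is unused).
# retain_graph=True keeps the graph alive for the backward() call below.
dw, dq = torch.autograd.grad(f_w, [w, q], retain_graph=True, allow_unused=True)
print(dw, dq.item())  # None 1.0

# The built-in SGD optimizer skips parameters whose .grad is None.
opt = torch.optim.SGD([w, q], lr=0.1)
f_w.backward()  # sets q.grad to 1; w.grad stays None
opt.step()      # w stays at 2.0, q becomes 3.0 - 0.1 * 1.0
print(w.item(), q.item())
```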

If l were instead a non-linear function of w, then w.grad would have been populated after f(w).backward(). For example

import torch

def fun(q):
    def result(w):
        # now dl/dw = 2 * w * q
        l = w**2 * q
        return torch.autograd.grad(l, w, create_graph=True)[0]
    return result

w = torch.tensor((2.), requires_grad=True)
q = torch.tensor((3.), requires_grad=True)

f = fun(q)

f(w).backward()

print('w.grad:', w.grad)
print('q.grad:', q.grad)

which results in

w.grad: tensor(6.)
q.grad: tensor(4.)
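These values can be checked by hand for w = 2 and q = 3:

```latex
l = w^2 q, \qquad f(w) = \frac{\partial l}{\partial w} = 2wq,
\qquad
\frac{\partial f}{\partial w} = 2q = 6, \qquad \frac{\partial f}{\partial q} = 2w = 4.
```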

Upvotes: 2
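For the dim == 2 error in the question's EDIT, a likely cause is the following. For dim == 1, the gradient of q*w*x with respect to w is q*x, which does not involve w, so the updated w depends only on q. For dim == 2, the gradient with respect to w[0] is q*w[1]*x, which does contain w[1]. From the second iteration on, w[1] is the previous update, whose graph was already freed by the backward() call in q_new, so the next backward() has to walk through that freed graph again. Below is a minimal sketch of one possible fix, assuming you do not need gradients to flow through earlier iterations: it restarts the graph from detached leaf copies of w at the start of every w step. The functions w_step and q_step are hypothetical stand-ins for the question's w_new and q_new, and the q step uses torch.autograd.grad instead of backward().

```python
import torch

x = torch.tensor(2.0)

def w_step(ws, q):
    # Hypothetical fix: restart from fresh leaf copies of w, so the gradients
    # below cannot reach back into a graph that an earlier backward freed.
    ws = [wi.detach().requires_grad_() for wi in ws]
    loss_g = q * (ws[0] * ws[1] * x)
    # create_graph=True keeps each gradient differentiable with respect to q
    grads = torch.autograd.grad(loss_g, ws, create_graph=True)
    # each updated w is now a differentiable function of q
    return [wi.detach() + gi for wi, gi in zip(ws, grads)]

def q_step(ws, q):
    loss_f = 2 * (ws[0] * ws[1] * x)
    # d(loss_f)/dq flows through the updated w's back to q
    (dq,) = torch.autograd.grad(loss_f, q)
    # new leaf tensor, mirroring q_new() in the question
    return (q.detach() + dq).requires_grad_()

ws = [torch.tensor(3.0), torch.tensor(5.0)]
q = torch.tensor(4.0, requires_grad=True)
history = []
for _ in range(3):
    ws = w_step(ws, q)
    q = q_step(ws, q)
    history.append(([wi.item() for wi in ws], q.item()))
print(history)
```

Passing retain_graph=True, as the error message suggests, should also avoid the error, but then every q step backpropagates through all earlier iterations and the retained graph keeps growing.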
