Peter

Reputation: 13505

pytorch: how to directly find gradient w.r.t. loss

In Theano, it was very easy to get the gradient of some variable w.r.t. a given loss:

loss = f(x, w)
dl_dw = tt.grad(loss, wrt=w)

I get that PyTorch goes by a different paradigm, where you'd do something like:

loss = f(x, w)
loss.backward()
dl_dw = w.grad

The thing is, I might not want to do a full backward propagation through the graph - just along the path needed to get to w.

I know you can define Variables with requires_grad=False if you don't want to backpropagate through them. But then you have to decide that at the time of variable creation (and the requires_grad property is attached to the variable, rather than to the call that gets the gradient, which seems odd).
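
To make that concrete, here is a minimal sketch (the tensors and names are my own, not from the original post) showing that backward() only fills in .grad for leaves created with requires_grad=True:

import torch
import numpy as np
from torch.autograd import Variable

x = Variable(torch.from_numpy(np.random.randn(5, 4)))
w1 = Variable(torch.from_numpy(np.random.randn(4, 3)), requires_grad=True)   # will get a .grad
w2 = Variable(torch.from_numpy(np.random.randn(4, 3)), requires_grad=False)  # frozen at creation

loss = ((x.mm(w1) + x.mm(w2)) ** 2).sum()
loss.backward()

print(w1.grad)  # populated by backward()
print(w2.grad)  # None - no gradient was computed or stored for w2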

My question is: is there some way to backpropagate on demand (i.e. only backpropagate along the path needed to compute dl_dw, as you would in Theano)?

Upvotes: 4

Views: 2700

Answers (1)

Peter

Reputation: 13505

It turns out that this is really easy. Just use torch.autograd.grad.

Example:

import torch
import numpy as np
from torch.autograd import grad

# x: (5, 4) input, w: (4, 3) weights, y: (5, 3) target. Only w requires a gradient.
x = torch.autograd.Variable(torch.from_numpy(np.random.randn(5, 4)))
w = torch.autograd.Variable(torch.from_numpy(np.random.randn(4, 3)), requires_grad=True)
y = torch.autograd.Variable(torch.from_numpy(np.random.randn(5, 3)))
loss = ((x.mm(w) - y)**2).sum()

# grad returns a tuple with one entry per requested input.
(d_loss_d_w, ) = grad(loss, w)

# Check against the analytic gradient: d/dw sum((xw - y)^2) = 2 * x^T (xw - y)
assert np.allclose(d_loss_d_w.data.numpy(), (x.transpose(0, 1).mm(x.mm(w)-y)*2).data.numpy())
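
For anyone on a newer PyTorch (0.4 or later, where Variable has been merged into Tensor), here is a roughly equivalent sketch with plain tensors - treat it as an illustration rather than the original answer's code:

import torch

x = torch.randn(5, 4)
w = torch.randn(4, 3, requires_grad=True)
y = torch.randn(5, 3)

loss = ((x @ w - y) ** 2).sum()

# Differentiate loss w.r.t. w only; unlike backward(), this does not write into w.grad.
(d_loss_d_w,) = torch.autograd.grad(loss, w)

# Same analytic check: d/dw sum((xw - y)^2) = 2 * x^T (xw - y)
assert torch.allclose(d_loss_d_w, 2 * x.t() @ (x @ w - y))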

Thanks to JerryLin for answering the question here.

Upvotes: 4
