Reputation: 21625
Consider this very simple implementation of gradient descent, whereby I attempt to fit a linear regression (mx + b) to some toy data.
import torch
# Make some data
torch.manual_seed(0)
X = torch.rand(35) * 5
Y = 3 * X + torch.rand(35)
# Initialize m and b
m = torch.rand(size=(1,), requires_grad=True)
b = torch.rand(size=(1,), requires_grad=True)
# Pass 1
yhat = X * m + b # Calculate yhat
loss = torch.sqrt(torch.mean((yhat - Y)**2)) # Calculate the loss
loss.backward() # Reverse mode differentiation
m = m - 0.1*m.grad # update m
b = b - 0.1*b.grad # update b
m.grad = None # zero out m gradient
b.grad = None # zero out b gradient
# Pass 2
yhat = X * m + b # Calculate yhat
loss = torch.sqrt(torch.mean((yhat - Y)**2)) # Calculate the loss
loss.backward() # Reverse mode differentiation
m = m - 0.1*m.grad # ERROR
The first pass works fine, but the second pass errors on the last line, m = m - 0.1*m.grad.
/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1013: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at aten/src/ATen/core/TensorBody.h:417.)
return self._grad
My understanding of why this happens is that, during Pass 1, this line
m = m - 0.1*m.grad
rebinds the name m
to a brand-new tensor produced by the subtraction (i.e. a totally separate block of memory). So m goes from being a leaf tensor to a non-leaf tensor.
# Pass 1
...
print(f"{m.is_leaf}") # True
m = m - 0.1*m.grad
print(f"{m.is_leaf}") # False
I've seen it mentioned that one could use something along the lines of m.data = m - 0.1*m.grad, but I haven't seen much discussion about this technique.
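For concreteness, here is a minimal sketch of that variant, using the same setup as above. Assigning to .data mutates the tensor's underlying storage without rebinding the name, so m stays a leaf tensor (though .data sidesteps autograd's tracking, which is why it tends to be discouraged):

```python
import torch

torch.manual_seed(0)
X = torch.rand(35) * 5
Y = 3 * X + torch.rand(35)

m = torch.rand(size=(1,), requires_grad=True)
b = torch.rand(size=(1,), requires_grad=True)

yhat = X * m + b
loss = torch.sqrt(torch.mean((yhat - Y)**2))
loss.backward()

# Mutating .data updates the storage in place; the name m still points
# at the same leaf tensor, so m.grad will keep being populated.
m.data = m.data - 0.1 * m.grad
b.data = b.data - 0.1 * b.grad

print(m.is_leaf)  # still a leaf tensor
```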
Upvotes: 0
Views: 961
Reputation: 40638
Your observation is correct. To perform the update without breaking the autograd graph, you should:

- Apply the modification with in-place operators.
- Wrap the calls in the torch.no_grad context manager.
For instance:
with torch.no_grad():
    m -= 0.1*m.grad # update m
    b -= 0.1*b.grad # update b
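Putting this together with your setup, a full training loop might look like the sketch below (the 100-iteration count and the 0.1 learning rate are arbitrary choices for illustration):

```python
import torch

torch.manual_seed(0)
X = torch.rand(35) * 5
Y = 3 * X + torch.rand(35)

m = torch.rand(size=(1,), requires_grad=True)
b = torch.rand(size=(1,), requires_grad=True)

for _ in range(100):
    yhat = X * m + b                              # forward pass
    loss = torch.sqrt(torch.mean((yhat - Y)**2))  # RMSE loss
    loss.backward()                               # reverse mode differentiation
    with torch.no_grad():
        m -= 0.1 * m.grad  # in-place update keeps m a leaf tensor
        b -= 0.1 * b.grad  # in-place update keeps b a leaf tensor
    m.grad = None          # zero out gradients for the next pass
    b.grad = None

print(m.is_leaf, b.is_leaf)  # both remain leaf tensors
```

Because the updates happen in place under torch.no_grad, m and b are never rebound to new non-leaf tensors, so .grad keeps being populated on every backward pass.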
Upvotes: 1