Ben

Reputation: 21625

What's the proper way to update a leaf tensor's values (e.g. during the update step of gradient descent)?

Toy Example

Consider this very simple implementation of gradient descent, in which I attempt to fit a linear regression (mx + b) to some toy data.

import torch

# Make some data
torch.manual_seed(0)
X = torch.rand(35) * 5
Y = 3 * X + torch.rand(35)

# Initialize m and b
m = torch.rand(size=(1,), requires_grad=True)
b = torch.rand(size=(1,), requires_grad=True)

# Pass 1
yhat = X * m + b    #  Calculate yhat
loss = torch.sqrt(torch.mean((yhat - Y)**2)) # Calculate the loss
loss.backward()     # Reverse mode differentiation
m = m - 0.1*m.grad  # update m
b = b - 0.1*b.grad  # update b
m.grad = None       # zero out m gradient
b.grad = None       # zero out b gradient

# Pass 2
yhat = X * m + b    #  Calculate yhat
loss = torch.sqrt(torch.mean((yhat - Y)**2)) # Calculate the loss
loss.backward()     # Reverse mode differentiation
m = m - 0.1*m.grad  # ERROR

The first pass works fine, but the second pass errors on the last line, m = m - 0.1*m.grad.

Error

/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1013: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at  aten/src/ATen/core/TensorBody.h:417.)
  return self._grad

My understanding of why this happens is that, during Pass 1, this line

m = m - 0.1*m.grad

creates a brand-new tensor (i.e. a totally separate block of memory) as the output of an autograd-tracked operation, and rebinds the name m to it. So m goes from being a leaf tensor to a non-leaf tensor, and backward() will no longer populate its .grad.

# Pass 1
...
print(f"{m.is_leaf}")  # True
m = m - 0.1*m.grad  
print(f"{m.is_leaf}")  # False

So, how does one perform an update?

I've seen it mentioned that one could use something along the lines of m.data = m - 0.1*m.grad, but I haven't seen much discussion about this technique.

Upvotes: 0

Views: 961

Answers (1)

Ivan

Reputation: 40638

Your observation is correct. In order to perform the update you should:

  1. Apply the modification with in-place operators.

  2. Wrap the calls in the torch.no_grad context manager, so the update itself is not tracked by autograd.

For instance:

with torch.no_grad():
    m -= 0.1*m.grad  # update m
    b -= 0.1*b.grad  # update b
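
Putting this together with the gradient zeroing from the question, a complete two-pass version of the toy example might look like the following sketch. In-place updates under torch.no_grad keep m and b as leaf tensors, so the second backward() pass populates their .grad without warnings:

```python
import torch

# Toy data, same setup as in the question
torch.manual_seed(0)
X = torch.rand(35) * 5
Y = 3 * X + torch.rand(35)

# Initialize m and b as leaf tensors
m = torch.rand(size=(1,), requires_grad=True)
b = torch.rand(size=(1,), requires_grad=True)

for _ in range(2):  # two passes, mirroring Pass 1 and Pass 2
    yhat = X * m + b
    loss = torch.sqrt(torch.mean((yhat - Y) ** 2))
    loss.backward()
    with torch.no_grad():
        m -= 0.1 * m.grad  # in-place update; m remains a leaf
        b -= 0.1 * b.grad  # in-place update; b remains a leaf
    m.grad = None          # zero out m gradient for the next pass
    b.grad = None          # zero out b gradient for the next pass

print(m.is_leaf, b.is_leaf)  # True True
```

This is essentially what torch.optim optimizers do internally: they mutate the parameters in place inside a no-grad context and provide zero_grad() to clear the gradients.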

Upvotes: 1
