Reputation: 333
Some of this code was adapted from the book Deep Learning with PyTorch.
Script: linear regression (trying to predict t_c given t_u)
import torch

# Targets in Celsius, inputs in unknown units
t_c = torch.tensor([0.5, 14.0, 15.0, 28.0, 11.0, 8.0,
                    3.0, -4.0, 6.0, 13.0, 21.0])
t_u = torch.tensor([35.7, 55.9, 58.2, 81.9, 56.3, 48.9,
                    33.9, 21.8, 48.4, 60.4, 68.4])

def model(t_u, w, b):
    return w * t_u + b

def loss_fn(t_p, t_c):
    squared_diffs = (t_p - t_c) ** 2
    return squared_diffs.mean()

params = torch.tensor([1.0, 0.0], requires_grad=True)

loss = loss_fn(model(t_u, params[0], params[1]), t_c)
loss.backward()
print(params.grad)  # tensor([4517.2969,   82.6000])
Here I am passing the 0th and 1st elements of params as inputs to the model function, which performs scalar-to-vector multiplication and addition.
My question is: what exactly is PyTorch doing to compute the gradients of the params tensor? The "feedforward" step uses two subtensors of the params tensor, rather than separate tensors for the weight and bias, which is what I am familiar with.
My guess is: params[0] and params[1] are both references to elements of params, and each has its own distinct gradient stored somewhere in params.grad. So the .backward() call treats params[0] and params[1] as new individual tensors (as if we temporarily had two separate weight and bias tensors) and updates their gradients (params[0].grad, params[1].grad), hence updating params.grad, since they are references to it.
Upvotes: 0
Views: 314
Reputation: 4826
The main idea here is that the indexing operation returns a new view of the tensor. If you are not using in-place operations (+=, -=, etc.), the "view" aspect does not really matter and you can treat the result as just another tensor.
In that case, indexing is no different from other operations such as addition or matrix multiplication: there is an input (the original tensor), an output (the selected entries), and a local gradient (1 where an entry was selected, 0 otherwise*). Back-propagation then happens as usual.
* More specifically, the gradient of an output entry with respect to an input entry is 1 if the output entry was selected from that input entry, and 0 otherwise.
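To see the "1 if selected, 0 otherwise" rule in action, here is a small standalone sketch (the 3-element tensor and the constant 3.0 are made up for illustration, not taken from the question):

```python
import torch

# Hypothetical 3-element parameter tensor (illustration only)
params = torch.tensor([2.0, 5.0, 7.0], requires_grad=True)
a = params[0]        # indexing: a differentiable "select" op in the graph
c = 3.0 * a + 1.0    # downstream computation; dc/da = 3
c.backward()
print(params.grad)   # tensor([3., 0., 0.])
```

The gradient of 3 flows back through the select op into slot 0 of params.grad, and the unselected entries get 0. Note also that params[0].grad stays None: the indexed result is a non-leaf tensor, so its gradient is not retained separately; it is only accumulated into the leaf tensor params.grad.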
EDIT: Maybe it's easier to see it this way:
a = params[0]
c = W*a + b
--------------------------
dc/dparams
= {dc/dparams[0], dc/dparams[1], dc/dparams[2], ...}
--------------------------
dc/dparams[0]
= dc/da * da/dparams[0]
= W * 1
= W
--------------------------
dc/dparams[1] = dc/dparams[2] = ... = 0
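Applied to the code in the question, you can verify the same chain rule by hand: dloss/dw = mean(2*(t_p - t_c)*t_u) and dloss/db = mean(2*(t_p - t_c)), with indexing merely routing each value into its slot of params.grad. This check is mine, not part of the original post:

```python
import torch

t_c = torch.tensor([0.5, 14.0, 15.0, 28.0, 11.0, 8.0,
                    3.0, -4.0, 6.0, 13.0, 21.0])
t_u = torch.tensor([35.7, 55.9, 58.2, 81.9, 56.3, 48.9,
                    33.9, 21.8, 48.4, 60.4, 68.4])

params = torch.tensor([1.0, 0.0], requires_grad=True)
t_p = params[0] * t_u + params[1]
loss = ((t_p - t_c) ** 2).mean()
loss.backward()

# Hand-computed chain rule for the same loss; indexing routes
# each partial derivative into its slot of params.grad.
with torch.no_grad():
    diff = params[0] * t_u + params[1] - t_c
    manual = torch.stack([(2 * diff * t_u).mean(), (2 * diff).mean()])
print(torch.allclose(params.grad, manual))  # True
```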
Upvotes: 1