Shashwat

Reputation: 449

I am looking for a comprehensive explanation of the `inputs` parameter of the `.backward()` method in PyTorch

I am having trouble understanding the usage of the `inputs` keyword argument in the `.backward()` call.

The Documentation says the following:

inputs (sequence of Tensor) – Inputs w.r.t. which the gradient will be accumulated into .grad. All other Tensors will be ignored. If not provided, the gradient is accumulated into all the leaf Tensors that were used to compute the tensors.

From what I understand, this allows us to specify the inputs with respect to which gradients will be computed.

Isn't that already specified when `.backward()` is called on some tensor, like a loss in `loss.backward()`? Wouldn't the computation graph ensure that gradients are calculated with respect to the relevant parameters?

I haven't found sources that explain this in more depth. I'd appreciate being pointed to an explanation.

Upvotes: 6

Views: 478

Answers (1)

cheersmate

Reputation: 2656

It is simply a way to limit the set of parameters for which gradients are calculated and accumulated (as stated in the docs). Here is an example:

import torch
from torch import nn

class Model(nn.Module):
    # warning: toy model for explanatory purposes only
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.as_tensor(1.))
        self.b = nn.Parameter(torch.as_tensor(1.))

    def forward(self, x):
        return x + self.a + 2 * self.b

model = Model()

print(f'no gradients accumulated so far: {model.a.grad}, {model.b.grad}')

loss = model(1)**2
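# forward(1) = 1 + a + 2*b = 4, so loss = 16;
# d(loss)/da = 2*4*1 = 8 and d(loss)/db = 2*4*2 = 16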
loss.backward()
print(f'gradients after calling loss.backward(): {model.a.grad}, {model.b.grad}')

model.zero_grad()
print('gradients reset by model.zero_grad()')

loss = model(1)**2
loss.backward(inputs=[model.a])
print(f'gradients after calling loss.backward(inputs=[model.a]): {model.a.grad}, {model.b.grad}')

Output:

no gradients accumulated so far: None, None
gradients after calling loss.backward(): 8.0, 16.0
gradients reset by model.zero_grad()
gradients after calling loss.backward(inputs=[model.a]): 8.0, 0.0
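(A note on versions: on recent PyTorch releases, where `zero_grad()` defaults to `set_to_none=True`, the last line prints `8.0, None` instead of `8.0, 0.0`, but the effect of `inputs` is the same: `model.b` receives no gradient.)

For comparison, `torch.autograd.grad` also restricts differentiation to an explicit list of inputs, but it returns the gradients instead of accumulating them into `.grad`. A minimal sketch continuing the example above (the expected value follows from the same arithmetic as before):

model.zero_grad()
loss = model(1)**2
# returns a tuple with one gradient per listed input; .grad attributes are untouched
grads = torch.autograd.grad(loss, [model.a])
print(grads)  # expected: (tensor(8.),)

In practice, passing `inputs` is useful when you only intend to update a subset of the parameters in a given step, since it avoids accumulating (and later having to zero) gradients for everything else.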

Upvotes: 7
