madman_with_a_box

Reputation: 84

How to compute gradient of the error with respect to the model input?

Given a simple 2-layer neural network, the traditional approach is to compute the gradient w.r.t. the weights/model parameters. For an experiment, I want to compute the gradient of the error w.r.t. the input. Are there existing PyTorch methods that allow me to do this?

More concretely, consider the following neural network:

import torch.nn as nn
import torch.nn.functional as F

class NeuralNet(nn.Module):
    def __init__(self, n_features, n_hidden, n_classes, dropout):
        super(NeuralNet, self).__init__()

        self.fc1 = nn.Linear(n_features, n_hidden)
        self.sigmoid = nn.Sigmoid()
        self.fc2 = nn.Linear(n_hidden, n_classes)
        self.dropout = dropout

    def forward(self, x):
        x = self.sigmoid(self.fc1(x))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

I instantiate the model and an optimizer for the weights as follows:

import torch.optim as optim
model = NeuralNet(n_features=args.n_features,
                  n_hidden=args.n_hidden,
                  n_classes=args.n_classes,
                  dropout=args.dropout)
optimizer_w = optim.SGD(model.parameters(), lr=0.001)

While training, I update the weights as usual. Now, given the current values of the weights, I should be able to use them to compute the gradient w.r.t. the input, but I am unable to figure out how.

import time

def train(epoch):
    t = time.time()
    model.train()
    optimizer_w.zero_grad()
    output = model(features)
    loss_train = F.nll_loss(output[idx_train], labels[idx_train])
    acc_train = accuracy(output[idx_train], labels[idx_train])
    loss_train.backward()
    optimizer_w.step()

    # grad_features = gradient of loss_train w.r.t. features
    # features -= 0.001 * grad_features

for epoch in range(args.epochs):
    train(epoch)

Upvotes: 1

Views: 2301

Answers (1)

Jatentaki

Reputation: 13113

It is possible: just set input.requires_grad = True for each input batch you feed in, and after loss.backward() input.grad will hold the gradient of the loss with respect to the input. In other words, if the input to your model (which you call features in your code) is some M x N x ... tensor, features.grad will be a tensor of the same shape, where each element holds the gradient with respect to the corresponding element of features. Writing i for a generalized index, features.grad[i] is the derivative of the loss with respect to features[i]; if features has, for instance, 3 dimensions, read this as features.grad[i, j, k].
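A minimal sketch of this, assuming random data and hypothetical sizes; the nn.Sequential model below is only a stand-in for the question's NeuralNet, with dropout left out:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration
n_samples, n_features, n_hidden, n_classes = 8, 5, 16, 3

# Stand-in for NeuralNet (Linear -> Sigmoid -> Linear -> log-softmax, no dropout)
model = nn.Sequential(
    nn.Linear(n_features, n_hidden),
    nn.Sigmoid(),
    nn.Linear(n_hidden, n_classes),
    nn.LogSoftmax(dim=1),
)

features = torch.randn(n_samples, n_features)
labels = torch.randint(0, n_classes, (n_samples,))

# Ask autograd to track the input as well as the weights
features.requires_grad = True

loss = F.nll_loss(model(features), labels)
loss.backward()

# features.grad[i, j] is d(loss) / d(features[i, j])
print(features.grad.shape)  # torch.Size([8, 5])
```

The same backward() call also fills the .grad of every model parameter, so the usual weight update is unaffected.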

Regarding the error you're getting: PyTorch operations build a graph representing the mathematical operations they describe, which is then used for differentiation. For instance, c = a + b creates a graph where a and b are leaf nodes and c is not a leaf (since it results from other expressions). Your model is the expression: its inputs and parameters are the leaves, whereas all intermediate and final outputs are not leaves. You can think of leaves as "constants" or "parameters" and of all other tensors as functions of those. The error message tells you that you can only set requires_grad on leaf tensors.
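A small sketch of the c = a + b example, checking which tensors autograd treats as leaves. Here a and b are created with requires_grad=True; by PyTorch convention, any tensor with requires_grad=False counts as a leaf regardless of how it was created.

```python
import torch

a = torch.randn(3, requires_grad=True)  # leaf: created directly, not by an operation
b = torch.randn(3, requires_grad=True)  # leaf
c = a + b                               # not a leaf: produced by an addition

print(a.is_leaf, b.is_leaf, c.is_leaf)  # True True False
print(c.grad_fn)                        # the AddBackward0 node linking c back to a and b
```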

Your problem is that at the first iteration, features is random (or however else you initialize it) and is therefore a valid leaf. After the first iteration, features is no longer a leaf, since it is now an expression computed from its previous value. In pseudocode, you have

f_1 = initial_value # valid leaf
f_2 = f_1 + your_grad_stuff # not a leaf: f_2 is a function of f_1
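The pseudocode can be made concrete with a toy loss; the squared-sum loss and the 0.01 step size are illustrative assumptions. After one gradient step without detach, f_2 is not a leaf, so trying to set its requires_grad flag raises a RuntimeError:

```python
import torch

torch.manual_seed(0)

f_1 = torch.randn(4, requires_grad=True)  # valid leaf
loss = (f_1 ** 2).sum()                   # hypothetical loss, for illustration only
loss.backward()

f_2 = f_1 - 0.01 * f_1.grad               # computed from f_1, so not a leaf
print(f_2.is_leaf)                         # False

error_message = None
try:
    f_2.requires_grad = True               # only allowed on leaf tensors
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```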

To deal with that, you need to use detach(), which breaks the links in the graph and makes autograd treat a tensor as a constant, no matter how it was created. In particular, no gradient is backpropagated through detach(). So you need something like

features = features.detach() - 0.01 * features.grad
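Putting this into the question's training loop gives something like the sketch below. It assumes random data, hypothetical sizes, and an nn.Sequential stand-in for NeuralNet (no dropout); the weights are updated by optimizer_w as in the question, and the input by a plain gradient step of size 0.01:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical sizes and data, for illustration only
n_samples, n_features, n_hidden, n_classes = 8, 5, 16, 3
model = nn.Sequential(
    nn.Linear(n_features, n_hidden),
    nn.Sigmoid(),
    nn.Linear(n_hidden, n_classes),
    nn.LogSoftmax(dim=1),
)
optimizer_w = optim.SGD(model.parameters(), lr=0.001)

features = torch.randn(n_samples, n_features)
labels = torch.randint(0, n_classes, (n_samples,))

for epoch in range(5):
    features.requires_grad = True       # legal: features is a leaf at this point
    optimizer_w.zero_grad()             # clear the weight gradients
    loss = F.nll_loss(model(features), labels)
    loss.backward()                     # fills the weights' .grad and features.grad
    optimizer_w.step()                  # gradient step on the weights

    # Gradient step on the input; detach() cuts the graph, so the result is a fresh leaf
    features = features.detach() - 0.01 * features.grad

print(features.is_leaf, features.requires_grad)  # True False
```

Because features is rebuilt through detach() on every iteration, it starts each epoch as a new leaf whose grad is None, so there is no stale input gradient to zero and setting requires_grad at the top of the loop always succeeds.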

Note: you may need a few more detach() calls elsewhere; it is hard to say without seeing your whole code and knowing its exact purpose.
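If you only need the gradient with respect to the input and would rather not touch any .grad attributes, torch.autograd.grad returns it directly. A sketch, again with hypothetical sizes and a stand-in model:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes and a stand-in for NeuralNet, for illustration only
model = nn.Sequential(nn.Linear(5, 16), nn.Sigmoid(), nn.Linear(16, 3), nn.LogSoftmax(dim=1))
features = torch.randn(8, 5, requires_grad=True)  # the input must still require grad
labels = torch.randint(0, 3, (8,))

loss = F.nll_loss(model(features), labels)

# Returns one gradient per tensor passed as inputs; nothing is accumulated into .grad
(grad_features,) = torch.autograd.grad(loss, features)

print(grad_features.shape)  # torch.Size([8, 5])
print(features.grad)        # None
```

Passing several tensors, e.g. torch.autograd.grad(loss, [features, *model.parameters()]), returns the input gradient and the weight gradients from a single backward pass.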

Upvotes: 2
