Penguin

Reputation: 2411

What is the correct way to update an input variable during training?

I have an input

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal

inp = torch.tensor([1.0])

and a neural network

class Model_updater(nn.Module):
    def __init__(self):
        super(Model_updater, self).__init__()
        self.fc1 = nn.Linear(1, 2)
        self.fc2 = nn.Linear(2, 3)
        self.fc3 = nn.Linear(3, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net_updater = Model_updater()

opt_updater = optim.Adam(net_updater.parameters())

I'm trying to update my input using the neural network's output:

inp = torch.tensor([1.0])
epochs = 3

for i in range(epochs):
    opt_updater.zero_grad()

    inp_copy = inp.detach().clone()

    mu, sigma = net_updater(inp_copy)
    dist1 = Normal(mu, torch.abs(sigma))
    a = dist1.rsample()

    inp += a

    loss = torch.tensor(5.0) - inp

    loss.backward(retain_graph=True)
    opt_updater.step()

But I'm getting the error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3, 2]], which is output 0 of TBackward, is at version 2; expected version 1

I also tried changing the loss calculation to

loss = torch.tensor(5.0) - inp_copy

But got the error

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I also tried without retain_graph=True, but then I get

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling .backward() or autograd.grad() the first time.

This doesn't really make sense to me, because I don't see where I'm calling backward() twice.

Upvotes: 0

Views: 406

Answers (1)

ayandas

Reputation: 2268

The first error comes from inp += a: it updates inp in place while inp is still attached to the autograd graph, so each iteration's graph leaks into the next, and once opt_updater.step() has modified the weights in place, backpropagating through that stale graph fails the version check. The second error comes from inp_copy being detached, so a loss built from it has no grad_fn. Most likely, this is what you want:

inp1 = inp + a        # out-of-place update: inp1 keeps the graph through a
inp.data = inp1.data  # carry the new value in inp without recording it in the graph

loss = torch.tensor(5.0) - inp1  # use the updated value, which has a grad_fn
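
Put together, the loop would look something like this (a minimal sketch with the fix above applied; note that retain_graph=True is no longer needed, because each iteration now builds a fresh graph):

for i in range(epochs):
    opt_updater.zero_grad()

    inp_copy = inp.detach().clone()

    mu, sigma = net_updater(inp_copy)
    dist1 = Normal(mu, torch.abs(sigma))
    a = dist1.rsample()

    inp1 = inp + a                   # out-of-place, so the graph stays local to this iteration
    inp.data = inp1.data             # propagate the new value without graph history

    loss = torch.tensor(5.0) - inp1  # differentiable w.r.t. the network through a

    loss.backward()                  # no retain_graph needed: a fresh graph per iteration
    opt_updater.step()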

Upvotes: 1
