Daniel Möller

Reputation: 86600

How to assign a new value to a pytorch Variable without breaking backpropagation?

I have a pytorch variable that is used as a trainable input for a model. At some point I need to manually reassign all values in this variable.

How can I do that without breaking the connections with the loss function?

Suppose the current values are [1.2, 3.2, 43.2] and I simply want them to become [1,2,3].
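To make the setup concrete, a minimal sketch (the tensor here just stands in for my actual trainable input):

import torch

# trainable input to the model
x = torch.tensor([1.2, 3.2, 43.2], requires_grad=True)

# ... the model uses x in its forward pass ...

# at some point I want x to hold [1, 2, 3] instead,
# without losing the ability to backpropagate through x afterwards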


Edit

At the time I asked this question, I hadn't realized that PyTorch doesn't have a static graph like TensorFlow or Keras do.

In PyTorch, the training loop is written by hand and you call everything explicitly in each training step. (There is no notion of a placeholder + static graph that you later feed data into.)

Consequently, we can't "break the graph", since we use the new variable to perform all further computations again in the next step anyway. I was worried about a problem that happens in Keras, not in PyTorch.
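For illustration, a minimal sketch of such a manual loop (the quadratic loss and the values are just placeholders): the graph is rebuilt from the tensor's current values on every forward pass, so overwriting them between steps doesn't break anything.

import torch

x = torch.tensor([1.2, 3.2, 43.2], requires_grad=True)
optimizer = torch.optim.SGD([x], lr=0.1)

for step in range(2):
    optimizer.zero_grad()
    loss = (x ** 2).sum()   # toy loss standing in for the real model
    loss.backward()
    optimizer.step()

# overwrite the values; the next forward pass builds a fresh graph
# from these new values, so nothing is "broken"
with torch.no_grad():
    x.copy_(torch.tensor([1.0, 2.0, 3.0]))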

Upvotes: 13

Views: 17012

Answers (1)

MBT

Reputation: 24099

You can use the data attribute of tensors to modify the values: operations and changes on data are not tracked by autograd and therefore never appear in the graph. So the graph stays intact, and modifying the data attribute has no influence on backpropagation.

Since you haven't given an example, this one is based on your comment statement:
'Suppose I want to change the weights of a layer.'
I use normal tensors here, but this works the same way for the weight.data and bias.data attributes of a layer.

Here is a short example:

import torch
import torch.nn.functional as F

# Test 1, random vector with CE
w1 = torch.rand(1, 3, requires_grad=True)
loss = F.cross_entropy(w1, torch.tensor([1]))
loss.backward()
print('w1.data', w1.data)
print('w1.grad', w1.grad)
print()

# Test 2, replacing values of w2 with w1, before CE
# to make sure that everything is exactly like in Test 1 after replacing the values
w2 = torch.zeros(1, 3, requires_grad=True)
w2.data = w1.data
loss = F.cross_entropy(w2, torch.tensor([1]))
loss.backward()
print('w2.data', w2.data)
print('w2.grad', w2.grad)
print()

# Test 3, replace data after computation
w3 = torch.rand(1, 3, requires_grad=True)
loss = F.cross_entropy(w3, torch.tensor([1]))
# setting values
# the graph of the previous computation is still intact, as you can see in the print-outs below
w3.data = w1.data
loss.backward()

# data were replaced with values from w1
print('w3.data', w3.data)
# gradient still shows results from computation with w3
print('w3.grad', w3.grad)

Output:

w1.data tensor([[ 0.9367,  0.6669,  0.3106]])
w1.grad tensor([[ 0.4351, -0.6678,  0.2326]])

w2.data tensor([[ 0.9367,  0.6669,  0.3106]])
w2.grad tensor([[ 0.4351, -0.6678,  0.2326]])

w3.data tensor([[ 0.9367,  0.6669,  0.3106]])
w3.grad tensor([[ 0.3179, -0.7114,  0.3935]])

The most interesting part here is w3. By the time backward is called, its values have already been replaced by the values of w1.
But the gradients are calculated from the CE function using the original values of w3. The replaced values have no effect on the graph: the graph connection is not broken, and the replacement has no influence on backpropagation. I hope this is what you were looking for!
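And since your comment mentioned changing the weights of a layer, here is a minimal sketch of the same pattern with an nn.Linear (the layer size and the new values are made up):

import torch
import torch.nn as nn

layer = nn.Linear(3, 1)

# overwrite the parameters via .data; autograd does not track this,
# so later forward/backward passes are unaffected
layer.weight.data = torch.tensor([[1.0, 2.0, 3.0]])
layer.bias.data = torch.zeros(1)

out = layer(torch.rand(1, 3)).sum()
out.backward()
print(layer.weight.grad)  # gradients computed with the new weights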

Upvotes: 14
