Reputation: 2431
I have a model:
import torch
import torch.nn as nn
import torch.optim as optim
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(1, 3)
        self.fc2 = nn.Linear(3, 2)
        self.fc3 = nn.Linear(2, 1)

    def forward(self, x):
        x1 = self.fc1(x)
        x = torch.relu(x1)
        x2 = self.fc2(x)
        x = torch.relu(x2)
        x3 = self.fc3(x)
        return x3, x2, x1

net = Model()
I'm trying to manually update the parameters with
i, j = torch.meshgrid(torch.arange(3), torch.arange(2))
i = i.reshape(-1)
j = j.reshape(-1)
update = torch.ones(6,1)
print(i)
print(j)
print(update.squeeze())
print(net.fc2.weight[j,i].data)
net.fc2.weight[j,i].data += update.squeeze()
print(net.fc2.weight[j,i].data)
>>> tensor([0, 0, 1, 1, 2, 2])
tensor([0, 1, 0, 1, 0, 1])
tensor([1., 1., 1., 1., 1., 1.])
tensor([-0.0209, -0.3770, 0.4982, -0.2123, -0.2630, -0.5580])
tensor([-0.0209, -0.3770, 0.4982, -0.2123, -0.2630, -0.5580])
But nothing seems to change.
However, if I do
print(net.fc2.weight[1].data)
net.fc2.weight[1].data += 1
print(net.fc2.weight[1].data)
>>> tensor([-0.3770, -0.2123, -0.5580])
tensor([0.6230, 0.7877, 0.4420])
They do change.
What am I doing wrong in the first approach and how can I make it work?
Upvotes: 0
Views: 1361
Reputation: 2268
The point you are missing is this: when you index with constants (integers or slices), you get a "view" that shares memory with the original tensor. When you index with another tensor, you get a new tensor (a new node in the computation graph) that holds a copy of the selected values.
PyTorch provides a .data_ptr() method to peek at the address of a tensor's underlying memory.
>> net.fc2.weight.data.data_ptr()
2911054070464
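>> net.fc2.weight[0].data.data_ptr()
2911054070464
Constant indexing returns a view into the same raw data, so the first row starts at the same address as the full weight (other rows, such as net.fc2.weight[1], start at a fixed offset inside that same buffer). Indexing with a tensor, however, creates a new node and hence a new underlying memory location: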
>> net.fc2.weight[j, i].data.data_ptr()
2911054068672
So, in your case, net.fc2.weight[j,i] creates a new tensor/node, and the += modifies that temporary copy. That's why your original tensor remains unchanged. In the constant indexing case, you are writing to the same memory location, hence the change is reflected in the weight.
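To see the view/copy difference without comparing raw pointers, here is a minimal sketch on a small standalone tensor. The names w, row and picked are made up for illustration and are not part of the model above.

```python
import torch

w = torch.zeros(2, 3)

# Basic (integer/slice) indexing returns a view that shares w's memory.
row = w[1]
row += 1            # in-place add on the view writes through to w

# Indexing with index tensors returns a copy with its own memory.
j = torch.tensor([0, 0])
i = torch.tensor([0, 1])
picked = w[j, i]
picked += 5         # modifies only the copy; w[0] stays at zero

view_changed_w = bool((w[1] == 1).all())
copy_changed_w = bool((w[0] != 0).any())
print(view_changed_w, copy_changed_w)  # True False
```

The second half mirrors what happens with net.fc2.weight[j,i].data += ...: the addition lands on a temporary that is thrown away.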
To fix your problem, instead of doing this
net.fc2.weight[j,i].data += update.squeeze()
do this
net.fc2.weight.data[j,i] += update.squeeze()
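This grabs the underlying .data first and then indexes it, which means the indexing operation is entirely out of autograd's tracking machinery. It also works because Python turns an augmented assignment on an indexed target into a read, an add, and a write back through item assignment on net.fc2.weight.data, so the result lands in the original memory.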
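As an alternative that avoids .data altogether, the same update can be wrapped in torch.no_grad(). This is only a sketch: fc2 is a standalone nn.Linear standing in for net.fc2, and the indexing="ij" argument assumes a reasonably recent PyTorch version (it reproduces the index order printed in the question).

```python
import torch
import torch.nn as nn

fc2 = nn.Linear(3, 2)  # stand-in for net.fc2 from the question

i, j = torch.meshgrid(torch.arange(3), torch.arange(2), indexing="ij")
i, j = i.reshape(-1), j.reshape(-1)
update = torch.ones(6, 1)

before = fc2.weight.detach().clone()

# Inside no_grad, the indexed += is not recorded by autograd and
# writes back into the parameter itself.
with torch.no_grad():
    fc2.weight[j, i] += update.squeeze()

delta = fc2.weight.detach() - before
print(delta)  # approximately 1 at every position
```

The parameter keeps requires_grad=True afterwards, so training can continue as usual.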
Upvotes: 2