I have a neural network whose prediction is stored in outputs. I want to transform outputs before the loss is computed and backpropagation happens.
Here is my general code:
with torch.set_grad_enabled(training):
    outputs = net(x_batch[:, 0], x_batch[:, 1])  # the prediction of the NN
    # My issue is here:
    outputs = transform_torch(outputs)
    loss = my_loss(outputs, y_batch)
    if training:
        loss.backward()
        optimizer.step()
        scheduler.step()  # since PyTorch 1.1, step the LR scheduler after optimizer.step()
I have a transformation function which I put my output through:
def transform_torch(predictions):
    torch_dimensions = predictions.size()
    torch_grad = predictions.grad_fn
    cuda0 = torch.device('cuda:0')
    new_tensor = torch.ones(torch_dimensions, dtype=torch.float64, device=cuda0, requires_grad=True)
    for i in range(int(len(predictions))):
        a = predictions[i]
        # with torch.no_grad(): # Note: no training happens if this line is kept in
        new_tensor[i] = torch.flip(torch.cumsum(torch.flip(a, dims = [0]), dim = 0), dims = [0])
    return new_tensor
My problem is that I get this error on the second-to-last line of transform_torch, the assignment to new_tensor[i]:
RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.
Any suggestions? I have already tried using "with torch.no_grad():" (commented out above), but this results in very poor training, and I believe the gradients don't backpropagate properly through the transformation function.
Thanks!
The error message describes the problem accurately: when you create a new tensor with requires_grad=True, you create a leaf node in the graph (just like the parameters of a model), and you are not allowed to do in-place operations on it or on views of it, such as new_tensor[i].
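A minimal sketch that reproduces the error on the CPU, using hypothetical tensor names rather than the question's network. It also suggests why wrapping the assignment in torch.no_grad() avoids the error but stops training: the copy is no longer recorded by autograd, so no gradient reaches the tensor that was copied in.

```python
import torch

# A tensor created directly with requires_grad=True is a leaf of the autograd graph.
leaf = torch.ones(3, 4, requires_grad=True)
print(leaf.is_leaf)  # True

error_message = None
try:
    # leaf[0] is a view of the leaf; assigning into it is an in-place operation.
    leaf[0] = torch.zeros(4)
except RuntimeError as err:
    error_message = str(err)
print(error_message)

# Under no_grad the same kind of assignment is allowed, but it is not recorded,
# so gradients never flow back to the tensor whose values were copied in.
source = torch.randn(3, 4, requires_grad=True)
target = torch.ones(3, 4, requires_grad=True)
with torch.no_grad():
    target[0] = source[0] * 2
target.sum().backward()
print(source.grad)  # None
```

This matches the symptom described in the question: with no_grad the code runs, but the network's parameters receive no gradient from the transformed output.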
The solution is simple: you do not need to create new_tensor in advance. It is not supposed to be a leaf node; just build it on the fly from the per-row results:
new_tensor = []
for i in range(int(len(predictions))):
    a = predictions[i]
    new_tensor.append(torch.flip(torch.cumsum(torch.flip(a, ...), ...), ...))
new_tensor = torch.stack(new_tensor, 0)
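For reference, a runnable sketch of the fixed function, combining the list-and-torch.stack approach above with the flip/cumsum/flip arguments from the question's transform_torch. The random 4 x 5 input is a hypothetical stand-in for the network output, created on the CPU so the snippet runs without a GPU.

```python
import torch

def transform_torch(predictions):
    # Collect each transformed row instead of writing into a pre-allocated leaf tensor.
    rows = []
    for i in range(len(predictions)):
        a = predictions[i]
        # Reverse cumulative sum of the row: flip, cumsum, flip back.
        rows.append(torch.flip(torch.cumsum(torch.flip(a, dims=[0]), dim=0), dims=[0]))
    # torch.stack returns a new non-leaf tensor that is part of the autograd graph.
    return torch.stack(rows, 0)

# Hypothetical stand-in for the network output.
predictions = torch.randn(4, 5, requires_grad=True)
transformed = transform_torch(predictions)
print(transformed.is_leaf, transformed.requires_grad)  # False True

transformed.sum().backward()
print(predictions.grad is not None)  # True
```

Because the result is computed from predictions instead of being copied into a separately created leaf, backward() reaches predictions, and from there the network's parameters.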
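This new_tensor will inherit properties such as dtype and device from predictions (rather than the hard-coded float64 and cuda:0), and it will already have requires_grad=True, because it is computed from predictions by differentiable operations.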
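Since flipping, taking the cumulative sum and flipping back computes a reverse cumulative sum of each row, the Python loop can probably be dropped altogether. A minimal sketch, assuming predictions is a batch whose rows lie along dim 0, so that dim 0 of each row is dim 1 of the batch, as in the loop above; reverse_cumsum is a hypothetical name:

```python
import torch

def reverse_cumsum(predictions):
    # Same result as the per-row loop, computed for the whole batch at once:
    # flip along dim 1, cumulative sum along dim 1, flip back along dim 1.
    return torch.flip(torch.cumsum(torch.flip(predictions, dims=[1]), dim=1), dims=[1])

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]], requires_grad=True)
y = reverse_cumsum(x)
print(y.detach())  # values: [[6., 5., 3.], [15., 11., 6.]]
```

The output is still a non-leaf tensor connected to x, so gradients flow exactly as with torch.stack, while avoiding one small set of operations per row, which is usually faster on a GPU.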