Reputation: 83
I'm a student and a beginner in both Python and PyTorch. I have a very basic neural network for which I am encountering the mentioned RuntimeError. The code to reproduce the error is this:
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
# Ensure Reproducibility
torch.manual_seed(0)
# Data Generation
x = torch.randn((100,1), requires_grad = True)
y = 1 + 2 * x + 0.3 * torch.randn(100,1)
# Shuffles the indices
idx = np.arange(100)
np.random.shuffle(idx)
# Uses the first 70 random indices for training
train_idx = idx[:70]
# Uses the remaining indices for validation
val_idx = idx[70:]
# Generates train and validation sets
x_train, y_train = x[train_idx], y[train_idx]
x_val, y_val = x[val_idx], y[val_idx]
class OurFirstNeuralNetwork(nn.Module):
    def __init__(self):
        super(OurFirstNeuralNetwork, self).__init__()
        # Here we "define" our Neural Network Architecture
        self.fc1 = nn.Linear(1, 5)
        self.non_linearity_fc1 = nn.ReLU()
        self.fc2 = nn.Linear(5, 1)
        #self.non_linearity_fc2 = nn.ReLU()

    def forward(self, x):
        # The forward pass
        # Here we define how activations "flow" between neurons. We've already discussed the "Sum" and "Transformation" steps of the forward pass.
        sum_fc1 = self.fc1(x)
        transformation_fc1 = self.non_linearity_fc1(sum_fc1)
        sum_fc2 = self.fc2(transformation_fc1)
        #transformation_fc2 = self.non_linearity_fc2(sum_fc2)
        # sum_fc2 is the output of our model, which marks the end of our forward pass.
        return sum_fc2
# Instantiate the model and train
model = OurFirstNeuralNetwork()
print(model)
print(model.state_dict())
n_epochs = 1000
loss_fn = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(model.parameters())
for epoch in range(n_epochs):
    model.train()
    optimizer.zero_grad()
    prediction = model(x_train)
    loss = loss_fn(y_train, prediction)
    print(epoch, loss)
    loss.backward(retain_graph=True)
    optimizer.step()
print(model.state_dict())
Everything here is basic and standard, and this works fine.
However, when I take out the "retain_graph=True" argument, it throws the RuntimeError. From reading various forums, I understand that this has to do with the graph getting thrown away after the first iteration, but I have seen many tutorials and blogs where a plain loss.backward() is the way to go, especially since it conserves memory. I am not able to conceptually grasp why the same does not work for me.
Any help is appreciated, and my apologies if the way in which I have asked my question is not in the expected format. I am open to feedback and am happy to include more details or rephrase the question so that it is easier for everyone. Thank you in advance!
Upvotes: 3
Views: 7848
Reputation: 16440
You need to add optimizer.zero_grad() after optimizer.step() to zero out the gradients.

Why do you need to do this?

When you call loss.backward(), torch computes the gradients of the loss with respect to the parameters and accumulates them into each parameter's .grad attribute. When you call optimizer.step(), the parameters are updated using that .grad attribute, i.e. roughly parameter = parameter - lr * parameter.grad (see the sketch below).
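To make the role of .grad concrete, here is a minimal, illustrative sketch of that update rule on a single tensor (a plain SGD-style step rather than Adam as in the question; the names w and lr are just for demonstration):

import torch

w = torch.randn(3, requires_grad=True)   # a stand-in for one model parameter
loss = (w ** 2).sum()
loss.backward()                          # fills w.grad with d(loss)/dw

lr = 0.1
with torch.no_grad():
    w -= lr * w.grad                     # what a plain SGD optimizer.step() boils down to
print(w.grad)                            # .grad is NOT cleared by the update; it stays until zeroed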
Since you do not clear the gradients and then call backward a second time, it has to compute dl/d(updated param), which requires backpropagating through the parameter.grad of the first pass. The computation graph for those gradients is not stored during backward, and hence you have to pass retain_graph=True to get rid of the error. However, that is not what we want when updating the parameters. Rather, we want to clear the gradients and restart with a fresh computation graph each iteration, which is why you need to zero the gradients with a .zero_grad() call.
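For reference, a minimal sketch of the training-loop ordering suggested above, reusing model, loss_fn, optimizer, x_train and y_train from the question:

for epoch in range(n_epochs):
    model.train()
    prediction = model(x_train)          # fresh forward pass builds a fresh graph
    loss = loss_fn(y_train, prediction)
    loss.backward()                      # populate .grad for each parameter
    optimizer.step()                     # update parameters from .grad
    optimizer.zero_grad()                # clear .grad before the next iteration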
Also see Why do we need to call zero_grad() in PyTorch?
Upvotes: 8