Reputation: 317
I'm working with a linear regression example in PyTorch. I know it was a mistake to put 'loss.backward()' inside 'with torch.no_grad():', but why does it still work with my code?
According to the PyTorch docs, torch.autograd.no_grad is a context manager that disables gradient calculation. So I'm really confused.
Code here:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Toy dataset
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],
[9.779], [6.182], [7.59], [2.167], [7.042],
[10.791], [5.313], [7.997], [3.1]], dtype=np.float32)
y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],
[3.366], [2.596], [2.53], [1.221], [2.827],
[3.465], [1.65], [2.904], [1.3]], dtype=np.float32)
input_size = 1
output_size = 1
epochs = 100
learning_rate = 0.05
model = nn.Linear(input_size, output_size)
criterion = nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# training
for epoch in range(epochs):
    # convert numpy to tensor
    inputs = torch.from_numpy(x_train)
    targets = torch.from_numpy(y_train)
    # forward
    out = model(inputs)
    loss = criterion(out, targets)
    # backward
    with torch.no_grad():
        model.zero_grad()
        loss.backward()
        optimizer.step()
        print('inputs grad : ', inputs.requires_grad)
    if epoch % 5 == 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, loss.item()))
predicted = model(torch.from_numpy(x_train)).detach().numpy()
plt.plot(x_train, y_train, 'ro', label='Original data')
plt.plot(x_train, predicted, label='Fitted line')
plt.legend()
plt.show()
# Save the model checkpoint
torch.save(model.state_dict(), 'model/linear_model.ckpt')
Thanks in advance for answering my question.
Upvotes: 1
Views: 1979
Reputation: 1931
This works because the loss was computed before you entered no_grad, so the computation graph for that loss was built with gradients enabled. no_grad only stops new operations from being recorded; it does not prevent loss.backward() from computing gradients through a graph that was already built outside of it. Basically, you keep updating the weights of your layers using gradients that come from a forward pass that ran outside of no_grad.
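You can convince yourself of this with a minimal standalone sketch (not from your script): build the graph outside no_grad, then call backward() inside it and check that the gradient is still populated.
import torch
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2                  # graph built here, with grad mode enabled
with torch.no_grad():
    z = x ** 2              # this op is not recorded: z.requires_grad is False
    y.backward()            # backward through the existing graph still works
print(z.requires_grad)      # False
print(x.grad)               # tensor(4.) -- d(x**2)/dx at x = 2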
When you actually put the forward pass under no_grad:
for epoch in range(epochs):
    # convert numpy to tensor
    inputs = torch.from_numpy(x_train)
    targets = torch.from_numpy(y_train)
    with torch.no_grad():  # no_grad used here
        # forward
        out = model(inputs)
        loss = criterion(out, targets)
        model.zero_grad()
        loss.backward()
        optimizer.step()
        print('inputs grad : ', inputs.requires_grad)
    if epoch % 5 == 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, loss.item()))
Then you will get the expected error:
element 0 of tensors does not require grad and does not have a grad_fn
That is, in your code you are using no_grad where it is not appropriate to use it.
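Where no_grad is appropriate (a small sketch based on your own script, not part of the original answer) is around code that should not build a graph at all, for example the final prediction, which then no longer needs .detach():
# inference only: no graph is needed, so no_grad is the right tool here
with torch.no_grad():
    predicted = model(torch.from_numpy(x_train)).numpy()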
If you print the .requires_grad of loss, you will see that loss has requires_grad set to True.
That is, when you do this:
for epoch in range(epochs):
    # convert numpy to tensor
    inputs = torch.from_numpy(x_train)
    targets = torch.from_numpy(y_train)
    # forward
    out = model(inputs)
    loss = criterion(out, targets)
    # backward
    with torch.no_grad():
        model.zero_grad()
        loss.backward()
        optimizer.step()
        print('inputs grad : ', inputs.requires_grad)
        print('loss grad : ', loss.requires_grad)  # prints loss.requires_grad
    if epoch % 5 == 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, loss.item()))
You will see:
inputs grad : False
loss grad : True
Additionally, print('inputs grad : ', inputs.requires_grad) will always print False, because torch.from_numpy creates a leaf tensor with requires_grad=False by default and nothing in the loop changes that. That is, if you do
for epoch in range(epochs):
    # convert numpy to tensor
    inputs = torch.from_numpy(x_train)
    targets = torch.from_numpy(y_train)
    print('inputs grad : ', inputs.requires_grad)  # print inputs.requires_grad before the forward pass
    # forward
    out = model(inputs)
    loss = criterion(out, targets)
    # backward
    with torch.no_grad():
        model.zero_grad()
        loss.backward()
        optimizer.step()
        print('inputs grad : ', inputs.requires_grad)
        print('loss grad : ', loss.requires_grad)
    if epoch % 5 == 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, loss.item()))
You will get:
inputs grad : False
inputs grad : False
loss grad : True
That is, you are checking the wrong things to find out what you did wrong. The best thing you can do is to read the PyTorch docs on autograd mechanics again.
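For example (a minimal sketch reusing the model, criterion, x_train and y_train from your script, not part of the original answer), the thing worth checking is the parameter gradients, because those are what loss.backward() actually fills in, even when it is called inside no_grad:
out = model(torch.from_numpy(x_train))
loss = criterion(out, torch.from_numpy(y_train))
model.zero_grad()
with torch.no_grad():
    loss.backward()
# the gradients exist even though backward() ran inside no_grad
print(model.weight.grad)  # 1x1 tensor, not None
print(model.bias.grad)    # 1-element tensor, not None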
Upvotes: 8