Reputation: 3
I am training a deep learning model with the PyTorch framework, and I added torch.no_grad to speed up the training phase:
model.train()
for epoch in range(epochs):
    for data, label in loader:
        data, label = data.to(device), label.to(device)
        with torch.no_grad():
            out = model(data)
        out.requires_grad = True
        # model.zero_grad(), loss(), loss.backward(), optim.step()
The speed improved, but something is wrong with the gradient updates: the model doesn't converge correctly. Can someone explain why this doesn't work?
Upvotes: 0
Views: 338
Reputation: 2569
Simply put, when the forward pass runs inside the torch.no_grad context manager, the gradients are not computed, so the model cannot receive any updates.
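Setting requires_grad = True on the output afterwards does not help, because the output is already detached from the model parameters. A minimal sketch of what happens (the nn.Linear model and tensor shapes here are just illustrative):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
data = torch.randn(8, 4)

with torch.no_grad():
    out = model(data)        # no computation graph is recorded here
print(out.grad_fn)           # None: out is detached from the model parameters

out.requires_grad = True     # only turns out into a leaf tensor
loss = out.sum()
loss.backward()              # backpropagates no further than out
print(model.weight.grad)     # None: the parameters never receive gradients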
torch.no_grad is meant to be used in other cases, for example when evaluating the model. From the docs:
Disabling gradient calculation is useful for inference, when you are sure that you will not call Tensor.backward()
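So run the training forward pass without no_grad and reserve it for evaluation. A rough sketch of the intended pattern, assuming a criterion, optim, and val_loader along the lines of your code:

model.train()
for epoch in range(epochs):
    for data, label in loader:
        data, label = data.to(device), label.to(device)
        out = model(data)              # forward pass with gradient tracking
        loss = criterion(out, label)
        optim.zero_grad()
        loss.backward()
        optim.step()

model.eval()
with torch.no_grad():                  # appropriate here: no backward pass is needed
    for data, label in val_loader:
        data, label = data.to(device), label.to(device)
        out = model(data)              # faster, no graph is built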
Upvotes: 1