Reputation: 40969
I'm following a PyTorch tutorial which uses the BERT NLP model (feature extractor) from the Huggingface Transformers library. There are two pieces of interrelated code for gradient updates that I don't understand.
(1) torch.no_grad()
The tutorial has a class where the forward()
function creates a torch.no_grad()
block around a call to the BERT feature extractor, like this:
import torch
import torch.nn as nn
from transformers import BertModel

bert = BertModel.from_pretrained('bert-base-uncased')

class BERTGRUSentiment(nn.Module):
    def __init__(self, bert):
        super().__init__()
        self.bert = bert

    def forward(self, text):
        # no gradients are recorded for the BERT forward pass
        with torch.no_grad():
            embedded = self.bert(text)[0]
        # ... (rest of the tutorial's forward pass omitted)
(2) param.requires_grad = False
There is another portion in the same tutorial where the BERT parameters are frozen.
for name, param in model.named_parameters():
    if name.startswith('bert'):
        param.requires_grad = False
When would I need (1) and/or (2)?
Additionally, I ran all four combinations and found:
   with torch.no_grad()   requires_grad = False   Trainable parameters   Result
   --------------------   ---------------------   --------------------   ------------------
a. Yes                    Yes                     3M                     Ran successfully
b. Yes                    No                      112M                   Ran successfully
c. No                     Yes                     3M                     Ran successfully
d. No                     No                      112M                   CUDA out of memory
Can someone please explain what's going on? Why am I getting CUDA out of memory
for (d) but not (b)? Both have 112M learnable parameters.
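For reference, the parameter counts in the table are the numbers of trainable parameters. A minimal sketch of how they can be computed (assuming model is the full tutorial model that wraps bert):

def count_trainable_parameters(model):
    # only parameters with requires_grad=True will receive gradient updates
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'{count_trainable_parameters(model):,} trainable parameters')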
Upvotes: 14
Views: 17482
Reputation: 11490
This is an older discussion, which has changed slightly over the years (mainly regarding the purpose of with torch.no_grad() as a pattern). An excellent answer that largely covers your question as well can already be found on Stack Overflow. However, since the original question here is vastly different, I'll refrain from marking this as a duplicate, especially because of the second part about memory.
An initial explanation of no_grad is given here: with torch.no_grad() "is a context manager and is used to prevent calculating gradients [...]". requires_grad, on the other hand, is used "to freeze part of your model and train the rest [...]" (source: again, the SO post).
Essentially, with requires_grad you are just disabling gradient computation for parts of a network, whereas no_grad will not store any gradients at all, since you're likely using it for inference and not training.
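To make the difference concrete, here is a minimal, self-contained sketch (not the tutorial's code, just a toy nn.Linear example): under no_grad no computation graph is recorded at all, whereas freezing parameters only stops gradients for those parameters while the trainable rest of the network still builds a graph.

import torch
import torch.nn as nn

layer = nn.Linear(10, 10)   # stand-in for the big frozen feature extractor
head = nn.Linear(10, 2)     # stand-in for the small trainable head
x = torch.randn(1, 10)

# (1) torch.no_grad(): nothing is recorded for backward at all
with torch.no_grad():
    out = layer(x)
print(out.requires_grad, out.grad_fn)       # False None

# (2) requires_grad = False: these parameters are frozen, but the
# trainable part of the network still records its forward pass
for p in layer.parameters():
    p.requires_grad = False

out = head(layer(x))
out.sum().backward()
print(layer.weight.grad)                    # None -> frozen, no gradient
print(head.weight.grad is not None)         # True -> still trained as usual

The same mechanics drive the memory behaviour analyzed below.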
To analyze the behavior of your four combinations, let us look at what is happening in each case:
a) and b) do not store any gradients at all, which means that you have vastly more memory available to you, no matter the number of parameters, since you're not retaining them for a potential backward pass.

c) has to store the forward pass for later backpropagation; however, only a limited number of parameters (3 million) are stored, which makes this still manageable.

d), however, needs to store the forward pass for all 112 million parameters, which causes you to run out of memory.

Upvotes: 13