Deshwal

Reputation: 4162

Train the last n% of layers of BERT in PyTorch using the HuggingFace library (e.g. train the last 5 BertLayers out of 12)

BERT's architecture is roughly embeddings -> encoder (12 BertLayers) -> pooler. I want to train only the last 40% of the layers of the BERT model. I can freeze all the layers like this:

from transformers import AutoModel

bert = AutoModel.from_pretrained('bert-base-uncased')

# freeze all parameters
for param in bert.parameters():
    param.requires_grad = False

But I want to train the last 40% of the layers. When I run len(list(bert.parameters())), it gives 199, so let us take 79 as 40% of the parameters. Can I do something like:

for param in list(bert.parameters())[-79:]:  # 79 of the 199 trainable params, i.e. 40%
    param.requires_grad = False

I think this will freeze the first 60% of the layers.

Also, can someone tell me which layers this will freeze, in terms of the architecture?

Upvotes: 1

Views: 2964

Answers (1)

cronoik

Reputation: 19520

You are probably looking for named_parameters.

for name, param in bert.named_parameters():                                            
    print(name)

Output:

embeddings.word_embeddings.weight
embeddings.position_embeddings.weight
embeddings.token_type_embeddings.weight
embeddings.LayerNorm.weight
embeddings.LayerNorm.bias
encoder.layer.0.attention.self.query.weight
encoder.layer.0.attention.self.query.bias
encoder.layer.0.attention.self.key.weight
...

named_parameters will also show you that your snippet has frozen not the first 60% but the last 40%; the parameters that remain trainable are the ones at the start of the model:

for name, param in bert.named_parameters():
    if param.requires_grad:
        print(name)

Output:

embeddings.word_embeddings.weight
embeddings.position_embeddings.weight
embeddings.token_type_embeddings.weight
embeddings.LayerNorm.weight
embeddings.LayerNorm.bias
encoder.layer.0.attention.self.query.weight
encoder.layer.0.attention.self.query.bias
encoder.layer.0.attention.self.key.weight
encoder.layer.0.attention.self.key.bias
encoder.layer.0.attention.self.value.weight
...
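As for which layers such a slice touches in terms of the architecture: you can tally the named parameters per module. A small sketch (the group_key helper is just an illustrative name, not part of the library):

from collections import Counter

from transformers import AutoModel

bert = AutoModel.from_pretrained('bert-base-uncased')

def group_key(name):
    # Group "encoder.layer.N.*" under "encoder.layer.N"; everything else
    # (embeddings, pooler) under its top-level module name.
    parts = name.split('.')
    return '.'.join(parts[:3]) if parts[0] == 'encoder' else parts[0]

counts = Counter(group_key(name) for name, _ in bert.named_parameters())
for module, n in counts.items():
    print(module, n)

For bert-base-uncased this prints 5 parameters for embeddings, 16 for each of the 12 encoder.layer.N blocks, and 2 for the pooler (5 + 12 * 16 + 2 = 199). A slice of the last 79 parameters therefore covers the pooler (2), layers 8 to 11 (4 * 16 = 64), and 13 of the 16 parameters of layer 7, so a cut by parameter count lands in the middle of a BertLayer rather than on a layer boundary.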

You can freeze the first 60% with:

for name, param in list(bert.named_parameters())[:-79]: 
    print('I will be frozen: {}'.format(name)) 
    param.requires_grad = False
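If you want the boundary to fall exactly on a BertLayer, for example training the last 5 of the 12 layers plus the pooler as in the title, you can freeze by parameter name instead of by position. A minimal sketch, assuming bert-base-uncased with 12 encoder layers:

from transformers import AutoModel

bert = AutoModel.from_pretrained('bert-base-uncased')

# Freeze the embeddings and encoder layers 0-6; layers 7-11 and the
# pooler stay trainable.
frozen_prefixes = ('embeddings.',) + tuple(f'encoder.layer.{i}.' for i in range(7))

for name, param in bert.named_parameters():
    param.requires_grad = not name.startswith(frozen_prefixes)

The trailing dot in 'encoder.layer.{i}.' matters: without it, a prefix like encoder.layer.1 would also match layers 10 and 11.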

Upvotes: 11
