Alex

Reputation: 83

Having 6 labels instead of 2 in Hugging Face BertForSequenceClassification

I was just wondering whether it is possible to extend the Hugging Face BertForSequenceClassification model to more than 2 labels. The docs say we can pass positional arguments, but it seems like "labels" is not working. Does anybody have an idea?

Model assignment

import torch as th
import transformers

labels = th.tensor([0, 0, 0, 0, 0, 0], dtype=th.long).unsqueeze(0)
print(labels.shape)
modelBERTClass = transformers.BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', 
    labels=labels
    )

l = [module for module in modelBERTClass.modules()]
l

Console Output

torch.Size([1, 6])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-122-fea9a36402a6> in <module>()
      3 modelBERTClass = transformers.BertForSequenceClassification.from_pretrained(
      4     'bert-base-uncased',
----> 5     labels=labels
      6     )
      7 

/usr/local/lib/python3.6/dist-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    653 
    654         # Instantiate model.
--> 655         model = cls(config, *model_args, **model_kwargs)
    656 
    657         if state_dict is None and not from_tf:

TypeError: __init__() got an unexpected keyword argument 'labels'

Upvotes: 4

Views: 7346

Answers (2)

cronoik

Reputation: 19455

You can set the output shape of the classification layer with from_pretrained via the num_labels parameter:

from transformers import BertForSequenceClassification
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=6)
print(model.classifier)

Output:

Linear(in_features=768, out_features=6, bias=True)
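
Note that `labels` is an argument of the model's forward pass, not of the constructor, which is why the call in the question fails. A minimal sketch of passing labels at forward time (it uses a tiny randomly initialized config, an assumption just to keep the example light; `'bert-base-uncased'` accepts the same forward arguments):

```python
import torch
from transformers import BertConfig, BertForSequenceClassification

# Tiny randomly initialized model (an assumption, to avoid downloading
# a checkpoint); the forward signature matches the pretrained model.
config = BertConfig(hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64,
                    num_labels=6)
model = BertForSequenceClassification(config)

input_ids = torch.tensor([[101, 2023, 2003, 102]])  # dummy token ids
labels = torch.tensor([3])                          # a class index in 0..5
outputs = model(input_ids=input_ids, labels=labels)
print(outputs.loss)          # scalar cross-entropy loss
print(outputs.logits.shape)  # torch.Size([1, 6])
```

With `num_labels=6` in the config, the cross-entropy loss over the 6 classes is computed automatically whenever `labels` is supplied to the forward call.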

Upvotes: 7

bpfrd

Reputation: 1025

The above solution isn't totally correct. The docs clearly state that `from_config` doesn't load the weights; you can confirm this by printing, for example, the weights of the embedding layer. So you need to load the weights yourself. I would suggest the following:

from transformers import BertForSequenceClassification, BertConfig

#loading pretrained model first
pretrained_model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
pretrained_model.bert.embeddings.word_embeddings.weight.data

>tensor([[-0.0102, -0.0615, -0.0265,  ..., -0.0199, -0.0372, -0.0098],
        [-0.0117, -0.0600, -0.0323,  ..., -0.0168, -0.0401, -0.0107],
        [-0.0198, -0.0627, -0.0326,  ..., -0.0165, -0.0420, -0.0032],
        ...,
        [-0.0218, -0.0556, -0.0135,  ..., -0.0043, -0.0151, -0.0249],
        [-0.0462, -0.0565, -0.0019,  ...,  0.0157, -0.0139, -0.0095],
        [ 0.0015, -0.0821, -0.0160,  ..., -0.0081, -0.0475,  0.0753]])


config = BertConfig.from_pretrained('bert-base-uncased')
config.num_labels = 6
model = BertForSequenceClassification(config) #or from_config in auto models

model.bert.embeddings.word_embeddings.weight.data

>tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0189,  0.0015, -0.0462,  ...,  0.0228, -0.0095, -0.0070],
        [-0.0137,  0.0101,  0.0365,  ...,  0.0034, -0.0145, -0.0203],
        ...,
        [ 0.0140,  0.0357,  0.0051,  ...,  0.0403,  0.0142,  0.0068],
        [ 0.0132, -0.0024, -0.0015,  ..., -0.0329,  0.0066,  0.0123],
        [ 0.0132,  0.0432,  0.0248,  ..., -0.0017,  0.0158,  0.0155]])


#copying weights for all layers except the new classifier
for name, param in list(model.named_parameters())[:-2]:
    param.requires_grad_(False)  # optionally freeze the copied layers
    param.data = pretrained_model.state_dict().get(name).data.clone()

model.bert.embeddings.word_embeddings.weight.data

>tensor([[-0.0102, -0.0615, -0.0265,  ..., -0.0199, -0.0372, -0.0098],
        [-0.0117, -0.0600, -0.0323,  ..., -0.0168, -0.0401, -0.0107],
        [-0.0198, -0.0627, -0.0326,  ..., -0.0165, -0.0420, -0.0032],
        ...,
        [-0.0218, -0.0556, -0.0135,  ..., -0.0043, -0.0151, -0.0249],
        [-0.0462, -0.0565, -0.0019,  ...,  0.0157, -0.0139, -0.0095],
        [ 0.0015, -0.0821, -0.0160,  ..., -0.0081, -0.0475,  0.0753]])
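
The same copy can also be sketched with `load_state_dict(strict=False)` after filtering out the resized classifier head. Tiny random configs stand in for the real checkpoint here (an assumption to keep the example light; the pattern is identical for `'bert-base-uncased'`):

```python
import torch
from transformers import BertConfig, BertForSequenceClassification

# Tiny configs stand in for 'bert-base-uncased' (an assumption to keep
# the sketch light); the copy pattern is the same for the real model.
kwargs = dict(hidden_size=32, num_hidden_layers=2,
              num_attention_heads=2, intermediate_size=64)
pretrained = BertForSequenceClassification(BertConfig(num_labels=2, **kwargs))
model = BertForSequenceClassification(BertConfig(num_labels=6, **kwargs))

# Copy everything except the classifier head, whose shape has changed.
state = {k: v for k, v in pretrained.state_dict().items()
         if not k.startswith('classifier')}
missing, unexpected = model.load_state_dict(state, strict=False)
print(missing)     # only the classifier weight and bias are left out
print(unexpected)  # []
```

With `strict=False`, the only missing keys are the classifier's weight and bias, which keep their fresh random initialization for the new 6-label head while every other parameter is taken from the source model.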

Upvotes: 2
