I am using Hugging Face Transformers to implement a BERT model with BertForSequenceClassification.from_pretrained().
The model is trying to predict 1 of 24 classes. I am using a batch size of 32 and a sequence length of 66.
When I try to call the model in training, I get the following error:
ValueError: Expected input batch_size (32) to match target batch_size (768).
However, my target shape is 32x24. It seems that somewhere inside the model call this is being flattened to 768x1 (32 × 24 = 768). Here is a test I ran to check:
```python
for i in train_dataloader:
    i = tuple(t.to(device) for t in i)
    print(i[0].shape, i[1].shape, i[2].shape)  # here i[2].shape is (32, 24)
    output = model(i[0], attention_mask=i[1], labels=i[2])  # here PyTorch complains that i[2]'s shape is now (768, 1)
    print(output.logits.shape)
    break
```
This outputs:
```
torch.Size([32, 66]) torch.Size([32, 66]) torch.Size([32, 24])
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-68-c69db6168cc3> in <module>
      2     i = tuple(t.to(device) for t in i)
      3     print(i[0].shape, i[1].shape, i[2].shape)
----> 4     output = model(i[0], attention_mask=i[1], labels=i[2])
      5     print(output.logits.shape)
      6     break

4 frames

/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3024     if size_average is not None or reduce is not None:
   3025         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3026     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
   3027
   3028

ValueError: Expected input batch_size (32) to match target batch_size (768).
```
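For single-label classification, BertForSequenceClassification computes its loss with PyTorch's CrossEntropyLoss applied to logits.view(-1, num_labels) and labels.view(-1). The targets are therefore expected to be integer class indices, not one-hot class vectors: target should have shape [batch_size], not [batch_size, n_classes]. A 32x24 one-hot label tensor is flattened into 768 targets, which is exactly the mismatch reported in the error.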
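The flattening can be reproduced with plain PyTorch, without loading BERT. This is a minimal sketch, assuming the single-label branch of BertForSequenceClassification reduces to F.cross_entropy(logits.view(-1, num_labels), labels.view(-1)); the random tensors stand in for the model's logits and the question's labels, and single_label_loss is a hypothetical helper, not a Transformers API. One-hot labels should raise the same ValueError as in the question, while index labels of shape [batch_size] produce a scalar loss.

```python
import torch
import torch.nn.functional as F

batch_size, n_classes = 32, 24

# Random stand-ins for the model's logits and the question's one-hot labels (illustration only)
logits = torch.randn(batch_size, n_classes)
class_ids = torch.randint(0, n_classes, (batch_size,))
one_hot = F.one_hot(class_ids, num_classes=n_classes)  # shape (32, 24), dtype int64


def single_label_loss(logits, labels):
    # Hypothetical helper mimicking the single-label loss: both tensors are flattened first
    return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))


try:
    # 32 x 24 one-hot labels become 768 targets after labels.view(-1)
    single_label_loss(logits, one_hot)
    mismatch_error = None
except ValueError as err:
    mismatch_error = str(err)

print(mismatch_error)

# Class indices of shape (32,) line up with the 32 rows of logits
loss = single_label_loss(logits, class_ids)
print(loss.shape)
```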
You can convert your one-hot vectors back into class indices quite simply as follows (provided each class vector is indeed one-hot):
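```python
# Class indices 0..n_classes-1, created on the same device as target, one row per sample
raveler = torch.arange(0, n_classes, device=target.device).unsqueeze(0).expand(batch_size, n_classes)
# Each one-hot row selects a single index, so the row sum is that class index
target = (target * raveler).sum(dim=1)
```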
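A slightly fuller sketch of the same conversion, applied once to the whole label tensor before the DataLoader is built, so every batch already carries indices of shape [batch_size] that can be passed as labels=. The helper one_hot_to_indices and the toy tensors are hypothetical stand-ins for the question's data. The assert checks the result against argmax(dim=1), which returns the same indices for one-hot rows and is an arguably simpler alternative.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

batch_size, n_classes, seq_len = 32, 24, 66


def one_hot_to_indices(target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: convert one-hot rows of shape (N, C) into class indices of shape (N,)."""
    # Row vector [0, 1, ..., C-1] on target's device; broadcasting makes expand() unnecessary
    raveler = torch.arange(target.size(1), device=target.device).unsqueeze(0)
    return (target * raveler).sum(dim=1)


# Hypothetical toy data shaped like the question's batches (random, illustration only)
num_examples = 64
input_ids = torch.randint(0, 1000, (num_examples, seq_len))
attention_mask = torch.ones(num_examples, seq_len, dtype=torch.long)
one_hot_labels = F.one_hot(torch.randint(0, n_classes, (num_examples,)), num_classes=n_classes)

# Convert once, before batching, so the DataLoader yields index labels
label_ids = one_hot_to_indices(one_hot_labels)
assert torch.equal(label_ids, one_hot_labels.argmax(dim=1))  # same result as argmax for one-hot rows

loader = DataLoader(TensorDataset(input_ids, attention_mask, label_ids), batch_size=batch_size)
first_batch = next(iter(loader))
print([t.shape for t in first_batch])
# [torch.Size([32, 66]), torch.Size([32, 66]), torch.Size([32])]
```

Broadcasting the (1, C) index row against the (N, C) one-hot tensor does the same job as the explicit expand(batch_size, n_classes) in the answer, and it also works for a final batch that is smaller than batch_size.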