
Reputation: 11

ValueError: Target size (torch.Size([8])) must be the same as input size (torch.Size([8, 2]))

I'm trying to implement a code for sentiment analysis( positive or negative labels) using BERT and i want to add a BiLSTM layer to see if I can increase the accuracy of the pretrained model from HuggingFace. I have the below code and a few questions :

import numpy as np
import pandas as pd
from sklearn import metrics
import transformers
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertModel, BertConfig
from torch import cuda
import re
import torch.nn as nn

device = 'cuda' if cuda.is_available() else 'cpu'
MAX_LEN = 200
LEARNING_RATE = 1e-05 #5e-5, 3e-5 or 2e-5
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

class CustomDataset(Dataset):
 def __init__(self, dataframe, tokenizer, max_len):
  self.tokenizer = tokenizer
  self.data = dataframe
  self.comment_text = dataframe.review
  self.targets = self.data.sentiment
  self.max_len = max_len
 def __len__(self):
  return len(self.comment_text)
 def __getitem__(self, index):
  comment_text = str(self.comment_text[index])
  comment_text = " ".join(comment_text.split())

  inputs = self.tokenizer.encode_plus(comment_text,None,add_special_tokens=True,max_length=self.max_len,
  ids = inputs['input_ids']
  mask = inputs['attention_mask']
  token_type_ids = inputs["token_type_ids"]

  return {
   'ids': torch.tensor(ids, dtype=torch.long),
   'mask': torch.tensor(mask, dtype=torch.long),
   'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
   'targets': torch.tensor(self.targets[index], dtype=torch.float)
train_size = 0.8
train_dataset = train_dataset.reset_index(drop=True)

print("FULL Dataset: {}".format(df.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("TEST Dataset: {}".format(test_dataset.shape))

training_set = CustomDataset(train_dataset, tokenizer, MAX_LEN)
testing_set = CustomDataset(test_dataset, tokenizer, MAX_LEN)
train_params = {'batch_size': TRAIN_BATCH_SIZE,'shuffle': True,'num_workers': 0}
test_params = {'batch_size': VALID_BATCH_SIZE,'shuffle': True,'num_workers': 0}
training_loader = DataLoader(training_set, **train_params)
testing_loader = DataLoader(testing_set, **test_params)

class BERTClass(torch.nn.Module):
 def __init__(self):
   super(BERTClass, self).__init__()
   self.bert = BertModel.from_pretrained('bert-base-uncased',return_dict=False, num_labels =2)
   self.lstm = nn.LSTM(768, 256, batch_first=True, bidirectional=True)
   self.linear = nn.Linear(256*2,2)

 def forward(self, ids , mask,token_type_ids):
  sequence_output, pooled_output = self.bert(ids, attention_mask=mask, token_type_ids = token_type_ids)
  lstm_output, (h, c) = self.lstm(sequence_output)  ## extract the 1st token's embeddings
  hidden = torch.cat((lstm_output[:, -1, :256], lstm_output[:, 0, 256:]), dim=-1)
  linear_output = self.linear(lstm_output[:, -1].view(-1, 256 * 2))

  return linear_output

model = BERTClass()
def loss_fn(outputs, targets):
 return torch.nn.BCEWithLogitsLoss()(outputs, targets)
optimizer = torch.optim.Adam(params =  model.parameters(), lr=LEARNING_RATE)

def train(epoch):
 for _, data in enumerate(training_loader, 0):
  ids = data['ids'].to(device, dtype=torch.long)
  mask = data['mask'].to(device, dtype=torch.long)
  token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
  targets = data['targets'].to(device, dtype=torch.float)
  outputs = model(ids, mask, token_type_ids)
  loss = loss_fn(outputs, targets)
  if _ % 5000 == 0:
   print(f'Epoch: {epoch}, Loss:  {loss.item()}')

for epoch in range(EPOCHS):

So on the above code I ran into the error : Target size (torch.Size([8])) must be the same as input size (torch.Size([8, 2])) . Checked online and tried to use targets = targets.unsqueeze(2) but then I get another error that I must use values from [-2,1] for unsqueeze. I also tried to modify the loss function to

def loss_fn(outputs, targets):
 return torch.nn.BCELoss()(outputs, targets)

but I still receive the same error. Can someone advise if there is a solution to this problem? Or what can I do to make this work fine? Many thanks in advance.

Upvotes: 0

Views: 1447

Answers (0)

Related Questions