Reputation: 91
I want to fine-tune LaBSE for question answering using the SQuAD dataset, and I got this error:
ValueError: The model did not return a loss from the inputs, only the following keys: last_hidden_state,pooler_output. For reference, the inputs it received are input_ids,token_type_ids,attention_mask.
I am trying to fine-tune the model using PyTorch. I used a smaller batch size and took just 10% of the training dataset because I had problems with memory allocation. Once the memory allocation problems are gone, this error happens. To be honest, I'm stuck with it. Do you have any hints?
I'm following the Hugging Face tutorial, but I want to use a different evaluation (I want to do it myself), so I skipped the evaluation part of the dataset.
from datasets import load_dataset
raw_datasets = load_dataset("squad", split='train')
from transformers import BertTokenizerFast, BertModel
from transformers import AutoTokenizer
model_checkpoint = "setu4993/LaBSE"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = BertModel.from_pretrained(model_checkpoint)
max_length = 384
stride = 128
def preprocess_training_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs
train_dataset = raw_datasets.map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_datasets.column_names,
)
len(raw_datasets), len(train_dataset)
from transformers import TrainingArguments
args = TrainingArguments(
    "bert-finetuned-squad",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
)
from transformers import Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
Upvotes: 8
Views: 17855
Reputation: 4797
Maybe datasets' set_format is called but the label column is missing, like this:
sth_encoded.set_format("torch", columns = ["input_ids", "attention_mask"])
It should be:
sth_encoded.set_format("torch", columns = ["input_ids", "attention_mask", "label"])
Upvotes: 0
Reputation: 21
When my train and test data had these column names: Class Index for the labels column and Description for the text column, I got the same error. After renaming the columns to labels and text, it worked without any errors!
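For reference, a minimal sketch of that renaming step with the datasets API (the dataset variable and the original column names below are placeholders matching the ones mentioned above):

# assuming dataset is a datasets.Dataset whose columns are "Class Index" and "Description"
dataset = dataset.rename_column("Class Index", "labels")
dataset = dataset.rename_column("Description", "text")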
Upvotes: 1
Reputation: 81
Rename the columns to text and labels for text classification with the distilbert-base-uncased model. This needs to be checked for other domains and models too.
Upvotes: 4
Reputation: 103
Hi,
Please make sure of the following: with a BertForQuestionAnswering model, the Hugging Face source on GitHub shows that start_positions and end_positions are the keys/column names the model accepts during the forward pass, and those labels are what let it return a loss. In your code you load a plain BertModel, which only returns last_hidden_state and pooler_output and has no question-answering head.
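A minimal sketch of that change, assuming the LaBSE checkpoint is loaded into a QA architecture (the qa_outputs head weights will be newly initialized and trained from scratch):

from transformers import AutoTokenizer, BertForQuestionAnswering

model_checkpoint = "setu4993/LaBSE"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# BertForQuestionAnswering puts a span-prediction head on top of the encoder and
# computes a loss whenever start_positions/end_positions are passed in the batch
model = BertForQuestionAnswering.from_pretrained(model_checkpoint)

With this model, the start_positions and end_positions columns produced by the preprocessing function are consumed in the forward pass, so the Trainer gets a loss back.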
Let me know if you or someone else is able to resolve the error with the mentioned fix!
Thanks!
Upvotes: 9