Reputation: 125
I am using Hugging Face's DistilBERT model as the backend for a question-and-answer application. The text I am feeding the model is one very large single text field. Even though the text field is a single string, the punctuation was left in place as a clue for BERT. When I run the application I get the "Token indices sequence length" error. I am using the tokenizer.encode_plus() method to pass the text into the model, and I have tried various ways to truncate the input ids to a length <= 512.
I am currently using Windows 10 but I will also be porting the code to a Raspberry Pi 4 platform.
The code is failing at this line:
start_scores, end_scores = model(torch.tensor([input_ids]), attention_mask=torch.tensor([attention_mask]))
I am attempting to perform the truncation at this line:
encoding = tokenizer.encode_plus(question, tokenizer(context, truncation=True).input_ids)
The entire code is here:
from transformers import AutoTokenizer, DistilBertTokenizer, DistilBertForQuestionAnswering
import torch

# globals - set once, used everywhere
tokenizer = None
model = None
context = ''

def establishSettings():
    global tokenizer, model, context
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', return_token_type_ids=True, model_max_length=512)
    model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased-distilled-squad', return_dict=False)
    # context = "Some 1,500 volcanoes are still considered potentially active around the world today 161 of those over 10 percent sit within the boundaries of the United States."
    # get the volcano corpus
    with open('volcanic.corpus', encoding="utf8") as file:
        context = file.read().replace('\n', '')
    print(len(tokenizer(context, truncation=True).input_ids))

def askQuestion(question):
    global tokenizer, model, context
    print("\nQuestion ", question)
    encoding = tokenizer.encode_plus(question, tokenizer(context, truncation=True).input_ids)
    input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
    start_scores, end_scores = model(torch.tensor([input_ids]), attention_mask=torch.tensor([attention_mask]))
    ans_tokens = input_ids[torch.argmax(start_scores): torch.argmax(end_scores) + 1]
    answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
    #all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    return answer_tokens

def main():
    # set the global items once
    establishSettings()
    # ask a question
    question = "How many potentially active volcanoes are there in the world today?"
    answer_tokens = askQuestion(question)
    print("answer_tokens: ", answer_tokens)
    if len(answer_tokens) == 0:
        answer = "Sorry, I don't have an answer for that one. Ask me another question about New Mexico volcanoes."
        print(answer)
    else:
        answer_tokens_to_string = tokenizer.convert_tokens_to_string(answer_tokens)
        print("\nFinal Answer : ")
        print(answer_tokens_to_string)

if __name__ == '__main__':
    main()
What is the best way to truncate the input_ids to a length of <= 512?
Upvotes: 0
Views: 2961
Reputation: 1376
Edit this line:
encoding = tokenizer.encode_plus(question, tokenizer(context, truncation=True).input_ids)
to
encoding = tokenizer.encode_plus(question, tokenizer(context, truncation=True, max_length=512).input_ids)
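One thing to watch: the line above caps the context alone at 512 ids, and encode_plus then adds the question and the special tokens on top, so the combined sequence may still run past 512. Below is a minimal sketch of the length budget, written in plain Python so it runs without downloading a model. The helper names truncate_context_ids and build_pair are hypothetical (they are not part of transformers), and the ids 101 and 102 are assumed to be [CLS] and [SEP] in the uncased BERT vocabulary. The question and context ids are assumed to have been produced without special tokens, for example with tokenizer.encode(text, add_special_tokens=False).

```python
# Hypothetical helpers, not part of the transformers library.
# Assumption: the pair layout is [CLS] question [SEP] context [SEP],
# which costs 3 special tokens in total.

def truncate_context_ids(question_ids, context_ids, max_length=512, num_special_tokens=3):
    """Trim only the context so question + context + special tokens fit in max_length."""
    budget = max_length - len(question_ids) - num_special_tokens
    if budget < 0:
        raise ValueError("The question alone does not fit in max_length.")
    return context_ids[:budget]


def build_pair(question_ids, context_ids, cls_id=101, sep_id=102, max_length=512):
    """Assemble input_ids and attention_mask for a question/context pair."""
    # cls_id and sep_id are assumed values; read them from the tokenizer in real code.
    context_ids = truncate_context_ids(question_ids, context_ids, max_length)
    input_ids = [cls_id] + question_ids + [sep_id] + context_ids + [sep_id]
    attention_mask = [1] * len(input_ids)
    return input_ids, attention_mask


if __name__ == "__main__":
    # Dummy ids standing in for a 10-token question and a 1000-token context.
    question_ids = list(range(1000, 1010))
    context_ids = list(range(2000, 3000))
    input_ids, attention_mask = build_pair(question_ids, context_ids)
    print(len(input_ids))
```

As far as I know, the tokenizer can apply the same budget itself if you pass both strings in one call, e.g. tokenizer(question, context, truncation="only_second", max_length=512), which truncates only the context. Either way, everything in the corpus beyond the first window is dropped, so an answer located later in the file cannot be found.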
Upvotes: 3