Reputation: 16301
I'm using a BERT model for an extractive QA task with the transformers library class BertForQuestionAnswering. Extractive question answering is the task of answering a question about a given context text by outputting the start and end indices of the answer span within the context. My code is the following:
import os
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertForQuestionAnswering

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
    cache_dir=os.getenv("cache_dir", "../../models"))
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased',
    cache_dir=os.getenv("cache_dir", "../../models"))

question = "What is the capital of Italy?"
text = "The capital of Italy is Rome."
inputs = tokenizer.encode_plus(question, text, return_tensors='pt')
start, end = model(**inputs)
start_max = torch.argmax(F.softmax(start, dim=-1))
end_max = torch.argmax(F.softmax(end, dim=-1)) + 1  # add one because of Python slice indexing
answer = tokenizer.decode(inputs["input_ids"][0][start_max:end_max])
print(answer)
I get this error:
start_max = torch.argmax(F.softmax(start, dim = -1))
File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 1583, in softmax
ret = input.softmax(dim)
AttributeError: 'str' object has no attribute 'softmax'
I have also tried a slightly different approach:
encoding = tokenizer.encode_plus(text=question, text_pair=text, add_special_tokens=True)
inputs = encoding['input_ids'] #Token embeddings
sentence_embedding = encoding['token_type_ids'] #Segment embeddings
tokens = tokenizer.convert_ids_to_tokens(inputs) #input tokens
start_scores, end_scores = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))
start_index = torch.argmax(start_scores)
end_index = torch.argmax(end_scores)
answer = ' '.join(tokens[start_index:end_index+1])
but the error is essentially the same:
start_index = torch.argmax(start_scores)
TypeError: argmax(): argument 'input' (position 1) must be Tensor, not str
I assume this is caused by unpacking the output as
start, end = model(**inputs)
If so, how do I correctly unpack this model's outputs?
Upvotes: 1
Views: 1085
Reputation: 26
Due to a version update of transformers, the model now returns a dictionary-like ModelOutput object instead of a tuple of (start, end) logits.
You can add the following parameter to the model call:
return_dict=False
Upvotes: 1