loretoparisi

Reputation: 16301

Using transformers class BertForQuestionAnswering for Extractive Question Answering

I'm using a BERT model for an extractive QA task with the BertForQuestionAnswering class from the transformers library. Extractive Question Answering is the task of answering a question from a given context text by outputting the start and end indexes of the answer span within that context. My code is the following:

import os
import torch
import torch.nn.functional as F
from transformers import BertForQuestionAnswering, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased',
    cache_dir=os.getenv("cache_dir", "../../models"))
question = "What is the capital of Italy?"
text = "The capital of Italy is Rome."
inputs = tokenizer.encode_plus(question, text, return_tensors='pt')
start, end = model(**inputs)
start_max = torch.argmax(F.softmax(start, dim = -1))
end_max = torch.argmax(F.softmax(end, dim = -1)) + 1  # add one because Python slicing excludes the end index
answer = tokenizer.decode(inputs["input_ids"][0][start_max : end_max])
print(answer)

I get this error:

start_max = torch.argmax(F.softmax(start, dim = -1))
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 1583, in softmax
    ret = input.softmax(dim)
AttributeError: 'str' object has no attribute 'softmax'

I have also tried a slightly different approach:

encoding = tokenizer.encode_plus(text=question, text_pair=text, add_special_tokens=True)
inputs = encoding['input_ids']  # token ids
sentence_embedding = encoding['token_type_ids']  # segment (token type) ids
tokens = tokenizer.convert_ids_to_tokens(inputs)  # input tokens
start_scores, end_scores = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))
start_index = torch.argmax(start_scores)
end_index = torch.argmax(end_scores)
answer = ' '.join(tokens[start_index:end_index+1])

but the error is essentially the same:

    start_index = torch.argmax(start_scores)
TypeError: argmax(): argument 'input' (position 1) must be Tensor, not str

I assume this is due to how I unpack the model output:

start, end = model(**inputs)

If so, how do I correctly unpack this model's outputs?

Upvotes: 1

Views: 1085

Answers (1)

rawan qadri

Reputation: 26

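Due to a version update of transformers, the model now returns a dictionary-like output object by default rather than a (start, end) tuple. Unpacking that object iterates over its keys, which is why start and end end up as strings. You can add the following parameter to get a tuple back: return_dict=False. Alternatively, keep the default output and read the tensors from its start_logits and end_logits fields.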
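A minimal sketch of both options follows. To run without downloading weights it builds a tiny, randomly initialised BertForQuestionAnswering from a BertConfig; the config sizes and the dummy input_ids are assumptions for illustration only, and in the question's code the same calls would apply to the model loaded with from_pretrained and the tokenizer's inputs.

```python
import torch
from transformers import BertConfig, BertForQuestionAnswering

# Tiny random model so the sketch runs offline (sizes are arbitrary assumptions).
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
)
model = BertForQuestionAnswering(config)
model.eval()  # disable dropout so repeated calls give the same logits

# Dummy batch: one sequence of 8 token ids, standing in for tokenizer output.
input_ids = torch.randint(0, config.vocab_size, (1, 8))

with torch.no_grad():
    # Default call: a dict-like output. Tuple-unpacking it yields its KEYS.
    outputs = model(input_ids=input_ids)
    keys = list(outputs.keys())  # the strings the original unpack received

    # Option 1: read the tensors by name from the output object.
    start_logits, end_logits = outputs.start_logits, outputs.end_logits

    # Option 2: ask for a plain tuple, as suggested in the answer.
    start_t, end_t = model(input_ids=input_ids, return_dict=False)

# Both options give tensors, so argmax works as intended.
start_index = int(torch.argmax(start_logits, dim=-1))
end_index = int(torch.argmax(end_logits, dim=-1)) + 1  # exclusive slice end
print(keys, start_index, end_index)
```

With either option, the rest of the question's code (argmax over the logits, then slicing the token ids and decoding) should work unchanged. Note that the softmax before argmax is optional, since it does not change which index is largest.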

Upvotes: 1
