Niloofar Adelkhani

Reputation: 227

TypeError in torch.argmax() when trying to find the token with the highest `start` score

I want to run this code for question answering using Hugging Face Transformers.

import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer

#Model
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

#Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

question = '''Why was the student group called "the Methodists?"'''

paragraph = ''' The movement which would become The United Methodist Church began in the mid-18th century within the Church of England.
            A small group of students, including John Wesley, Charles Wesley and George Whitefield, met on the Oxford University campus.
            They focused on Bible study, methodical study of scripture and living a holy life.
            Other students mocked them, saying they were the "Holy Club" and "the Methodists", being methodical and exceptionally detailed in their Bible study, opinions and disciplined lifestyle.
            Eventually, the so-called Methodists started individual societies or classes for members of the Church of England who wanted to live a more religious life. '''
            
encoding = tokenizer.encode_plus(text=question,text_pair=paragraph)

inputs = encoding['input_ids']  #Token embeddings
sentence_embedding = encoding['token_type_ids']  #Segment embeddings
tokens = tokenizer.convert_ids_to_tokens(inputs) #input tokens

start_scores, end_scores = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))

start_index = torch.argmax(start_scores)

but I get this error at the last line:

Exception has occurred: TypeError
argmax(): argument 'input' (position 1) must be Tensor, not str
  File "D:\bert\QuestionAnswering.py", line 33, in <module>
    start_index = torch.argmax(start_scores)

I don't know what's wrong. Can anyone help me?

Upvotes: 2

Views: 883

Answers (2)

maxim velikanov

Reputation: 86

Hugging Face Transformers provides a simple, high-level way of running the model through its pipeline API, as shown in this guide:

from transformers import pipeline

nlp = pipeline('question-answering', model=model, tokenizer=tokenizer)
print(nlp(question=question, context=paragraph, topk=5))

The topk argument lets you select several top-scoring answers.
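
For reference, with topk=1 (the default) the pipeline returns a single dict, and with topk > 1 it returns a list of such dicts, each containing 'score', 'start', 'end' and 'answer' keys. A minimal sketch of printing the candidates, assuming the nlp pipeline built above:

results = nlp(question=question, context=paragraph, topk=5)

# Each candidate is a dict with the answer text and its confidence score
for r in results:
    print(round(r['score'], 3), r['answer'])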

Upvotes: 2

Aman

Reputation: 8995

BertForQuestionAnswering returns a QuestionAnsweringModelOutput object.

Since you unpack the output of BertForQuestionAnswering into start_scores, end_scores, you are iterating over the returned QuestionAnsweringModelOutput, which yields its keys. As a result, start_scores and end_scores end up holding the strings 'start_logits' and 'end_logits', which causes the type mismatch error.

The following should work:

outputs = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))

start_index = torch.argmax(outputs.start_logits)
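
If you also want the answer text rather than just the start index, here is a minimal sketch, assuming the tokens list built in the question:

end_index = torch.argmax(outputs.end_logits)

# Convert the 0-dim tensors to ints and join the word-piece tokens
# between the predicted start and end positions (inclusive)
answer = ' '.join(tokens[int(start_index):int(end_index) + 1])
print(answer)

Alternatively, depending on your transformers version, you can pass return_dict=False to the model call to get a plain tuple back, so the original start_scores, end_scores unpacking works unchanged.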

Upvotes: 2
