Reputation: 11
I am writing a Question Answering system using pre-trained BERT with a linear layer and a softmax layer on top. In the templates available on the net, the labels of one example usually consist of only one `answer_start_index` and one `answer_end_index`. For example, from Huggingface, when instantiating a `SquadFeatures` object:
```
self.input_ids = input_ids
self.attention_mask = attention_mask
self.token_type_ids = token_type_ids
self.cls_index = cls_index
self.p_mask = p_mask
self.example_index = example_index
self.unique_id = unique_id
self.paragraph_len = paragraph_len
self.token_is_max_context = token_is_max_context
self.tokens = tokens
self.token_to_orig_map = token_to_orig_map
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
self.qas_id = qas_id
```
However, in my own dataset I have examples where the answer word is found at several locations in the context, i.e. there may be several correct spans constituting the answer.
My problem is that I don't know how to handle such examples. In the templates available on the net, labels are usually in a list, say:
In my case this may look like:
In other words, I do not have a list containing one label per example, but a list containing either single-labels or a list of "labels" for an example, i.e. a list consisting of lists.
When following other templates the next step in the process is:
```
input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
token_type_ids = torch.cat(token_type_ids, dim=0)
span_starts = torch.tensor(span_starts)  # works only with one integer label per example
span_ends = torch.tensor(span_ends)      # works only with one integer label per example
```
However, this of course raises an error, as my span_start and span_end lists do not contain only single items but sometimes a list within the list.
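A minimal sketch of one possible workaround, assuming each label is either a single index or a list of indices: wrap every entry in a list and pad to a common length with a sentinel, so the result is rectangular. The helper `pad_spans`, the sentinel `-1`, and the sample labels are hypothetical and chosen only for illustration; they do not come from any library.

```python
def pad_spans(spans, pad_value=-1):
    """Turn a mix of ints and lists of ints into equal-length rows."""
    # Wrap bare ints so every example is a list of candidate positions.
    rows = [[s] if isinstance(s, int) else list(s) for s in spans]
    width = max(len(r) for r in rows)
    # Right-pad shorter rows with the sentinel value.
    return [r + [pad_value] * (width - len(r)) for r in rows]


# Made-up labels: the second example has two valid answer spans.
span_starts = [3, [5, 17], 8]
padded_starts = pad_spans(span_starts)
print(padded_starts)  # [[3, -1], [5, 17], [8, -1]]
```

The padded list is rectangular, so `torch.tensor(padded_starts)` should accept it. A custom loss would still need to mask out the `-1` entries, since a loss written for one start and one end index per example will not handle the extra dimension.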
Does anyone have an idea of how I can tackle this problem? Should I only use examples where a single span in the context constitutes the answer?
If I work around the torch error, will backpropagation, evaluation, and the loss computation still work?
Thank You! /B
Upvotes: 1
Views: 1508
Reputation: 46351
Have you checked the code below?
```
from transformers import BertTokenizer, BertForQuestionAnswering
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
encoding = tokenizer.encode_plus(question, text)
input_ids, token_type_ids = encoding["input_ids"], encoding["token_type_ids"]

# Index the output rather than unpacking it: newer transformers versions
# return an output object instead of a plain tuple.
outputs = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids]))
start_scores, end_scores = outputs[0], outputs[1]

all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])
assert answer == "a nice puppet"
```
I am not sure if this is the best way, but instead of `argmax` you could use `topk` and check whether any of the top candidates corresponds to the correct answer.
```
t = torch.LongTensor([0,1,2,3,4,5,6,7,8,9])
_, indices = t.topk(4)
indices  # tensor([9, 8, 7, 6])
```
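Building on the `topk` idea, here is a rough sketch of how several candidate spans could be read off the scores instead of a single argmax pair. It works on plain Python lists (for instance `start_scores[0].tolist()`), and the names `top_spans` and `max_len` as well as the example scores are hypothetical, picked only for illustration.

```python
import heapq


def top_spans(start_scores, end_scores, k=2, max_len=10):
    """Return the k highest-scoring (start, end) pairs with start <= end."""
    candidates = []
    for s, s_score in enumerate(start_scores):
        # Only consider ends at or after the start, at most max_len tokens long.
        for e in range(s, min(s + max_len, len(end_scores))):
            candidates.append((s_score + end_scores[e], s, e))
    # Keep the k candidates with the largest summed score.
    best = heapq.nlargest(k, candidates)
    return [(s, e) for _, s, e in best]


# Made-up scores for a five-token input.
start = [0.1, 2.0, 0.0, 1.5, 0.2]
end = [0.0, 0.3, 2.5, 0.1, 1.8]
print(top_spans(start, end, k=2))  # [(1, 2), (1, 4)]
```

Note that the returned spans may overlap, as they do here. If the answer word occurs at several separate locations, you would probably want to filter out overlapping candidates before comparing against the gold spans.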
Upvotes: 0