Reputation: 53
I am training an XLM-RoBERTa model for token classification using Huggingface Transformers. The maximum token length of my already fine-tuned model is 166. I truncated longer sequences and padded shorter ones in the training data. Now, during inference/prediction, I would like to predict all tokens, even in sequences longer than 166. But if I read the documentation correctly, overflowing tokens get thrown away. Is that correct? I am not completely sure what the return_overflowing_tokens and stride parameters do. Could they be used to split sequences that are too long into two or more shorter sequences?
I have already tried splitting my text data into sentences to get smaller chunks, but some of them still exceed the max token length. Ideally, overflowing tokens would automatically be placed into an additional sequence.
Upvotes: 4
Views: 4883
Reputation: 19495
Let's say you have the following string:
from transformers import XLMRobertaTokenizerFast
model_id = "xlm-roberta-large-finetuned-conll03-english"
t = XLMRobertaTokenizerFast.from_pretrained(model_id)
sample = "this is an example and context is important to retrieve meaningful contextualized token embeddings from the self attention mechanism of the transformer"
print(f"this string has {len(t.tokenize(sample))} tokens")
Output:
this string has 32 tokens
The tokenizer's max_length would truncate the text, and your model would therefore never classify the truncated tokens:
encoded_max_length = t(sample, max_length=10, truncation=True).input_ids
print(len(encoded_max_length))
print(t.batch_decode(encoded_max_length))
Output:
10
['<s>', 'this', 'is', 'an', 'example', 'and', 'context', 'is', 'important', '</s>']
To also pass the truncated tokens to the model, you can use the return_overflowing_tokens parameter:
encoded_overflow = t(sample, max_length=10, truncation=True, return_overflowing_tokens=True).input_ids
print([len(x) for x in encoded_overflow])
print(*t.batch_decode(encoded_overflow), sep="\n")
Output:
[10, 10, 10, 10]
<s> this is an example and context is important</s>
<s> to retrieve meaningful contextual</s>
<s>ized token embeddings from</s>
<s> the self attention mechanism of the transformer</s>
You might notice an issue here. Your model will probably struggle to generate meaningful embeddings (for your downstream task) for the tokens at the beginning and end of each sequence, since they lack context due to the hard cut. The ized token of the third sequence is a good example of this problem.
The standard approach to that problem is the sliding window approach, which keeps some tokens of the previous sequence for the current sequence. You can control the sliding window with the stride parameter of the tokenizer:
encoded_overflow_stride = t(sample, max_length=10, truncation=True, stride=3, return_overflowing_tokens=True).input_ids
print([len(x) for x in encoded_overflow_stride])
print(*t.batch_decode(encoded_overflow_stride), sep="\n")
Output:
[10, 10, 10, 10, 10, 9]
<s> this is an example and context is important</s>
<s> context is important to retrieve meaning</s>
<s>trieve meaningful contextualized to</s>
<s>ualized token embeddings</s>
<s>embeddings from the self attention mechanism</s>
<s> self attention mechanism of the transformer</s>
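To come back to the original question (labeling all tokens of an over-long sequence), you can feed each window to the token classification model and collect the per-token predictions afterwards. The following is only a rough sketch, not part of the original answer: it reuses t, model_id, and sample from above, keeps the default stride of 0 so no token is scored twice, and skips padding and special tokens when collecting labels. With stride > 0 you would additionally have to deduplicate the overlapping tokens, for example by keeping the prediction from the window in which a token has the most context.
import torch
from transformers import AutoModelForTokenClassification

# sketch: load the token classification head of the same checkpoint used above
model = AutoModelForTokenClassification.from_pretrained(model_id)

# tokenize into windows; padding makes all windows the same length for batching
enc = t(
    sample,
    max_length=10,
    truncation=True,
    return_overflowing_tokens=True,
    padding=True,
    return_tensors="pt",
)

with torch.no_grad():
    logits = model(input_ids=enc.input_ids, attention_mask=enc.attention_mask).logits
pred_ids = logits.argmax(dim=-1)

# collect one (token, label) pair per real token across all windows,
# skipping padding and the special <s>/</s> tokens
predictions = []
for row, (ids, mask) in enumerate(zip(enc.input_ids, enc.attention_mask)):
    for pos, (tok_id, m) in enumerate(zip(ids.tolist(), mask.tolist())):
        if m == 0 or tok_id in t.all_special_ids:
            continue
        token = t.convert_ids_to_tokens(tok_id)
        label = model.config.id2label[pred_ids[row, pos].item()]
        predictions.append((token, label))

print(predictions[:5])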
Upvotes: 18