Tessa W

Reputation: 53

How can I handle overflowing tokens in Huggingface Transformer model?

I am training an XLM-RoBERTa model for token classification using Huggingface Transformers. The maximum token length of the already fine-tuned model is 166. I truncated longer sequences and padded shorter ones in the training data. Now, at inference/prediction time I would like to predict all tokens, even in sequences longer than 166. But if I read the documentation correctly, overflowing tokens are thrown away. Is that correct? I am not completely sure what the "return_overflowing_tokens" and stride parameters do. Could they be used to split sequences that are too long into two or more shorter sequences?

I have already tried to split my text data into sentences to get smaller chunks, but some of them still exceed the max token length. It would be ideal if overflowing tokens were automatically added to an additional sequence.

Upvotes: 4

Views: 4883

Answers (1)

cronoik

Reputation: 19495

Let's say you have the following string:

from transformers import XLMRobertaTokenizerFast

model_id = "xlm-roberta-large-finetuned-conll03-english"
t = XLMRobertaTokenizerFast.from_pretrained(model_id)

sample = "this is an example and context is important to retrieve meaningful contextualized token embeddings from the self attention mechanism of the transformer"
print(f"this string has {len(t.tokenize(sample))} tokens")

Output:

this string has 32 tokens

The tokenizer's max_length parameter truncates the text, so your model would never classify the truncated tokens:

encoded_max_length = t(sample, max_length=10, truncation=True).input_ids

print(len(encoded_max_length))
print(t.batch_decode(encoded_max_length))  # decodes each id separately, showing the kept tokens one by one

Output:

10
['<s>', 'this', 'is', 'an', 'example', 'and', 'context', 'is', 'important', '</s>']

To also pass the truncated tokens to the model, you can use the return_overflowing_tokens parameter:

encoded_overflow = t(sample, max_length=10, truncation=True, return_overflowing_tokens=True).input_ids

print([len(x) for x in encoded_overflow])
print(*t.batch_decode(encoded_overflow), sep="\n")

Output:

[10, 10, 10, 10]
<s> this is an example and context is important</s>
<s> to retrieve meaningful contextual</s>
<s>ized token embeddings from</s>
<s> the self attention mechanism of the transformer</s>

You might notice an issue here: your model will probably struggle to generate meaningful embeddings (for your downstream task) for the tokens at the beginning and end of each sequence, since they lack context due to the hard cut. The ized token of the third sequence is a good example of this problem.

The standard approach to this problem is a sliding window, which keeps some tokens of the previous sequence in the current one. You can control the sliding window with the stride parameter of the tokenizer:

encoded_overflow_stride = t(sample, max_length=10, truncation=True, stride=3, return_overflowing_tokens=True).input_ids

print([len(x) for x in encoded_overflow_stride])
print(*t.batch_decode(encoded_overflow_stride), sep="\n")

Output:

[10, 10, 10, 10, 10, 9]
<s> this is an example and context is important</s>
<s> context is important to retrieve meaning</s>
<s>trieve meaningful contextualized to</s>
<s>ualized token embeddings</s>
<s>embeddings from the self attention mechanism</s>
<s> self attention mechanism of the transformer</s>

Upvotes: 18
