Reputation: 333
I'm following the Hugging Face course material: https://huggingface.co/course/chapter7/3?fw=pt (great stuff btw!). However, I've now run into an issue.
When I run training and evaluation with the default data_collator, everything goes fine. But when I use the custom whole_word_masking_data_collator, it fails because the key "word_ids" is missing.
My data is as follows:
DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'token_type_ids', 'word_ids'],
        num_rows: 30639
    })
    test: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'token_type_ids', 'word_ids'],
        num_rows: 29946
    })
    unsupervised: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'token_type_ids', 'word_ids'],
        num_rows: 61465
    })
})
When I call my whole_word_masking_data_collator directly like this, everything is fine:
whole_word_masking_data_collator([lm_datasets["train"][0]])
However when I use it in a trainer like this:
from transformers import Trainer

trainer = Trainer(
    model=masked_model,
    args=training_args,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["test"],
    data_collator=whole_word_masking_data_collator,
)
It gives me the following error:
KeyError: 'word_ids'
Which I find bizarre, because this key is clearly present in the data and the whole_word_masking_data_collator function works fine standalone.
When I inspected the keys inside my collator function, the key was indeed missing. I only got these keys:
dict_keys(['attention_mask', 'input_ids', 'labels', 'token_type_ids'])
So my question is: where in my code does the key "word_ids" go missing?
Upvotes: 2
Views: 2649
Reputation: 21
I solved this by setting remove_unused_columns=False in the Trainer args.
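For reference, a minimal sketch of where that flag goes (output_dir and the other argument values here are just placeholders, not the asker's actual settings):

from transformers import TrainingArguments, Trainer

# remove_unused_columns=False stops the Trainer from dropping dataset columns
# (such as "word_ids") that are not arguments of the model's forward() method,
# so the custom collator still receives them.
training_args = TrainingArguments(
    output_dir="mlm-finetuned",  # placeholder output directory
    remove_unused_columns=False,
)

trainer = Trainer(
    model=masked_model,
    args=training_args,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["test"],
    data_collator=whole_word_masking_data_collator,
)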
Upvotes: 2
Reputation: 308
In case you're still facing this problem, I found the solution in the same doc:
By default, the Trainer will remove any columns that are not part of the model’s forward() method. This means that if you’re using the whole word masking collator, you’ll also need to set remove_unused_columns=False to ensure we don’t lose the word_ids column during training.
Upvotes: 5