Rick Vink

Reputation: 333

Dataset key lost during training when using the Hugging Face Trainer

I'm following the course material of Hugging Face: https://huggingface.co/course/chapter7/3?fw=pt (great stuff btw!). However, I've now run into an issue.

When I run training and eval with the default data_collator, everything works fine. But when I use the custom whole_word_masking_data_collator, it fails because the key "word_ids" is missing.

My data is as follows:

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'token_type_ids', 'word_ids'],
        num_rows: 30639
    })
    test: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'token_type_ids', 'word_ids'],
        num_rows: 29946
    })
    unsupervised: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'token_type_ids', 'word_ids'],
        num_rows: 61465
    })
})

When I call whole_word_masking_data_collator directly like this, everything is fine:

whole_word_masking_data_collator([lm_datasets["train"][0]])

However, when I use it in a Trainer like this:

from transformers import Trainer

trainer = Trainer(
    model=masked_model,
    args=training_args,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["test"],
    data_collator=whole_word_masking_data_collator,
)

it gives me the following error:

KeyError: 'word_ids'

I find this bizarre, because the key is clearly present in the data, and the whole_word_masking_data_collator function works fine standalone.

When I checked the keys inside my collator function, the key is indeed missing; I only got these keys:

dict_keys(['attention_mask', 'input_ids', 'labels', 'token_type_ids'])

So my question is: where in my code does the key "word_ids" go missing?

Upvotes: 2

Views: 2649

Answers (2)

Yassine Elkheir

Reputation: 21

I solved this by setting remove_unused_columns=False in the Trainer's arguments.
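A minimal sketch of what that looks like (the output_dir and other values are placeholders, not from the question):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="mlm-finetuned",       # hypothetical output path
    remove_unused_columns=False,      # keep "word_ids" for the custom collator
)
```

With this flag set, the Trainer passes the dataset columns through to the data collator unchanged instead of pruning them first.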

Upvotes: 2

kpriya

Reputation: 308

In case you're still facing this problem, I found the solution in the same doc:

By default, the Trainer will remove any columns that are not part of the model’s forward() method. This means that if you’re using the whole word masking collator, you’ll also need to set remove_unused_columns=False to ensure we don’t lose the word_ids column during training.
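A simplified sketch of the mechanism (not the actual Trainer source): the Trainer inspects the model's forward() signature and drops every dataset column whose name is not a parameter of it, which is exactly why "word_ids" disappears before the collator ever sees it.

```python
import inspect

# Stand-in for a masked-LM model's forward() signature (hypothetical).
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
    pass

# Columns the Trainer considers "used": the forward() parameters.
signature_columns = set(inspect.signature(forward).parameters) - {"self"}

example = {
    "input_ids": [101, 2023, 102],
    "attention_mask": [1, 1, 1],
    "token_type_ids": [0, 0, 0],
    "labels": [-100, 2023, -100],
    "word_ids": [None, 0, None],
}

# Anything not in the signature is dropped, so "word_ids" vanishes.
kept = {k: v for k, v in example.items() if k in signature_columns}
print(sorted(kept))  # ['attention_mask', 'input_ids', 'labels', 'token_type_ids']
```

This matches the keys the asker saw inside the collator; remove_unused_columns=False simply skips this pruning step.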

Upvotes: 5
