I get this error when calling trainer.train() with a Seq2SeqTrainer from Hugging Face transformers:

RuntimeError: Expected a 'mps:0' generator device but found 'cpu'

This is the trainer:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)
When I print the model's device and the training arguments' device, both report mps (not cpu). I am also using the same conda environment I used before (so torch and transformers have not been updated), and back then training worked without this error. I don't know whether I have updated macOS in the meantime, or whether that could affect this.
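Checking only the model and the training arguments may not be enough, because the message is about a random-number generator, not the model. A minimal diagnostic sketch, assuming PyTorch 2.x, that also reports where freshly created tensors land and which device the default generator lives on. One possible (unconfirmed) cause of this kind of mismatch is that something in the environment calls torch.set_default_device("mps"), so new tensors default to MPS while the data sampler's generator is still created on the CPU. The helper name describe_devices is hypothetical and not part of transformers or PyTorch.

```python
import torch
from torch import nn


def describe_devices(model: nn.Module) -> dict:
    """Collect device information relevant to a generator/device mismatch.

    Hypothetical helper for illustration; not part of transformers or PyTorch.
    """
    return {
        # True only when the Metal (MPS) backend is usable on this machine.
        "mps_available": torch.backends.mps.is_available(),
        # Device that holds the model's first parameter.
        "model_device": str(next(model.parameters()).device),
        # Device of PyTorch's global default random-number generator (CPU).
        "default_generator_device": str(torch.default_generator.device),
        # Device a new tensor lands on when no device is given; this
        # reports mps if torch.set_default_device("mps") was called.
        "default_tensor_device": str(torch.empty(0).device),
    }


if __name__ == "__main__":
    info = describe_devices(nn.Linear(4, 2))
    for key, value in info.items():
        print(f"{key}: {value}")
```

In the real script one would pass the actual model (for example `describe_devices(trainer.model)`) right before trainer.train(). If default_tensor_device shows mps while default_generator_device shows cpu, that points toward a global default-device setting rather than the Trainer configuration itself.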