MAC

Reputation: 1515

Getting an error with PyTorch Lightning when passing a model checkpoint

I am training a model for a multi-label classification problem using Hugging Face models, with PyTorch Lightning as the training framework.

Here is the code:

Early stopping triggers when the validation loss hasn't improved for the last 2 epochs:

early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)
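As an aside, the patience logic can be sketched in plain Python (a simplified illustration of the idea, not Lightning's actual `EarlyStopping` implementation):

```python
def should_stop(losses, patience=2):
    """Return True when the last `patience` losses show no improvement
    over the best loss seen before them (simplified patience logic)."""
    if len(losses) <= patience:
        return False
    best_before = min(losses[:-patience])
    return all(loss >= best_before for loss in losses[-patience:])

# val_loss improved up to 0.5, then plateaued for 2 epochs -> stop
print(should_stop([0.9, 0.7, 0.5, 0.55, 0.6]))  # True
# val_loss improved within the last 2 epochs -> keep training
print(should_stop([0.9, 0.7, 0.5, 0.45, 0.6]))  # False
```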

We can start the training process:

checkpoint_callback = ModelCheckpoint(
  dirpath="checkpoints",
  filename="best-checkpoint",
  save_top_k=1,
  verbose=True,
  monitor="val_loss",
  mode="min"
)


trainer = pl.Trainer(
  logger=logger,
  callbacks=[early_stopping_callback],
  max_epochs=N_EPOCHS,
  checkpoint_callback=checkpoint_callback,
  gpus=1,
  progress_bar_refresh_rate=30
)

As soon as I run this, I get this error:

~/.local/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py in _configure_checkpoint_callbacks(self, checkpoint_callback)
     75             if isinstance(checkpoint_callback, Callback):
     76                 error_msg += " Pass callback instances to the `callbacks` argument in the Trainer constructor instead."
---> 77             raise MisconfigurationException(error_msg)
     78         if self._trainer_has_checkpoint_callbacks() and checkpoint_callback is False:
     79             raise MisconfigurationException(

MisconfigurationException: Invalid type provided for checkpoint_callback: Expected bool but received <class 'pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint'>. Pass callback instances to the `callbacks` argument in the Trainer constructor instead.

How can I fix this issue?

Upvotes: 0

Views: 7234

Answers (1)

Ivan

Reputation: 40648

You can look up the description of the `checkpoint_callback` argument on the documentation page for `pl.Trainer`:

checkpoint_callback (bool) – If True, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in callbacks.

You shouldn't pass your custom `ModelCheckpoint` to this argument; it only accepts a bool. What you are looking to do is pass both the `EarlyStopping` and the `ModelCheckpoint` instances in the `callbacks` list:

early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    verbose=True,
    monitor="val_loss",
    mode="min")

trainer = pl.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=N_EPOCHS,
    gpus=1,
    progress_bar_refresh_rate=30)

Upvotes: 5
