Daraan

Reputation: 3790

Get paths of saved checkpoints from Pytorch-Lightning ModelCheckpoint

I am using PyTorch Lightning and a ModelCheckpoint which saves models with a formatted filename like filename="model_{epoch}-{val_acc:.2f}".

Later I want to access these checkpoints again. For simplicity I want the best ones from save_top_k=N. As the filename is dynamic, I wonder how I can retrieve the checkpoint files easily.
Is there a built-in attribute in the ModelCheckpoint or the trainer that gives me the saved checkpoints? For example like

checkpoint_callback.get_top_k_paths()

I know I can do it with glob and model_dir (roughly as sketched below). I assume the callback has to keep track of them anyway, so I wonder what the built-in way is to access the model paths.
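For reference, the glob workaround I mean looks roughly like this (model_dir is just my name for the checkpoint directory):

import glob
import os

# current workaround: glob the checkpoint directory for the formatted filenames
model_dir = "my/path/"  # wherever the ModelCheckpoint writes to
ckpt_paths = sorted(glob.glob(os.path.join(model_dir, "model_*.ckpt")))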

Upvotes: 3

Views: 2202

Answers (3)

Drew Galbraith

Reputation: 553

Solution

You could record the path of each checkpoint as you go. Running dir(trainer) (and of a few of its children) shows the dirpath and filename used by the ModelCheckpoint in its on_save_checkpoint() function. So you could write something like I have in my other answer, but with this modified function:

def on_save_checkpoint(self, trainer, pl_module, checkpoint):
    super().on_save_checkpoint(trainer=trainer, pl_module=pl_module, checkpoint=checkpoint)  # save the checkpoint

    # append the saved checkpoint location to a log file as checkpoints come in
    # (outfile is the log-file path set up in my other answer)
    with open(outfile, mode='a') as outf:
        outf.write(f"{trainer.checkpoint_callback.dirpath}/{trainer.checkpoint_callback.filename}\n")

If you log the val_loss and use it in the filename, you can then see which of them are the top k. Good luck!

Upvotes: 0

Daraan

Reputation: 3790

All stored checkpoints can be found in ModelCheckpoint.best_k_models : Dict[str, Tensor], where the keys are the checkpoint paths and the values are the tracked metric.

Additionally, ModelCheckpoint has these attributes: best_model_path, best_model_score, kth_best_model_path, kth_value, last_model_path and best_k_models.
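For illustration, a minimal sketch of reading these attributes after training (the monitor/mode/save_top_k values here are just example settings, not your exact setup):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath="my/path/",
    filename="model_{epoch}-{val_acc:.2f}",
    monitor="val_acc",
    mode="max",
    save_top_k=3,
)
trainer = Trainer(callbacks=[checkpoint_callback])
model = ...  # your LightningModule
trainer.fit(model)

# paths of the kept checkpoints mapped to their monitored metric values
for path, score in checkpoint_callback.best_k_models.items():
    print(path, score.item())

print(checkpoint_callback.best_model_path)   # path of the single best checkpoint
print(checkpoint_callback.best_model_score)  # its metric value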


Note: when loading a checkpoint

These values are only guaranteed to be restored when model_checkpoint.dirpath matches the one stored in the checkpoint's state_dict["dirpath"], i.e. you did not change the directory; otherwise only best_model_path is restored.

Otherwise, as Aniket Maurya states, you have to look at dirpath or at the files next to best_model_path.

Upvotes: 0

Aniket Maurya

Reputation: 380

You can retrieve the best model path from the checkpoint callback after training:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
model = ...  # your LightningModule
trainer.fit(model)
print(checkpoint_callback.best_model_path)

To find all the checkpoints, you can list the files in the dirpath where the checkpoints are saved, for example as sketched below.
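A minimal sketch, assuming the checkpoints were written with the default .ckpt extension:

import os

# collect every checkpoint file saved in the callback's directory
ckpt_files = [
    os.path.join(checkpoint_callback.dirpath, f)
    for f in os.listdir(checkpoint_callback.dirpath)
    if f.endswith(".ckpt")
]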

Upvotes: 2
