Reputation: 3790
I am using PyTorch Lightning and a ModelCheckpoint which saves models with a formatted filename like filename="model_{epoch}-{val_acc:.2f}".
Later I want to access these checkpoints again. For simplicity, I want the best from save_top_k=N.
As the filename is dynamic, I wonder how I can retrieve the checkpoint files easily. Is there a built-in attribute in the ModelCheckpoint or the trainer that gives me the saved checkpoints? For example, something like checkpoint_callback.get_top_k_paths().
I know I can do it with glob and model_dir, roughly as sketched below. I assume the callback has to keep track of the paths anyway, so I wonder what the built-in way is to access them.
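For reference, this is roughly the glob workaround I mean (a minimal sketch; it assumes model_dir is the dirpath given to ModelCheckpoint and that checkpoints use the default .ckpt suffix):

import glob
import os

# collect every checkpoint file written to the checkpoint directory
ckpt_paths = glob.glob(os.path.join(model_dir, "*.ckpt"))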
Upvotes: 3
Views: 2202
Reputation: 553
You could record the path of each checkpoint as you go. Running dir(trainer) (and a few of its children) shows the dirpath and filename being used in the ModelCheckpoint's on_save_checkpoint() hook. So you could write something like I have in my other answer, but with this modified function:
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
    state = super().on_save_checkpoint(trainer, pl_module, checkpoint)  # save the callback state as usual
    # append the checkpoint directory and filename pattern as they come in
    # (outfile is the log-file path defined as in my other answer)
    with open(outfile, mode='a') as outf:
        outf.write(f"{trainer.checkpoint_callback.dirpath} {trainer.checkpoint_callback.filename}\n")
    return state
If you log the val_loss and use it in the name, you can just see what the top k of them are. Good luck!
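For example, here is a minimal sketch of ranking the saved files by the metric embedded in the name; it assumes the filename="model_{epoch}-{val_acc:.2f}" pattern from the question and that ckpt_dir is the checkpoint directory:

import glob
import os
import re

def top_k_by_filename(ckpt_dir, k=3):
    # pull the metric value out of names like model_7-0.92.ckpt
    scored = []
    for path in glob.glob(os.path.join(ckpt_dir, "*.ckpt")):
        match = re.search(r"-(\d+\.\d+)\.ckpt$", os.path.basename(path))
        if match:
            scored.append((float(match.group(1)), path))
    # higher val_acc is better; sort ascending instead if you track val_loss
    scored.sort(reverse=True)
    return [path for _, path in scored[:k]]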
Upvotes: 0
Reputation: 3790
All stored checkpoints can be found in ModelCheckpoint.best_k_models: Dict[str, Tensor], where the keys are the paths and the values are the tracked metric.
Additionally, ModelCheckpoint has these attributes: best_model_path, best_model_score, kth_best_model_path, kth_value, last_model_path and best_k_models.
These values are only guaranteed when model_checkpoint.dirpath matches the one in the checkpoint's state_dict["dirpath"], i.e. you did not change the directory; otherwise only best_model_path is restored.
In that case, as Aniket Maurya states, you have to look at dirpath or at the other files next to best_model_path.
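For concreteness, a minimal sketch of reading these attributes (it assumes checkpoint_callback is the ModelCheckpoint instance used during training, or one restored with the same dirpath, and that a higher metric such as val_acc is better):

best_k = checkpoint_callback.best_k_models  # Dict[str, Tensor]: path -> tracked metric
# rank the stored paths by their metric; use reverse=False for a loss-like metric
ranked_paths = sorted(best_k, key=best_k.get, reverse=True)
print(checkpoint_callback.best_model_path)  # single best checkpoint
print(ranked_paths)                         # all top-k checkpoint paths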
Upvotes: 0
Reputation: 380
You can retrieve the best model path after training from the checkpoint callback:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
model = ...
trainer.fit(model)
checkpoint_callback.best_model_path
To find all the checkpoints, you can list the files in the dirpath where the checkpoints are saved.
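For example, a minimal sketch of that (it assumes the default .ckpt extension and the checkpoint_callback from above):

import os

# list every checkpoint file that ModelCheckpoint wrote to dirpath
all_ckpts = [
    os.path.join(checkpoint_callback.dirpath, f)
    for f in os.listdir(checkpoint_callback.dirpath)
    if f.endswith(".ckpt")
]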
Upvotes: 2