Exa

Reputation: 468

PyTorch: load checkpoint from batch without iterating over dataset again

Instead of loading from an epoch-wise checkpoint, I need to be able to load from a batch-level one. I am aware that this is not optimal, but since I only have limited training time before my training gets interrupted (Google Colab free version), I need to be able to resume from the batch where it stopped, or close to it.

I also do not want to iterate over all the data again, but continue with the data the model has not seen yet.

My current approach, which does not work:

import torch

# model, optimizer and checkpoint_file are defined elsewhere in the script
def save_checkpoint(state, file=checkpoint_file):
  # state: dict with the model/optimizer state dicts plus bookkeeping values
  torch.save(state, file)

def load_checkpoint(checkpoint):
  # checkpoint: the dict returned by torch.load(checkpoint_file)
  model.load_state_dict(checkpoint['state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer'])
  train_loss = checkpoint['train_loss']
  val_loss = checkpoint['val_loss']
  epoch = checkpoint['epoch']
  step = checkpoint['step']
  batch = checkpoint['batch']
  return model, optimizer, train_loss, val_loss, epoch, step, batch

While it does load the weights from where training stopped, it still iterates over all of the data again.
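
Restoring `epoch` and `step` only restores two integers; the training loop itself has to use them to decide where to start. A minimal sketch, assuming `shuffle=False`, a fixed batch size, and a checkpoint that stores the index of the next unseen batch as `step`. `batches_to_run`, `num_samples`, and `batch_size` are hypothetical names, and the sketch only computes which sample indices each remaining batch would contain rather than running a model.

```python
def batches_to_run(num_samples, batch_size, num_epochs, start_epoch=0, start_step=0):
    """Yield (epoch, step, indices) for every batch not yet trained on.

    Assumes the data order is identical in every epoch (shuffle=False),
    so a saved (epoch, step) pair identifies exactly which samples were seen.
    """
    num_batches = (num_samples + batch_size - 1) // batch_size  # ceil division
    for epoch in range(start_epoch, num_epochs):
        # Only the interrupted epoch starts mid-way; later epochs start at 0.
        first_step = start_step if epoch == start_epoch else 0
        for step in range(first_step, num_batches):
            start = step * batch_size
            end = min(start + batch_size, num_samples)
            yield epoch, step, list(range(start, end))
```

For example, with 10 samples, a batch size of 3 and a checkpoint saved at `epoch=0, step=2`, the first resumed batch covers samples 6, 7 and 8, and epoch 1 then starts from sample 0 as usual.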

Also, do I even need to save train_loss and val_loss? I see no difference in the printed loss whether I include them or not, so I assume they are already included in model.load_state_dict (?)

I assume that capturing step and batch this way won't work, and that I actually need some sort of index tracker within my DataSet class? I already have this in the DataSet class:

  def __getitem__(self, idx):
    question = self.data_qs[idx]
    answer1 = self.data_a1s[idx]
    answer2 = self.data_a2s[idx]
    target = self.targets[idx]
    # assumed: the sample is returned as a tuple
    return question, answer1, answer2, target

So, could this be useful?

Upvotes: 0

Views: 564

Answers (1)

Kaushik Roy

Reputation: 1685

You can achieve your goal by creating a custom Dataset class with a property self.start_index = step*batch; in your __getitem__ function, the new index should then be (self.start_index + idx) % len(self.data_qs). If you create your DataLoader with shuffle=False, this trick will work.
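
A minimal sketch of this offset trick, assuming `step` is the number of completed batches and `batch` is the batch size, so that `step * batch` samples have already been seen. `OffsetDataset` is a hypothetical wrapper written as a plain class so the sketch runs without PyTorch; a map-style dataset only has to implement `__len__` and `__getitem__`, and in a real script you would subclass `torch.utils.data.Dataset` or put the offset directly into the existing `DataSet.__getitem__`.

```python
class OffsetDataset:
    """Hypothetical wrapper: index 0 maps to `start_index` of the base dataset.

    With shuffle=False, a DataLoader over this wrapper serves the unseen
    samples first; the modulo wraps around, so the samples that were already
    seen come at the end of the resumed epoch.
    """

    def __init__(self, base, start_index=0):
        self.base = base
        self.start_index = start_index

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        # Shift the requested index by the number of samples already consumed.
        return self.base[(self.start_index + idx) % len(self.base)]
```

Because `__len__` is unchanged, the resumed epoch still has full length and revisits the already-seen samples after the wraparound. If the resumed epoch should cover only the unseen samples, one option is to return `len(self.base) - self.start_index` from `__len__` for that epoch, in which case the modulo never triggers.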

Additionally, with shuffle=True you would need to maintain an index mapper instead; that approach still needs to be verified.
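
One possible reading of the "index mapper" idea, sketched under assumptions: each epoch's shuffled order is derived from a fixed seed plus the epoch number, so the same order can be regenerated after a restart and the already-consumed prefix skipped. `epoch_order`, `remaining_indices`, and `seed` are hypothetical; the resulting index list could then be wrapped with `torch.utils.data.Subset` and passed to a `DataLoader` with `shuffle=False`, so the shuffling is done by this code rather than by the loader.

```python
import random


def epoch_order(num_samples, seed, epoch):
    """Return a reproducible shuffled order of sample indices for one epoch."""
    rng = random.Random(seed + epoch)  # same (seed, epoch) -> same permutation
    order = list(range(num_samples))
    rng.shuffle(order)
    return order


def remaining_indices(num_samples, seed, epoch, step, batch_size):
    """Indices not yet seen in `epoch`, kept in that epoch's shuffled order."""
    return epoch_order(num_samples, seed, epoch)[step * batch_size:]
```

Only `seed`, `epoch` and `step` need to be stored in the checkpoint for this to work; the permutation itself is recomputed on resume.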

Upvotes: 1
