Reputation: 468
Instead of loading from an epoch-wise checkpoint, I need to be able to load from a batch-wise one. I am aware that this is not optimal, but since I only have limited training time before my training gets interrupted (Google Colab free version), I need to be able to resume from the batch it stopped at, or around that batch.
I also do not want to iterate over all the data again, but continue with the data the model has not seen yet.
My current approach, which does not work:
def save_checkpoint(state, file=checkpoint_file):
    torch.save(state, file)

def load_checkpoint(checkpoint):
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    train_loss = checkpoint['train_loss']
    val_loss = checkpoint['val_loss']
    epoch = checkpoint['epoch']
    step = checkpoint['step']
    batch = checkpoint['batch']
    return model, optimizer, train_loss, val_loss, epoch, step, batch
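For context, the dict that load_checkpoint expects has to be assembled somewhere in the training loop. A rough, hypothetical sketch of that call site (names like train_loader, num_epochs, save_every, train_loss and val_loss are placeholders, and step/batch are interpreted here as "batches completed" and "batch size", so that step * batch gives the number of samples already seen):

for epoch in range(num_epochs):
    for batch_idx, (question, answer1, answer2, target) in enumerate(train_loader):
        # ... forward pass, loss computation, backward, optimizer.step() ...
        if batch_idx % save_every == 0:
            save_checkpoint({
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'epoch': epoch,
                'step': batch_idx,                 # batches finished in this epoch
                'batch': train_loader.batch_size,  # samples per batch
            })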
While it does load the weights from where it stopped, it iterates over all data again.
Also, do I even need to capture train_loss and val_loss? I cannot see any difference in the reported loss whether I include them or not, so I assume they are already covered by model.load_state_dict (?)
I assume capturing step and batch won't work this way, and that I actually need to include some sort of index tracker within my DataSet class? I do already have this within the DataSet class:
def __getitem__(self, idx):
    question = self.data_qs[idx]
    answer1 = self.data_a1s[idx]
    answer2 = self.data_a2s[idx]
    target = self.targets[idx]
So, could this be useful?
Upvotes: 0
Views: 564
Reputation: 1685
You can achieve your goal by creating a custom Dataset class with a property self.start_index = step * batch; in your __getitem__ method, the new index should then be (self.start_index + idx) % len(self.data_qs).
If you create your DataLoader with shuffle=False, this trick will work.
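A minimal sketch of that idea, assuming the dataset wraps data_qs, data_a1s, data_a2s and targets as in the question and that step and batch are read back from the checkpoint (the names outside the class are placeholders):

from torch.utils.data import Dataset, DataLoader

class QADataset(Dataset):
    def __init__(self, data_qs, data_a1s, data_a2s, targets, start_index=0):
        self.data_qs = data_qs
        self.data_a1s = data_a1s
        self.data_a2s = data_a2s
        self.targets = targets
        # Number of samples already consumed before the interruption,
        # i.e. step * batch from the checkpoint.
        self.start_index = start_index

    def __len__(self):
        return len(self.data_qs)

    def __getitem__(self, idx):
        # Start reading where the previous run stopped; wraps around at the end.
        idx = (self.start_index + idx) % len(self.data_qs)
        return (self.data_qs[idx], self.data_a1s[idx],
                self.data_a2s[idx], self.targets[idx])

# shuffle=False keeps the order deterministic so the offset stays meaningful.
dataset = QADataset(data_qs, data_a1s, data_a2s, targets,
                    start_index=checkpoint['step'] * checkpoint['batch'])
loader = DataLoader(dataset, batch_size=checkpoint['batch'], shuffle=False)

Note that with the modulo the DataLoader still yields len(self.data_qs) samples in the resumed epoch, wrapping around to the skipped samples at the end. If those already-seen samples should not be revisited at all, one option is to return len(self.data_qs) - self.start_index from __len__ during that first resumed epoch and reset start_index to 0 afterwards.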
Additionally, with shuffle=True you would need to maintain an index mapping yourself and verify that it behaves as expected.
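One possible way to build such an index mapper, sketched under the assumption that a permutation seed (perm_seed) and the step/batch counters are stored in the checkpoint (these keys and the dataset variable are hypothetical):

import torch
from torch.utils.data import DataLoader, Subset

# Recreate the same shuffled order that the interrupted run used.
g = torch.Generator().manual_seed(checkpoint['perm_seed'])
perm = torch.randperm(len(dataset), generator=g).tolist()

# Drop the samples that were already consumed before the interruption.
consumed = checkpoint['step'] * checkpoint['batch']
remaining = perm[consumed:]

# shuffle=False here because `perm` already provides the shuffled order.
loader = DataLoader(Subset(dataset, remaining),
                    batch_size=checkpoint['batch'],
                    shuffle=False)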
Upvotes: 1