Penguin

Reputation: 2441

Why is saving state_dict getting slower as training progresses?

I'm saving my model's and optimizer's state dict as follows:

if epoch % 50000 == 0:
    # checkpoint save every 50000 epochs
    print('\nSaving model... Loss is: ', loss)
    torch.save({
        'epoch': epoch,
        'model': self.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'scheduler': self.scheduler.state_dict(),
        'loss': loss,
        'losses': self.losses,
        }, PATH)

When I first start training it saves in less than 5 seconds. However, after a couple of hours of training it takes over two minutes to save. The only reason I can think of is the list of losses, but I can't see how that would increase the time by that much.

Update 1:
I have my losses as:

self.losses = []

I'm appending the loss at each epoch to this list as follows:

    #... loss calculation
    loss.backward()
    self.optimizer.step()
    self.scheduler.step() 

    self.losses.append(loss)

Upvotes: 0

Views: 646

Answers (1)

trialNerror

Reputation: 3573

As mentioned in the comments, the instruction

self.losses.append(loss) 

is definitely the culprit, and should be replaced with

self.losses.append(loss.item())

The reason is that when you store the tensor loss, you also store the whole computational graph alongside it (all the information required to perform the backprop). In other words, you are not merely storing a scalar value, but also references to every tensor that was involved in the computation of the loss and their relations (which ones were added, multiplied, etc.). So it grows really big, really fast.
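You can see this directly: a loss computed inside the graph still carries a grad_fn, and appending it keeps that whole graph reachable. A minimal sketch with a toy computation (a random linear map standing in for your model, not your actual code):

import torch

# toy forward pass standing in for any model
w = torch.randn(10, 1, requires_grad=True)
x = torch.randn(4, 10)
loss = ((x @ w) ** 2).mean()

print(loss.requires_grad)    # True  -> still attached to the graph
print(loss.grad_fn)          # <MeanBackward0 ...> -> graph is kept alive

losses_bad = [loss]          # keeps every intermediate tensor reachable
losses_good = [loss.item()]  # plain Python float, no graph attached
print(type(losses_good[0]))  # <class 'float'>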

When you do loss.item() (or loss.detach(), which would also work, although it returns a detached tensor rather than a float), you cut the value loose from the computational graph, and thus you only store what you intended: the loss value itself, as a simple Python float.
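Applied to the loop from your update (assuming the same attribute names), the fix is a one-line change:

    #... loss calculation
    loss.backward()
    self.optimizer.step()
    self.scheduler.step()

    self.losses.append(loss.item())  # store a float, not the graph-carrying tensor

After this, the losses list only grows by one float per epoch, so the checkpoint should keep saving quickly.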

Upvotes: 2
