Novak

Reputation: 4779

PyTorch 0.4 LSTM: Why does each epoch get slower?

I have a toy model of a PyTorch 0.4 LSTM running on a GPU. The idea of the toy problem is that I define a single 3-vector as an input, along with a rotation matrix R. The ground-truth targets are then a sequence of vectors: at T0, the input vector; at T1, the input rotated by R; at T2, the input rotated by R twice, and so on. (The input is padded to the output length with zero vectors after the first time step.)

The loss is the average L2 difference between the ground truth and the outputs. The rotation matrix, the construction of the input/output data, and the loss function are probably not of interest and are not shown here.
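For concreteness, a minimal loss matching that description might look like the following. This is a hypothetical reconstruction, not the actual `L2_Loss` used; it simply averages the Euclidean distance between each target vector and the corresponding prediction:

```python
import torch

class L2_Loss(torch.nn.Module):
    """Hypothetical reconstruction (not the asker's actual code):
    average Euclidean (L2) distance between targets and predictions."""
    def forward(self, targets, preds):
        # Norm over the last dim gives one distance per time step; average them.
        return torch.norm(targets - preds, dim=-1).mean()
```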

Never mind that the results are pretty terrible: Why does this become successively slower with each passing epoch?!

I've shown on-GPU numbers below, but the same thing happens on the CPU (just with larger times). The time to execute ten epochs of this silly little thing grows rapidly; it's quite noticeable just watching the numbers scroll by.

epoch:   0,     loss: 0.1753,   time previous: 33:28.616360 time now: 33:28.622033  time delta: 0:00:00.005673
epoch:  10,     loss: 0.2568,   time previous: 33:28.622033 time now: 33:28.830665  time delta: 0:00:00.208632
epoch:  20,     loss: 0.2092,   time previous: 33:28.830665 time now: 33:29.324966  time delta: 0:00:00.494301
epoch:  30,     loss: 0.2663,   time previous: 33:29.324966 time now: 33:30.109241  time delta: 0:00:00.784275
epoch:  40,     loss: 0.1965,   time previous: 33:30.109241 time now: 33:31.184024  time delta: 0:00:01.074783
epoch:  50,     loss: 0.2232,   time previous: 33:31.184024 time now: 33:32.556106  time delta: 0:00:01.372082
epoch:  60,     loss: 0.1258,   time previous: 33:32.556106 time now: 33:34.215477  time delta: 0:00:01.659371
epoch:  70,     loss: 0.2237,   time previous: 33:34.215477 time now: 33:36.173928  time delta: 0:00:01.958451
epoch:  80,     loss: 0.1076,   time previous: 33:36.173928 time now: 33:38.436041  time delta: 0:00:02.262113
epoch:  90,     loss: 0.1194,   time previous: 33:38.436041 time now: 33:40.978748  time delta: 0:00:02.542707
epoch: 100,     loss: 0.2099,   time previous: 33:40.978748 time now: 33:43.844310  time delta: 0:00:02.865562

The model:

import datetime

import numpy as np
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Sequence(nn.Module):
    def __init__(self):
        super(Sequence, self).__init__()

        self.lstm1 = nn.LSTM(3, 30)
        self.lstm2 = nn.LSTM(30, 300)
        self.lstm3 = nn.LSTM(300, 30)
        self.lstm4 = nn.LSTM(30, 3)

        self.hidden1 = self.init_hidden(dim=30)
        self.hidden2 = self.init_hidden(dim=300)
        self.hidden3 = self.init_hidden(dim=30)
        self.hidden4 = self.init_hidden(dim=3)

        self.dense = nn.Linear(30, 3)
        self.relu  = nn.LeakyReLU()

    def init_hidden(self, dim):
        return (torch.zeros(1, 1, dim).to(device),
                torch.zeros(1, 1, dim).to(device))

    def forward(self, inputs):
        out1, self.hidden1 = self.lstm1(inputs, self.hidden1)
        out2, self.hidden2 = self.lstm2(out1,   self.hidden2)
        out3, self.hidden3 = self.lstm3(out2,   self.hidden3)
        #out4, self.hidden4 = self.lstm4(out3,   self.hidden4)

        # This is intended to act as a dense layer on the output of the LSTM
        out4 = self.relu(self.dense(out3))

        return out4

The training loop:

sequence  = Sequence().to(device)

criterion = L2_Loss()
optimizer = torch.optim.Adam(sequence.parameters())
_, _, _, R = getRotation(np.pi/27, np.pi/26, np.pi/25)

losses = []
date1 = datetime.datetime.now()
for epoch in range(1001):
    # Define the input -- each row of 3 is a vector, a distinct input
    # Define the target directly from the input by application of the rotation matrix
    # Define predictions by running the input through the model

    inputs       = getInput(25)
    targets      = getOutput(inputs, R)

    inputs       = torch.cat(inputs).view(len(inputs), 1, -1).to(device)
    targets      = torch.cat(targets).view(len(targets), 1, -1).to(device)

    target_preds = sequence(inputs)
    target_preds = target_preds.view(len(target_preds), 1, -1)
    loss = criterion(targets, target_preds)

    losses.append(loss.item())
    if epoch % 10 == 0:
        date2 = datetime.datetime.now()
        print("epoch: %3d, \tloss: %6.4f, \ttime previous: %s\ttime now: %s\ttime delta: %s" %
              (epoch, loss.item(), date1.strftime("%M:%S.%f"), date2.strftime("%M:%S.%f"), date2 - date1))
        date1 = date2

    # Zero out the grads, run the loss backward, and optimize on the grads
    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    optimizer.step()

Upvotes: 0

Views: 1455

Answers (1)

Novak

Reputation: 4779

Short answer: Because we did not detach the hidden states, the system kept backpropagating farther and farther back through time, taking up more memory and requiring more time with every epoch.

Long answer: This answer is meant to work without teacher forcing. "Teacher forcing" is when the inputs at all time steps are ground-truth values. Without teacher forcing, by contrast, the input at each time step is the output from the previous time step, no matter how early in the training regime (and therefore how wildly erratic) that output might be.

This is a bit of a manual operation in PyTorch: we have to keep track not only of the output, but of the hidden state of the network at each step, so we can provide it to the next. Detachment has to happen not at every time step, but at the beginning of each sequence. A method that seems to work is to define a "detach" method on the Sequence model (which manually detaches all the hidden states) and call it explicitly after optimizer.step().
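A minimal sketch of that idea, using a stripped-down single-LSTM version of the model (the full model would detach hidden1 through hidden4 the same way):

```python
import torch
import torch.nn as nn

class Sequence(nn.Module):
    """Stripped-down one-layer version of the model, just to show the detach."""
    def __init__(self):
        super(Sequence, self).__init__()
        self.lstm1 = nn.LSTM(3, 30)
        self.hidden1 = (torch.zeros(1, 1, 30), torch.zeros(1, 1, 30))

    def forward(self, inputs):
        out1, self.hidden1 = self.lstm1(inputs, self.hidden1)
        return out1

    def detach(self):
        # Break the autograd graph at the hidden state: h.detach() returns
        # a tensor with the same values but no history, so the next
        # backward() stops here instead of unrolling into past epochs.
        self.hidden1 = tuple(h.detach() for h in self.hidden1)
```

Called as `sequence.detach()` right after `optimizer.step()`, each epoch's backward pass then only covers that epoch's portion of the graph.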

This prevents the gradual accumulation of the hidden states, prevents the gradual slowdown, and still seems to train the network.

I cannot truly vouch for it because I have only employed it on a toy model, not a real problem.

Note 1: There are probably better ways to factor this as well, such as reusing the network's hidden-state initialization at the start of each sequence instead of a manual detach.
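As a sketch of that alternative (an assumption on my part, not tested beyond the toy problem): freshly created zero tensors carry no autograd history, so simply re-initializing the hidden states at the start of each sequence would also keep the graph from growing.

```python
import torch

def init_hidden(dim, device=torch.device("cpu")):
    # Fresh zero tensors carry no autograd history, so there is
    # nothing old to backpropagate through.
    return (torch.zeros(1, 1, dim, device=device),
            torch.zeros(1, 1, dim, device=device))

# In the training loop, before feeding each new sequence:
#   sequence.hidden1 = init_hidden(30)
#   sequence.hidden2 = init_hidden(300)
#   ...
```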

Note 2: The loss.backward(retain_graph=True) statement retains the graph only because error messages suggested it. Once the detach is in place, that warning disappears and the flag is no longer needed.

I leave this answer un-accepted in the hopes that someone knowledgeable will add their expertise.

Upvotes: 1
