I have a toy model of a PyTorch 0.4 LSTM on a GPU. The overall idea of the toy problem is that I define a single 3-vector as an input and define a rotation matrix R. The ground-truth targets are then a sequence of vectors: at T0, the input vector; at T1, the input vector rotated by R; at T2, the input rotated by R twice, and so on. (The input is padded to the output length with zero inputs after T1.)
The loss is the average L2 difference between the ground truth and the outputs. The rotation matrix, the construction of the input/output data, and the loss function are probably not of interest and are not shown here.
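For concreteness, here is roughly what those helpers could look like (a simplified sketch, not the exact code):

import numpy as np
import torch

def getRotation(a, b, c):
    # Elementary rotations about x, y, z, plus their product R
    Rx = np.array([[1, 0, 0], [0, np.cos(a), -np.sin(a)], [0, np.sin(a), np.cos(a)]])
    Ry = np.array([[np.cos(b), 0, np.sin(b)], [0, 1, 0], [-np.sin(b), 0, np.cos(b)]])
    Rz = np.array([[np.cos(c), -np.sin(c), 0], [np.sin(c), np.cos(c), 0], [0, 0, 1]])
    return Rx, Ry, Rz, Rz @ Ry @ Rx

def getInput(seq_len):
    # A single random 3-vector, padded to the output length with zeros
    return [torch.randn(3)] + [torch.zeros(3) for _ in range(seq_len - 1)]

def getOutput(inputs, R):
    # Target at step t is the initial vector rotated t times by R
    R = torch.tensor(R, dtype=torch.float32)
    v, targets = inputs[0], []
    for _ in inputs:
        targets.append(v)
        v = R @ v
    return targets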
Never mind that the results are pretty terrible: Why does this become successively slower with each passing epoch?!
I've shown on-GPU numbers below, but this happens on the CPU as well (only with larger times). The time to execute ten epochs of this silly little thing grows rapidly; it's quite noticeable just watching the numbers scroll by.
epoch: 0, loss: 0.1753, time previous: 33:28.616360 time now: 33:28.622033 time delta: 0:00:00.005673
epoch: 10, loss: 0.2568, time previous: 33:28.622033 time now: 33:28.830665 time delta: 0:00:00.208632
epoch: 20, loss: 0.2092, time previous: 33:28.830665 time now: 33:29.324966 time delta: 0:00:00.494301
epoch: 30, loss: 0.2663, time previous: 33:29.324966 time now: 33:30.109241 time delta: 0:00:00.784275
epoch: 40, loss: 0.1965, time previous: 33:30.109241 time now: 33:31.184024 time delta: 0:00:01.074783
epoch: 50, loss: 0.2232, time previous: 33:31.184024 time now: 33:32.556106 time delta: 0:00:01.372082
epoch: 60, loss: 0.1258, time previous: 33:32.556106 time now: 33:34.215477 time delta: 0:00:01.659371
epoch: 70, loss: 0.2237, time previous: 33:34.215477 time now: 33:36.173928 time delta: 0:00:01.958451
epoch: 80, loss: 0.1076, time previous: 33:36.173928 time now: 33:38.436041 time delta: 0:00:02.262113
epoch: 90, loss: 0.1194, time previous: 33:38.436041 time now: 33:40.978748 time delta: 0:00:02.542707
epoch: 100, loss: 0.2099, time previous: 33:40.978748 time now: 33:43.844310 time delta: 0:00:02.865562
The model:
import torch
import torch.nn as nn

# Run on the GPU if available (the slowdown happens on the CPU as well)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Sequence(torch.nn.Module):
    def __init__(self):
        super(Sequence, self).__init__()
        self.lstm1 = nn.LSTM(3, 30)
        self.lstm2 = nn.LSTM(30, 300)
        self.lstm3 = nn.LSTM(300, 30)
        self.lstm4 = nn.LSTM(30, 3)
        self.hidden1 = self.init_hidden(dim=30)
        self.hidden2 = self.init_hidden(dim=300)
        self.hidden3 = self.init_hidden(dim=30)
        self.hidden4 = self.init_hidden(dim=3)
        self.dense = torch.nn.Linear(30, 3)
        self.relu = nn.LeakyReLU()

    def init_hidden(self, dim):
        return (torch.zeros(1, 1, dim).to(device), torch.zeros(1, 1, dim).to(device))

    def forward(self, inputs):
        out1, self.hidden1 = self.lstm1(inputs, self.hidden1)
        out2, self.hidden2 = self.lstm2(out1, self.hidden2)
        out3, self.hidden3 = self.lstm3(out2, self.hidden3)
        # out4, self.hidden4 = self.lstm4(out3, self.hidden4)
        # This is intended to act as a dense layer on the output of the LSTM
        out4 = self.relu(self.dense(out3))
        return out4
The training loop:
import datetime
import numpy as np

sequence = Sequence().to(device)
criterion = L2_Loss()
optimizer = torch.optim.Adam(sequence.parameters())
_, _, _, R = getRotation(np.pi/27, np.pi/26, np.pi/25)
losses = []
date1 = datetime.datetime.now()

for epoch in range(1001):
    # Define the input -- each row of 3 is a vector, a distinct input
    # Define the targets directly from the input by application of the rotation matrix
    # Define the predictions by running the input through the model
    inputs = getInput(25)
    targets = getOutput(inputs, R)
    inputs = torch.cat(inputs).view(len(inputs), 1, -1).to(device)
    targets = torch.cat(targets).view(len(targets), 1, -1).to(device)
    target_preds = sequence(inputs)
    target_preds = target_preds.view(len(target_preds), 1, -1)
    loss = criterion(targets, target_preds).to(device)
    losses.append(loss.data[0])

    if (epoch % 10 == 0):
        date2 = datetime.datetime.now()
        print("epoch: %3d, \tloss: %6.4f, \ttime previous: %s\ttime now: %s\ttime delta: %s" %
              (epoch, loss.data[0], date1.strftime("%M:%S.%f"), date2.strftime("%M:%S.%f"), date2 - date1))
        date1 = date2

    # Zero out the grads, run the loss backward, and optimize on the grads
    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    optimizer.step()
Upvotes: 0
Views: 1455
Short answer: Because we did not detach the hidden states, the system kept backpropagating farther and farther back through time, taking up more memory and requiring more time with every epoch.
Long answer: This answer is meant to work without teacher forcing. "Teacher forcing" means that the inputs at every time step are the ground-truth values; without it, the input at each time step is the output from the previous time step, no matter how early in the training regime (and therefore how wildly erratic) that output might be.
This is a bit of a manual operation in PyTorch, requiring us to keep track of not only the output but also the hidden state of the network at each step, so we can provide it to the next. The detachment has to happen not at every time step, but at the beginning of each sequence. A method that seems to work is to define a "detach" method on the Sequence model (which manually detaches all the hidden states) and call it explicitly after optimizer.step().
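For example, a minimal version of such a method, added to the Sequence class above, might look like this (the name detach_hidden is arbitrary, and this is only a sketch of the idea):

    def detach_hidden(self):
        # Detach every hidden/cell state so the next sequence does not
        # backpropagate into earlier sequences
        self.hidden1 = tuple(h.detach() for h in self.hidden1)
        self.hidden2 = tuple(h.detach() for h in self.hidden2)
        self.hidden3 = tuple(h.detach() for h in self.hidden3)
        self.hidden4 = tuple(h.detach() for h in self.hidden4)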
This prevents the computation graph attached to the hidden states from accumulating across sequences, stops the gradual slowdown, and still seems to train the network.
I cannot truly vouch for it because I have only employed it on a toy model, not a real problem.
Note 1: There are probably also better ways to factor this, for example re-initializing the hidden states at the start of each sequence instead of manually detaching them.
Note 2: The loss.backward(retain_graph=True) statement retains the graph because error messages suggested it. Once the detach is in place, that warning disappears and retain_graph=True is no longer needed.
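The end of the training loop in the question would then become something like this (again, just a sketch):

    # Zero out the grads, run the loss backward, and optimize on the grads
    optimizer.zero_grad()
    loss.backward()               # retain_graph=True no longer needed
    optimizer.step()
    sequence.detach_hidden()      # cut the graph before the next sequence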
I am leaving this answer unaccepted in the hope that someone more knowledgeable will add their expertise.
Upvotes: 1