I'm working with an LSTM on time-series data and I've observed a problem in the gradients of my network. I have one layer of 121 LSTM cells. For each cell I have one input value and I get one output value. I work with a batch size of 121 and I define the LSTM with batch_first=True, so my outputs have shape [batch, timestep, features].
Once I have the outputs (a tensor of size [121, 121, 1]), I calculate the loss using MSELoss() and backpropagate it. This is where the problem appears: looking at the gradients of each cell, I notice that the gradients of the first 100 cells (more or less) are zero.
In theory, if I'm not mistaken, when I backpropagate the error I calculate a gradient for each output, so I have a gradient for each cell. If that is true, I can't understand why they are zero in the first cells.
Does anybody know what is happening?
Thank you!
PS: here is the gradient flow of the last cells: [image]
Update:
As I tried to ask before, I still have a question about LSTM backpropagation. As you can see from the image below, in a single cell, apart from the gradients that come from other cells, I think there's also another gradient from the cell itself.
For example, let's look at cell 1. I get the output y1 and calculate the loss E1. I do the same with the other cells. So, when I backpropagate into cell 1, I get dE2/dy2 * dy2/dh1 * dh1/dw1 + ..., which are the gradients related to the following cells in the network (BPTT), as @kmario23 and @DavidNg explained. I also have the gradient related to E1 (dE1/dy1 * dy1/dw1). The first gradients can vanish as they flow back, but this one should not.
So, to sum up: even though I have a long layer of LSTM cells, to my understanding each cell has a gradient related to its own output, so I don't understand why some gradients are equal to zero. What happens to the error related to E1? Why is only BPTT calculated?
I have dealt with these problems several times, and here is my advice:
Use a smaller number of timesteps
The hidden output of the previous timestep is passed to the current step and multiplied by the weights. Because this multiplication is repeated at every step, the gradient explodes or vanishes exponentially with the number of timesteps. For example:
# it's exploding
1.1^121 ≈ 101980  # imagine how large it gets when the weight is bigger than 1.1
# or it's vanishing
0.9^121 = 2.9063214161987074e-06  # ~ 0.0 when we init the weight smaller than 1.0
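A quick plain-Python check of these numbers. The scalars 1.1 and 0.9 are only illustrative stand-ins for repeated multiplication by W_hh, and `repeated_product` is a hypothetical helper, not part of any library:

```python
def repeated_product(weight, steps):
    # Multiply by the same scalar weight once per timestep.
    result = 1.0
    for _ in range(steps):
        result *= weight
    return result

for steps in (10, 30, 121):
    print(f"{steps:3d} steps: 1.1^n = {repeated_product(1.1, steps):.3e}, "
          f"0.9^n = {repeated_product(0.9, steps):.3e}")
```

At 30 steps the two factors are only about 17.4 and 0.042, which is the intuition behind shortening the sequence.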
To reduce clutter, I take the example of a simple RNNCell with weights W_ih and W_hh and no bias. In your case W_hh is just a single number, but the argument generalizes to any matrix W_hh. We also use the identity activation.
If we unroll the RNN over all K=3 time steps, we get:
h_1 = W_ih * x_0 + W_hh * h_0 (1)
h_2 = W_ih * x_1 + W_hh * h_1 (2)
h_3 = W_ih * x_2 + W_hh * h_2 (3)
Therefore, to update the weight W_hh, we have to accumulate the gradients from steps (1), (2) and (3).
grad(W_hh) = grad(W_hh at step 1) + grad(W_hh at step 2) + grad(W_hh at step 3)
# step 3
grad(W_hh at step3) = d_loss/d(h_3) * d(h_3)/d(W_hh)
grad(W_hh at step3) = d_loss/d(h_3) * h_2
# step 2
grad(W_hh at step2) = d_loss/d(h_2) * d(h_2)/d(W_hh)
grad(W_hh at step2) = d_loss/d(h_3) * d(h_3)/d(h_2) * d(h_2)/d(W_hh)
grad(W_hh at step2) = d_loss/d(h_3) * d(h_3)/d(h_2) * h_1
# step 1
grad(W_hh at step1) = d_loss/d(h_1) * d(h_1)/d(W_hh)
grad(W_hh at step1) = d_loss/d(h_3) * d(h_3)/d(h_2) * d(h_2)/d(h_1) * d(h_1)/d(W_hh)
grad(W_hh at step1) = d_loss/d(h_3) * d(h_3)/d(h_2) * d(h_2)/d(h_1) * h_0
# We also have:
d(h_i)/d(h_(i-1)) = W_hh
# Then:
grad(W_hh at step3) = d_loss/d(h_3) * h_2
grad(W_hh at step2) = d_loss/d(h_3) * W_hh * h_1
grad(W_hh at step1) = d_loss/d(h_3) * W_hh * W_hh * h_0
Let d_loss/d(h_3) = v
# We accumulate all gradients for W_hh
grad(W_hh) = v * h_2 + v * W_hh * h_1 + v * W_hh * W_hh * h_0
# If W_hh is initialized much larger than 1.0, grad(W_hh) explodes quickly (-> infinity).
# If W_hh is initialized much smaller than 1.0, the contributions of the early steps vanish quickly (-> 0), since they are scaled by powers of W_hh that shrink exponentially with the number of steps
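The derivation above can be checked numerically. This is a sketch under the same assumptions (scalar weights, identity activation, no bias, K=3), with a hypothetical loss 0.5 * (h_3 - target)^2 so that v = h_3 - target; the function names and sample values are made up for illustration. It compares the hand-derived formula against a central finite difference:

```python
def unroll(w_ih, w_hh, xs, h0):
    """Return [h_0, h_1, ..., h_K] for h_t = w_ih * x_(t-1) + w_hh * h_(t-1)."""
    hs = [h0]
    for x in xs:
        hs.append(w_ih * x + w_hh * hs[-1])
    return hs

def loss(w_ih, w_hh, xs, h0, target):
    # The loss depends only on h_3, as in the derivation.
    h3 = unroll(w_ih, w_hh, xs, h0)[-1]
    return 0.5 * (h3 - target) ** 2

def analytic_grad_w_hh(w_ih, w_hh, xs, h0, target):
    h0_, h1, h2, h3 = unroll(w_ih, w_hh, xs, h0)
    v = h3 - target  # v = d_loss/d(h_3)
    # grad(W_hh) = v * h_2 + v * W_hh * h_1 + v * W_hh * W_hh * h_0
    return v * h2 + v * w_hh * h1 + v * w_hh * w_hh * h0_

def numeric_grad_w_hh(w_ih, w_hh, xs, h0, target, eps=1e-6):
    # Central finite difference as an independent check.
    up = loss(w_ih, w_hh + eps, xs, h0, target)
    down = loss(w_ih, w_hh - eps, xs, h0, target)
    return (up - down) / (2 * eps)

args = dict(w_ih=0.5, w_hh=0.8, xs=[1.0, 2.0, 3.0], h0=0.1, target=0.0)
print(analytic_grad_w_hh(**args), numeric_grad_w_hh(**args))
```

Trying a small w_hh such as 0.1 shows the step-1 term v * W_hh * W_hh * h_0 shrinking by a factor of 100 relative to v * h_0.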
Although the LSTM cell has gates that mitigate these problems (e.g. the forget gate, which discards irrelevant long-range information), it is still affected by a large number of timesteps. How to design network architectures that learn long dependencies in sequential data is still an open question.
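One way to see why the gates only mitigate the problem: in the standard LSTM cell-state update c_t = f_t * c_(t-1) + i_t * g_t (the form given in PyTorch's nn.LSTM documentation), the direct path has d(c_t)/d(c_(t-1)) = f_t, so a gradient going from step 121 back to step 1 along that path is scaled by the product of the forget-gate values in between. This sketch ignores the other paths through h_t and the gates, and uses made-up constant forget-gate values:

```python
import math

# Hypothetical constant forget-gate activations; in a real LSTM f_t is a
# sigmoid output in (0, 1) that changes with the input at every step.
def cell_path_scale(forget_values):
    """Scale applied to a gradient travelling back along c_t = f_t * c_(t-1) + ..."""
    return math.prod(forget_values)

steps = 120  # multiplications between step 1 and step 121
for f in (0.5, 0.9, 0.99):
    print(f"f_t = {f}: scale after {steps} steps = {cell_path_scale([f] * steps):.3e}")
```

Since f_t lies in (0, 1), this path cannot explode, but it only keeps the gradient alive when the forget gate stays very close to 1 (about 0.3 for f_t = 0.99, versus about 3e-6 for f_t = 0.9).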
To avoid these problems, reduce the number of timesteps (seq_len) by splitting the sequence into subsequences.
bs = 121
seq_len = 121
new_seq_len = int(seq_len // k)  # k = 2, 2.5 or anything else to experiment with
X (of [bs, seq_len, 1]) -> [X1[bs, new_seq_len, 1], X2[bs, new_seq_len, 1], ...]
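A minimal sketch of the split along the time axis, using a plain list as a stand-in for one row of X; `split_into_subsequences` is a hypothetical helper. Note that `seq_len // k` with a float k such as 2.5 returns a float, hence the `int()`:

```python
def split_into_subsequences(sequence, new_seq_len):
    # Consecutive, non-overlapping chunks; the last one may be shorter.
    return [sequence[i:i + new_seq_len] for i in range(0, len(sequence), new_seq_len)]

seq_len = 121
k = 2.5
new_seq_len = int(seq_len // k)  # 121 // 2.5 == 48.0 (a float), so convert to int
sequence = list(range(seq_len))  # stand-in for one row of X along the time axis
chunks = split_into_subsequences(sequence, new_seq_len)
print([len(c) for c in chunks])
```

For a PyTorch tensor of shape [bs, seq_len, 1] with batch_first=True, `X.split(new_seq_len, dim=1)` should give the same chunking along the time dimension, again with a shorter last chunk.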
Then pass each small batch Xi into the model, using as its initial hidden state h_(i-1), the hidden output of the previous batch X(i-1):
h_i = model(Xi, h_(i-1))
This way the model can still learn some of the long dependencies that a model unrolled over all 121 timesteps would see.
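A sketch of this scheme (truncated BPTT) for the same scalar linear RNN as above, assuming a per-step loss 0.5 * (h_t - y_t)^2; `tbptt_grads` and its inputs are hypothetical. The hidden state is carried across chunks, but its dependence on w_hh is cut at each chunk boundary. For brevity the gradient is accumulated forward in time (d(h_t)/d(w_hh) = h_(t-1) + w_hh * d(h_(t-1))/d(w_hh)) rather than by a backward pass; within a chunk this gives the same gradient as backpropagating through that chunk. In PyTorch the cut is usually made by detaching the carried state, e.g. `h = tuple(s.detach() for s in h)` for the `(h_n, c_n)` returned by `nn.LSTM`.

```python
def tbptt_grads(w_ih, w_hh, xs, ys, h0, chunk_len):
    """Return one d(loss)/d(w_hh) per chunk, loss = sum 0.5 * (h_t - y_t)^2."""
    grads = []
    h = h0
    for start in range(0, len(xs), chunk_len):
        dh_dw = 0.0  # truncation: treat the incoming h as a constant (like detach)
        grad = 0.0
        for x, y in zip(xs[start:start + chunk_len], ys[start:start + chunk_len]):
            h_prev = h
            h = w_ih * x + w_hh * h_prev
            dh_dw = h_prev + w_hh * dh_dw  # d(h_t)/d(w_hh) within this chunk
            grad += (h - y) * dh_dw
        grads.append(grad)
    return grads

xs = [1.0, 1.0, 1.0]
ys = [0.0, 0.0, 0.0]
print(tbptt_grads(1.0, 0.5, xs, ys, h0=0.0, chunk_len=3))  # [5.0] (full BPTT)
print(tbptt_grads(1.0, 0.5, xs, ys, h0=0.0, chunk_len=1))  # [0.0, 1.5, 2.625]
```

The truncated gradients sum to 4.125 instead of 5.0; the missing 0.875 is exactly the long-range term v * W_hh * h_1 for the loss at step 3, which is the price paid for shorter, better-conditioned gradient paths.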