Pawel Faron

Reputation: 312

Why am I getting NaN after adding relu activation in LSTM?

I have a simple LSTM network that looks roughly like this:

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMCell, MultiRNNCell

lstm_activation = tf.nn.relu

# two stacked LSTM layers, both using relu as the cell activation
cells_fw = [LSTMCell(num_units=100, activation=lstm_activation), 
            LSTMCell(num_units=10, activation=lstm_activation)]

stacked_cells_fw = MultiRNNCell(cells_fw)

# run the stacked cells over the embedded sequences
_, states = tf.nn.dynamic_rnn(cell=stacked_cells_fw,
                              inputs=embedding_layer,
                              sequence_length=features['length'],
                              dtype=tf.float32)

# concatenate the final hidden state (h) of each layer
output_states = [s.h for s in states]
states = tf.concat(output_states, 1)

My question is: when I don't use an activation (activation=None) or use tanh, everything works, but when I switch to relu I keep getting "NaN loss during training". Why is that? It's 100% reproducible.

Upvotes: 1

Views: 799

Answers (1)

gorjan

Reputation: 5555

When you use the relu activation function inside the LSTM cell, all the outputs from the cell, as well as the cell state, are guaranteed to be >= 0. Unlike tanh, relu is also unbounded above, so the cell state is never squashed back into a fixed range and can keep growing across time steps. As a result, the activations and gradients become extremely large and explode, which eventually produces NaN values in the loss. For example, run the following code snippet and observe that the outputs are never < 0.

import numpy as np
import tensorflow as tf

# batch of 4 sequences, 3 time steps, 2 features (np.random.rand gives float64)
X = np.random.rand(4, 3, 2)
lstm_cell = tf.nn.rnn_cell.LSTMCell(5, activation=tf.nn.relu)
hidden_states, _ = tf.nn.dynamic_rnn(cell=lstm_cell, inputs=X, dtype=tf.float64)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(hidden_states))  # every entry is >= 0
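Since the underlying failure mode is exploding gradients, one common mitigation (not part of the original answer, just a minimal sketch; `loss` and the optimizer choice are placeholders for whatever your training setup uses) is to clip the global gradient norm before applying the update, e.g. with tf.clip_by_global_norm:

optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
# `loss` is a placeholder for your model's loss tensor
grads_and_vars = optimizer.compute_gradients(loss)
grads, variables = zip(*grads_and_vars)
# rescale gradients so their global norm never exceeds 5.0
clipped_grads, _ = tf.clip_by_global_norm(grads, clip_norm=5.0)
train_op = optimizer.apply_gradients(list(zip(clipped_grads, variables)))

Even with clipping, the values inside the cell can still blow up under relu, so in practice keeping the default tanh activation (which the question already reports as working) is usually the simpler fix.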

Upvotes: 3
