Reputation: 487
Assuming training is finished: what values does Keras use for the 0th cell state and hidden states at inference (in LSTM and RNN layers)? I could think of at least three scenarios, and could not find any conclusive answer in the documentation:
(a) The initial states are learned and then used for all predictions
(b) the initial states are always set to zero
(c) the initial states are always random (let's hope not...?)
Upvotes: 4
Views: 1201
Reputation: 19776
If using LSTM(stateful=True), hidden states are initialized to zero, change with fit or predict, and are kept at whatever they are until .reset_states() is called. If LSTM(stateful=False), states are reset to zero after each batch during fitting/predicting/etc.
This can be verified from the .reset_states() source code and by direct inspection; both are shown below for stateful=True, and a short sketch contrasting the two modes follows the helper functions. For more info on how states are passed, see this answer.
Direct inspection:
batch_shape = (2, 10, 4)  # (batch_size, timesteps, features)
model = make_model(batch_shape)  # defined under "Functions / imports used"
X = np.random.randn(*batch_shape)
y = np.random.randint(0, 2, (batch_shape[0], 1))

show_lstm_states("STATES INITIALIZED")
model.train_on_batch(X, y)
show_lstm_states("STATES AFTER TRAIN")
model.reset_states()
show_lstm_states("STATES AFTER RESET")
model.predict(X)
show_lstm_states("STATES AFTER PREDICT")
Output:
STATES INITIALIZED
[[0. 0. 0. 0.]
[0. 0. 0. 0.]]
[[0. 0. 0. 0.]
[0. 0. 0. 0.]]
STATES AFTER TRAIN
[[0.12061571 0.03639204 0.20810013 0.05309075]
[0.01832913 0.00062357 0.10566339 0.60108346]]
[[0.21241754 0.0773523 0.37392718 0.15590034]
[0.08496398 0.00112716 0.23814857 0.95995367]]
STATES AFTER RESET
[[0. 0. 0. 0.]
[0. 0. 0. 0.]]
[[0. 0. 0. 0.]
[0. 0. 0. 0.]]
STATES AFTER PREDICT
[[0.12162527 0.03720453 0.20628096 0.05421837]
[0.01849432 0.00064993 0.1045063 0.6097021 ]]
[[0.21398112 0.07894284 0.3709934 0.15928769]
[0.08605779 0.00117485 0.23606434 0.97212094]]
Functions / imports used:
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Input, Dense, LSTM
from tensorflow.keras.models import Model
import numpy as np
def make_model(batch_shape):
    # Fixed batch size is required for stateful=True
    ipt = Input(batch_shape=batch_shape)
    x = LSTM(4, stateful=True, activation='relu')(ipt)
    out = Dense(1, activation='sigmoid')(x)
    model = Model(ipt, out)
    model.compile('adam', 'binary_crossentropy')
    return model

def show_lstm_states(txt=''):
    print('\n' + txt)
    states = model.layers[1].states  # [hidden state, cell state]
    for state in states:
        if tf.__version__[0] == '2':
            print(state.numpy())
        else:
            print(K.get_value(state))
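As a complementary check (a sketch not in the original verification; make_stateless_model is a hypothetical mirror of make_model above with stateful=False): with stateful=False, repeated predict calls on the same input should return identical outputs, since states are re-zeroed each batch, whereas with stateful=True they generally differ, since the carried-over states change the starting point of the second call.
def make_stateless_model(batch_shape):
    # Identical to make_model except stateful=False
    ipt = Input(batch_shape=batch_shape)
    x = LSTM(4, stateful=False, activation='relu')(ipt)
    out = Dense(1, activation='sigmoid')(x)
    model = Model(ipt, out)
    model.compile('adam', 'binary_crossentropy')
    return model

batch_shape = (2, 10, 4)
X = np.random.randn(*batch_shape)

stateless = make_stateless_model(batch_shape)
print(np.allclose(stateless.predict(X), stateless.predict(X)))
# True: states are re-zeroed before each batch

stateful = make_model(batch_shape)
print(np.allclose(stateful.predict(X), stateful.predict(X)))
# False (in general): states carry over between calls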
Inspect source code:
from inspect import getsource
print(getsource(model.layers[1].reset_states))
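Finally, if the default zeros are not desired, they can be overridden. A minimal sketch (assuming the standard functional-API initial_state argument; ipt2, h0, c0, model2 are names introduced here for illustration) of supplying explicit initial states at inference:
ipt2 = Input(batch_shape=(2, 10, 4))
h0 = Input(batch_shape=(2, 4))  # initial hidden state, fed at predict time
c0 = Input(batch_shape=(2, 4))  # initial cell state, fed at predict time
x = LSTM(4, activation='relu')(ipt2, initial_state=[h0, c0])
out = Dense(1, activation='sigmoid')(x)
model2 = Model([ipt2, h0, c0], out)

# Feeding zeros reproduces the default behavior:
preds = model2.predict([np.random.randn(2, 10, 4),
                        np.zeros((2, 4)),
                        np.zeros((2, 4))])

# For a stateful layer, explicit values can also be set in place
# (assumption: tf.keras RNN.reset_states accepts a states argument):
# model.layers[1].reset_states([np.ones((2, 4)), np.ones((2, 4))])
This confirms scenario (b) from the question: the defaults are zeros, neither learned nor random, but they can be replaced explicitly.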
Upvotes: 2
Reputation: 1
My understanding from this is that the initial states are set to zero in most cases.
Upvotes: -1