Jonn Dove

Reputation: 487

What does Keras do with the initial values of cell & hidden states (RNN, LSTM) for inference?

Assuming training is finished: what values does Keras use for the initial (0th) cell and hidden states at inference (in LSTM and RNN layers)? I can think of at least three scenarios, and could not find a conclusive answer in the documentation:

(a) The initial states are learned and then used for all predictions

(b) the initial states are always set to zero

(c) the initial states are always random (let's hope not...?)

Upvotes: 4

Views: 1201

Answers (2)

OverLordGoldDragon

Reputation: 19776

If using LSTM(stateful=True), hidden and cell states are initialized to zero, change during fit or predict, and are kept at whatever values they reach until .reset_states() is called. If LSTM(stateful=False), states are reset to zero after each batch when fitting/predicting/etc.

This can be verified from the .reset_states() source code and by direct inspection; both are shown below for stateful=True. For more info on how states are passed, see this answer.


Direct inspection:

batch_shape = (2, 10, 4)
model = make_model(batch_shape)

X = np.random.randn(*batch_shape)
y = np.random.randint(0, 2, (batch_shape[0], 1))

show_lstm_states("STATES INITIALIZED")
model.train_on_batch(X, y)

show_lstm_states("STATES AFTER TRAIN")
model.reset_states()
show_lstm_states("STATES AFTER RESET")

model.predict(X)
show_lstm_states("STATES AFTER PREDICT")

Output:

STATES INITIALIZED
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]]
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]]

STATES AFTER TRAIN
[[0.12061571 0.03639204 0.20810013 0.05309075]
 [0.01832913 0.00062357 0.10566339 0.60108346]]
[[0.21241754 0.0773523  0.37392718 0.15590034]
 [0.08496398 0.00112716 0.23814857 0.95995367]]

STATES AFTER RESET
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]]
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]]

STATES AFTER PREDICT
[[0.12162527 0.03720453 0.20628096 0.05421837]
 [0.01849432 0.00064993 0.1045063  0.6097021 ]]
[[0.21398112 0.07894284 0.3709934  0.15928769]
 [0.08605779 0.00117485 0.23606434 0.97212094]]

Functions / imports used:

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Input, Dense, LSTM
from tensorflow.keras.models import Model
import numpy as np

def make_model(batch_shape):
    ipt = Input(batch_shape=batch_shape)
    x   = LSTM(4, stateful=True, activation='relu')(ipt)
    out = Dense(1, activation='sigmoid')(x)

    model = Model(ipt, out)
    model.compile('adam', 'binary_crossentropy')

    return model

def show_lstm_states(txt=''):
    print('\n' + txt)
    states = model.layers[1].states  # `model` is the global defined above

    for state in states:
        if tf.__version__[0] == '2':
            print(state.numpy())       # TF2: states are tf.Variables
        else:
            print(K.get_value(state))  # TF1: fetch values via the backend

Inspect source code:

from inspect import getsource
print(getsource(model.layers[1].reset_states))
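For intuition on why zero initial states make inference deterministic, here is a minimal NumPy sketch of an LSTM forward pass (not Keras code; the gate ordering i, f, g, o follows the common Keras convention, and the weights here are random stand-ins). Starting each sequence from h = c = 0 means the same input always produces the same output:

```python
import numpy as np

def lstm_step(x, h, c, W, U, b):
    # One LSTM time step; gate order i, f, g (candidate), o.
    z = x @ W + h @ U + b            # shape: (4 * units,)
    i, f, g, o = np.split(z, 4)
    sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, c_new

rng = np.random.default_rng(0)
units, features, timesteps = 4, 3, 5
W = rng.standard_normal((features, 4 * units))  # input kernel (stand-in)
U = rng.standard_normal((units, 4 * units))     # recurrent kernel (stand-in)
b = np.zeros(4 * units)
X = rng.standard_normal((timesteps, features))

def run_sequence(X):
    h = np.zeros(units)  # zero initial hidden state -- scenario (b)
    c = np.zeros(units)  # zero initial cell state
    for x in X:
        h, c = lstm_step(x, h, c, W, U, b)
    return h

# Same input from the same zero initial states -> identical outputs.
assert np.allclose(run_sequence(X), run_sequence(X))
```

With stateful=True, by contrast, the second call would start from the states left over by the first call, which is exactly the difference the direct inspection above demonstrates.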

Upvotes: 2

Carson Cummins

Reputation: 1

My understanding from this is that the states are initialized to zero in most cases.

Upvotes: -1
