gusto

Reputation: 324

Initial states in Keras LSTM

I've been struggling to understand exactly when the hidden_state is reinitialized in my Keras LSTM model when stateful=False. Various tutorials I've seen imply that it is reset at the beginning of each batch, but from what I can tell, it is actually reset between each sample in a batch. Am I wrong?

I've written the following code to test this out:

from keras.models import Sequential
from keras.layers import Dense, LSTM
import keras.backend as K
import numpy as np
import tensorflow as tf

a = [1, 0, 0]
b = [0, 1, 0]
c = [0, 0, 1]

seq = [a, b, c, b, a]

x = seq[:-1]
y = seq[1:]
window_size = 1

x = np.array(x).reshape((len(x), window_size, 3))
y = np.array(y)

def run_with_batch_size(batch_size=1):
  model = Sequential()
  model.add(LSTM(20, input_shape=(1, 3)))
  model.add(Dense(3, activation='softmax'))
  model.compile(loss='mean_squared_error', optimizer='adam')

  for i in range(500):
    model.fit(x, y,
      batch_size=batch_size,
      epochs=1,
      verbose=0,
      shuffle=False
    )

  print(model.predict(np.array([[a], [b]]), batch_size=batch_size))
  print()
  print(model.predict(np.array([[b], [c]]), batch_size=batch_size))
  print()
  print(model.predict(np.array([[c], [b]]), batch_size=batch_size))


print('-'*30)
run_with_batch_size(1)
print('**')
run_with_batch_size(2)

The result of running this code:

------------------------------
# batch_size 1
[[0.01296294 0.9755857  0.01145133]
 [0.48558792 0.02751653 0.4868956 ]]

[[0.48558792 0.02751653 0.4868956 ]
 [0.01358072 0.9738273  0.01259203]]

[[0.01358072 0.9738273  0.01259203]
 [0.48558792 0.02751653 0.4868956 ]]
**
# batch_size 2
# output of batch (a, b)
[[0.0255649  0.94444686 0.02998832]
 [0.47172785 0.05804421 0.47022793]]

# output of batch (b, c)
# notice first output here is the same as the second output from above
[[0.47172785 0.05804421 0.47022793]
 [0.03059724 0.93813574 0.03126698]]

[[0.03059724 0.93813574 0.03126698]
 [0.47172785 0.05804421 0.47022793]]
------------------------------

When my batch_size is 1: each sample is fed through individually, yet the prediction for a given input (e.g. b) is identical no matter which input preceded it.

When my batch_size is 2: both samples go through in a single batch, and the second output of the (a, b) batch is still identical to the first output of the (b, c) batch — so the state from a does not appear to carry over into b even within a batch.

I'm still pretty fresh to this area, so it's very possible that I'm misunderstanding something. Is the initial state reset between each sample in a batch rather than between each batch?

Upvotes: 2

Views: 4096

Answers (1)

nuric

Reputation: 11225

Great testing, and you are on the right track. To answer the question directly: when stateful=False, the initial state is reset for every sample in the batch on every forward pass. From the source code:

def get_initial_state(self, inputs):
  # build an all-zero tensor of shape (samples, output_dim)
  initial_state = K.zeros_like(inputs)  # (samples, timesteps, input_dim)
  initial_state = K.sum(initial_state, axis=(1, 2))  # (samples,)
  initial_state = K.expand_dims(initial_state)  # (samples, 1)
  # ...
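To see what that shape juggling produces, here is a minimal NumPy sketch (not the actual Keras code — the `units` tiling at the end is an assumption about the elided part of the snippet) of how a per-sample all-zero state is built from the input tensor:

```python
import numpy as np

def get_initial_state(inputs, units):
    # inputs has shape (samples, timesteps, input_dim)
    initial_state = np.zeros_like(inputs)            # (samples, timesteps, input_dim)
    initial_state = initial_state.sum(axis=(1, 2))   # (samples,)
    initial_state = initial_state[:, None]           # (samples, 1)
    # broadcast to one zero state vector per sample
    return np.tile(initial_state, (1, units))        # (samples, units)

states = get_initial_state(np.random.rand(2, 1, 3), units=20)
print(states.shape)   # (2, 20)
print(states.sum())   # 0.0 — all zeros, one clean state per sample
```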

This means every sample in the batch gets a clean initial state of zeros. This function is used inside the call method:

if initial_state is not None:
  pass
elif self.stateful:
  initial_state = self.states
else:
  initial_state = self.get_initial_state(inputs)

So if stateful=False and you have not provided an explicit initial_state, the code creates fresh initial states for the RNN (an LSTM inherits from the RNN layer, so this applies to it as well). Since call is responsible for computing the forward pass, every forward pass — which, as you discovered, is computed per batch — starts each sample from new zero initial states.
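You can reproduce the effect you measured with a toy recurrence, without Keras at all. This is a sketch with a plain tanh step standing in for the LSTM cell, and arbitrary random weights — the point is only how the initial state is handled:

```python
import numpy as np

rng = np.random.default_rng(0)
Wx = rng.standard_normal((3, 4))   # input -> hidden weights (toy values)
Wh = rng.standard_normal((4, 4))   # hidden -> hidden weights

def step(x, h):
    # one simple tanh-RNN step, a stand-in for the LSTM cell
    return np.tanh(x @ Wx + h @ Wh)

a = np.array([1.0, 0.0, 0.0])
b = np.array([0.0, 1.0, 0.0])

# stateful=False behaviour: every sample starts from h = 0
out_b_alone   = step(b, np.zeros(4))
out_b_after_a = step(b, np.zeros(4))   # state reset before b, even if a came first
print(np.allclose(out_b_after_a, out_b_alone))   # True — identical outputs

# stateful behaviour: the state left by a is carried into b
h_after_a = step(a, np.zeros(4))
out_b_carried = step(b, h_after_a)
print(np.allclose(out_b_carried, out_b_alone))   # False — preceding input matters
```

This matches your experiment: with stateful=False the prediction for b is the same wherever it appears, because its hidden state is zeroed every time.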

Upvotes: 4

Related Questions