Leafar

Reputation: 374

How to extract cell state from a LSTM at each timestep in Keras?

Is there a way in Keras to retrieve the cell state (i.e., the c vector) of an LSTM layer at every timestep of a given input?

It seems the return_state argument returns the last cell state after the computation is done, but I also need the intermediate ones. Also, I don't want to pass these cell states to the next layer; I only want to be able to access them.

Preferably using TensorFlow as the backend.

Thanks

Upvotes: 12

Views: 2991

Answers (4)

Chau Luu

Reputation: 41

I was looking for a solution to this issue, and after reading the guidance for creating your own custom RNN cell in tf.keras (https://www.tensorflow.org/api_docs/python/tf/keras/layers/AbstractRNNCell), I believe the following is the most concise and easy-to-read way of doing this for TensorFlow 2:

import tensorflow as tf
from tensorflow.keras.layers import LSTMCell

class LSTMCellReturnCellState(LSTMCell):

    def call(self, inputs, states, training=None):
        real_inputs = inputs[:, :self.units]  # decouple [h, c]: keep only the first `units` input features
        outputs, [h, c] = super().call(real_inputs, states, training=training)
        # output h and c concatenated along the feature axis; pass [h, c] on as the state
        return tf.concat([h, c], axis=1), [h, c]



num_units = 512
test_input = tf.random.uniform([5,100,num_units])

rnn = tf.keras.layers.RNN(LSTMCellReturnCellState(num_units),
                          return_sequences=True, return_state=True)

whole_seq_output, final_memory_state, final_carry_state = rnn(test_input)

print(whole_seq_output.shape)
>>> (5,100,1024)

# Hidden state sequence
h_seq = whole_seq_output[:,:,:num_units] # (5,100,512)

# Cell state sequence
c_seq = whole_seq_output[:,:,num_units:] # (5,100,512)
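
As a quick sanity check (a sketch, assuming the variables above), the last entry of the cell-state sequence should match the final carry state returned via return_state:

# the cell state at the last timestep equals the final carry state
last_c = c_seq[:, -1, :]  # (5, 512)
print(tf.reduce_all(tf.equal(last_c, final_carry_state)))  # prints a True tensor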

As mentioned in another answer, the advantage of this approach is that the custom cell can easily be wrapped in tf.keras.layers.RNN as a drop-in replacement for the normal LSTMCell.

Here is a Colab notebook with the code running as expected on tensorflow==2.6.0.

Upvotes: 4

Taw

Reputation: 429

First, this is not possible to do with tf.keras.layers.LSTM; you have to use LSTMCell instead, or subclass LSTM. Second, there is no need to subclass LSTMCell to get the sequence of cell states: LSTMCell already returns a list of the hidden state (h) and cell state (c) every time you call it. For those not familiar with LSTMCell, it takes in the current [h, c] tensors and the input at the current timestep (it cannot take in a sequence of timesteps), and returns the activations along with the updated [h, c]. Here is an example showing how to use LSTMCell to process a sequence of timesteps and accumulate the cell states.

import numpy as np
import tensorflow as tf

# example inputs: 3 timesteps, 4 features
inputs = tf.convert_to_tensor(np.random.rand(3, 4), dtype='float32')
# must initialize the hidden/cell state for the LSTM cell
h_c = [tf.zeros((1, 2)), tf.zeros((1, 2))]
lstm = tf.keras.layers.LSTMCell(2)

# example of how you accumulate cell state over repeated calls to LSTMCell
inputs = tf.unstack(inputs, axis=0)
c_states = []
for cur_inputs in inputs:
    out, h_c = lstm(tf.expand_dims(cur_inputs, axis=0), h_c)
    h, c = h_c
    c_states.append(c)
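
If you want the accumulated cell states as a single tensor, a small follow-up (a sketch, assuming the variables above) is to concatenate them along the batch axis:

c_seq = tf.concat(c_states, axis=0)  # (3, 2): one cell state per timestep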

Upvotes: 1

zimmerrol

Reputation: 4951

You can access the states of any RNN by setting return_sequences = True in the initializer. You can find more information about this parameter here.
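
For reference, a minimal sketch (assuming a tf.keras LSTM layer) of what return_sequences=True returns, which is the hidden state at every timestep:

import numpy as np
import tensorflow as tf

inp = tf.keras.Input(shape=(10, 7))
out = tf.keras.layers.LSTM(32, return_sequences=True)(inp)  # h at every timestep
model = tf.keras.Model(inp, out)

print(model.predict(np.zeros((1, 10, 7))).shape)  # (1, 10, 32)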

Upvotes: -3

Ronakrit W.

Reputation: 693

I know it's pretty late, but I hope this can help.

What you are asking is technically possible by modifying the LSTM cell's call method. I modified it so that it returns 4 dimensions instead of 3 when you set return_sequences=True.

Code

import tensorflow as tf
from keras import backend as K
from keras.layers import LSTMCell
from keras.layers.recurrent import _generate_dropout_mask

class Mod_LSTMCELL(LSTMCell):
    def call(self, inputs, states, training=None):
        if 0 < self.dropout < 1 and self._dropout_mask is None:
            self._dropout_mask = _generate_dropout_mask(
                K.ones_like(inputs),
                self.dropout,
                training=training,
                count=4)
        if (0 < self.recurrent_dropout < 1 and
                self._recurrent_dropout_mask is None):
            self._recurrent_dropout_mask = _generate_dropout_mask(
                K.ones_like(states[0]),
                self.recurrent_dropout,
                training=training,
                count=4)

        # dropout matrices for input units
        dp_mask = self._dropout_mask
        # dropout matrices for recurrent units
        rec_dp_mask = self._recurrent_dropout_mask

        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

        if self.implementation == 1:
            if 0 < self.dropout < 1.:
                inputs_i = inputs * dp_mask[0]
                inputs_f = inputs * dp_mask[1]
                inputs_c = inputs * dp_mask[2]
                inputs_o = inputs * dp_mask[3]
            else:
                inputs_i = inputs
                inputs_f = inputs
                inputs_c = inputs
                inputs_o = inputs
            x_i = K.dot(inputs_i, self.kernel_i)
            x_f = K.dot(inputs_f, self.kernel_f)
            x_c = K.dot(inputs_c, self.kernel_c)
            x_o = K.dot(inputs_o, self.kernel_o)
            if self.use_bias:
                x_i = K.bias_add(x_i, self.bias_i)
                x_f = K.bias_add(x_f, self.bias_f)
                x_c = K.bias_add(x_c, self.bias_c)
                x_o = K.bias_add(x_o, self.bias_o)

            if 0 < self.recurrent_dropout < 1.:
                h_tm1_i = h_tm1 * rec_dp_mask[0]
                h_tm1_f = h_tm1 * rec_dp_mask[1]
                h_tm1_c = h_tm1 * rec_dp_mask[2]
                h_tm1_o = h_tm1 * rec_dp_mask[3]
            else:
                h_tm1_i = h_tm1
                h_tm1_f = h_tm1
                h_tm1_c = h_tm1
                h_tm1_o = h_tm1
            i = self.recurrent_activation(x_i + K.dot(h_tm1_i,
                                                      self.recurrent_kernel_i))
            f = self.recurrent_activation(x_f + K.dot(h_tm1_f,
                                                      self.recurrent_kernel_f))
            c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c,
                                                            self.recurrent_kernel_c))
            o = self.recurrent_activation(x_o + K.dot(h_tm1_o,
                                                      self.recurrent_kernel_o))
        else:
            if 0. < self.dropout < 1.:
                inputs *= dp_mask[0]
            z = K.dot(inputs, self.kernel)
            if 0. < self.recurrent_dropout < 1.:
                h_tm1 *= rec_dp_mask[0]
            z += K.dot(h_tm1, self.recurrent_kernel)
            if self.use_bias:
                z = K.bias_add(z, self.bias)

            z0 = z[:, :self.units]
            z1 = z[:, self.units: 2 * self.units]
            z2 = z[:, 2 * self.units: 3 * self.units]
            z3 = z[:, 3 * self.units:]

            i = self.recurrent_activation(z0)
            f = self.recurrent_activation(z1)
            c = f * c_tm1 + i * self.activation(z2)
            o = self.recurrent_activation(z3)

        h = o * self.activation(c)
        if 0 < self.dropout + self.recurrent_dropout:
            if training is None:
                h._uses_learning_phase = True
        # return h and c stacked together as the per-timestep output
        # (shape (1, 2, units) for batch size 1), while passing [h, c] on as the state
        return tf.expand_dims(tf.concat([h, c], axis=0), 0), [h, c]

Sample code

import numpy as np
import keras
from keras.layers import Input, RNN
from keras.models import Model

# create a cell
test = Mod_LSTMCELL(100)

# Input timesteps=10, features=7
in1 = Input(shape=(10,7))
out1 = RNN(test, return_sequences=True)(in1)

M = Model(inputs=[in1],outputs=[out1])
M.compile(keras.optimizers.Adam(),loss='mse')

ans = M.predict(np.arange(7*10,dtype=np.float32).reshape(1, 10, 7))

print(ans.shape)  # (1, 10, 2, 100): (batch, timesteps, [h, c], units)
# state_h at the first timestep
print(ans[0, 0, 0, :])
# state_c at the first timestep
print(ans[0, 0, 1, :])
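
If you want the full state sequences as separate arrays, one small follow-up (a sketch, assuming the shapes above) is:

h_seq = ans[0, :, 0, :]  # (10, 100): hidden state at every timestep
c_seq = ans[0, :, 1, :]  # (10, 100): cell state at every timestep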

Upvotes: 2
