NeuroEng

Reputation: 191

Stateful RNN (LSTM) in keras

Imagine the following data:

X = [x1, x2, x3, x4, x5, x6, ...]

and

Y = [y1, y2, y3, y4, ...]

The labels correspond to the inputs in the following manner:

[x1, x2] -> y1
[x2, x3] -> y2
...
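The window-to-label pairing above can be built with NumPy. This is a minimal sketch: the window length of 2 comes from the example, while the helper name `make_windows` and the toy values standing in for x1..x6 are hypothetical.

```python
import numpy as np

def make_windows(x, window=2):
    """Stack overlapping windows so that row i is [x[i], ..., x[i + window - 1]]."""
    x = np.asarray(x)
    n = len(x) - window + 1
    return np.stack([x[i:i + window] for i in range(n)])

# Hypothetical stand-ins for x1..x6
X = np.array([10, 20, 30, 40, 50, 60], dtype=np.float32)
windows = make_windows(X)            # rows: [10, 20], [20, 30], ..., [50, 60]
# Keras recurrent layers expect (samples, timesteps, features)
windows = windows[..., np.newaxis]
print(windows.shape)                 # (5, 2, 1)
```

With 6 inputs this yields 5 windows, so Y needs one label per window.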

I am trying to build a model in Keras so that, during classification, the model remembers what it classified at the previous step. I want it to be causal, in the sense that the next prediction depends directly on the previous one, somewhat like other methods such as HMMs. So something like this:

y2 = f([x2, x3], y1)
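One simple way to express y2 = f([x2, x3], y1) without any stateful layer is to append the previous window's label to each input as an extra feature. Note that this is a different technique from a stateful LSTM (it resembles teacher forcing during training); the function name, the `start_label` placeholder for the first window, and the toy labels are assumptions for illustration. At prediction time the true previous label is unknown, so the model's own previous prediction would have to be fed back instead.

```python
import numpy as np

def add_previous_label(windows, labels, start_label=0.0):
    """Append y_{i-1} to every time step of window i (window 0 gets start_label)."""
    windows = np.asarray(windows, dtype=np.float32)       # (samples, timesteps, features)
    labels = np.asarray(labels, dtype=np.float32)         # (samples,)
    prev = np.concatenate([[start_label], labels[:-1]])   # shift labels right by one
    # Repeat each previous label across the time steps of its window
    prev = np.broadcast_to(prev[:, None, None], windows.shape[:2] + (1,))
    return np.concatenate([windows, prev], axis=-1)

windows = np.array([[[1.0], [2.0]], [[2.0], [3.0]], [[3.0], [4.0]]])  # [x1,x2], [x2,x3], [x3,x4]
labels = np.array([1.0, 0.0, 1.0])                                     # toy y1, y2, y3
features = add_previous_label(windows, labels)
print(features.shape)    # (3, 2, 2)
print(features[1])       # window [x2, x3] with y1 appended to each step
```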

I have read this page, where they divide each batch into sub-batches (if that's the correct term) and reset the state between main batches. What I want instead is to not shuffle the batches and to introduce that causality into the model.
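For context, the Keras documentation for `stateful=True` says the final state of sample i in one batch is reused as the initial state of sample i in the next batch, so the ordering that matters is across batches rather than within one. Below is a minimal NumPy sketch of reordering an ordered series under that rule; the helper name is hypothetical. Training would then use `shuffle=False` and a fixed batch size, with states reset between epochs.

```python
import numpy as np

def to_stateful_batches(samples, batch_size):
    """Reorder samples so that sample i of batch k+1 continues sample i of batch k.

    The series is cut into `batch_size` contiguous streams; batch k holds the
    k-th element of every stream. Leftover samples that do not fill a stream are dropped.
    """
    samples = np.asarray(samples)
    steps = len(samples) // batch_size                  # batches per epoch
    usable = samples[: steps * batch_size]
    streams = usable.reshape((batch_size, steps) + samples.shape[1:])
    # Swap stream and step axes: result[k, i] is stream i at step k
    return np.swapaxes(streams, 0, 1)

ordered = np.arange(12)              # stands in for 12 consecutive windows
batches = to_stateful_batches(ordered, batch_size=3)
print(batches)
# batch 0: [0 4 8], batch 1: [1 5 9], batch 2: [2 6 10], batch 3: [3 7 11]
```

With `batch_size=1` this reduces to the original order, which is the simplest setup for the causal behaviour described above.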

My question is: how can I do this with stateful LSTMs?
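To make the stateless/stateful difference concrete, here is a framework-free NumPy sketch. It is not the Keras implementation; the class name, gate ordering and random weights are illustrative assumptions. With state carried over, processing [x1, x2] and then [x2, x3] is equivalent to running the LSTM once over x1, x2, x2, x3. Note that what gets carried is the hidden/cell state, not the predicted label y1, so a stateful LSTM only approximates the HMM-style dependence described above.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

class TinyLSTM:
    """Single LSTM cell with optional state carry-over between calls (conceptual sketch)."""

    def __init__(self, input_dim, units, seed=0):
        rng = np.random.default_rng(seed)
        # Gate weights stacked in the (assumed) order: input, forget, candidate, output
        self.W = rng.normal(scale=0.5, size=(input_dim, 4 * units))
        self.U = rng.normal(scale=0.5, size=(units, 4 * units))
        self.b = np.zeros(4 * units)
        self.units = units
        self.reset_states()

    def reset_states(self):
        self.h = np.zeros(self.units)
        self.c = np.zeros(self.units)

    def __call__(self, window, stateful=True):
        if not stateful:
            self.reset_states()          # stateless: every window starts from zeros
        for x_t in window:               # window: (timesteps, input_dim)
            z = x_t @ self.W + self.h @ self.U + self.b
            i, f, g, o = np.split(z, 4)
            i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
            self.c = f * self.c + i * np.tanh(g)
            self.h = o * np.tanh(self.c)
        return self.h                    # last hidden state, e.g. fed to a classifier

windows = np.array([[[1.0], [2.0]], [[2.0], [3.0]]])   # [x1, x2], [x2, x3]

cell = TinyLSTM(input_dim=1, units=4)
h_first = cell(windows[0])
h_carried = cell(windows[1])                  # continues from the state left by window 0
h_fresh = cell(windows[1], stateful=False)    # starts again from zeros

# Carrying state matches one pass over the long sequence x1, x2, x2, x3
cell.reset_states()
h_long = cell(np.concatenate(windows))
print(np.allclose(h_carried, h_long))   # True
```

In Keras the corresponding switch is the `stateful=True` argument of the LSTM layer.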

Upvotes: 0

Views: 208

Answers (1)

Jirayu Kaewprateep

Reputation: 766

One way is to write a custom layer that inherits from the LSTM class:

[ Sample ]:

import tensorflow as tf

# ---------------------------------------------------------
# Class / Definition
# ---------------------------------------------------------
class MyLSTMLayer(tf.keras.layers.LSTM):
    def __init__(self, units, return_sequences=True, return_state=False):
        # Forward the constructor arguments instead of hard-coding them
        super().__init__(units, return_sequences=return_sequences,
                         return_state=return_state)
        self.num_units = units

    def build(self, input_shape):
        # Overrides LSTM.build(): only a single input kernel is created,
        # no gate or recurrent weights
        self.kernel = self.add_weight(
            name="kernel",
            shape=[int(input_shape[-1]), self.num_units])

    def call(self, inputs):
        # Overrides LSTM.call(): a plain linear projection of each time step,
        # with no state carried between steps or batches
        return tf.matmul(inputs, self.kernel)

# ---------------------------------------------------------
# Variables
# ---------------------------------------------------------
start = 3
limit = 12
delta = 3
sample = tf.range(start, limit, delta)                 # [3, 6, 9]
sample = tf.cast(sample, dtype=tf.float32)
sample = tf.reshape(sample, (sample.shape[0], 1, 1))   # (batch, timesteps, features)
layer = MyLSTMLayer(sample.shape[0], True, False)

print(sample)
print(layer(sample))

[ Output ]:

tf.Tensor(
[[[3.]]

 [[6.]]

 [[9.]]], shape=(3, 1, 1), dtype=float32)
tf.Tensor(
[[[-1.8635211  2.6157026 -1.6650987]]

 [[-3.7270422  5.2314053 -3.3301973]]

 [[-5.5905633  7.8471084 -4.995296 ]]], shape=(3, 1, 3), dtype=float32)
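The printed output is consistent with the overridden `call()`: it is a plain `tf.matmul` with a (1, 3) kernel, so each output row is the input value times the same kernel row, and the rows for 6 and 9 are 2x and 3x the row for 3. A small NumPy check follows; recovering the kernel from the first printed row is an assumption made only for this check.

```python
import numpy as np

sample = np.array([3.0, 6.0, 9.0]).reshape(3, 1, 1)
printed = np.array([
    [[-1.8635211, 2.6157026, -1.6650987]],
    [[-3.7270422, 5.2314053, -3.3301973]],
    [[-5.5905633, 7.8471084, -4.995296]],
])

# Recover the (1, 3) kernel from the first row, since output = 3 * kernel there
kernel = printed[0] / 3.0            # shape (1, 3)
reconstructed = sample @ kernel      # (3, 1, 1) @ (1, 3) -> (3, 1, 3)

print(np.allclose(reconstructed, printed, atol=1e-5))   # True
```

So, as written, the subclass keeps no state across time steps or calls; it is a starting point for a custom layer rather than a stateful LSTM.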

Upvotes: 1
