Shamane Siriwardhana

Reputation: 4201

How can I create an LSTM cell with only a single time step in TensorFlow?

layer_1 = tf.layers.dense(inputs=layer_c, units=512, activation=tf.nn.tanh, name='layer1')

layer_2 = tf.layers.dense(inputs=layer_1, units=512, activation=tf.nn.tanh, name='layer2')

Here my layer_2 output has shape [batch_size, 512]. I need to feed this layer_2 output through a single LSTM unit, but when I tried tf.nn.static_rnn it gave an error saying my input should be a sequence. How can I perform this task?

Upvotes: 0

Views: 283

Answers (1)

0xsx

Reputation: 500

According to the documentation for static_rnn, the inputs argument expects a list:

inputs: A length T list of inputs, each a Tensor of shape [batch_size, input_size], or a nested tuple of such elements.

In your case, T==1, so you can just pass a single-element list containing your previous layer. To keep the internal cell and hidden states across timesteps, you can add additional placeholders for them and pass them to static_rnn via the initial_state argument. Because cell.state_size is a tuple for LSTM cells (cell state and hidden state), you have to pass a tuple for this argument, and a tuple is returned for the output state.
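If it helps to see why the state is a (cell_state, hidden_state) pair, here is a rough pure-NumPy sketch of a single LSTM step carried across timesteps. The weight layout and random initialization here are illustrative placeholders, not TensorFlow's internals:

```python
import numpy as np

def lstm_step(x, c, h, W, b):
    """One LSTM timestep: takes input x and the (c, h) state pair,
    returns the updated (c, h) pair. h is also the cell's output."""
    # a single matmul computes all four gates at once
    z = np.concatenate([x, h], axis=1) @ W + b
    i, f, o, g = np.split(z, 4, axis=1)
    sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))
    new_c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)  # cell state
    new_h = sigmoid(o) * np.tanh(new_c)               # hidden state / output
    return new_c, new_h

batch_size, num_feats, num_units = 3, 100, 5
rng = np.random.default_rng(0)
W = rng.standard_normal((num_feats + num_units, 4 * num_units)) * 0.1
b = np.zeros(4 * num_units)

# initialize both halves of the state to zero, then thread them
# through one call per timestep, just like the feed_dict loop below
c = np.zeros((batch_size, num_units))
h = np.zeros((batch_size, num_units))
for _ in range(6):
    x = np.ones((batch_size, num_feats))
    c, h = lstm_step(x, c, h, W, b)  # pass states on to the next step
```

Both halves of the tuple must survive between calls, which is exactly what the placeholders in the TensorFlow example below do.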

Here is a minimal working example based on your code that feeds an array of ones as the input at each timestep and carries the internal states across time:

import tensorflow as tf
import numpy as np


num_timesteps = 6
batch_size = 3
num_input_feats = 100
num_lstm_units = 5

lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units)


input_x = tf.placeholder(tf.float32, [None, num_input_feats], name='input')

input_c_state = tf.placeholder(tf.float32, [None, lstm_cell.state_size.c], name='c_state')
input_h_state = tf.placeholder(tf.float32, [None, lstm_cell.state_size.h], name='h_state')

layer_1 = tf.layers.dense(inputs=input_x, units=512, activation=tf.nn.tanh, name='layer1')
layer_2 = tf.layers.dense(inputs=layer_1, units=512, activation=tf.nn.tanh, name='layer2')

layer_2_next, next_state = tf.nn.static_rnn(lstm_cell, [layer_2], dtype=tf.float32,
                                            initial_state=(input_c_state, input_h_state))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())

  # initialize the internal cell state and hidden state to zero
  cur_c_state = np.zeros([batch_size, lstm_cell.state_size.c], dtype="float32")
  cur_h_state = np.zeros([batch_size, lstm_cell.state_size.h], dtype="float32")


  for i in range(num_timesteps):

    # here is your single timestep of input
    cur_x = np.ones([batch_size, num_input_feats], dtype="float32")

    y_out, out_state = sess.run([layer_2_next, next_state],
                                feed_dict={input_x: cur_x,
                                           input_c_state: cur_c_state,
                                           input_h_state: cur_h_state})
    cur_c_state, cur_h_state = out_state   # pass states along to the next timestep

    print(y_out)  # here is your single timestep of output

Upvotes: 1
