Reputation: 4201
layer_1 = tf.layers.dense(inputs=layer_c, units=512, activation=tf.nn.tanh, name='layer1')
layer_2 = tf.layers.dense(inputs=layer_1, units=512, activation=tf.nn.tanh, name='layer2')
Here my layer_2 output has shape [batch_size, 512]. I need to send this layer_2 output through a single LSTM unit, but when I try tf.nn.static_rnn it gives an error saying my input should be a sequence. How can I perform this task?
Upvotes: 0
Views: 283
Reputation: 500
From the documentation for static_rnn, the inputs argument expects a list:

inputs: A length T list of inputs, each a Tensor of shape [batch_size, input_size], or a nested tuple of such elements.

In your case, T == 1, so you can just pass a single-element list containing your previous layer. To keep track of the internal cell and hidden states in such a way that you can carry them across timesteps, you can add extra placeholders for them and pass them to static_rnn via the initial_state argument. Because cell.state_size is a tuple for LSTM cells (of (cell_state, hidden_state)), we have to pass a tuple for this argument, and a tuple will be returned for the output state.
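For intuition about why the state is a (c, h) pair: a single LSTM step updates a cell state c and a hidden state h together, and both are needed to compute the next step. Here is a minimal NumPy sketch of the standard LSTM equations, with random placeholder weights (these are illustrative, not the variables TensorFlow creates, and the i/f/g/o gate ordering is just a convention for this sketch):

```python
import numpy as np

def lstm_step(x, c, h, W, b):
    """One LSTM step (standard equations): returns the new (c, h) pair.
    W maps the concatenated [x; h] to the 4 stacked gate pre-activations."""
    z = np.concatenate([x, h], axis=1) @ W + b       # [batch, 4 * units]
    i, f, g, o = np.split(z, 4, axis=1)              # gate pre-activations
    sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g) # cell state update
    h_new = sigmoid(o) * np.tanh(c_new)              # hidden state (the output)
    return c_new, h_new

batch, in_feats, units = 3, 512, 5                   # mirrors layer_2 -> 5-unit LSTM
rng = np.random.default_rng(0)
W = rng.standard_normal((in_feats + units, 4 * units)) * 0.01
b = np.zeros(4 * units)

x = np.ones((batch, in_feats), dtype=np.float32)
c = np.zeros((batch, units))
h = np.zeros((batch, units))
c, h = lstm_step(x, c, h, W, b)
print(c.shape, h.shape)  # (3, 5) (3, 5)
```

Both halves of the state have shape [batch_size, num_units], which is why the example below uses two separate placeholders of that shape.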
Here is a minimal working example based on your code, feeding ones into the input placeholder at each timestep and carrying the internal states across time:
import tensorflow as tf
import numpy as np

num_timesteps = 6
batch_size = 3
num_input_feats = 100
num_lstm_units = 5

lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units)

input_x = tf.placeholder(tf.float32, [None, num_input_feats], name='input')
input_c_state = tf.placeholder(tf.float32, [None, lstm_cell.state_size.c], name='c_state')
input_h_state = tf.placeholder(tf.float32, [None, lstm_cell.state_size.h], name='h_state')

layer_1 = tf.layers.dense(inputs=input_x, units=512, activation=tf.nn.tanh, name='layer1')
layer_2 = tf.layers.dense(inputs=layer_1, units=512, activation=tf.nn.tanh, name='layer2')

layer_2_next, next_state = tf.nn.static_rnn(lstm_cell, [layer_2], dtype=tf.float32,
                                            initial_state=(input_c_state, input_h_state))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # initialize the internal cell state and hidden state to zero
    cur_c_state = np.zeros([batch_size, lstm_cell.state_size.c], dtype="float32")
    cur_h_state = np.zeros([batch_size, lstm_cell.state_size.h], dtype="float32")
    for i in range(num_timesteps):
        # here is your single timestep of input
        cur_x = np.ones([batch_size, num_input_feats], dtype="float32")
        y_out, out_state = sess.run([layer_2_next, next_state],
                                    feed_dict={input_x: cur_x,
                                               input_c_state: cur_c_state,
                                               input_h_state: cur_h_state})
        cur_c_state, cur_h_state = out_state  # pass the states along to the next timestep
        print(y_out)  # here is your single timestep of output
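To see why feeding out_state back into the placeholders matters, here is a toy NumPy comparison using a plain tanh RNN step (a simplified stand-in for the LSTM, with made-up random weights): one run carries the state between calls the way the loop above does, the other resets it to zeros every call, which is what you would effectively get without the state placeholders.

```python
import numpy as np

def rnn_step(x, h, Wx, Wh):
    """Toy tanh RNN step: new hidden state from input x and previous state h."""
    return np.tanh(x @ Wx + h @ Wh)

rng = np.random.default_rng(1)
T, batch, feats, units = 6, 3, 4, 5
xs = rng.standard_normal((T, batch, feats))
Wx = rng.standard_normal((feats, units)) * 0.5
Wh = rng.standard_normal((units, units)) * 0.5

# carry the state from one call to the next (like feeding out_state back in)
h_carried = np.zeros((batch, units))
outs_carried = []
for t in range(T):
    h_carried = rnn_step(xs[t], h_carried, Wx, Wh)
    outs_carried.append(h_carried)

# reset the state to zeros on every call (no memory between timesteps)
outs_reset = []
for t in range(T):
    outs_reset.append(rnn_step(xs[t], np.zeros((batch, units)), Wx, Wh))

# the first step matches (both start from zeros); later steps diverge
print(np.allclose(outs_carried[0], outs_reset[0]))    # True
print(np.allclose(outs_carried[-1], outs_reset[-1]))  # False
```

Without carrying the state, every sess.run call would treat its input as the first timestep of a fresh sequence.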
Upvotes: 1