chasep255

Reputation: 12175

TensorFlow: feeding the initial RNN state

I came across the following example and did not know it was possible to feed an RNN state this way:

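import tensorflow as tf
# rnn_cell is TensorFlow's RNN cell module; `cell` (one layer) and `args` come from the surrounding class.
self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=True)
self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
# zero_state builds zero-filled tensors (a tuple, one entry per layer), not placeholders.
self.initial_state = cell.zero_state(args.batch_size, tf.float32)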

In this code the initial state is declared as a zero state. As far as I know this is not a placeholder; it is just a tuple of zero tensors.
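To make that "tuple of zero tensors" concrete, here is a plain-Python sketch (no TensorFlow) of the nested structure that MultiRNNCell(..., state_is_tuple=True).zero_state(...) produces, assuming the per-layer cell is an LSTM; the question's snippet does not show which cell is used. The namedtuple only mirrors the (c, h) field layout of TensorFlow's LSTMStateTuple, zero_state_like is a hypothetical helper, and lists of zeros stand in for the real tensors.

```python
from collections import namedtuple

# Stand-in with the same (c, h) fields as TensorFlow's LSTMStateTuple; not the TF class itself.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def zero_state_like(num_layers, batch_size, num_units):
    """Hypothetical helper: the nested shape of a multi-layer LSTM zero state.

    Each leaf is a batch_size x num_units list of zeros, standing in for a zeros tensor.
    """
    def zeros():
        return [[0.0] * num_units for _ in range(batch_size)]

    # One (c, h) pair per layer, collected in a plain tuple.
    return tuple(LSTMStateTuple(c=zeros(), h=zeros()) for _ in range(num_layers))


state = zero_state_like(num_layers=2, batch_size=1, num_units=3)
print(len(state))   # 2
print(state[0].h)   # [[0.0, 0.0, 0.0]]
```

Every leaf of this structure is an ordinary tensor in the real graph, which is why the whole thing can appear as a feed_dict key as long as the fed value has the same nesting.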

Then, in the function that uses the RNN model to generate text, the initial state is fed in through session.run:

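import numpy as np
import tensorflow as tf

def sample(self, sess, chars, vocab, num=200, prime='The ', sampling_type=1):
    # Evaluate the zero state to get concrete numpy arrays with the same nested tuple structure.
    state = sess.run(self.cell.zero_state(1, tf.float32))
    # Feed every prime character except the last one, updating the state at each step.
    for char in prime[:-1]:
        x = np.zeros((1, 1))
        x[0, 0] = vocab[char]
        # self.initial_state is a tuple of tensors, so it can be used as a feed_dict key.
        feed = {self.input_data: x, self.initial_state: state}
        # Carry the final state forward as the next step's initial state.
        [state] = sess.run([self.final_state], feed)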
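The loop in sample follows a simple pattern: run one step, read final_state out, and feed it back in as initial_state on the next call. A framework-free sketch of that state threading is below; toy_step, its weights and the tiny vocab are hypothetical stand-ins for the TensorFlow graph and the real vocabulary, chosen only to show how the state tuple is carried between calls.

```python
import math


def toy_step(x, state, w=0.5, u=0.9):
    """Hypothetical one-step RNN standing in for sess.run([final_state], feed).

    `state` is a tuple with one entry per layer, like the fed initial_state.
    """
    new_state = []
    inp = x
    for h in state:
        h = math.tanh(w * inp + u * h)
        new_state.append(h)
        inp = h  # this layer's output is the next layer's input
    return tuple(new_state)


vocab = {"T": 1.0, "h": 2.0, "e": 3.0, " ": 0.0}  # hypothetical encoding
prime = "The "

state = (0.0, 0.0)  # analogue of zero_state(1, tf.float32) for a 2-layer cell
for char in prime[:-1]:
    # analogue of feed = {input_data: x, initial_state: state}
    state = toy_step(vocab[char], state)

print(state)
```

The only thing that changes between iterations is the value bound to the state key; the graph (here, the function) stays the same, which is exactly why feeding the initial state is useful.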

Since self.initial_state is not a placeholder, how can it be fed in session.run?

Here is a link to the code I was looking at.

Upvotes: 2

Views: 627

Answers (2)

Allen

Reputation: 11

I had the same question when reading similar RNN code.

As I understand it, cell.zero_state returns a tuple of tensors, and those tensors are feedable. Placeholders are just tensors too.

So if you do:

print(init_state[0])

# You will get something like 
<tf.Tensor 'LSTM_cell/initial_state/BasicLSTMCellZeroState/zeros:0' shape=(50, 10) dtype=float32>

And feed_dict lets you feed any tensor, or a tuple of tensors, as a key.
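A minimal sketch of both points, assuming a TensorFlow 2.x install used through the tf.compat.v1 graph-mode API (the question's code predates TF 2). The tensors here are illustrative stand-ins for zero_state's output: a tuple of tf.zeros tensors, none of which is a placeholder, fed as a single tuple key with a tuple of arrays as the value.

```python
import numpy as np
import tensorflow as tf

# Graph mode, as in the TF 1.x code from the question.
tf.compat.v1.disable_eager_execution()

# A tuple of zero tensors, similar in spirit to what zero_state returns (shapes are illustrative).
initial_state = (tf.zeros([1, 3]), tf.zeros([1, 3]))
# Some computation that depends on the initial state.
final_state = tuple(s + 1.0 for s in initial_state)

with tf.compat.v1.Session() as sess:
    # Without a feed, the zeros are used: each component is all ones.
    default = sess.run(final_state)
    # Feed the whole tuple at once: the key is a tuple of tensors, the value a tuple of arrays.
    fed_value = (np.full((1, 3), 5.0, np.float32), np.full((1, 3), -1.0, np.float32))
    fed = sess.run(final_state, feed_dict={initial_state: fed_value})

print(default)
print(fed)  # first component all 6.0, second all 0.0
```

The fed value must have the same nesting as the key; a list instead of a tuple will not work as a key because lists are not hashable.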

Upvotes: 0

Casey Chu

Reputation: 25463

Note that you can feed any tensor, not just placeholders. In this case you can also feed each component of the tuple manually (one entry per layer; this example assumes two layers):

feed = {
    self.input_data: x, 
    self.initial_state[0]: state[0], 
    self.initial_state[1]: state[1]
}
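Feeding the whole tuple ({self.initial_state: state}) and feeding each component, as above, should amount to the same thing: as I understand the TF 1.x Session.run documentation, a feed_dict key may be a nested tuple of tensors whose value has the same structure, and it is expanded into one feed per leaf tensor. The plain-Python sketch below models that expansion; flatten_feed is a hypothetical helper, not TensorFlow's implementation, and strings stand in for tensors and arrays. It also replaces the hard-coded [0] and [1] with a loop that works for any number of layers.

```python
def _leaves(structure):
    """Yield the leaves of a nested tuple (namedtuples included) in order."""
    if isinstance(structure, tuple):
        for item in structure:
            yield from _leaves(item)
    else:
        yield structure


def flatten_feed(feed):
    """Hypothetical helper: expand tuple keys into one leaf-key -> leaf-value entry each."""
    flat = {}
    for key, value in feed.items():
        if isinstance(key, tuple):
            keys, values = list(_leaves(key)), list(_leaves(value))
            if len(keys) != len(values):
                raise ValueError("key and value structures do not match")
            flat.update(zip(keys, values))
        else:
            flat[key] = value
    return flat


# Strings stand in for tensors (keys) and numpy arrays (values): 2 layers of (c, h).
initial_state = (("c0", "h0"), ("c1", "h1"))
state = (("C0", "H0"), ("C1", "H1"))

whole = flatten_feed({"input_data": "x", initial_state: state})

# Per-layer feed for any number of layers, instead of hard-coding [0] and [1].
per_layer = {"input_data": "x"}
for layer_key, layer_value in zip(initial_state, state):
    per_layer[layer_key] = layer_value

print(whole == flatten_feed(per_layer))  # True
```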

Upvotes: 2
