Reputation: 12175
I came across the following example, and I did not know it was possible to feed an RNN state this way.
self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=True)
self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
self.initial_state = cell.zero_state(args.batch_size, tf.float32)
In this segment of code the initial state is declared as a zeroed state. To my knowledge this is not a placeholder. It is just a tuple of zero tensors.
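To illustrate, here is a minimal standalone sketch (using the TF 1.x rnn_cell API; the layer count and sizes are made up) showing that the zero state is just ordinary graph tensors:
import tensorflow as tf

cell = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.BasicLSTMCell(10) for _ in range(2)],
    state_is_tuple=True)
init_state = cell.zero_state(batch_size=50, dtype=tf.float32)
print(init_state)
# A tuple with one LSTMStateTuple(c=<Tensor>, h=<Tensor>) per layer --
# ordinary tensors in the graph, not placeholders.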
Then, in the function where the RNN model is used for generation, this initial state is fed in session.run:
def sample(self, sess, chars, vocab, num=200, prime='The ', sampling_type=1):
    state = sess.run(self.cell.zero_state(1, tf.float32))
    for char in prime[:-1]:
        x = np.zeros((1, 1))
        x[0, 0] = vocab[char]
        feed = {self.input_data: x, self.initial_state: state}
        [state] = sess.run([self.final_state], feed)
Since self.initial_state is not a placeholder, how can it be fed in session.run?
Here is a link to the code I was looking at.
Upvotes: 2
Views: 627
Reputation: 11
I came across the same question when reading similar RNN code.
From my understanding, rnn_cell.zero_state actually returns a tuple of tensors, which are feedable. Your placeholders are also tensors.
So if you do:
print(init_state[0])
# You will get something like:
# <tf.Tensor 'LSTM_cell/initial_state/BasicLSTMCellZeroState/zeros:0' shape=(50, 10) dtype=float32>
And feed_dict allows you to feed any of these, as long as the key is a tensor (or a nested tuple of tensors).
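A minimal sketch of that idea (the cell size and values are made up; any feedable state tensor behaves the same way):
import numpy as np
import tensorflow as tf

cell = tf.nn.rnn_cell.BasicLSTMCell(10)
init_state = cell.zero_state(batch_size=50, dtype=tf.float32)
c0 = init_state.c                  # a concrete zero tensor, not a placeholder
doubled = 2.0 * c0                 # some op downstream of the state
with tf.Session() as sess:
    value = np.ones(c0.get_shape().as_list(), dtype=np.float32)
    # The state tensor itself serves as the feed_dict key, just like a placeholder:
    print(sess.run(doubled, feed_dict={c0: value}))  # all 2.0, not 0.0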
Upvotes: 0
Reputation: 25463
Note that you can feed values into almost any tensor in the graph, not just placeholders. So in this case, you can feed each component of the state tuple manually:
feed = {
    self.input_data: x,
    self.initial_state[0]: state[0],
    self.initial_state[1]: state[1]
}
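If the cells are LSTM cells with state_is_tuple=True, each element of that tuple is itself an LSTMStateTuple, so a fully explicit version would feed the c and h tensors of every layer. A sketch, assuming the same attributes as in the question's code:
feed = {self.input_data: x}
for layer_tensors, layer_values in zip(self.initial_state, state):
    feed[layer_tensors.c] = layer_values.c   # cell state for this layer
    feed[layer_tensors.h] = layer_values.h   # hidden state for this layer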
Upvotes: 2