Reputation: 22714
TensorFlow newbie here! I understand that Variables are trained over time, while placeholders are used for input data that doesn't change as your model trains (like input images, and class labels for those images).
I'm trying to implement the forward propagation of an RNN using TensorFlow, and I'm wondering what type I should use to store the output of the RNN cell. The NumPy RNN implementation uses
hiddenStates = np.zeros((T, self.hidden_dim)) #T is the length of the sequence
Then it iteratively saves the output in the np.zeros array.
In case of TF, which one should I use, tf.zeros or tf.placeholder?
What is the best practice in this case? I think it should be fine to use tf.zeros but wanted to double check.
Upvotes: 1
Views: 1031
Reputation: 1104
First of all, it is important to understand that everything inside TensorFlow is a Tensor. So when you perform some kind of computation (e.g. an RNN implementation like outputs = rnn(...)), the output of this computation is returned as a Tensor, and you don't need to store it inside any kind of structure. You can retrieve it by running the corresponding node (i.e. outputs) like session.run(outputs, feed_dict).
That said, I think what you need is to take the final state of an RNN and provide it as the initial state of a subsequent computation. There are two ways:
A) If you are using RNNCell implementations: during the construction of your model you can construct the zero state like this:
cell = (some RNNCell implementation)
initial_state = cell.zero_state(batch_size, tf.float32)
B) If you are implementing your own cell: define the state as a zero Tensor:
initial_state = tf.zeros([batch_size, hidden_size])
Then, in both cases you will have something like:
output, final_state = rnn(input, initial_state)
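To make the output, final_state contract concrete, here is a minimal NumPy sketch of a vanilla tanh RNN forward pass, in the spirit of the NumPy code from the question. It is not TensorFlow code; the function name rnn_forward, the weight names, and the sizes are all assumptions made for illustration.

```python
import numpy as np

def rnn_forward(inputs, initial_state, W_xh, W_hh, b_h):
    """Vanilla tanh RNN over a single sequence.

    inputs:        array of shape (T, input_dim)
    initial_state: array of shape (hidden_dim,)
    Returns (outputs, final_state), where outputs has shape (T, hidden_dim).
    """
    T = inputs.shape[0]
    hidden_dim = initial_state.shape[0]
    outputs = np.zeros((T, hidden_dim))  # buffer that is filled step by step
    state = initial_state
    for t in range(T):
        state = np.tanh(inputs[t] @ W_xh + state @ W_hh + b_h)
        outputs[t] = state
    return outputs, state

# Hypothetical sizes and random weights, only for illustration.
rng = np.random.default_rng(0)
T, input_dim, hidden_dim = 5, 3, 4
W_xh = rng.normal(size=(input_dim, hidden_dim))
W_hh = rng.normal(size=(hidden_dim, hidden_dim))
b_h = np.zeros(hidden_dim)

inputs = rng.normal(size=(T, input_dim))
initial_state = np.zeros(hidden_dim)
outputs, final_state = rnn_forward(inputs, initial_state, W_xh, W_hh, b_h)
```

In this sketch the final state is simply the last row of outputs. In TensorFlow the same two values come back as Tensors, so no preallocated buffer is needed on your side.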
In your execution loop you can initialize your state first and then provide the final_state as initial_state inside your feed_dict:
state = session.run(initial_state)
for step in range(epochs):
    feed_dict = {initial_state: state}
    _, state = session.run((train_op, final_state), feed_dict)
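Why carrying the state matters can be checked with a small NumPy stand-in for the loop above. Here run_chunk plays the role of session.run(final_state, feed_dict); that name, the weights, and the chunk sizes are assumptions for illustration, not part of any TensorFlow API.

```python
import numpy as np

# Hypothetical sizes and random weights, only for illustration.
rng = np.random.default_rng(0)
input_dim, hidden_dim = 3, 4
W_xh = rng.normal(size=(input_dim, hidden_dim))
W_hh = rng.normal(size=(hidden_dim, hidden_dim))

def run_chunk(chunk, state):
    """Stand-in for one session.run call: returns the state after the chunk."""
    for x in chunk:
        state = np.tanh(x @ W_xh + state @ W_hh)
    return state

sequence = rng.normal(size=(12, input_dim))
chunks = np.split(sequence, 3)  # three chunks of 4 time steps each

# Carry the state from one chunk to the next, as in the loop above.
state = np.zeros(hidden_dim)
for chunk in chunks:
    state = run_chunk(chunk, state)
carried = state

# Reference: process the whole sequence in a single call.
whole = run_chunk(sequence, np.zeros(hidden_dim))
# Contrast: start the last chunk from zeros, i.e. forget the earlier chunks.
reset_each_time = run_chunk(chunks[-1], np.zeros(hidden_dim))

matches_whole = np.allclose(carried, whole)
```

With the state fed back in, the chunked run ends in the same state as a single pass over the full sequence. Starting each chunk from zeros discards everything the network saw earlier.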
How you actually construct your feed_dict depends on the implementation of the RNN.
For a BasicLSTMCell, for example, a state is an LSTMStateTuple object and you need to provide both c and h:
feed_dict = {initial_state.c: state.c, initial_state.h: state.h}
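The reason both parts have to be fed is that an LSTM step reads and writes the cell state c as well as the hidden state h. A rough NumPy sketch of one LSTM step, using a namedtuple as a stand-in for the state object, shows this. The LSTMState name, the gate ordering, and the weight layout here are assumptions for illustration and do not mirror TensorFlow's internals.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for an LSTM state object with two parts.
LSTMState = namedtuple("LSTMState", ["c", "h"])

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, state, W, b):
    """One LSTM step; W has shape (input_dim + hidden_dim, 4 * hidden_dim)."""
    z = np.concatenate([x, state.h]) @ W + b
    i, f, o, g = np.split(z, 4)  # input, forget, output gates and candidate
    c = sigmoid(f) * state.c + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(c)
    return LSTMState(c=c, h=h)

# Hypothetical sizes and random weights, only for illustration.
rng = np.random.default_rng(0)
input_dim, hidden_dim = 3, 4
W = rng.normal(size=(input_dim + hidden_dim, 4 * hidden_dim))
b = np.zeros(4 * hidden_dim)

# Both c and h are carried from one step to the next.
state = LSTMState(c=np.zeros(hidden_dim), h=np.zeros(hidden_dim))
for x in rng.normal(size=(6, input_dim)):
    state = lstm_step(x, state, W, b)
```

Dropping either field between calls would change the result of the next step, which is why the feed_dict needs an entry for each of them.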
Upvotes: 2