aerin

Reputation: 22714

tf.zeros vs tf.placeholder as RNN initial state

Tensorflow newbie here! I understand that Variables will be trained over time, and placeholders are used for input data that doesn't change as your model trains (like input images, and class labels for those images).

I'm trying to implement the forward propagation of an RNN using Tensorflow, and I'm wondering what type I should use to save the outputs of the RNN cell. In a NumPy RNN implementation, it uses

hiddenStates = np.zeros((T, self.hidden_dim)) #T is the length of the sequence

Then it iteratively saves the output in the np.zeros array.
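Roughly this pattern (a sketch of what I mean; the weight names U and W are just illustrative, not from my actual code):

import numpy as np

def rnn_forward(x, U, W, hidden_dim):
    # x: one input sequence, shape (T, input_dim)
    # U: input-to-hidden weights, W: hidden-to-hidden weights (illustrative names)
    T = len(x)
    hiddenStates = np.zeros((T, hidden_dim))
    prev_h = np.zeros(hidden_dim)
    for t in range(T):
        prev_h = np.tanh(x[t].dot(U.T) + prev_h.dot(W.T))
        hiddenStates[t] = prev_h  # iteratively save each step's hidden state in the zeros array
    return hiddenStates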

In the case of TF, which one should I use: tf.zeros or tf.placeholder?

What is the best practice in this case? I think it should be fine to use tf.zeros but wanted to double check.

Upvotes: 1

Views: 1031

Answers (1)

Giuseppe Marra

Reputation: 1104

First of all, it is important to understand that everything inside Tensorflow is a Tensor. So when you are performing some kind of computation (e.g. an RNN implementation like outputs = rnn(...)), the output of this computation is returned as a Tensor. You don't need to store it inside any kind of structure; you can retrieve it by running the corresponding node (i.e. output), like session.run(output, feed_dict).
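For example, a minimal sketch (not RNN-specific, just to show the pattern):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 3])
y = tf.reduce_sum(x, axis=1)  # y is a Tensor (a node in the graph), not a value

with tf.Session() as session:
    # Retrieve the value by running the corresponding node.
    value = session.run(y, feed_dict={x: [[1.0, 2.0, 3.0]]})
    print(value)  # [6.]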

That said, I think what you need is to take the final state of an RNN and provide it as the initial state of a subsequent computation. There are two ways to do it:

A) If you are using RNNCell implementations: during the construction of your model you can create the zero state like this:

cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size)  # or any other RNNCell implementation
initial_state = cell.zero_state(batch_size, tf.float32)

B) If you are implementing your own stuff: define the state as a zero Tensor:

initial_state = tf.zeros([batch_size, hidden_size])
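For instance, a minimal hand-rolled step loop could look like this (a sketch only; the weight names W_x, W_h, b and the shapes are illustrative):

import tensorflow as tf

batch_size, seq_len, input_dim, hidden_size = 4, 10, 8, 16

inputs = tf.placeholder(tf.float32, [seq_len, batch_size, input_dim])
W_x = tf.Variable(tf.random_normal([input_dim, hidden_size]))
W_h = tf.Variable(tf.random_normal([hidden_size, hidden_size]))
b = tf.Variable(tf.zeros([hidden_size]))

initial_state = tf.zeros([batch_size, hidden_size])

state = initial_state
outputs = []
for t in range(seq_len):
    # One vanilla RNN step; each hidden state is just another node in the graph.
    state = tf.tanh(tf.matmul(inputs[t], W_x) + tf.matmul(state, W_h) + b)
    outputs.append(state)
output = tf.stack(outputs)  # shape (seq_len, batch_size, hidden_size)
final_state = state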

Then, in both cases you will have something like:

output, final_state = rnn(input, initial_state)
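In case A, for example, that line could concretely be a call to tf.nn.dynamic_rnn (the shapes below are illustrative):

import tensorflow as tf

batch_size, seq_len, input_dim, hidden_size = 4, 10, 8, 16

inputs = tf.placeholder(tf.float32, [batch_size, seq_len, input_dim])
cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
initial_state = cell.zero_state(batch_size, tf.float32)

# output collects the hidden state at every time step,
# final_state is the state after the last step.
output, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)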

In your execution loop you can initialize your state first and then provide the final_state as initial_state inside your feed_dict:

state = session.run(initial_state)
for step in range(epochs):
    feed_dict = {initial_state: state}
    _, state = session.run((train_op, final_state), feed_dict)

How you actually construct your feed_dict depends on the implementation of the RNN.

For a BasicLSTMCell, for example, the state is an LSTMStateTuple object and you need to provide both c and h:

feed_dict = {initial_state.c: state.c, initial_state.h: state.h}
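Putting everything together for a BasicLSTMCell, the whole pattern could look roughly like this (a sketch; the loss and the data are dummies, just to make it runnable):

import numpy as np
import tensorflow as tf

batch_size, seq_len, input_dim, hidden_size = 4, 10, 8, 16

inputs = tf.placeholder(tf.float32, [batch_size, seq_len, input_dim])
cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size)
initial_state = cell.zero_state(batch_size, tf.float32)  # an LSTMStateTuple
output, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)

loss = tf.reduce_mean(tf.square(output))  # dummy loss
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    state = session.run(initial_state)  # NumPy LSTMStateTuple of zeros
    for step in range(100):
        batch = np.random.rand(batch_size, seq_len, input_dim)  # dummy data
        feed_dict = {inputs: batch,
                     initial_state.c: state.c,  # feed both parts of the state tuple
                     initial_state.h: state.h}
        _, state = session.run((train_op, final_state), feed_dict)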

Upvotes: 2
