tao_oat

Reputation: 1037

TensorFlow throws error only when using MultiRNNCell

I'm building an encoder-decoder model in TensorFlow 1.0.1 using the legacy sequence-to-sequence framework. Everything works as it should when I have one layer of LSTMs in the encoder and decoder. However, when I try with >1 layers of LSTMs wrapped in a MultiRNNCell, I get an error when calling tf.contrib.legacy_seq2seq.rnn_decoder.

The full error is at the end of this post, but in brief, it's caused by a line

(c_prev, m_prev) = state

in TensorFlow that throws TypeError: 'Tensor' object is not iterable. I'm confused by this, since the initial state I'm passing to rnn_decoder is indeed a tuple, as it should be. As far as I can tell, the only difference between using 1 or >1 layers is that the latter involves using MultiRNNCell. Are there some API quirks that I should know about when using this?

This is my code (based on the example in this GitHub repo). Apologies for its length; this is as minimal as I could make it while still being complete and verifiable.

import tensorflow as tf
import tensorflow.contrib.legacy_seq2seq as seq2seq
import tensorflow.contrib.rnn as rnn

seq_len = 50
input_dim = 300
output_dim = 12
num_layers = 2
hidden_units = 100

sess = tf.Session()

encoder_inputs = []
decoder_inputs = []

for i in range(seq_len):
    encoder_inputs.append(tf.placeholder(tf.float32, shape=(None, input_dim),
                                         name="encoder_{0}".format(i)))

for i in range(seq_len + 1):
    decoder_inputs.append(tf.placeholder(tf.float32, shape=(None, output_dim),
                                         name="decoder_{0}".format(i)))

if num_layers > 1:
    # Encoder cells (bidirectional)
    # Forward
    enc_cells_fw = [rnn.LSTMCell(hidden_units)
                    for _ in range(num_layers)]
    enc_cell_fw = rnn.MultiRNNCell(enc_cells_fw)
    # Backward
    enc_cells_bw = [rnn.LSTMCell(hidden_units)
                    for _ in range(num_layers)]
    enc_cell_bw = rnn.MultiRNNCell(enc_cells_bw)
    # Decoder cell
    dec_cells = [rnn.LSTMCell(2*hidden_units)
                 for _ in range(num_layers)]
    dec_cell = rnn.MultiRNNCell(dec_cells)
else:
    # Encoder
    enc_cell_fw = rnn.LSTMCell(hidden_units)
    enc_cell_bw = rnn.LSTMCell(hidden_units)
    # Decoder
    dec_cell = rnn.LSTMCell(2*hidden_units)

# Make sure input and output are the correct dimensions
enc_cell_fw = rnn.InputProjectionWrapper(enc_cell_fw, input_dim)
enc_cell_bw = rnn.InputProjectionWrapper(enc_cell_bw, input_dim)
dec_cell = rnn.OutputProjectionWrapper(dec_cell, output_dim)

_, final_fw_state, final_bw_state = \
     rnn.static_bidirectional_rnn(enc_cell_fw,
                                  enc_cell_bw,
                                  encoder_inputs,
                                  dtype=tf.float32)

# Concatenate forward and backward cell states
# (The state is a tuple of previous output and cell state)
if num_layers == 1:
    initial_dec_state = tuple([tf.concat([final_fw_state[i],
                                          final_bw_state[i]], 1) 
                               for i in range(2)])
else:
    initial_dec_state = tuple([tf.concat([final_fw_state[-1][i],
                                          final_bw_state[-1][i]], 1) 
                               for i in range(2)])

decoder = seq2seq.rnn_decoder(decoder_inputs, initial_dec_state, dec_cell)

tf.global_variables_initializer().run(session=sess)

And this is the error:

Traceback (most recent call last):
  File "example.py", line 67, in <module>
    decoder = seq2seq.rnn_decoder(decoder_inputs, initial_dec_state, dec_cell)
  File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py", line 150, in rnn_decoder
    output, state = cell(inp, state)
  File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py", line 426, in __call__
    output, res_state = self._cell(inputs, state)
  File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py", line 655, in __call__
    cur_inp, new_state = cell(cur_inp, cur_state)
  File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py", line 321, in __call__
    (c_prev, m_prev) = state
  File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 502, in __iter__
    raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.

Thank you!

Upvotes: 0

Views: 1136

Answers (1)

Mirco Nani

Reputation: 76

The problem is in the format of the initial state (initial_dec_state) passed to seq2seq.rnn_decoder.

When you use rnn.MultiRNNCell, you're building a multilayer recurrent network, so you need to provide an initial state for each of these layers.

Hence, you should provide a list of tuples as the initial state, where each element of the list is the previous state coming from the corresponding layer of the recurrent network.

So your initial_dec_state, initialized like this:

    initial_dec_state = tuple([tf.concat([final_fw_state[-1][i],
                                          final_bw_state[-1][i]], 1)
                               for i in range(2)])

instead should be like this:

    initial_dec_state = [
        tuple([tf.concat([final_fw_state[j][i], final_bw_state[j][i]], 1)
               for i in range(2)])
        for j in range(len(final_fw_state))
    ]

which creates a list of tuples in the format:

    [(state_c1, state_m1), (state_c2, state_m2) ...]
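To make the nesting concrete, here is a minimal pure-Python sketch (not TensorFlow code; plain strings stand in for the actual state tensors, and the names are illustrative only):

```python
num_layers = 2

# What a single LSTMCell expects as its state: one (c, m) pair.
single_layer_state = ("c", "m")

# What MultiRNNCell expects: one (c, m) pair per layer.
multi_layer_state = tuple(
    ("c{}".format(j), "m{}".format(j)) for j in range(num_layers)
)
print(multi_layer_state)  # (('c0', 'm0'), ('c1', 'm1'))
```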

In more detail, the 'Tensor' object is not iterable. error happens because seq2seq.rnn_decoder internally calls your rnn.MultiRNNCell (dec_cell), passing the initial state (initial_dec_state) to it.

rnn.MultiRNNCell.__call__ iterates through the list of initial states and, for each of them, unpacks the tuple (c_prev, m_prev) (in the statement (c_prev, m_prev) = state).

So if you pass just a single tuple, rnn.MultiRNNCell.__call__ will iterate over it; as soon as it reaches (c_prev, m_prev) = state, it finds a tensor (where it expects a tuple) as state and throws the 'Tensor' object is not iterable. error.
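That failure mode can be reproduced in plain Python, with a stand-in class mimicking the one relevant behavior of tf.Tensor (raising on iteration); everything here is a hypothetical sketch, not the actual TensorFlow implementation:

```python
class FakeTensor(object):
    """Stand-in for tf.Tensor: attempting to iterate over it raises."""
    def __iter__(self):
        raise TypeError("'Tensor' object is not iterable.")

def multi_cell_step(state):
    """Mimics how MultiRNNCell consumes its state argument."""
    for layer_state in state:        # one entry expected per layer
        c_prev, m_prev = layer_state # each layer expects a (c, m) pair
        # ... the real cell would compute with c_prev and m_prev here ...

# Passing a single (c, m) pair instead of one pair per layer: the loop
# treats each tensor as a layer state and tries to unpack it.
try:
    multi_cell_step((FakeTensor(), FakeTensor()))
except TypeError as e:
    print(e)  # 'Tensor' object is not iterable.
```

With the correctly nested state, e.g. `multi_cell_step(((FakeTensor(), FakeTensor()), (FakeTensor(), FakeTensor())))`, the unpacking succeeds.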

A good way to know which format of initial state seq2seq.rnn_decoder expects is to call dec_cell.zero_state(batch_size, dtype=tf.float32). This method returns zero-filled state tensor(s) in the exact format needed to initialize the recurrent module you're using.

Upvotes: 6
