suku

Reputation: 10929

AttributeError: 'tensorflow.python.ops.rnn' has no attribute 'rnn'

I am following this tutorial on Recurrent Neural Networks.

These are the imports:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.ops import rnn
from tensorflow.contrib.rnn import core_rnn_cell

This is the code for input processing:

x = tf.transpose(x, [1,0,2])
x = tf.reshape(x, [-1, chunk_size])
x = tf.split(x, n_chunks, 0)

lstm_cell = core_rnn_cell.BasicLSTMCell(rnn_size)
outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)
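To make the shapes concrete, here is a pure-Python sketch (no TensorFlow needed) of what the transpose/reshape/split steps above do. It assumes x has shape [batch, n_chunks, chunk_size], as in the tutorial's MNIST setup; nested lists stand in for tensors, and the function name split_into_time_steps is hypothetical. The result is a Python list of n_chunks items, each of shape [batch, chunk_size], which is the list-of-tensors form that static_rnn expects as its inputs argument.

```python
def split_into_time_steps(x, n_chunks):
    """Mimic the three tensor ops on nested lists.

    x is assumed to be shaped [batch, n_chunks, chunk_size].
    """
    batch = len(x)
    # tf.transpose(x, [1, 0, 2]): [batch, n_chunks, chunk_size] -> [n_chunks, batch, chunk_size]
    transposed = [[x[b][t] for b in range(batch)] for t in range(n_chunks)]
    # tf.reshape(x, [-1, chunk_size]): -> [n_chunks * batch, chunk_size]
    flat = [row for step in transposed for row in step]
    # tf.split(x, n_chunks, 0): -> list of n_chunks pieces, each [batch, chunk_size]
    return [flat[i * batch:(i + 1) * batch] for i in range(n_chunks)]


# Toy input: batch=2, n_chunks=3, chunk_size=2; each row is [batch index, time step].
x = [[[b, t] for t in range(3)] for b in range(2)]
steps = split_into_time_steps(x, 3)
# steps[1] == [[0, 1], [1, 1]]: time step 1 for every example in the batch
```

Each element of the returned list holds one time step for the whole batch, which is why the transpose has to happen before the reshape and split.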

I am getting the following error on the outputs, states line:

AttributeError: module 'tensorflow.python.ops.rnn' has no attribute 'rnn'

TensorFlow was updated recently, so what should the new code for the offending line be?

Upvotes: 5

Views: 18880

Answers (3)

Basma Elshoky

Reputation: 149

Use the static_rnn method instead of rnn:

 outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)

instead of:

 outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)
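static_rnn takes a Python list of per-time-step inputs and applies the cell once per element, passing the state from one step to the next. A rough pure-Python sketch of that unrolling loop follows. The names static_unroll and running_sum_cell are hypothetical, and the toy cell only mimics the (output, new_state) = cell(input, state) calling convention of TensorFlow RNN cells; the real static_rnn also builds a zero initial state from dtype and supports sequence_length, which this sketch omits.

```python
def static_unroll(cell, inputs, initial_state):
    """Illustrative static unrolling: call `cell` once per element of the
    Python list `inputs`, threading the state through each step."""
    outputs = []
    state = initial_state
    for inp in inputs:
        output, state = cell(inp, state)
        outputs.append(output)
    # Like static_rnn: a list of per-step outputs plus the final state.
    return outputs, state


def running_sum_cell(inp, state):
    # Hypothetical toy cell: the new state is the running sum,
    # and the output equals the new state.
    new_state = state + inp
    return new_state, new_state


outputs, final_state = static_unroll(running_sum_cell, [1, 2, 3], 0)
# outputs == [1, 3, 6], final_state == 6
```

Because the loop runs over a fixed-length Python list, the number of steps is fixed when the graph is built, which is what "static" refers to here.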

Upvotes: 0

suku

Reputation: 10929

For people using a newer version of TensorFlow, use this code:

from tensorflow.contrib import rnn 


lstm_cell = rnn.BasicLSTMCell(rnn_size) 
outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)

instead of:

from tensorflow.python.ops import rnn, rnn_cell 
lstm_cell = rnn_cell.BasicLSTMCell(rnn_size,state_is_tuple=True) 
outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)

PS: @BrendanA suggested using tf.nn.rnn_cell.LSTMCell instead of rnn_cell.BasicLSTMCell.
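Because static_rnn has lived in different modules across TensorFlow 1.x releases (tensorflow.contrib.rnn in this answer, tensorflow.python.ops.rnn in the other answers), one option is to look it up at runtime. The sketch below is a hypothetical helper, resolve_first, built only on importlib. The candidate list is taken from the paths mentioned on this page and is an assumption, so check it against the TensorFlow version you have installed.

```python
import importlib


def resolve_first(candidates):
    """Return the first attribute found among dotted paths such as
    "package.module.attr", or None if none of them can be imported."""
    for path in candidates:
        module_name, _, attr = path.rpartition(".")
        if not module_name:
            continue  # no module part, nothing to import
        try:
            module = importlib.import_module(module_name)
        except ImportError:  # also covers ModuleNotFoundError
            continue
        if hasattr(module, attr):
            return getattr(module, attr)
    return None


# Assumed locations of static_rnn, taken from the answers on this page.
STATIC_RNN_CANDIDATES = [
    "tensorflow.contrib.rnn.static_rnn",
    "tensorflow.python.ops.rnn.static_rnn",
]

# Usage (requires TensorFlow 1.x):
# static_rnn = resolve_first(STATIC_RNN_CANDIDATES)
# outputs, states = static_rnn(lstm_cell, x, dtype=tf.float32)
```

Trying the paths in order means the contrib location is preferred when it exists, and the helper returns None instead of raising when no candidate is available.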

Upvotes: 26

Thanks @suku

I get the following error: ImportError: No module named 'tensorflow.contrib.rnn.python.ops.core_rnn'

To solve it, I replaced:

from tensorflow.contrib.rnn.python.ops import core_rnn

with:

from tensorflow.python.ops import rnn, rnn_cell

In my code I had also used core_rnn.static_rnn:

 outputs,_ = core_rnn.static_rnn(cell, input_list, dtype=tf.float32)

After that change I got this error:

NameError: name 'core_rnn' is not defined

This is solved by replacing that line with:

outputs,_ = rnn.static_rnn(cell, input_list, dtype=tf.float32)

Python: 3.6 (64-bit), TensorFlow: 1.10.0

Upvotes: 0
