als5ev

Having trouble understanding LSTM use in a TensorFlow code sample

Why is the pred variable calculated before any of the training iterations occur? I would expect pred to be generated (through the RNN() function) on each pass through the data, at every iteration.

There must be something I am missing. Is pred something like a function object? I have looked at the docs for tf.matmul(), and it returns a tensor, not a function.

Full source: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/recurrent_network.py

Here is the code:

```python
import tensorflow as tf
from tensorflow.contrib import rnn  # TensorFlow 1.x only

# x, y, weights, biases, n_steps, n_hidden and learning_rate are defined
# earlier in the full source linked above.

def RNN(x, weights, biases):

    # Prepare data shape to match `rnn` function requirements
    # Current data input shape: (batch_size, n_steps, n_input)
    # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)

    # Unstack to get a list of 'n_steps' tensors of shape (batch_size, n_input)
    x = tf.unstack(x, n_steps, 1)

    # Define a lstm cell with tensorflow
    lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)

    # Get lstm cell output
    outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)

    # Linear activation, using rnn inner loop last output
    return tf.matmul(outputs[-1], weights['out']) + biases['out']

pred = RNN(x, weights, biases)

# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Evaluate model
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initializing the variables
init = tf.global_variables_initializer()
```

Answers (1)

David Parks

TensorFlow code written in the 1.x graph-and-session style used here has two distinct phases. First, you build a "dependency graph" containing all of the operations you will use. During this phase no data is processed: you are only defining the operations you want to occur, and TensorFlow records the dependencies between them.

For example, in order to compute the accuracy, you'll need to first compute correct_pred, and to compute correct_pred you'll need to first compute pred, and so on.

So all you have done in the code shown is tell TensorFlow which operations you want. They are stored in a graph (a tf.Graph object, essentially a container holding all the mathematical operations and the tensors that flow between them).
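As a rough picture of what gets recorded during that first phase, here is a toy sketch in plain Python. It is not TensorFlow code: the graph is just a dictionary mapping each op name (borrowed from the question) to the names of its inputs, biases are omitted for brevity, and the deps_in_order helper is hypothetical. Walking the dictionary recovers the chain accuracy, correct_pred, pred, x without computing a single value.

```python
# Toy stand-in for a dataflow graph: op name -> names of the ops it consumes.
# These entries only record dependencies; no numbers are computed here.
graph = {
    "x": [],                        # input placeholder
    "y": [],                        # label placeholder
    "weights": [],                  # model parameters
    "pred": ["x", "weights"],       # like pred = RNN(x, weights, biases)
    "correct_pred": ["pred", "y"],  # like tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    "accuracy": ["correct_pred"],   # like tf.reduce_mean(tf.cast(...))
    "cost": ["pred", "y"],          # like the softmax cross-entropy loss
}


def deps_in_order(graph, target):
    """Return every op that target depends on (including target), inputs first."""
    order, seen = [], set()

    def visit(name):
        if name in seen:
            return
        seen.add(name)
        for dep in graph[name]:
            visit(dep)
        order.append(name)

    visit(target)
    return order


print(deps_in_order(graph, "accuracy"))
# ['x', 'weights', 'pred', 'y', 'correct_pred', 'accuracy']
```

Note that cost does not appear in the result: asking for accuracy never touches ops that accuracy does not depend on.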

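Second, you run operations on actual data by calling sess.run(fetches, feed_dict=...), where fetches is the tensor or operation (or a list of them) you want evaluated, and feed_dict is a dictionary mapping placeholders such as x and y to concrete values.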

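Notice that when you call sess.run, you have to tell it what you want from the graph. For example, to get the accuracy on one batch (here batch_x and batch_y stand for a batch of inputs and its one-hot labels):

   sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})

TensorFlow will try to compute accuracy. It sees that accuracy depends on correct_pred, so it computes that first; correct_pred in turn depends on pred and y, and so on back through the dependency graph you defined, until it reaches the placeholders whose values you supplied in feed_dict.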
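The evaluation order described above can be mimicked with a small recursive evaluator. This is a toy model under simplifying assumptions, not TensorFlow's implementation: ops are plain Python functions over numbers, the graph dictionary and the run helper are hypothetical, and feed_dict supplies values for the placeholder-like entries. The trace list records which ops actually executed.

```python
# Toy graph: op name -> (function, list of input op names).
# Placeholders have no function; their values must come from feed_dict.
graph = {
    "x": (None, []),
    "y": (None, []),
    "pred": (lambda x: 2 * x, ["x"]),  # stand-in for RNN(x, weights, biases)
    "correct_pred": (lambda p, y: p == y, ["pred", "y"]),
    "accuracy": (lambda c: 1.0 if c else 0.0, ["correct_pred"]),
}


def run(graph, fetch, feed_dict, trace=None):
    """Evaluate fetch by first evaluating its inputs, like a simplified sess.run."""
    cache = dict(feed_dict)  # fed values short-circuit evaluation; rebuilt on every call

    def evaluate(name):
        if name in cache:
            return cache[name]
        fn, inputs = graph[name]
        if fn is None:
            raise ValueError(f"You must feed a value for placeholder {name!r}")
        args = [evaluate(dep) for dep in inputs]
        if trace is not None:
            trace.append(name)  # record that this op actually ran
        cache[name] = fn(*args)
        return cache[name]

    return evaluate(fetch)


trace = []
print(run(graph, "accuracy", {"x": 3, "y": 6}, trace))  # 1.0
print(trace)  # ['pred', 'correct_pred', 'accuracy']
```

Asking only for pred, for example run(graph, "pred", {"x": 5}), works without feeding y, because pred does not depend on it. Because the cache is rebuilt on each call, nothing computed in one run is reused by the next.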

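The misunderstanding is thinking that pred, in the code you listed, computes something. It doesn't. The line:

   pred = RNN(x, weights, biases)

only defines the operations and their dependencies. pred is a tf.Tensor: a symbolic handle to the output of the last operation in that subgraph, not a function and not a concrete value. It is evaluated anew on every sess.run call that needs it, using the batch you feed and the current values of the variables, which is how it gets recomputed at every training iteration.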
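To connect this back to the original question, the toy sketch below builds a graph once and then runs a short training loop against it. Everything here is a hypothetical plain-Python stand-in, not TensorFlow: a one-weight linear model w * x replaces the LSTM, a hand-written gradient op plus an in-place update replaces AdamOptimizer, and the run helper plays the role of sess.run. The same batch is fed three times so that any change in pred can only come from the updated weight, not from a new input or from caching.

```python
# Toy graph, built once: op name -> (function, names of its inputs).
# Entries with no function are supplied at run time (variables or fed values).
graph = {
    "x": (None, []),
    "y": (None, []),
    "w": (None, []),  # like a tf.Variable
    "pred": (lambda w, x: w * x, ["w", "x"]),  # stand-in for RNN(x, weights, biases)
    "grad": (lambda w, x, y: 2 * (w * x - y) * x, ["w", "x", "y"]),  # d/dw of (w*x - y)**2
}
variables = {"w": 0.0}  # persistent state: survives between run calls


def run(fetch, feed_dict):
    """Simplified sess.run: start from current variables plus fed values, compute the rest."""
    values = {**variables, **feed_dict}  # rebuilt each call, so no op result is cached

    def evaluate(name):
        if name not in values:
            fn, inputs = graph[name]
            values[name] = fn(*[evaluate(dep) for dep in inputs])
        return values[name]

    return evaluate(fetch)


learning_rate = 0.25
batch = {"x": 1.0, "y": 2.0}
preds = []
for step in range(3):
    preds.append(run("pred", batch))  # recomputed with the current w every iteration
    variables["w"] -= learning_rate * run("grad", batch)  # stand-in for the optimizer op
print(preds)  # [0.0, 1.0, 1.5]
```

The graph dictionary never changes inside the loop; only the variable state does. That mirrors the question's code, where RNN() is called once to build the graph and each training step re-evaluates it.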
