Tobbey

Reputation: 495

Side effect in tf.while_loop

I am currently having a hard time understanding how TensorFlow works, and I find the Python interface somewhat obscure.

I recently tried to run a simple print statement inside a tf.while_loop, and there are many things that remain unclear to me:

import tensorflow as tf

nb_iter = tf.constant(value=10)
#This solution does not work at all
#nb_iter = tf.get_variable('nb_iter', shape=(1), dtype=tf.int32, trainable=False)
i = tf.get_variable('i', shape=(), trainable=False,
                     initializer=tf.zeros_initializer(), dtype=nb_iter.dtype)

loop_condition = lambda i: tf.less(i, nb_iter)
def loop_body(i):
    tf.Print(i, [i], message='Another iteration')
    return [tf.add(i, 1)]

i = tf.while_loop(loop_condition, loop_body, [i])

initializer_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(initializer_op)
    res = sess.run(i)
    print('res is now {}'.format(res))

Notice that if I initialize nb_iter with

nb_iter = tf.get_variable('nb_iter', shape=(1), dtype=tf.int32, trainable=False)

I got the following error:

ValueError: Shape must be rank 0 but is rank 1 for 'while/LoopCond' (op: 'LoopCond') with input shapes: [1].

It gets even worse when I try to use the index i to index a tensor (example not shown here); I then get the following error:

ValueError: Operation 'while/strided_slice' has been marked as not fetchable.

Can someone point me to documentation that explains how tf.while_loop works when used with tf.Variables, whether it is possible to use side effects (like print) inside the loop, and whether a tensor can be indexed with the loop variable?

Thank you in advance for your help

Upvotes: 2

Views: 623

Answers (1)

Tobbey

Reputation: 495

There were actually many things wrong with my first example:

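tf.Print is not executed unless its output is actually used. It is an identity op that prints as a side effect, and in graph mode only the ops that the fetched tensor depends on are run. Its result must therefore be reassigned and used (i.e. i = tf.Print(i, [...])) rather than discarded.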

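The loop condition must be a scalar boolean, i.e. a rank-0 tensor, not a rank-1 tensor. Creating nb_iter with shape=(1) gives it shape [1] (rank 1), so tf.less(i, nb_iter) is also rank 1 and LoopCond rejects it; shape=() creates a true scalar. ...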
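To see why a discarded tf.Print never fires, it helps to model the rule that graph-mode Session.run follows: only the ops the fetched tensor transitively depends on are executed. The sketch below is a toy, pure-Python model of that pruning rule, for illustration only; Node, run and make_print are hypothetical names and this is not how TensorFlow is implemented.

```python
class Node:
    """Toy graph node: a function applied to the values of its input nodes (hypothetical)."""

    def __init__(self, fn, inputs=(), name=""):
        self.fn = fn
        self.inputs = list(inputs)
        self.name = name


def run(fetch, log):
    """Evaluate only the nodes that `fetch` depends on; `log` records execution order."""
    cache = {}

    def evaluate(node):
        if node not in cache:
            args = [evaluate(n) for n in node.inputs]
            log.append(node.name)
            cache[node] = node.fn(*args)
        return cache[node]

    return evaluate(fetch)


def make_print(x, message, out):
    # Identity node with a side effect, in the spirit of tf.Print (hypothetical helper).
    def fn(value):
        out.append(message + str(value))
        return value
    return Node(fn, [x], name="print")


printed = []
i = Node(lambda: 0, name="i")

# Case 1: the print node's result is discarded, so nothing fetched depends on it.
make_print(i, "Another iteration ", printed)
dangling = Node(lambda v: v + 1, [i], name="add")
log1 = []
print(run(dangling, log1), log1, printed)  # 1 ['i', 'add'] []

# Case 2: the print node's result feeds the add, so it is executed.
p = make_print(i, "Another iteration ", printed)
chained = Node(lambda v: v + 1, [p], name="add")
log2 = []
print(run(chained, log2), log2, printed)  # 1 ['i', 'print', 'add'] ['Another iteration 0']
```

In TensorFlow terms, reassigning i = tf.Print(i, ...) is what puts the print on the path to the fetched tensor; wrapping the subsequent ops in tf.control_dependencies([...]) is another way to force an op to run.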

Here is the code that works:

import tensorflow as tf

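# nb_iter = tf.constant(value=10) would also give a scalar bound; a variable is
# used here to show that it must be created with shape=() (rank 0), not shape=(1).
nb_iter = tf.get_variable('nb_iter', shape=(), dtype=tf.int32, trainable=False,
                          initializer=tf.zeros_initializer())
nb_iter = tf.add(nb_iter, 10)  # the loop bound is 0 + 10 = 10
i = tf.get_variable('i', shape=(), trainable=False,
                    initializer=tf.zeros_initializer(), dtype=nb_iter.dtype)
# A 10-element vector of random floats, indexed with the loop counter below.
v = tf.get_variable('v', shape=(10,), trainable=False,
                    initializer=tf.random_uniform_initializer(), dtype=tf.float32)

loop_condition = lambda i: tf.less(i, nb_iter)
def loop_body(i):
    # tf.Print returns i unchanged and prints v[i] as a side effect; reassigning
    # its result and feeding it to tf.add is what makes the print op run.
    i = tf.Print(i, [v[i]], message='Another vector element: ')
    return [tf.add(i, 1)]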

i = tf.while_loop(loop_condition, loop_body, [i])

initializer_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(initializer_op)
    res = sess.run(i)
    print('res is now {}'.format(res))
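For comparison, the graph above computes the same thing as an ordinary Python loop: tf.while_loop(cond, body, loop_vars) repeatedly applies body to the loop variables while cond holds. The sketch below is only a plain-Python reference for those semantics; the while_loop function here is a hypothetical stand-in, not the TensorFlow API, and random.random() is an assumed stand-in for the uniform initializer of v, so the printed values will differ from the output shown below.

```python
import random


def while_loop(cond, body, loop_vars):
    """Hypothetical plain-Python reference: apply body while cond holds."""
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars


nb_iter = 10
# Stand-in for the 10-element variable v, drawn from [0, 1) like the
# default range of tf.random_uniform_initializer for floats.
v = [random.random() for _ in range(nb_iter)]


def loop_body(i):
    # Side effect first, then return the updated loop variable as a list.
    print('Another vector element: [{}]'.format(v[i]))
    return [i + 1]


(res,) = while_loop(lambda i: i < nb_iter, loop_body, [0])
print('res is now {}'.format(res))
```

In plain Python the print simply runs in order; in the TensorFlow graph it only runs because its output is threaded into the returned loop variable.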

output:

Another vector element: [0.203766704]
Another vector element: [0.692927241]
Another vector element: [0.732221603]
Another vector element: [0.0556482077]
Another vector element: [0.422092319]
Another vector element: [0.597698212]
Another vector element: [0.92387116]
Another vector element: [0.590101123]
Another vector element: [0.741415381]
Another vector element: [0.514917374]
res is now 10

Upvotes: 1
