Reputation: 495
I am currently having a hard time understanding how TensorFlow works, and I feel like the Python interface is somewhat obscure.
I recently tried to run a simple print statement inside a tf.while_loop, and there are many things that remain unclear to me:
import tensorflow as tf

nb_iter = tf.constant(value=10)
#This solution does not work at all
#nb_iter = tf.get_variable('nb_iter', shape=(1), dtype=tf.int32, trainable=False)

i = tf.get_variable('i', shape=(), trainable=False,
                    initializer=tf.zeros_initializer(), dtype=nb_iter.dtype)

loop_condition = lambda i: tf.less(i, nb_iter)

def loop_body(i):
    tf.Print(i, [i], message='Another iteration')
    return [tf.add(i, 1)]

i = tf.while_loop(loop_condition, loop_body, [i])

initializer_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(initializer_op)
    res = sess.run(i)
    print('res is now {}'.format(res))
Notice that if I initialize nb_iter with
nb_iter = tf.get_variable('nb_iter', shape=(1), dtype=tf.int32, trainable=False)
I get the following error:
ValueError: Shape must be rank 0 but is rank 1 for 'while/LoopCond' (op: 'LoopCond') with input shapes: [1].
It gets even worse when I try to use the 'i' index for indexing a tensor (example not shown here); I then get the following error:
ValueError: Operation 'while/strided_slice' has been marked as not fetchable.
Can someone point me to documentation that explains how tf.while_loop works when used with tf.Variables, whether it is possible to use side effects (like printing) inside the loop, and how to index a tensor with the loop variable?
Thank you in advance for your help.
Upvotes: 2
Views: 623
Reputation: 495
There were actually many things wrong with my first example:
tf.Print is only executed if its output is actually used in the graph, so you need i = tf.Print(...) rather than a bare tf.Print(...) (see the short sketch below).
The loop condition must be a scalar boolean, i.e. a rank-0 tensor, not a rank-1 tensor, so nb_iter has to be declared with shape=() instead of shape=(1).
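A minimal sketch of the first point, outside of any loop (illustrative only): the print only happens when the return value of tf.Print is part of what gets evaluated.

import tensorflow as tf

i = tf.constant(0)

# Never printed: the op is created, but nothing depends on its output,
# so TensorFlow never executes it.
tf.Print(i, [i], message='never printed: ')

# Printed: the output of tf.Print is used downstream, so the op runs
# whenever j is evaluated.
j = tf.Print(i, [i], message='printed: ')
j = tf.add(j, 1)

with tf.Session() as sess:
    sess.run(j)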
Here is the code that works:
import tensorflow as tf

#nb_iter = tf.constant(value=10)
#This solution does not work at all
nb_iter = tf.get_variable('nb_iter', shape=(), dtype=tf.int32, trainable=False,
                          initializer=tf.zeros_initializer())
nb_iter = tf.add(nb_iter, 10)

i = tf.get_variable('i', shape=(), trainable=False,
                    initializer=tf.zeros_initializer(), dtype=nb_iter.dtype)
v = tf.get_variable('v', shape=(10,), trainable=False,
                    initializer=tf.random_uniform_initializer, dtype=tf.float32)

loop_condition = lambda i: tf.less(i, nb_iter)

def loop_body(i):
    # The result of tf.Print must be used, otherwise the print op never runs
    i = tf.Print(i, [v[i]], message='Another vector element: ')
    return [tf.add(i, 1)]

i = tf.while_loop(loop_condition, loop_body, [i])

initializer_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(initializer_op)
    res = sess.run(i)
    print('res is now {}'.format(res))
output:
Another vector element: [0.203766704]
Another vector element: [0.692927241]
Another vector element: [0.732221603]
Another vector element: [0.0556482077]
Another vector element: [0.422092319]
Another vector element: [0.597698212]
Another vector element: [0.92387116]
Another vector element: [0.590101123]
Another vector element: [0.741415381]
Another vector element: [0.514917374]
res is now 10
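Not part of the fix above, but since the question also asked about indexing a tensor with the loop variable: if you want to keep the values v[i] instead of only printing them, one possible sketch (variable names are illustrative) is to carry a tf.TensorArray as a second loop variable and stack it after the loop:

import tensorflow as tf

nb_iter = tf.constant(10)
v = tf.random_uniform(shape=(10,), dtype=tf.float32)

i0 = tf.constant(0)
acc0 = tf.TensorArray(dtype=tf.float32, size=10)

def loop_body(i, acc):
    # Store the element selected by the loop variable
    acc = acc.write(i, v[i])
    return i + 1, acc

i_final, acc_final = tf.while_loop(lambda i, acc: tf.less(i, nb_iter),
                                   loop_body, [i0, acc0])
collected = acc_final.stack()  # rank-1 tensor holding the 10 gathered values

with tf.Session() as sess:
    print(sess.run(collected))

Here collected can be fetched because it is created outside the loop body; trying to sess.run a tensor created inside the body is what typically produces the 'not fetchable' error mentioned in the question.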
Upvotes: 1