I'm trying to apply a while loop over a tensor's values. For example, for variable "a" I am trying to increase the tensor's values incrementally until a certain condition is met. However, I keep getting this error:
ValueError: Shape must be rank 0 but is rank 3 for 'while_12/LoopCond' (op: 'LoopCond') with input shapes: [3,1,1].
import numpy as np
import tensorflow as tf

a = np.random.random((3, 1, 1))
# e.g. a = array([[[0.76393723]],
#                 [[0.93270312]],
#                 [[0.08361106]]])
a1 = tf.constant(np.float64(a))
i = tf.constant(np.float64(6.14))
c = lambda i: tf.less(i, a1)
b = lambda x: tf.add(x, 0.1)
r = tf.while_loop(c, b, [a1])  # raises the ValueError above
The first argument of tf.while_loop() (the condition) must return a scalar boolean tensor. A tensor of rank 0 is a scalar, and that's what the error message is complaining about: tf.less() compares elementwise, so your condition returns a boolean tensor of shape [3,1,1]. In your example you probably want the condition to return True if all the numbers in the a1 tensor are less than 6.14. You can collapse the elementwise result to a single boolean with tf.reduce_all() (logical AND) or tf.reduce_any() (logical OR).
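To see the rank issue concretely, here is a small NumPy analogue. NumPy stands in for TensorFlow only so it runs without a graph or session: np.less plays the role of tf.less, and np.all / np.any play the roles of tf.reduce_all / tf.reduce_any. The shapes behave the same way, but this is an illustration, not TensorFlow code.

```python
import numpy as np

# Same shape as the question's `a`: three values in a (3, 1, 1) array.
a = np.array([[[0.76393723]],
              [[0.93270312]],
              [[0.08361106]]])
threshold = 6.14

# Elementwise comparison keeps the input's shape: this rank-3 boolean
# is the kind of value tf.while_loop's LoopCond rejects.
elementwise = np.less(a, threshold)
print(elementwise.shape, elementwise.ndim)   # (3, 1, 1) 3

# Reducing with logical AND / OR collapses it to a single boolean of
# rank 0, which is what a loop condition has to return.
all_less = np.all(elementwise)   # analogue of tf.reduce_all
any_less = np.any(elementwise)   # analogue of tf.reduce_any
print(np.ndim(all_less), bool(all_less), bool(any_less))   # 0 True True
```

Which reduction to use depends on the intended stopping rule: reduce_all keeps looping only while every element satisfies the test, reduce_any keeps looping while at least one does.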
This TensorFlow snippet worked for me:
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.reset_default_graph)

tf.reset_default_graph()
a = np.random.randint(1, 4, size=(3, 2))  # random integers in [1, 3]
print(a)
# e.g.
# [[1 1]
#  [2 3]
#  [1 1]]
a1 = tf.constant(a)
i = 6
# condition returns True while any number in `x` is less than 6;
# tf.reduce_any collapses the boolean matrix to a scalar
condition = lambda x: tf.reduce_any(tf.less(x, i))
body = lambda x: tf.add(x, 1)
loop = tf.while_loop(
    condition,
    body,
    [a1],
)
with tf.Session() as sess:
    result = sess.run(loop)
print(result)
# [[6 6]
#  [7 8]
#  [6 6]]
# All numbers are now greater than or equal to 6
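tf.while_loop(cond, body, loop_vars) keeps calling body on the loop variables for as long as cond returns True. The sketch below traces that with a hypothetical eager stand-in, `emulated_while_loop`, written in NumPy for a single loop variable. It is not TensorFlow's implementation; it just mirrors the rank-0 check on the condition and uses the fixed starting matrix from the example output above so the result is reproducible.

```python
import numpy as np


def emulated_while_loop(cond, body, x):
    """Hypothetical eager stand-in for tf.while_loop with one loop variable.

    Not TensorFlow's code: it only mimics the requirement that `cond`
    returns a scalar (rank-0) boolean, then applies `body` until it is False.
    """
    iterations = 0
    while True:
        keep_going = cond(x)
        rank = np.ndim(keep_going)
        if rank != 0:
            raise ValueError(f"Shape must be rank 0 but is rank {rank}")
        if not keep_going:
            return x, iterations
        x = body(x)
        iterations += 1


a = np.array([[1, 1],
              [2, 3],
              [1, 1]])
i = 6
condition = lambda x: np.any(np.less(x, i))   # like tf.reduce_any(tf.less(x, i))
body = lambda x: np.add(x, 1)                  # like tf.add(x, 1)

result, n = emulated_while_loop(condition, body, a)
print(result.tolist(), n)   # [[6, 6], [7, 8], [6, 6]] 5

# Without the reduction the condition is a rank-2 boolean matrix,
# so the rank check fails, much like LoopCond does in the question.
try:
    emulated_while_loop(lambda x: np.less(x, i), body, a)
except ValueError as e:
    print(e)   # Shape must be rank 0 but is rank 2
```

The smallest starting value is 1, so it takes five additions before no element is below 6, which is why the loop stops at [[6, 6], [7, 8], [6, 6]].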