MasterWizard


TensorFlow: while loop on a tensor

I'm trying to apply a while loop over a tensor's values. For example, for the variable `a` I am trying to increase the tensor's values incrementally until a certain condition is met. However, I keep getting this error:

ValueError: Shape must be rank 0 but is rank 3 for 'while_12/LoopCond' (op: 'LoopCond') with input shapes: [3,1,1].

import numpy as np
import tensorflow as tf

a = np.random.random((3,1,1))
# for example:
# array([[[0.76393723]],
#        [[0.93270312]],
#        [[0.08361106]]])
a1 = tf.constant(np.float64(a))
i = tf.constant(np.float64(6.14))

c = lambda i: tf.less(i, a1)
b = lambda x: tf.add(x, 0.1)
r = tf.while_loop(c, b, [a1])


Answers (1)

Vlad-HC


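The first argument of tf.while_loop() (the condition) must return a scalar boolean tensor. A tensor of rank 0 is a scalar, and that is what the error message is complaining about: tf.less() compares element by element, so your condition returns a boolean tensor of shape [3,1,1], i.e. rank 3. (Note also that inside `lambda i: ...` the name `i` is the loop variable, not the constant 6.14, so the comparison is against a1 rather than 6.14.) You probably want the condition to return True as long as all the numbers in the a1 tensor are less than 6.14. This can be achieved with tf.reduce_all() (logical AND over all elements) or tf.reduce_any() (logical OR over all elements), both of which collapse the boolean tensor to a single value.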
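The reduction step can be illustrated in plain NumPy, used here only as an analogue: np.all and np.any play the roles of tf.reduce_all and tf.reduce_any, and the sample values are the ones shown in the question. An element-wise comparison keeps the (3, 1, 1) shape, while a reduction yields a rank-0 value that can act as a loop condition.

```python
import numpy as np

# Sample values from the question, shape (3, 1, 1).
a = np.array([[[0.76393723]],
              [[0.93270312]],
              [[0.08361106]]])

# Element-wise comparison: one boolean per element, still rank 3.
elementwise = np.less(a, 6.14)
print(elementwise.shape, elementwise.ndim)  # (3, 1, 1) 3

# Reductions collapse the booleans to a single rank-0 value.
all_less = np.all(elementwise)  # logical AND over every element
any_less = np.any(elementwise)  # logical OR over every element
print(np.ndim(all_less), bool(all_less), bool(any_less))  # 0 True True

# A multi-element boolean array cannot be used directly as a condition.
try:
    bool(elementwise)
except ValueError as err:
    print("ValueError:", err)
```

The last part mirrors the TensorFlow error in spirit: a loop needs one yes/no answer per iteration, not one per element.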

This snippet worked for me (it uses the TensorFlow 1.x graph API with tf.Session):

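import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.reset_default_graph)

tf.reset_default_graph()

# random integers from 1 to 3 inclusive (np.random.random_integers is deprecated)
a = np.random.randint(1, 4, size=(3, 2))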
print(a)
# [[1 1]
#  [2 3]
#  [1 1]]

a1 = tf.constant(a)
i = 6

# condition returns True while any number in `x` is less than 6
condition = lambda x : tf.reduce_any(tf.less(x, i))
body      = lambda x : tf.add(x, 1)
loop = tf.while_loop(
    condition,
    body,
    [a1],
)

with tf.Session() as sess:
    result = sess.run(loop)
    print(result)
    # [[6 6]
    #  [7 8]
    #  [6 6]]
    # All numbers are now greater than or equal to 6

