Antony Joseph

Reputation: 275

TensorFlow: Error with tf.where()

I am not sure why tf.where() does not work as planned. I want to use the values of a where yt is less than 5, and the values of b otherwise.

import tensorflow as tf

tf.InteractiveSession()
yt = tf.constant([10, 1, 10])
a = tf.constant([1, 2, 3])
b = tf.constant([3, 4, 5])
tf.where(tf.less(yt, [5]), a, b).eval()

This gives the error:

where() takes at most 2 arguments (3 given)

Can you tell me why I am getting this error? Is there any other way to do this?

Upvotes: 3

Views: 2339

Answers (1)

mrry

Reputation: 126154

The signature of tf.where() changed between TensorFlow 0.10, when it took a single boolean tensor (plus an optional name) and returned the coordinates of its true elements, and TensorFlow 0.12+, where it takes three tensor arguments (condition, x, y) and returns a single output that selects element-wise from x or y, replacing the former tf.select().

As Himaprasoon suggests, upgrading to TensorFlow 0.12 or later should fix your problem.
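To make the three-argument semantics concrete, here is a minimal pure-Python sketch of the element-wise selection that tf.where(condition, x, y) performs on 1-D inputs. The helper name where_select is hypothetical and only illustrates the behaviour; it is not how TensorFlow implements the op.

```python
# Hypothetical helper illustrating the element-wise semantics of the
# three-argument tf.where(condition, x, y) (TensorFlow 0.12+) for 1-D inputs.
# This is a sketch of the behaviour, not TensorFlow's implementation.

def where_select(condition, x, y):
    """Return x[i] where condition[i] is true, otherwise y[i]."""
    if not (len(condition) == len(x) == len(y)):
        raise ValueError("condition, x and y must have the same length")
    return [xi if c else yi for c, xi, yi in zip(condition, x, y)]


yt = [10, 1, 10]
a = [1, 2, 3]
b = [3, 4, 5]

# Plays the role of tf.less(yt, 5): an element-wise comparison with a scalar.
mask = [v < 5 for v in yt]
result = where_select(mask, a, b)
print(result)  # [3, 2, 5]
```

Only the second element of yt is below 5, so the second value comes from a and the others from b. With TensorFlow 0.12 or later, the question's own tf.where(tf.less(yt, [5]), a, b).eval() should produce the same values. On TensorFlow 0.10 itself, tf.select(condition, t, e) was the function that filled this role.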

Upvotes: 4
