Rodsnjr

Reputation: 127

TensorFlow: tf.where (index) with an 'and' condition

I have the following Tensor:

import tensorflow as tf

# (class, index)
obj_classes_indexes = tf.constant([(0, 0), (0, 1), (0, 2), (1, 3)])

For each entry, I'm looking for the other objects with the same class and a higher index. For now I'm trying the following, using the first entry as the reference:

same_classes = tf.logical_and(tf.equal(obj_classes_indexes[:, 0], obj_classes_indexes[0][0]),
                              obj_classes_indexes[:, 1] > obj_classes_indexes[0][1])
found_indexes = tf.where(same_classes)

with tf.Session() as sess:
    print(sess.run(same_classes))
    print(sess.run(found_indexes))

The expected output would be:

[False  True  True False]
[1, 2]

But it's giving me:

[False  True  True False]
[[1], [2]]

I don't think the output of logical_and is actually the correct input to tf.where. Or am I missing something?

Thanks!

Upvotes: 1

Views: 1762

Answers (1)

hampi

Reputation: 145

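There is nothing wrong with the output. When called with only a condition, tf.where() returns a 2-D tensor, as the documentation puts it: "The coordinates are returned in a 2-D tensor where the first dimension (rows) represents the number of true elements, and the second dimension (columns) represents the coordinates of the true elements." Your condition is 1-D, so each true element has a single coordinate, which is why you get [[1], [2]] rather than [1, 2].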
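
To make the role of the second dimension concrete, here is a minimal sketch. It assumes TensorFlow 2.x with eager execution (the question uses the 1.x Session API), and the names mask_1d and mask_2d are made up for illustration. With a 2-D condition each true element needs two coordinates, so the result has two columns; a 1-D condition is the same case with a single column:

```python
import tensorflow as tf

# Assumption: TensorFlow 2.x, where eager execution is on by default.
# mask_1d and mask_2d are illustrative names, not from the question.
mask_1d = tf.constant([False, True, True, False])
mask_2d = tf.constant([[True, False],
                       [False, True]])

# With only a condition, tf.where returns one row per True element
# and one column per dimension of the condition.
coords_1d = tf.where(mask_1d)  # shape [2, 1]
coords_2d = tf.where(mask_2d)  # shape [2, 2]

print(coords_1d.numpy().tolist())  # [[1], [2]]
print(coords_2d.numpy().tolist())  # [[0, 0], [1, 1]]
```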

If you want the output to be a 1D tensor as you have mentioned, you could just add a reshape op in your case as below:

found_indexes = tf.where(same_classes)
found_indexes = tf.reshape(found_indexes, [-1])
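
For reference, a self-contained sketch of the whole thing, again assuming TensorFlow 2.x with eager execution rather than the Session API from the question. It also shows tf.squeeze with an explicit axis as a possible alternative to the reshape; the names found_2d, flat_reshape and flat_squeeze are only illustrative:

```python
import tensorflow as tf

# Assumption: TensorFlow 2.x (eager). Under 1.x you would evaluate
# these tensors inside a tf.Session instead of calling .numpy().

# (class, index)
obj_classes_indexes = tf.constant([(0, 0), (0, 1), (0, 2), (1, 3)])

# Same class as the first entry AND a larger index than the first entry.
same_classes = tf.logical_and(
    tf.equal(obj_classes_indexes[:, 0], obj_classes_indexes[0][0]),
    obj_classes_indexes[:, 1] > obj_classes_indexes[0][1],
)

found_2d = tf.where(same_classes)            # shape [2, 1]
flat_reshape = tf.reshape(found_2d, [-1])    # shape [2]
flat_squeeze = tf.squeeze(found_2d, axis=1)  # shape [2], drops only the coordinate axis

print(flat_reshape.numpy().tolist())  # [1, 2]
print(flat_squeeze.numpy().tolist())  # [1, 2]
```

If you go the squeeze route, passing axis=1 matters: without it, a result with exactly one match (shape [1, 1]) would be squeezed all the way down to a scalar.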

Hope this helps!

Upvotes: 2
