I have the following tensor:
import tensorflow as tf
# (class, index)
obj_class_indexes = tf.constant([(0, 0), (0, 1), (0, 2), (1, 3)])
For each entry, I'm looking for the other objects of the same class that have a higher index. For now I'm trying it for the first entry:
same_classes = tf.logical_and(tf.equal(obj_class_indexes[:, 0], obj_class_indexes[0][0]),
                              obj_class_indexes[:, 1] > obj_class_indexes[0][1])
found_indexes = tf.where(same_classes)
with tf.Session() as sess:
    print(sess.run(same_classes))
    print(sess.run(found_indexes))
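To check what the boolean mask in the question should contain, here is a plain-Python sketch (no TensorFlow; the list representation is an assumption for illustration) that evaluates the same condition element by element: same class as the first entry and a larger index.

```python
# Plain-Python model of the question's mask (TensorFlow not required).
# Each pair is (class, index), as in obj_class_indexes.
obj_class_indexes = [(0, 0), (0, 1), (0, 2), (1, 3)]

ref_class, ref_index = obj_class_indexes[0]

# Mirrors tf.logical_and(tf.equal(cls, ref_class), idx > ref_index)
same_classes = [cls == ref_class and idx > ref_index
                for cls, idx in obj_class_indexes]

print(same_classes)  # [False, True, True, False]
```

This matches the first line of the question's output, so the mask itself is correct.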
The expected output would be:
[False True True False]
[1, 2]
But it's giving me:
[False True True False]
[[1], [2]]
I suspect the output of logical_and is not the right input for the tf.where function. Or am I missing something?
Thanks!
There is nothing wrong with the output. tf.where() with a single argument is documented to return a 2-D tensor: "The coordinates are returned in a 2-D tensor where the first dimension (rows) represents the number of true elements, and the second dimension (columns) represents the coordinates of the true elements". Since your mask is 1-D, each row holds exactly one coordinate, which is why you get [[1], [2]].
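The coordinate layout described in that quote can be modelled in plain Python. The helper `where_coords` below is hypothetical and only mimics the shape of the result of single-argument tf.where (one row per True element, one column per input dimension, row-major order); it is not TensorFlow's implementation.

```python
def where_coords(mask):
    """Return [[i, j, ...], ...] for every True element of a nested-list mask.

    Hypothetical plain-Python model of single-argument tf.where:
    one row per True element, one column per input dimension.
    """
    coords = []

    def walk(node, prefix):
        if isinstance(node, (list, tuple)):
            # Descend one dimension, recording the position along it.
            for i, child in enumerate(node):
                walk(child, prefix + [i])
        elif node:
            # Leaf that is True: its full coordinate becomes one row.
            coords.append(prefix)

    walk(mask, [])
    return coords


# A 1-D mask still yields one single-element row per True value.
print(where_coords([False, True, True, False]))  # [[1], [2]]

# A 2-D mask yields rows with two coordinates each.
print(where_coords([[True, False], [False, True]]))  # [[0, 0], [1, 1]]
```

The second column only appears when the input has a second dimension, which is why the rows are kept even for 1-D input: the result always has shape [number of True elements, rank of input].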
If you want a 1-D tensor, as in your expected output, you can add a reshape op:
found_indexes = tf.where(same_classes)
found_indexes = tf.reshape(found_indexes, [-1])
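What the reshape to [-1] does to the coordinate rows can be sketched in plain Python (the `flatten` helper is hypothetical, not a TensorFlow function). It also shows why this trick is only appropriate when the mask is 1-D: with a 2-D mask, the (row, column) pairs would be interleaved into one list. For the 1-D case, tf.squeeze(found_indexes, axis=1) is an alternative that removes the size-1 column explicitly. In both cases the indexes keep the int64 dtype that tf.where returns.

```python
# Hypothetical plain-Python model of tf.reshape(found_indexes, [-1]):
# flatten the row-major coordinate matrix into a single list.
def flatten(rows):
    return [value for row in rows for value in row]


# tf.where on the 1-D mask [False, True, True, False] gives [[1], [2]].
found_indexes = [[1], [2]]
print(flatten(found_indexes))  # [1, 2]

# For a 2-D mask the rows hold (row, col) pairs, so flattening
# interleaves them and no longer gives a list of positions.
found_2d = [[0, 0], [1, 1]]
print(flatten(found_2d))  # [0, 0, 1, 1]
```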
Hope this helps!