I want to get the index of the minimum value in each row of a tensor, ignoring values that are 0.
import numpy as np
import tensorflow as tf

a = np.array([[0, 3, 9, 0],
              [0, 0, 5, 7]])
tensor_a = tf.constant(a, dtype=tf.int32)
max_index = tf.argmax(tensor_a, axis=1)  # -> [2, 3]
The code above defines a constant tensor. If I use tf.argmax, I get the indices [2, 3]. How can I get the index of 3 in row one and of 5 in row two, i.e. of the smallest value that is not zero? The indices I want are [1, 2]. How can I implement this in TensorFlow? Thanks.
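To pin down the expected result independently of TensorFlow, here is a minimal plain-Python reference (an illustration, not part of the original question). The helper name min_nonzero_index is hypothetical; for each row it considers only the non-zero entries and returns the column of the smallest one.

```python
import numpy as np

def min_nonzero_index(arr):
    """Hypothetical reference helper: for each row of a 2-D array, return the
    column index of the smallest non-zero entry (ties go to the lowest column)."""
    result = []
    for row in arr:
        # Collect (value, column) pairs for the non-zero entries only.
        candidates = [(value, col) for col, value in enumerate(row) if value != 0]
        # Tuples compare by value first, then by column, so ties pick the lowest column.
        result.append(min(candidates)[1])
    return result

a = np.array([[0, 3, 9, 0],
              [0, 0, 5, 7]])
print(min_nonzero_index(a))  # [1, 2]
```

A row with no non-zero entries makes min() raise ValueError here, since there is nothing to choose from.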
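It's hideous, but it works: replace every zero with a value larger than its row's maximum, so that tf.argmin can only land on a non-zero entry.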
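import numpy as np
import tensorflow as tf

# TensorFlow 1.x graph mode (tf.Session); in TF 2.x, drop the session and print min_index.numpy()
with tf.Session() as sess:
    a = np.array([[0, 3, 9, 0],
                  [0, 0, 5, 7]])
    tensor_a = tf.constant(a, dtype=tf.int64)
    # One more than each row's maximum, shaped [rows, 1] so it broadcasts across columns
    row_max = tf.reshape(tf.reduce_max(tensor_a, axis=-1), [-1, 1]) + 1
    # Replace non-positive entries with row_max so argmin only sees the non-zero values
    min_index = tf.argmin(tf.where(tensor_a > 0, tensor_a, row_max * tf.ones_like(tensor_a)), axis=1)
    print(min_index.eval())  # -> [1 2]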
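The same masking trick can be checked without TensorFlow. Below is a minimal NumPy sketch of the approach in this answer, assuming non-negative integer input as in the question; np.where plays the role of tf.where and a keepdims reduction replaces the tf.reshape.

```python
import numpy as np

a = np.array([[0, 3, 9, 0],
              [0, 0, 5, 7]])

# Per-row sentinel: one more than the row maximum, kept as shape (rows, 1) to broadcast.
sentinel = a.max(axis=1, keepdims=True) + 1
# Keep the positive entries and replace everything else with the sentinel.
masked = np.where(a > 0, a, sentinel)
# The sentinel exceeds every real value in its row, so argmin skips the zeros.
min_index = masked.argmin(axis=1)
print(min_index)  # [1 2]
```

One caveat that applies to both versions: an all-zero row becomes a row of identical sentinels, and argmin then silently returns 0. If that case can occur, check (a > 0).any(axis=1) alongside the result.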