auroua

Reputation: 153

In TensorFlow, how can I get the index of the minimum value, excluding zeros?

I want to get the index of the minimum value in each row of a tensor, ignoring entries that are 0.

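import numpy as np
import tensorflow as tf

a = np.array([[0, 3, 9, 0],
              [0, 0, 5, 7]])
tensor_a = tf.constant(a, dtype=tf.int32)
# argmax returns the index of the largest value in each row: [2, 3]
max_index = tf.argmax(tensor_a, axis=1)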

The code above defines a constant tensor; calling tf.argmax on it returns the indices [2, 3]. What I want instead is the index of the smallest nonzero value in each row: 3 in row one and 5 in row two, i.e. the indices [1, 2]. How can I implement this in TensorFlow? Thanks.
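For reference, the expected result can be reproduced in plain NumPy by masking: replace each zero with the largest value the integer dtype can hold, so a zero can never be the minimum, then take argmin along each row. This is a minimal sketch assuming an integer array; argmin_nonzero is a hypothetical helper name, not a library function.

```python
import numpy as np


def argmin_nonzero(a):
    """Row-wise index of the smallest nonzero entry (NumPy sketch, integer input assumed)."""
    a = np.asarray(a)
    # Zeros are replaced by the dtype's maximum so they never win the argmin.
    fill = np.iinfo(a.dtype).max
    masked = np.where(a != 0, a, fill)
    return masked.argmin(axis=1)


a = np.array([[0, 3, 9, 0],
              [0, 0, 5, 7]])
print(argmin_nonzero(a))  # [1 2]
```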

Upvotes: 2

Views: 1817

Answers (1)

Joe Comer

Reputation: 23

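It's hideous, but it works: replace every zero with a value larger than anything else in its row (the row maximum plus one), then take tf.argmin along each row. This uses the TensorFlow 1.x session API: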

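import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

with tf.Session() as sess:
    a = np.array([[0, 3, 9, 0],
                  [0, 0, 5, 7]])
    tensor_a = tf.constant(a, dtype=tf.int64)
    # Per-row maximum plus one: a fill value larger than every entry in its row.
    row_max = tf.reshape(tf.reduce_max(tensor_a, axis=-1), [-1, 1]) + 1
    # Replace zeros with the fill value, then take the row-wise argmin.
    fill = row_max * tf.ones_like(tensor_a)
    min_index = tf.argmin(tf.where(tf.not_equal(tensor_a, 0), tensor_a, fill), axis=1)
    print(min_index.eval())  # -> [1 2]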
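The snippet above relies on the TensorFlow 1.x session API. Below is a minimal sketch of the same masking idea for TensorFlow 2.x eager execution. argmin_nonzero is a hypothetical helper; the fill value is the dtype's maximum (tf.DType.max) rather than the row maximum plus one, and rows containing only zeros are mapped to -1, which is an assumed convention since the question does not say what such rows should return.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution


def argmin_nonzero(t):
    """Row-wise index of the smallest nonzero entry of a 2-D tensor.

    Rows that contain only zeros yield -1 (an assumed convention).
    """
    # Fill value: the largest value representable in the tensor's dtype,
    # so a zero entry can never be chosen as the minimum.
    fill = tf.fill(tf.shape(t), tf.constant(t.dtype.max, dtype=t.dtype))
    masked = tf.where(tf.not_equal(t, 0), t, fill)
    idx = tf.argmin(masked, axis=1)  # int64 indices
    # An all-zero row would silently give index 0, so flag such rows with -1.
    has_nonzero = tf.reduce_any(tf.not_equal(t, 0), axis=1)
    return tf.where(has_nonzero, idx, tf.constant(-1, dtype=tf.int64))


t = tf.constant([[0, 3, 9, 0],
                 [0, 0, 5, 7]], dtype=tf.int32)
print(argmin_nonzero(t).numpy())  # [1 2]
```

Using tf.not_equal rather than a > 0 test keeps negative entries eligible, which matches "minimum value except zero" literally.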

Upvotes: 2
