Reputation: 41
How do I get the top 3 maximum values (not just the single largest one) from preds? Right now I only get the index of the maximum with y_classe = tf.argmax(preds, axis=1, output_type=tf.int32).
Upvotes: 1
Views: 260
Reputation: 26708
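You can use tf.math.top_k, which returns the k largest entries along the last axis (sorted in descending order) together with their indices: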
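import tensorflow as tf

# One row of scores (batch of 1, 4 classes)
y_pred = [[-18.6, 0.51, 2.94, -12.8]]
max_entries = 3

# values: the k largest scores per row; indices: their positions (like tf.argmax, but for the top k)
values, indices = tf.math.top_k(y_pred, k=max_entries)
print(values)
print(indices)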
Output:
tf.Tensor([[  2.94   0.51 -12.8 ]], shape=(1, 3), dtype=float32)
tf.Tensor([[2 1 3]], shape=(1, 3), dtype=int32)
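To tie this back to the question: for a 2-D preds tensor of shape (batch, num_classes), top_k works on the last axis, which is the same axis as tf.argmax(preds, axis=1). The sketch below assumes such a 2-D float tensor (the values are made up for illustration) and compares the two; the first column of the top-3 indices should match what tf.argmax returns.

```python
import tensorflow as tf

# Hypothetical batch of 2 predictions over 4 classes
preds = tf.constant([[-18.6, 0.51, 2.94, -12.8],
                     [0.1, 3.2, -1.0, 2.5]])

# top_k operates on the last axis (axis=1 for a 2-D tensor)
top3_values, top3_classes = tf.math.top_k(preds, k=3)

# Single best class per row, as in the question
y_classe = tf.argmax(preds, axis=1, output_type=tf.int32)

print(top3_classes.numpy())  # indices of the 3 highest scores in each row
print(y_classe.numpy())      # index of the highest score in each row
```

For this made-up input, top3_classes is [[2, 1, 3], [1, 3, 0]] and y_classe is [2, 1], so column 0 of top3_classes equals the argmax. Both are int32, matching output_type=tf.int32 in the question.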
Upvotes: 1
Reputation: 319
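Alternatively, you can sort in descending order and take the first 3 elements. Note that this gives you only the values, not their indices: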
import tensorflow as tf

a = [1, 10, 26.9, 2.8, 166.32, 62.3]
# Sort from largest to smallest
sorted_a = tf.sort(a, direction='DESCENDING')
# Keep the first three entries (sorted_a[:3] is equivalent)
max_3 = tf.gather(sorted_a, [0, 1, 2])
print(max_3)
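If you want the sort-based approach to also give you the class indices (the top-3 analogue of tf.argmax), tf.argsort can be used alongside tf.sort. This is a sketch assuming a 2-D preds tensor of shape (batch, num_classes) with made-up values; both functions sort along the last axis, and [:, :3] keeps the first three columns of each row.

```python
import tensorflow as tf

# Hypothetical batch of predictions: 2 samples, 4 classes
preds = tf.constant([[-18.6, 0.51, 2.94, -12.8],
                     [0.1, 3.2, -1.0, 2.5]])

# Class indices ordered from highest to lowest score, truncated to 3
top3_classes = tf.argsort(preds, axis=-1, direction='DESCENDING')[:, :3]

# Matching scores, sorted the same way and truncated to 3
top3_values = tf.sort(preds, axis=-1, direction='DESCENDING')[:, :3]

print(top3_classes.numpy())
print(top3_values.numpy())
```

This yields the same indices as tf.math.top_k(preds, k=3) for this input (no ties), but top_k returns values and indices in a single call.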
Upvotes: 1