Reputation: 91
tf.math.argmax returns the index of the maximum value in a tensor.
a = tf.constant([1,2,3])
print(a)
print(tf.math.argmax(input = a))
output:
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
<tf.Tensor: shape=(), dtype=int64, numpy=2>
I want to apply the tf.math.argmax function to a list of tensors. How can I do it?
input = tf.constant([1,2,3,4,5,6])
split_sequence = tf.split(input, num_or_size_splits=2, axis=-1)
print(split_sequence)
tf.math.argmax(input = split_sequence)
output:
[<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>]
tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32)
<tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 1, 1])>
It is giving the wrong indices -> numpy=array([1, 1, 1])
desired output:
numpy=array([[2],[2]])
Upvotes: 1
Views: 679
Reputation: 24049
You can use map to apply any function to each value in the list.
(It's better not to use a Python built-in name as a variable, so I changed input to inp.)
import tensorflow as tf
inp = tf.constant([1,2,3,4,5,6])
split_sequence = tf.split(inp, num_or_size_splits=2, axis=-1)
print(split_sequence)
result = list(map(lambda x: [tf.math.argmax(x).numpy()] , split_sequence))
print(result)
Or, thanks to @jkr, we can use a list comprehension too. (Which one is better: map vs. list comprehension?)
>>> [[tf.math.argmax(item).numpy()] for item in split_sequence]
[[2], [2]]
output:
[
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>,
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>
]
[[2], [2]]
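If you specifically want a NumPy array in the shape of your desired output (numpy=array([[2],[2]])), a minimal sketch is to wrap the resulting Python list with np.array:
import numpy as np
# result is the Python list [[2], [2]] produced by either map or the list comprehension above
np.array(result)
# array([[2],
#        [2]])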
Benchmark (on Colab):
import tensorflow as tf
input = tf.constant([1,2,3,4,5,6]*1_000_000)
split_sequence = tf.split(input, num_or_size_splits=20, axis=-1)
%timeit tf.math.top_k(split_sequence, k=1).indices
# 13.5 ms ± 394 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit list(map(lambda x: [tf.math.argmax(x).numpy()] , split_sequence))
# 14 ms ± 2.39 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit [[tf.math.argmax(item).numpy()] for item in split_sequence]
# 8.77 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Upvotes: 3
Reputation: 26708
I would recommend simply using tf.math.top_k in your case:
import tensorflow as tf
input = tf.constant([1,2,3,4,5,6])
split_sequence = tf.split(input, num_or_size_splits=2, axis=-1)
x = tf.math.top_k(split_sequence, sorted=False, k=1).indices
print(x)
tf.Tensor(
[[2]
[2]], shape=(2, 1), dtype=int32)
Afterwards, if you want a NumPy array, just call x.numpy().
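For example, with the x from above it should look like:
x.numpy()
# array([[2],
#        [2]], dtype=int32)
And if you prefer a flat shape of (2,), tf.squeeze(x, axis=-1) should work too.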
Update 1:
An even simpler but slower approach is to just change the axis of tf.argmax:
input = tf.constant([1,2,3,4,5,6])
split_sequence = tf.split(input, num_or_size_splits=2, axis=-1)
tf.argmax(split_sequence, axis=-1)
# <tf.Tensor: shape=(2,), dtype=int64, numpy=array([2, 2])>
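This works presumably because the list of split tensors is converted to a single stacked tensor of shape (2, 3) before the argmax; an equivalent explicit form, just as a sketch:
stacked = tf.stack(split_sequence)  # shape (2, 3)
tf.argmax(stacked, axis=-1)
# <tf.Tensor: shape=(2,), dtype=int64, numpy=array([2, 2])>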
You do not need any explicit loop or map. See benchmarks:
import tensorflow as tf
input = tf.constant([1,2,3,4,5,6]*1_000_000)
split_sequence = tf.split(input, num_or_size_splits=20, axis=-1)
@tf.function
def top_k(split_sequence):
return tf.math.top_k(split_sequence, k=1, sorted=False).indices
@tf.function
def argmax(split_sequence):
return tf.argmax(split_sequence, axis=-1)
@tf.function
def _map(split_sequence):
return list(map(lambda x: [tf.math.argmax(x)] , split_sequence))
@tf.function
def _list(split_sequence):
return [[tf.math.argmax(item)] for item in split_sequence]
%timeit top_k(split_sequence)
# 3.5 ms ± 246 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit argmax(split_sequence)
# 16.6 ms ± 3.79 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit _map(split_sequence)
# 10.3 ms ± 929 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit _list(split_sequence)
# 10.2 ms ± 2.15 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
Upvotes: 3
Reputation: 11
<tf.Tensor: shape=(), dtype=int64, numpy=2>
You can see the output is numpy=2, i.e., the 2nd index of your constant, which is the value 3.
Upvotes: 1