Reputation: 141
I need a one-hot representation for the maximum value in a tensor.
For example, consider a 2 x 3 tensor:
[ [1, 5, 2],
  [0, 3, 7] ]
The one-hot-argmax representation I am aiming for looks like this:
[ [0, 1, 0],
[0, 0, 1] ]
I can do it as follows, where my_tensor is an N x 3 tensor:
position = tf.argmax(my_tensor, axis=1)      # Shape (N,)
one_hot_pos = tf.one_hot(position, depth=3)  # Shape (N, 3)
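The forward computation above can be checked with a NumPy sketch (numpy equivalents of the TF ops, not the TF graph itself):

```python
import numpy as np

my_tensor = np.array([[1, 5, 2],
                      [0, 3, 7]], dtype=np.float32)

position = np.argmax(my_tensor, axis=1)              # shape (N,): [1, 2]
one_hot_pos = np.eye(3, dtype=np.float32)[position]  # shape (N, 3)

print(one_hot_pos)  # [[0. 1. 0.], [0. 0. 1.]]
```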
But this part of the code needs to be differentiable, since I'm training through it.
My workaround was as follows, where EPSILON = 1e-3 is a small constant:
max_value = tf.reduce_max(my_tensor, axis=1, keepdims=True)
clip_min = max_value - EPSILON
one_hot_pos = (tf.clip_by_value(my_tensor, clip_min, max_value) - clip_min) / (max_value - clip_min)
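The same arithmetic as the workaround, sketched in NumPy so the numbers can be verified (the clip bounds broadcast row-wise exactly as in the TF version):

```python
import numpy as np

EPSILON = 1e-3
my_tensor = np.array([[1., 5., 2.],
                      [0., 3., 7.]])

max_value = np.max(my_tensor, axis=1, keepdims=True)  # shape (N, 1)
clip_min = max_value - EPSILON

# Values below max - EPSILON collapse to 0; the max itself maps to 1.
one_hot_pos = (np.clip(my_tensor, clip_min, max_value) - clip_min) / (max_value - clip_min)

print(one_hot_pos)  # [[0. 1. 0.], [0. 0. 1.]]
```

Note that entries within EPSILON of the maximum land strictly between 0 and 1, which is where the tie and division-by-zero issues come from.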
The workaround works most of the time, but, as expected, it has some issues:
- EPSILON: if it is too small, a division by zero might happen
- argmax only chooses one entry even in a tie situation

Do you know any better way of simulating the argmax followed by one_hot combination, while fixing the two mentioned issues, but using only differentiable TensorFlow functions?
Upvotes: 4
Views: 929
Reputation: 944
Do some maximum, tile and comparison operations. Like:
a = tf.Variable([ [1, 5, 2], [0, 3, 7] ]) # your tensor
m = tf.reduce_max(a, axis=1) # [5,7]
m = tf.expand_dims(m, -1) # [[5],[7]]
m = tf.tile(m, [1,3]) # [[5,5,5],[7,7,7]]
y = tf.cast(tf.equal(a, m), tf.float32)  # [[0,1,0],[0,0,1]]
This is a tricky comparison trick that is differentiable, and it marks every maximum in a tie.
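The steps above can be sketched in NumPy to confirm the result (note that with a tied row such as [7, 3, 7], this approach marks both maxima, unlike argmax):

```python
import numpy as np

a = np.array([[1, 5, 2],
              [0, 3, 7]], dtype=np.float32)

m = np.max(a, axis=1)            # [5. 7.]
m = np.expand_dims(m, -1)        # [[5.], [7.]]
m = np.tile(m, (1, 3))           # [[5. 5. 5.], [7. 7. 7.]]
y = (a == m).astype(np.float32)  # [[0. 1. 0.], [0. 0. 1.]]

print(y)
```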
Upvotes: 1