jhonatanoliveira

Reputation: 141

An equivalent but differentiable argmax expression in TensorFlow

I need a one-hot representation for the maximum value in a tensor.

For example, consider a tensor 2 x 3:

[ [1, 5, 2],
  [0, 3, 7] ]

The one-hot-argmax representation I am aiming for looks like this:

[ [0, 1, 0],
  [0, 0, 1] ]

I can do it as follows, where my_tensor is a N x 3 tensor:

position = tf.argmax(my_tensor, axis=1)      # Shape (N,)
one_hot_pos = tf.one_hot(position, depth=3)  # Shape (N, 3)
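For context, here is a minimal runnable sketch of this non-differentiable version on the 2 x 3 example above (assuming TensorFlow 2.x in eager mode; the values are just the sample from the question):

import tensorflow as tf

# The sample tensor from the question (N = 2 rows, 3 columns).
my_tensor = tf.constant([[1., 5., 2.],
                         [0., 3., 7.]])

position = tf.argmax(my_tensor, axis=1)      # [1, 2], shape (N,)
one_hot_pos = tf.one_hot(position, depth=3)  # [[0, 1, 0], [0, 0, 1]], shape (N, 3)

# tf.argmax has no gradient, so a tf.GradientTape would return None for the
# gradient of one_hot_pos with respect to my_tensor; that is the problem below.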

But this part of the code needs to be differentiable since I'm training over it. My workaround was as follows, where EPSILON = 1e-3 is a small constant:

max_value = tf.reduce_max(my_tensor, axis=1, keepdims=True)  # Shape (N, 1)
clip_min = max_value - EPSILON
# The row maximum maps to 1, anything more than EPSILON below it maps to 0,
# and values in between get a fractional score.
one_hot_pos = (tf.clip_by_value(my_tensor, clip_min, max_value) - clip_min) / (max_value - clip_min)
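Run on the same sample tensor, a minimal sketch of this workaround (assuming TensorFlow 2.x, with a tf.GradientTape used only to confirm that a gradient is defined):

import tensorflow as tf

EPSILON = 1e-3
my_tensor = tf.constant([[1., 5., 2.],
                         [0., 3., 7.]])

with tf.GradientTape() as tape:
    tape.watch(my_tensor)
    max_value = tf.reduce_max(my_tensor, axis=1, keepdims=True)
    clip_min = max_value - EPSILON
    one_hot_pos = (tf.clip_by_value(my_tensor, clip_min, max_value) - clip_min) / (max_value - clip_min)

print(one_hot_pos.numpy())  # [[0. 1. 0.]
                            #  [0. 0. 1.]]

# Unlike the argmax version, this returns a tensor rather than None:
print(tape.gradient(tf.reduce_sum(one_hot_pos), my_tensor))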

The workaround works most of the time, but - as expected - it has some issues:

Do you know a better way of simulating argmax followed by one_hot that fixes the two issues mentioned above, using only differentiable TensorFlow functions?

Upvotes: 4

Views: 929

Answers (1)

Nwoye CID

Reputation: 944

Do some maximum, tile, and comparison operations, like:

a = tf.Variable([ [1, 5, 2], [0, 3, 7] ])  # your tensor

m = tf.reduce_max(a, axis=1)  # [5, 7]
m = tf.expand_dims(m, -1)     # [[5], [7]]
m = tf.tile(m, [1, 3])        # [[5, 5, 5], [7, 7, 7]]

y = tf.cast(tf.equal(a, m), tf.float32)  # [[0, 1, 0], [0, 0, 1]]

This trick builds the one-hot mask with ordinary tensor operations and keeps it differentiable.
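For a quick forward check, the same steps condensed into a runnable snippet (assuming TensorFlow 2.x eager mode, using tf.constant instead of tf.Variable):

import tensorflow as tf

a = tf.constant([[1, 5, 2],
                 [0, 3, 7]])

m = tf.tile(tf.expand_dims(tf.reduce_max(a, axis=1), -1), [1, 3])
y = tf.cast(tf.equal(a, m), tf.float32)

print(y.numpy())  # [[0. 1. 0.]
                  #  [0. 0. 1.]]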

Upvotes: 1
