Martin Bagge

Reputation: 395

Getting around tf.argmax which is not differentiable

I've written a custom loss function for my neural network, but it can't compute any gradients. I think it is because I need the index of the highest value and am therefore using argmax to get this index.

As argmax is not differentiable, I need to get around this, but I don't know how that is possible.

Can anyone help?

Upvotes: 25

Views: 14422

Answers (6)

Lostefra

Reputation: 360

Here is the PyTorch equivalent of the solution that @Nova proposes:

import torch

def softargmax(x, beta=1e10):
    # as_tensor keeps the autograd graph intact if x is already a tensor that requires grad
    x = torch.as_tensor(x)
    x_range = torch.arange(x.shape[-1], dtype=x.dtype)
    return torch.sum(torch.nn.functional.softmax(x*beta, dim=-1) * x_range, dim=-1)
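
For example, a minimal check with made-up scores (a moderate beta is used here so the gradient is not numerically saturated):

# hypothetical scores for 5 classes; the largest value sits at index 3
scores = torch.tensor([0.1, 0.5, 0.2, 2.3, 0.7], requires_grad=True)
idx = softargmax(scores, beta=10.0)  # close to 3.0, but differentiable
idx.backward()                       # gradients flow back into scores
print(idx.item(), scores.grad)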

Upvotes: 1

Oskar Zdrojewski

Reputation: 79

What you're effectively doing with argmax is transforming a set without numerical order into a set with one. Consider what would happen if you took a derivative of a function that computes those indices:

  • In an unordered set, index 1 is as far away from index 2 as it is from index 50.
  • In an ordered set, number 1 is closer to number 2 than to number 50.

This means that if you were to differentiate argmax in your model, the model would prioritize indices that are close together. Let's say you have a categorical vector for labels: 'dog', 'cat', 'fish', 'monkey'. After you use argmax, if 'monkey' were the true label and 'dog' the predicted one, the loss would have a steeper slope than if the true label were 'cat'.

The same problem persists with local extrema (instead of labels). To an extent it also persists with the maximum/minimum functions, even though those are differentiable in TensorFlow.

Whether a programmed derivative exists is irrelevant if that derivative would make the cost function unpredictable.

Since you're writing a custom loss function, I would suggest changing the format of the true labels (from numerical to categorical) rather than the predicted labels. In that case you don't need a derivative.
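
For instance, a minimal sketch of that change (the label values and names below are made up for illustration):

import tensorflow as tf

# hypothetical integer labels for the four classes 'dog', 'cat', 'fish', 'monkey'
true_labels = tf.constant([3, 0, 1])
one_hot_labels = tf.one_hot(true_labels, depth=4)
# the loss can now compare one_hot_labels against the predicted probabilities directly,
# e.g. with a categorical cross-entropy, so no argmax on the predictions is needed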

Upvotes: 0

Nova

Reputation: 2109

As aidan suggested, it's just a softargmax stretched to the limits by beta. We can use tf.nn.softmax to get around the numerical issues:

def softargmax(x, beta=1e10):
  x = tf.convert_to_tensor(x)
  x_range = tf.range(x.shape.as_list()[-1], dtype=x.dtype)
  return tf.reduce_sum(tf.nn.softmax(x*beta) * x_range, axis=-1)
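
A quick check in TF 2.x eager mode (made-up values; assumes the softargmax above is in scope):

import tensorflow as tf

x = tf.constant([[0.1, 0.5, 2.3, 0.7]])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = softargmax(x)                  # close to [2.0]
print(y.numpy(), tape.gradient(y, x))  # the gradient tensor is defined, unlike with tf.argmax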

Upvotes: 20

S. C

Reputation: 156

If the value range of your input is positive and you do not need the exact index of the maximum value but its one-hot form is enough, you can use the sign function like this:

import tensorflow as tf
import numpy as np

sess = tf.Session()
x = tf.placeholder(dtype=tf.float32, shape=(None,))

# zero at the position of the maximum, one everywhere else
y = tf.sign(tf.reduce_max(x, axis=-1, keepdims=True) - x)
# flip it so the maximum position becomes one and everything else zero
y = (y - 1) * (-1)

print("I can compute the gradient", tf.gradients(y, x))

for run in range(10):
    data = np.random.random(10)
    print(data.argmax(), sess.run(y, feed_dict={x:data}))

Upvotes: 3

anj-s

Reputation: 21

tf.argmax is not differentiable because it returns an integer index. tf.reduce_max and tf.maximum are differentiable.
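
A small illustration of the difference (made-up values, TF 2.x style):

import tensorflow as tf

x = tf.Variable([0.1, 0.5, 2.3, 0.7])
with tf.GradientTape(persistent=True) as tape:
    max_val = tf.reduce_max(x)    # differentiable
    max_idx = tf.argmax(x)        # integer output, no gradient path
print(tape.gradient(max_val, x))  # [0., 0., 1., 0.] -- gradient flows to the argmax position
print(tape.gradient(max_idx, x))  # None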

Upvotes: 0

user3002273

Reputation:

If you are cool with approximations,

import tensorflow as tf
import numpy as np

sess = tf.Session()
x = tf.placeholder(dtype=tf.float32, shape=(None,))
beta = tf.placeholder(dtype=tf.float32)

# Pseudo-math for the below
# y = sum( i * exp(beta * x[i]) ) / sum( exp(beta * x[i]) )
y = tf.reduce_sum(tf.cumsum(tf.ones_like(x)) * tf.exp(beta * x) / tf.reduce_sum(tf.exp(beta * x))) - 1

print("I can compute the gradient", tf.gradients(y, x))

for run in range(10):
    data = np.random.randn(10)
    print(data.argmax(), sess.run(y, feed_dict={x:data/np.linalg.norm(data), beta:1e2}))

This uses the trick that computing the mean in a low-temperature environment gives the approximate maximum of the probability space. Low temperature in this case corresponds to beta being very large.

In fact, as beta approaches infinity, my algorithm will converge to the maximum (assuming the maximum is unique). Unfortunately, beta can't get too large before you have numerical errors and get NaN, but there are tricks to solve that I can go into if you care.

The output looks something like:

0 2.24459
9 9.0
8 8.0
4 4.0
4 4.0
8 8.0
9 9.0
6 6.0
9 8.99995
1 1.0

So you can see that it messes up in some spots, but often gets the right answer. Depending on your algorithm, this might be fine.
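
As for the numerical-error tricks mentioned above, one common fix (a sketch against the same placeholders; the names shifted, weights and y_stable are illustrative) is to shift by the maximum before exponentiating, which is what tf.nn.softmax does internally:

# shifting by the max leaves the weights unchanged but keeps exp() from overflowing
shifted = beta * x - tf.reduce_max(beta * x)
weights = tf.exp(shifted) / tf.reduce_sum(tf.exp(shifted))
y_stable = tf.reduce_sum(tf.cumsum(tf.ones_like(x)) * weights) - 1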

Upvotes: 14
