mobiusinversion

Reputation: 353

Keras backend: argmax if above threshold, else -1

I would like to create a custom accuracy function that uses argmax for y_pred only if the value at argmax exceeds a threshold, else -1.

In terms of the Keras backend, it would be a modification of sparse_categorical_accuracy:

return backend.cast(
    backend.equal(
        backend.flatten(y_true),
        backend.cast(backend.argmax(y_pred, axis=-1),
                     backend.floatx())),
    backend.floatx())

So, instead of:

backend.argmax(y_pred, axis=-1)

I need a function with the pseudocode logic:

argmax_values = backend.argmax(y_pred, axis=-1)
argmax_values if y_pred[argmax_values] > threshold else -1

As a concrete example, if:

x = [[0.75, 0.25], [0.85, 0.15], [0.5, 0.5], [0.95, 0.05]]

and threshold=0.8, then the result of the desired function would be:

[-1, 0, -1, 0]

How can I achieve this using the Keras backend? My Keras version is 2.2.4, so I do not have access to the TensorFlow 2 backend.

Upvotes: 0

Views: 158

Answers (1)

Lescurel

Reputation: 11631

You can use K.switch to conditionally assign values from two different tensors based on a condition. Using K.switch, your desired function would be:

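from keras import backend as K

def argmax_w_threshold(y_pred, threshold=0.8):
    # Index of the largest value in each row, cast to float so it can share a dtype with -1.
    argmax_values = K.cast(K.argmax(y_pred, axis=-1), K.floatx())
    return K.switch(
        # Per-row condition: is the largest value above the threshold?
        K.max(y_pred, axis=-1) > threshold,
        argmax_values,
        # Otherwise return -1 for that row.
        -1. * K.ones_like(argmax_values)
    )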

Note that the tensors in the then and else branches of K.switch must have the same shape, hence the use of K.ones_like.
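To get from here to the accuracy function in the question, the thresholded argmax can take the place of the plain argmax inside sparse_categorical_accuracy. The following is only a sketch, assuming Keras 2.2.4 with the TensorFlow 1.x backend and a 2-D y_pred; the factory name make_thresholded_accuracy is made up for illustration and is not part of Keras.

```python
from keras import backend as K


def argmax_w_threshold(y_pred, threshold=0.8):
    # Same function as above, repeated so this snippet runs on its own.
    argmax_values = K.cast(K.argmax(y_pred, axis=-1), K.floatx())
    return K.switch(
        K.max(y_pred, axis=-1) > threshold,
        argmax_values,
        -1. * K.ones_like(argmax_values),
    )


def make_thresholded_accuracy(threshold=0.8):
    # Hypothetical helper: builds a metric with the (y_true, y_pred)
    # signature that Keras expects, closing over the threshold.
    def thresholded_sparse_categorical_accuracy(y_true, y_pred):
        # A row predicted as -1 never equals a valid class index,
        # so low-confidence rows count as misses.
        return K.cast(
            K.equal(K.flatten(y_true), argmax_w_threshold(y_pred, threshold)),
            K.floatx(),
        )

    return thresholded_sparse_categorical_accuracy
```

It could then be passed to compile like any other metric, for example metrics=[make_thresholded_accuracy(0.8)]. Whether below-threshold rows should count as misses or be excluded from the average is a design choice; this sketch counts them as misses.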

On your example:

>>> import tensorflow as tf
>>> sess = tf.InteractiveSession()
>>> x = [[0.75, 0.25], [0.85, 0.15], [0.5, 0.5], [0.95, 0.05]]
>>> sess.run(argmax_w_threshold(x))
array([-1.,  0., -1.,  0.], dtype=float32)
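If you want to check the backend version against something that does not need a session, a plain-Python reference of the same rule can serve as a test oracle. This is a sketch for illustration only; both function names are hypothetical, and it assumes ties resolve to the first maximum, as argmax does.

```python
def argmax_w_threshold_reference(rows, threshold=0.8):
    """Return the argmax index per row, or -1 if the max is not above threshold."""
    result = []
    for row in rows:
        # max() with a key returns the first index on ties, like argmax.
        best = max(range(len(row)), key=row.__getitem__)
        result.append(best if row[best] > threshold else -1)
    return result


def thresholded_accuracy_reference(y_true, rows, threshold=0.8):
    """Fraction of rows whose thresholded argmax equals the true label."""
    preds = argmax_w_threshold_reference(rows, threshold)
    # A prediction of -1 can never match a valid class index.
    hits = [1.0 if int(t) == p else 0.0 for t, p in zip(y_true, preds)]
    return sum(hits) / len(hits)


x = [[0.75, 0.25], [0.85, 0.15], [0.5, 0.5], [0.95, 0.05]]
print(argmax_w_threshold_reference(x))
print(thresholded_accuracy_reference([0, 0, 1, 0], x))
```

The first call reproduces the values from the question, and the second shows how rows mapped to -1 lower the accuracy.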

Upvotes: 1
