I would like to create a custom accuracy function that uses argmax for y_pred only if the value at the argmax exceeds a threshold, else -1.
In terms of the Keras backend, it would be a modification of sparse_categorical_accuracy:
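from keras import backend

def sparse_categorical_accuracy(y_true, y_pred):
    # Predicted class = argmax of each row, compared with the flattened labels
    return backend.cast(
        backend.equal(
            backend.flatten(y_true),
            backend.cast(backend.argmax(y_pred, axis=-1),
                         backend.floatx())),
        backend.floatx())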
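For reference, the same computation can be mirrored in NumPy. This is only an illustration, assuming NumPy is available; sparse_categorical_accuracy_np is a hypothetical name, not part of Keras. It returns 1.0 where the argmax of a prediction row equals the integer label and 0.0 elsewhere:

```python
import numpy as np

def sparse_categorical_accuracy_np(y_true, y_pred):
    """Hypothetical NumPy mirror of the Keras expression (per-sample 0/1 scores)."""
    y_true = np.asarray(y_true, dtype=float).reshape(-1)  # like backend.flatten
    y_pred = np.asarray(y_pred, dtype=float)
    # Predicted class = index of the largest score in each row
    # (NumPy breaks ties toward the first index)
    labels = np.argmax(y_pred, axis=-1).astype(float)
    # 1.0 where the label matches, 0.0 otherwise (like backend.equal + cast)
    return (y_true == labels).astype(float)

x = [[0.75, 0.25], [0.85, 0.15], [0.5, 0.5], [0.95, 0.05]]
print(sparse_categorical_accuracy_np([0, 1, 1, 0], x).tolist())
```

Every row here has its argmax at index 0, so only the samples labelled 0 score 1.0.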
So, instead of:
backend.argmax(y_pred, axis=-1)
I need a function with the pseudocode logic:
argmax_values = backend.argmax(y_pred, axis=-1)
argmax_values if y_pred[argmax_values] > threshold else -1
As a concrete example, if:
x = [[0.75, 0.25], [0.85, 0.15], [0.5, 0.5], [0.95, 0.05]]
and threshold=0.8, then the result of the desired function would be:
[-1, 0, -1, 0]
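To pin down the intended semantics independently of any backend, here is a minimal NumPy sketch, assuming NumPy is available (argmax_w_threshold_np is a hypothetical helper name). It reproduces the expected result on this example: rows whose largest value does not exceed the threshold get -1.

```python
import numpy as np

def argmax_w_threshold_np(y_pred, threshold=0.8):
    """Hypothetical NumPy reference: per-row argmax, or -1 if the row max is <= threshold."""
    y_pred = np.asarray(y_pred, dtype=float)
    argmax_values = np.argmax(y_pred, axis=-1)
    # Keep the argmax only where the winning score is strictly above the threshold
    return np.where(np.max(y_pred, axis=-1) > threshold, argmax_values, -1)

x = [[0.75, 0.25], [0.85, 0.15], [0.5, 0.5], [0.95, 0.05]]
print(argmax_w_threshold_np(x, threshold=0.8).tolist())  # [-1, 0, -1, 0]
```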
How can I achieve this using the Keras backend? My Keras version is 2.2.4, so I do not have access to the TensorFlow 2 backend.
You can use K.switch to conditionally assign values from two different tensors based on a condition. Using K.switch, your desired function would be:
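from keras import backend as K

def argmax_w_threshold(y_pred, threshold=0.8):
    # Class index of the largest score per row, cast to float
    argmax_values = K.cast(K.argmax(y_pred, axis=-1), K.floatx())
    return K.switch(
        K.max(y_pred, axis=-1) > threshold,  # is the winning score above the threshold?
        argmax_values,                        # yes: keep the argmax
        -1. * K.ones_like(argmax_values)      # no: -1, same shape as argmax_values
    )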
Note that the tensors in the then and else branches of K.switch must have the same shape, hence the use of K.ones_like.
Applied to your example:
>>> import tensorflow as tf
>>> sess = tf.InteractiveSession()
>>> x = [[0.75, 0.25], [0.85, 0.15], [0.5, 0.5], [0.95, 0.05]]
>>> sess.run(argmax_w_threshold(x))
array([-1., 0., -1., 0.], dtype=float32)
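Since the end goal is an accuracy metric, the thresholded prediction would take the place of the cast argmax in the sparse_categorical_accuracy expression, i.e. comparing K.flatten(y_true) with argmax_w_threshold(y_pred). A below-threshold prediction (-1) can then never match a valid class index, so it counts as incorrect. The NumPy sketch below illustrates those per-sample scores under that assumption; thresholded_accuracy_np is a hypothetical name, and this has not been run against Keras 2.2.4 itself:

```python
import numpy as np

def thresholded_accuracy_np(y_true, y_pred, threshold=0.8):
    """Hypothetical NumPy mirror: 1.0 where the thresholded argmax equals the label."""
    y_true = np.asarray(y_true, dtype=float).reshape(-1)  # like K.flatten
    y_pred = np.asarray(y_pred, dtype=float)
    labels = np.argmax(y_pred, axis=-1).astype(float)
    # Low-confidence rows become -1.0, which never equals a class index >= 0
    labels = np.where(np.max(y_pred, axis=-1) > threshold, labels, -1.0)
    return (y_true == labels).astype(float)

x = [[0.75, 0.25], [0.85, 0.15], [0.5, 0.5], [0.95, 0.05]]
scores = thresholded_accuracy_np([0, 0, 1, 0], x)
print(scores.tolist(), scores.mean())  # [0.0, 1.0, 0.0, 1.0] 0.5
```

Only the two confident rows can score; the mean of the per-sample scores is the accuracy figure.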