Reputation: 53
I need to implement a new loss function for my deep network, which is the following:
import tensorflow as tf
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import array_ops

def gms_loss(targets=None, logits=None, name=None):
    # Shape checking: targets and logits must be compatible
    try:
        targets.get_shape().merge_with(logits.get_shape())
    except ValueError:
        raise ValueError("logits and targets must have the same shape (%s vs %s)"
                         % (logits.get_shape(), targets.get_shape()))

    # Compute the confusion matrix from the hard (argmax) predictions
    predictions = tf.nn.softmax(logits)
    cm = tf.confusion_matrix(tf.argmax(targets, 1), tf.argmax(predictions, 1), 3)

    def compute_sensitivities(name):
        """Compute the sensitivity per class via the confusion matrix."""
        per_row_sum = math_ops.to_float(math_ops.reduce_sum(cm, 1))
        cm_diag = math_ops.to_float(array_ops.diag_part(cm))
        # If a row sum is 0, set the denominator to 1 to avoid
        # division by zero.
        denominator = array_ops.where(
            math_ops.greater(per_row_sum, 0), per_row_sum,
            array_ops.ones_like(per_row_sum))
        return math_ops.div(cm_diag, denominator)

    # Product of the per-class sensitivities
    gms = math_ops.reduce_prod(compute_sensitivities('sensitivities'))
    return gms
Here is how it is called from the graph-building code:
test = gms_loss(targets=y, logits=pred)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(test)
And finally, the familiar error:
"ValueError: No gradients provided for any variable, check your graph for ops that do not support gradients, between variables..."
I'm not able to find the problem. If I use softmax_cross_entropy it works (but it doesn't optimize properly, which is why I need the new loss function).
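For reference, the cross-entropy baseline that does train (a sketch, assuming the TF 1.x API and that y holds one-hot labels) is:

test = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=pred))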
Thank you in advance
Upvotes: 0
Views: 519
Reputation: 2629
I think the problem is that tf.argmax() is not differentiable: it returns discrete class indices, and there is no gradient of those indices with respect to the logits. Because your confusion matrix is built from two argmax calls, everything computed from it is disconnected from the trainable variables, so the optimizer cannot calculate any gradient of the loss. I don't know of a way to make argmax itself differentiable, so I would recommend avoiding non-differentiable functions inside a loss.
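One common workaround (a minimal sketch, not tested against your exact setup) is to replace the hard confusion matrix with a "soft" one built directly from the softmax probabilities, so the whole loss stays differentiable. This assumes targets is a one-hot float tensor of shape [batch, num_classes]; soft_gms_loss and eps are names I made up for illustration:

import tensorflow as tf

def soft_gms_loss(targets, logits, eps=1e-7):
    """Differentiable surrogate: a soft confusion matrix built from
    probabilities instead of argmax indices, so gradients reach the logits."""
    predictions = tf.nn.softmax(logits)
    # soft_cm[i, j] = total probability mass that examples whose true
    # class is i are assigned to class j (assumes one-hot float targets).
    soft_cm = tf.matmul(targets, predictions, transpose_a=True)
    # Soft sensitivity per class: diagonal over row sums, guarded
    # against division by zero.
    row_sums = tf.reduce_sum(soft_cm, axis=1)
    sensitivities = tf.diag_part(soft_cm) / tf.maximum(row_sums, eps)
    # Minimizing the negative sum of logs maximizes the product of the
    # sensitivities, and avoids the product vanishing numerically.
    return -tf.reduce_sum(tf.log(sensitivities + eps))

Calling it the same way (test = soft_gms_loss(targets=y, logits=pred)) should then give AdamOptimizer gradients to work with, since every op in the chain is differentiable.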
Upvotes: 2