I want to calculate the loss function in my DNN in a different way depending on the value of the label.
Conceptually it's something like this:
def loss(logits, labels):
    if labels[0] == 0:
        return loss_function_1(logits, labels)
    else:
        return loss_function_2(logits, labels)
Obviously this won't work, because I can't do this comparison on a tensor object. I also can't use eval(), because I get an error that the network is not defined. Do I have another option?
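A minimal sketch of one common workaround, written in NumPy so it runs standalone: compute both candidate losses, then choose between them element-wise with a where-style select instead of a Python `if`. The two loss functions below are hypothetical placeholders for `loss_function_1` and `loss_function_2`, and the sketch assumes each one returns a loss per example. Assuming the framework is TensorFlow (the post doesn't say, though `eval()` hints at it), `tf.where(condition, x, y)` performs the same element-wise selection on tensors.

```python
import numpy as np

# Hypothetical stand-ins for loss_function_1 / loss_function_2 from the question.
# Each returns one loss value per example so the two can be selected element-wise.
def loss_function_1(logits, labels):
    # Placeholder: squared error
    return (logits - labels) ** 2

def loss_function_2(logits, labels):
    # Placeholder: absolute error
    return np.abs(logits - labels)

def loss(logits, labels):
    # Evaluate both candidate losses for every example...
    loss_a = loss_function_1(logits, labels)
    loss_b = loss_function_2(logits, labels)
    # ...then pick per example with a vectorized condition rather than a
    # Python `if`, which cannot inspect the value of a symbolic tensor.
    per_example = np.where(labels == 0, loss_a, loss_b)
    return per_example.mean()

logits = np.array([0.5, 2.0, 3.0])
labels = np.array([0.0, 1.0, 1.0])
print(loss(logits, labels))
```

If the choice really should be made once for the whole batch from `labels[0]`, as in the conceptual snippet, TensorFlow's `tf.cond(pred, true_fn, false_fn)` is the graph-level counterpart of the Python `if`.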