Avi Avidan

Reputation: 1106

TensorFlow: Using a value in a tensor as a parameter

I want to calculate the loss function in my DNN in a different way depending on the value of the label.

Conceptually it's something like this:

def loss(logits, labels):
    if labels[0] == 0:
        return loss_function_1(logits, labels)
    else:
        return loss_function_2(logits, labels)

This won't work, because a Python if statement can't branch on a tensor whose value is only known when the graph runs. I also can't use eval(), because I get an error that the network is not defined. Is there another option?

Upvotes: 1

Views: 60

Answers (1)

keveman

Reputation: 8487

You can use tf.cond for this. Build the predicate with tf.equal so that it is a boolean tensor rather than a Python comparison:

tf.cond(tf.equal(labels[0], 0),
        lambda: loss_function_1(logits, labels),
        lambda: loss_function_2(logits, labels))
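
For a fuller picture, here is a minimal runnable sketch of the same idea, assuming TensorFlow 2.x with eager execution. The bodies of loss_function_1 and loss_function_2 are hypothetical stand-ins (the question does not show them), and the sample tensors are made up for illustration. The predicate is built with tf.equal rather than ==, since on older 1.x releases == on a tensor is, as far as I know, not an element-wise comparison.

```python
import tensorflow as tf


# Hypothetical stand-in for the question's loss_function_1: mean squared error.
def loss_function_1(logits, labels):
    return tf.reduce_mean(tf.square(logits - tf.cast(labels, logits.dtype)))


# Hypothetical stand-in for the question's loss_function_2: mean absolute error.
def loss_function_2(logits, labels):
    return tf.reduce_mean(tf.abs(logits - tf.cast(labels, logits.dtype)))


def loss(logits, labels):
    # tf.equal yields a scalar boolean tensor, which is what tf.cond expects.
    pred = tf.equal(labels[0], 0)
    # Each branch is passed as a callable so only the selected one is evaluated.
    return tf.cond(pred,
                   lambda: loss_function_1(logits, labels),
                   lambda: loss_function_2(logits, labels))


# Made-up sample data: the first label selects which loss is used.
logits = tf.constant([0.5, 2.0])
labels_a = tf.constant([0, 1])  # first label is 0, so loss_function_1 is used
labels_b = tf.constant([1, 1])  # first label is not 0, so loss_function_2 is used

loss_a = loss(logits, labels_a)
loss_b = loss(logits, labels_b)
```

Note that tf.cond needs a scalar predicate, so this picks one loss for the whole batch based on labels[0]; it does not switch per example.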

Upvotes: 1
