Zara

Reputation: 105

Trouble with loss function tf.nn.weighted_cross_entropy_with_logits

I am trying to train a U-Net with binary targets. The usual binary cross-entropy loss does not perform well because the labels are very imbalanced (many more 0 pixels than 1 pixels), so I want to penalize false negatives more. TensorFlow doesn't have a ready-made weighted binary cross-entropy loss, and since I didn't want to write one from scratch, I'm trying to use tf.nn.weighted_cross_entropy_with_logits. To be able to pass the loss directly to model.compile, I'm writing this wrapper:

import tensorflow as tf

def loss_wrapper(y, x):
    x = tf.cast(x, 'float32')
    loss = tf.nn.weighted_cross_entropy_with_logits(y, x, pos_weight=10)
    return loss

However, even though I cast x to float32, I still get the error:

TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type int32 of argument 'x'.

when the TensorFlow loss is called. Can someone explain what is happening?

Upvotes: 1

Views: 128

Answers (1)

AloneTogether

Reputation: 26708

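If x represents your predictions, it probably already has type float32, so casting it changes nothing. The cast you need is on y, which presumably holds your labels (Keras calls a custom loss as loss(y_true, y_pred)). Internally the op multiplies the labels with the logits, and TensorFlow does not implicitly promote int32 to float32, which is exactly the Mul error you are seeing. So: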

loss = tf.nn.weighted_cross_entropy_with_logits(tf.cast(y, dtype=tf.float32), x, pos_weight=10)
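A complete version of the wrapper might look like the sketch below. It assumes TensorFlow 2.x (where the first argument of tf.nn.weighted_cross_entropy_with_logits is named labels) and that the model's last layer outputs raw logits, since the op applies the sigmoid itself. The function name weighted_bce_loss and the default pos_weight=10.0 are illustrative choices, not anything from the question.

```python
import tensorflow as tf


def weighted_bce_loss(y_true, y_pred, pos_weight=10.0):
    """Hypothetical weighted binary cross-entropy wrapper for model.compile.

    Assumes y_pred holds raw logits (no sigmoid on the final layer), because
    tf.nn.weighted_cross_entropy_with_logits applies the sigmoid internally.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Labels often arrive as int32 masks; the op multiplies labels with
    # logits, so cast the labels to the logits' float dtype first.
    y_true = tf.cast(y_true, y_pred.dtype)
    # pos_weight > 1 scales up the loss on positive pixels, i.e. it
    # penalizes false negatives more than false positives.
    return tf.nn.weighted_cross_entropy_with_logits(
        labels=y_true, logits=y_pred, pos_weight=pos_weight)
```

You would then pass it as model.compile(optimizer="adam", loss=weighted_bce_loss). Because the loss expects logits, the final layer should have no sigmoid activation; apply tf.sigmoid to the model output at prediction time if you need probabilities.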

Upvotes: 1
