Xavs Quah

Reputation: 11

Tensorflow - tf.nn.weighted_cross_entropy_with_logits - logits and targets must have the same shape

I've just started using TensorFlow for a project I'm working on. The program is meant to be a binary classifier that takes 12 features as input; the output is either a normal patient or a patient with a disease. The prevalence of the disease is quite low, so my dataset is very imbalanced, with 502 examples of normal controls and only 38 diseased patients. For this reason, I'm trying to use tf.nn.weighted_cross_entropy_with_logits as my cost function.
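(For context, a common heuristic for choosing pos_weight, an assumption here rather than something fixed by the question, is the ratio of negative to positive examples:)

num_normal = 502      # normal controls in the dataset
num_diseased = 38     # diseased patients, the rare positive class
pos_weight = num_normal / num_diseased   # ~13.2, up-weights positives in the loss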

The code is based on the iris custom estimator from the official TensorFlow documentation, and it works with tf.losses.sparse_softmax_cross_entropy as the cost function. However, when I change to tf.nn.weighted_cross_entropy_with_logits, I get a shape error and I'm not sure how to fix it:

ValueError: logits and targets must have the same shape ((?, 2) vs (?,))

From searching, I've found that similar problems have been solved by simply reshaping the labels, but I've tried this without success (and I don't understand why tf.losses.sparse_softmax_cross_entropy works fine while the weighted version does not).

My full code is here https://gist.github.com/revacious/83142573700c17b8d26a4a1b84b0dff7

Thanks!

Upvotes: 1

Views: 1147

Answers (1)

javidcf

Reputation: 59701

With non-sparse cross-entropy functions, you need to one-hot encode your labels so they have the same shape as your logits:

loss = tf.nn.weighted_cross_entropy_with_logits(tf.one_hot(labels, 2), logits, pos_weight)
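As a minimal sketch of the shapes involved (the constants below are illustrative, not taken from the question's data):

import tensorflow as tf

labels = tf.constant([0, 1, 0, 1])             # sparse labels, shape (4,)
logits = tf.constant([[2.0, -1.0],
                      [0.5, 1.5],
                      [1.0, 0.0],
                      [-0.5, 2.0]])            # shape (4, 2), one logit per class
targets = tf.one_hot(labels, 2)                # float one-hot, shape (4, 2), now matches logits
loss = tf.nn.weighted_cross_entropy_with_logits(targets=targets, logits=logits,
                                                pos_weight=13.2)
with tf.Session() as sess:
    print(sess.run(loss))                      # element-wise losses, shape (4, 2)

Note that this is the sigmoid-based loss, so it returns an element-wise result; you would typically pass it through tf.reduce_mean (or tf.reduce_sum) to get a scalar training loss.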

Note that tf.losses.sparse_softmax_cross_entropy also accepts a weights parameter, although it has a slightly different meaning (it is a per-sample weight rather than a per-class one). The equivalent formulation would be:

loss = tf.losses.sparse_softmax_cross_entropy(labels, logits,
                                              weights=pos_weight * labels + (1 - labels))
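A runnable sketch of this equivalent, again with made-up tensors (note the cast: if labels is an integer tensor and pos_weight a float, the weight arithmetic needs float labels to type-check):

import tensorflow as tf

labels = tf.constant([0, 1, 0, 1])                  # sparse labels, shape (4,)
logits = tf.constant([[2.0, -1.0],
                      [0.5, 1.5],
                      [1.0, 0.0],
                      [-0.5, 2.0]])
labels_f = tf.cast(labels, tf.float32)              # float copy for the weight arithmetic
weights = 13.2 * labels_f + (1.0 - labels_f)        # pos_weight for positives, 1.0 for negatives
loss = tf.losses.sparse_softmax_cross_entropy(labels, logits, weights=weights)
with tf.Session() as sess:
    print(sess.run(loss))                           # scalar: weighted average over the batch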

Upvotes: 1
