Stackd

Reputation: 693

TensorFlow: how to implement a per-class loss function for binary classification

I have two classes: positive (1) and negative (0).

The data set is very unbalanced, so at the moment my mini-batches contain mostly 0s. In fact, many batches will contain 0s only. I wanted to experiment with having a separate cost for the positive and negative examples; see code below.

The problem with my code is that I get many NaNs because the index_bound list will be empty for batches that contain no positive examples. What is an elegant way to resolve this?

import tensorflow as tf

def calc_loss_debug(logits, labels):
  # Flatten logits and labels to 1-D vectors.
  logits = tf.reshape(logits, [-1])
  labels = tf.reshape(labels, [-1])
  # Indices of the positive (1) and negative (0) examples.
  index_bound = tf.where(tf.equal(labels, tf.constant(1, dtype=tf.float32)))
  index_unbound = tf.where(tf.equal(labels, tf.constant(0, dtype=tf.float32)))
  # Per-example cross entropy.
  entropies = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
  # Split the entropies by class and average within each class.
  entropies_bound = tf.gather(entropies, index_bound)
  entropies_unbound = tf.gather(entropies, index_unbound)
  loss_bound = tf.reduce_mean(entropies_bound)      # NaN when index_bound is empty
  loss_unbound = tf.reduce_mean(entropies_unbound)  # NaN when index_unbound is empty
  return loss_bound, loss_unbound
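For illustration, the NaN comes from taking the mean of an empty tensor, which is what happens whenever a batch has no positive examples (a minimal TF 1.x check):

import tensorflow as tf

# tf.reduce_mean of an empty tensor is 0/0, i.e. NaN.
empty = tf.constant([], dtype=tf.float32)
with tf.Session() as sess:
  print(sess.run(tf.reduce_mean(empty)))  # prints nan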

Upvotes: 0

Views: 316

Answers (1)

lballes

Reputation: 1502

Since your labels are 0 and 1, you can avoid tf.where altogether with a construction like this:

labels = ...     # 1-D float tensor of 0s and 1s
entropies = ...  # per-example cross entropies, same shape as labels
# 1 where the label is 0, and vice versa.
labels_complement = tf.constant(1.0, dtype=tf.float32) - labels
# Summed entropy over the positive (1) and negative (0) examples.
entropy_ones = tf.reduce_sum(tf.multiply(labels, entropies))
entropy_zeros = tf.reduce_sum(tf.multiply(labels_complement, entropies))

To get the mean loss, you need to divide by the number of 0s and 1s in the batch, which can be easily computed as

num_ones = tf.reduce_sum(labels)              # count of positive examples
num_zeros = tf.reduce_sum(labels_complement)  # count of negative examples

Of course, you still have to guard against dividing by 0 when there are no 1s in the batch. I would suggest using tf.cond(tf.equal(num_ones, 0), ...).
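A minimal sketch of how that guard could look, putting all the pieces together (TF 1.x graph style; the name per_class_mean_loss is just illustrative, not from the original posts):

import tensorflow as tf

def per_class_mean_loss(logits, labels):
  # Flatten to 1-D, matching the setup in the question.
  logits = tf.reshape(logits, [-1])
  labels = tf.reshape(labels, [-1])
  entropies = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
  labels_complement = tf.constant(1.0, dtype=tf.float32) - labels
  entropy_ones = tf.reduce_sum(tf.multiply(labels, entropies))
  entropy_zeros = tf.reduce_sum(tf.multiply(labels_complement, entropies))
  num_ones = tf.reduce_sum(labels)
  num_zeros = tf.reduce_sum(labels_complement)
  # If a class is absent from the batch, return 0 for its loss
  # instead of the NaN that 0/0 would produce.
  loss_ones = tf.cond(tf.equal(num_ones, 0.0),
                      lambda: tf.constant(0.0),
                      lambda: entropy_ones / num_ones)
  loss_zeros = tf.cond(tf.equal(num_zeros, 0.0),
                       lambda: tf.constant(0.0),
                       lambda: entropy_zeros / num_zeros)
  return loss_ones, loss_zeros

An alternative that avoids tf.cond is to divide by tf.maximum(num_ones, 1.0): whenever num_ones is 0, the numerator entropy_ones is 0 as well, so the quotient is 0 rather than NaN.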

Upvotes: 1
