I have two classes: positive (1) and negative (0).
The data set is very unbalanced, so at the moment my mini-batches contain mostly 0s. In fact, many batches will contain 0s only. I wanted to experiment with having a separate cost for the positive and negative examples; see code below.
The problem with my code is that I get many NaNs, because index_bound will be empty for batches that contain no positive examples. What is an elegant way to resolve this?
def calc_loss_debug(logits, labels):
    # Flatten both tensors so they align element-wise.
    logits = tf.reshape(logits, [-1])
    labels = tf.reshape(labels, [-1])
    # Indices of positive (bound) and negative (unbound) examples.
    index_bound = tf.where(tf.equal(labels, tf.constant(1, dtype=tf.float32)))
    index_unbound = tf.where(tf.equal(labels, tf.constant(0, dtype=tf.float32)))
    entropies = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    entropies_bound = tf.gather(entropies, index_bound)
    entropies_unbound = tf.gather(entropies, index_unbound)
    # NaN when index_bound is empty: reduce_mean over zero elements.
    loss_bound = tf.reduce_mean(entropies_bound)
    loss_unbound = tf.reduce_mean(entropies_unbound)
    return loss_bound, loss_unbound
Since you have 0 and 1 labels, you can easily avoid tf.where with a construction like this:
labels = ...     # ground-truth labels in {0, 1}, dtype float32
entropies = ...  # per-example sigmoid cross-entropy, as above
labels_complement = tf.constant(1.0, dtype=tf.float32) - labels
entropy_ones = tf.reduce_sum(tf.multiply(labels, entropies))
entropy_zeros = tf.reduce_sum(tf.multiply(labels_complement, entropies))
To get the mean loss, you need to divide by the number of 0s and 1s in the batch, which can be easily computed as
num_ones = tf.reduce_sum(labels)
num_zeros = tf.reduce_sum(labels_complement)
Of course, you still have to guard against dividing by 0 when there are no 1s in the batch. I would suggest using tf.cond(tf.equal(num_ones, 0), ...).
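Putting the pieces together, here is a minimal sketch of the complete loss with that tf.cond guard (TF 1.x API). The function name calc_loss_masked and the choice of a zero loss term as the fallback are illustrative assumptions, not part of the original answer:

import tensorflow as tf

def calc_loss_masked(logits, labels):
    # Assumes labels are float32 values in {0, 1}.
    logits = tf.reshape(logits, [-1])
    labels = tf.reshape(labels, [-1])
    entropies = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    labels_complement = tf.constant(1.0, dtype=tf.float32) - labels
    # Mask the per-example entropies instead of gathering by index.
    entropy_ones = tf.reduce_sum(tf.multiply(labels, entropies))
    entropy_zeros = tf.reduce_sum(tf.multiply(labels_complement, entropies))
    num_ones = tf.reduce_sum(labels)
    num_zeros = tf.reduce_sum(labels_complement)
    # Guard each mean with tf.cond so an all-negative (or all-positive)
    # batch yields a zero loss term instead of NaN.
    loss_bound = tf.cond(tf.equal(num_ones, 0.0),
                         lambda: tf.constant(0.0),
                         lambda: entropy_ones / num_ones)
    loss_unbound = tf.cond(tf.equal(num_zeros, 0.0),
                           lambda: tf.constant(0.0),
                           lambda: entropy_zeros / num_zeros)
    return loss_bound, loss_unbound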