user4028648

TensorFlow: What is wrong with my (generalized) dice loss implementation?

I use TensorFlow 1.12 for semantic (image) segmentation based on materials. With a multinomial cross-entropy loss function, this yields okay-ish results (mIoU of 0.44), especially considering the sparse amount of training data I'm working with:

Segmentations after training with cross entropy loss

When I replace this with my dice loss implementation, however, the network predicts far fewer small segmentations, which is contrary to my understanding of its theory. I thought it's supposed to work better with imbalanced datasets and be better at predicting the smaller classes:

Segmentations after training with dice loss

A table shows this more clearly; as you can see, with dice loss many more of the smaller classes are never predicted (hence the undefined precision). With cross-entropy, at least some predictions are made for every class:

Table comparing metrics for the different losses

I initially thought that this was the network's way of increasing mIoU (since my understanding is that dice loss optimizes the dice coefficient, which is closely related to IoU, directly). However, mIoU with dice loss is 0.33 compared to cross-entropy's 0.44, so it has failed in that regard. I'm now wondering whether my implementation is correct:

def dice_loss(onehots_true, logits):
    probabilities = tf.nn.softmax(logits)
    #weights = 1.0 / ((tf.reduce_sum(onehots_true, axis=0)**2) + 1e-3)
    #weights = tf.clip_by_value(weights, 1e-17, 1.0 - 1e-7)
    numerator = tf.reduce_sum(onehots_true * probabilities, axis=0)
    #numerator = tf.reduce_sum(weights * numerator)
    denominator = tf.reduce_sum(onehots_true + probabilities, axis=0)
    #denominator = tf.reduce_sum(weights * denominator)
    loss = 1.0 - 2.0 * (numerator + 1) / (denominator + 1)
    return loss

Some implementations I found use weights, though I am not sure why, since mIoU isn't weighted either. At any rate, training stops prematurely after a few epochs with dreadful test results when I use the weights, hence I commented them out.

Does anyone see anything wrong with my dice loss implementation? I pretty faithfully followed online examples.

In order to speed up the labeling process, I only annotated with parallelogram-shaped polygons, and I copied some annotations from a larger dataset. This resulted in only a couple of ground truth segmentations per image:

Ground truth annotations

(This image actually contains slightly more annotations than average.)

Upvotes: 3

Views: 8404

Answers (3)

Amine Sehaba

Reputation: 120

Here is an implementation from GitHub that might help:

from keras import backend as K  # or tf.keras.backend, depending on your setup

def dice_coef(y_true, y_pred, smooth=1):
    """
    Dice = (2*|X & Y|)/ (|X|+ |Y|)
         =  2*sum(|A*B|)/(sum(A^2)+sum(B^2))
    ref: https://arxiv.org/pdf/1606.04797v1.pdf
    """
    intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
    return (2. * intersection + smooth) / (K.sum(K.square(y_true),-1) + K.sum(K.square(y_pred),-1) + smooth)

def dice_coef_loss(y_true, y_pred):
    return 1-dice_coef(y_true, y_pred)
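
A minimal usage sketch: compile a Keras model with this as the loss and the coefficient as a metric (the model and data names below are hypothetical):

# model, x_train and y_train_onehot are placeholder names for your own model and data
model.compile(optimizer='adam', loss=dice_coef_loss, metrics=[dice_coef])
model.fit(x_train, y_train_onehot, batch_size=8, epochs=10)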

Upvotes: 0

Rakshit Kothari

Reputation: 416

I'm going to add the formula for reference for anyone who answers in the future. The generalized dice loss is given by (formula taken from Sudre et al.):

GDL = 1 - 2 * (Σ_l w_l Σ_n r_ln p_ln) / (Σ_l w_l Σ_n (r_ln + p_ln)), with per-class weights w_l = 1 / (Σ_n r_ln)²

Classes are indexed by l and pixel locations by n; r_ln are the one-hot ground-truth labels. The probabilities p_ln can be generated by applying softmax or sigmoid to your network output.


In your implementation, the loss is summed across the batch. That would produce a very large loss value and your network gradients would explode. Instead, you need to use the average. Note that the weights are required to combat the class imbalance problem.

There is no concrete proof that GDL outperforms cross-entropy, save in a very specific example noted in the paper. GDL is attractive because it is directly related to IoU, ergo the loss function and evaluation metrics would improve hand-in-hand. If you still haven't managed to train your network, I'd recommend moving to cross-entropy for good.
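
To make the weighting concrete, here is a minimal sketch of a generalized dice loss along these lines in TensorFlow, assuming (like the question's code) onehots_true and logits flattened to shape [num_pixels, num_classes]; the epsilon value is an arbitrary choice:

def generalized_dice_loss(onehots_true, logits, eps=1e-7):
    probabilities = tf.nn.softmax(logits)
    # w_l = 1 / (sum_n r_ln)^2: down-weights large classes to counter class imbalance
    weights = 1.0 / (tf.reduce_sum(onehots_true, axis=0) ** 2 + eps)
    # inner sums run over pixels (axis=0), outer sums over classes
    numerator = tf.reduce_sum(weights * tf.reduce_sum(onehots_true * probabilities, axis=0))
    denominator = tf.reduce_sum(weights * tf.reduce_sum(onehots_true + probabilities, axis=0))
    # a single scalar loss for the batch, rather than a per-class vector
    return 1.0 - 2.0 * (numerator + eps) / (denominator + eps)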

Upvotes: 2

javidcf

Reputation: 59731

According to this article, your implementation of the Dice loss is incorrect. Instead of:

loss = 1.0 - 2.0 * (numerator + 1) / (denominator + 1)

You should have:

loss = 1.0 - (2.0 * numerator + 1) / (denominator + 1)
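
For context, a sketch of the question's function with only that line changed (everything else as posted):

def dice_loss(onehots_true, logits):
    probabilities = tf.nn.softmax(logits)
    numerator = tf.reduce_sum(onehots_true * probabilities, axis=0)
    denominator = tf.reduce_sum(onehots_true + probabilities, axis=0)
    # smoothing constant added after doubling the intersection: (2*|X & Y| + 1) / (|X| + |Y| + 1)
    loss = 1.0 - (2.0 * numerator + 1) / (denominator + 1)
    return loss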

Upvotes: -1
