Reputation: 165
I'm doing image segmentation with a U-Net-like architecture in TensorFlow with Keras, but I'm new to deep learning.
I've got this dataset with the following set shapes (some examples of these images and each channel are shown further down):
--> with 20% positive and 80% negative examples, split equally across each set
I ran a series of experiments, and the best filter combination produced this plot for BCE, with good accuracy:
The plot for the custom functions, dice_loss with the dice_coeff metric:
And some predictions generated by the best model on test images:
The problem is that when I switch to the Dice loss and coefficient, the predictions are poor, as we can see both in the plot and in the predicted images.
Why does it perform so badly with the Dice loss? What other loss function would you recommend?
My Dice loss and coefficient functions:
from tensorflow.keras import backend as K

def dice_coeff(y_true, y_pred, smooth=1):
    # Overlap between prediction and ground truth along the last axis
    intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
    # Smoothed Dice coefficient: 2*|X ∩ Y| / (|X| + |Y|)
    return (2. * intersection + smooth) / (
        K.sum(K.square(y_true), axis=-1)
        + K.sum(K.square(y_pred), axis=-1)
        + smooth)

def dice_loss(y_true, y_pred):
    return 1 - dice_coeff(y_true, y_pred)
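For reference, the functions plug into the Keras model roughly like this (a minimal sketch, assuming model, train_images, and train_masks are placeholders for the actual U-Net and data, not code from above):

# Minimal usage sketch -- `model`, `train_images`, and `train_masks`
# are assumed placeholders for the actual U-Net and training arrays.
model.compile(optimizer='adam', loss=dice_loss, metrics=[dice_coeff])
model.fit(train_images, train_masks, batch_size=16, epochs=50)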
Upvotes: 4
Views: 9706
Reputation: 1
This is because most segmentation frameworks, which minimize a loss function (such as pixel-wise cross-entropy) and then truncate at 0.5 or take the argmax, are designed for pixel-wise classification accuracy.
However, the commonly used objective in segmentation is the Dice (or IoU) metric, and these classification-based frameworks naturally cannot guarantee satisfactory results under the Dice metric. My recent paper in JMLR (https://www.jmlr.org/papers/v24/22-0712.html) proves that for Dice segmentation the correct approach is to: (1) rank the estimated pixel probabilities; and (2) compute the number of pixels to segment for each image (see page 6 of the paper).
Intuitively, Dice is a metric with a global property: it requires handling the segmentation prediction at the image level, since the decision to include a pixel depends not only on its own probability but also on the overall distribution of the other pixels, instead of simply truncating each pixel individually and independently at 0.5 as in pixel-wise classification.
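To make the two steps concrete, here is a rough sketch of the ranking idea. This is only an illustration: the expected-Dice formula below assumes independent pixels and is not the exact computation from the paper (see page 6 there for the real procedure).

import numpy as np

def rank_based_segmentation(probs):
    # Illustrative sketch of the rank-then-count idea, NOT the exact
    # computation from the paper: pick the top-k pixels by probability,
    # where k maximizes an approximate expected Dice under a pixel
    # independence assumption.
    p = np.sort(probs.ravel())[::-1]       # step (1): rank probabilities
    cum = np.cumsum(p)                     # sum of the top-k probabilities
    total = p.sum()                        # expected foreground size E[|Y|]
    k = np.arange(1, p.size + 1)
    # Approximate E[Dice] = 2*E[|S ∩ Y|] / (|S| + E[|Y|]) for each k
    expected_dice = 2.0 * cum / (k + total)
    k_best = int(k[np.argmax(expected_dice)])
    # step (2): segment the k_best most probable pixels
    thresh = p[k_best - 1]
    return probs >= thresh

The key point is that the cut-off adapts per image to the overall probability mass, rather than being fixed at 0.5 for every pixel independently.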
See the paper linked above for more details.
Upvotes: 0