Reputation: 2689
I have a highly imbalanced 3D dataset: about 80% of the volume is background, and I am only interested in the foreground elements, which constitute about 20% of the total volume at random locations. These locations are marked in the label tensor given to the network. The target tensor is binary, where 0 represents the background and 1 represents the areas I want to segment.
The size of each volume is [30, 512, 1024], and I iterate over each volume using blocks of size [30, 64, 64]. As a result, most of my blocks contain only 0 values in the target tensor.
I read that Dice loss is well suited to such problems and has been used successfully for segmenting 3D MRI scans. One simple implementation is from here: https://github.com/pytorch/pytorch/issues/1249#issuecomment-305088398
def dice_loss(input, target):
    smooth = 1.  # avoids division by zero when both tensors are all zeros
    iflat = input.view(-1)   # flatten predictions
    tflat = target.view(-1)  # flatten targets
    intersection = (iflat * tflat).sum()
    return 1 - ((2. * intersection + smooth) /
                (iflat.sum() + tflat.sum() + smooth))
This is not working for me. For a patch that is all background, tflat.sum() would be 0, which makes intersection 0 as well, so for the majority of my patches or blocks I get a return value of 1.
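To check this arithmetic, here is the same computation in plain Python (mirroring the dice_loss above, no PyTorch needed) on a fake all-background block of my size; the 0.5 value stands in for an untrained sigmoid output and the numbers are illustrative, not real network output:

```python
# Plain-Python mirror of dice_loss, to check the all-background case.
def dice_loss_flat(pred, target, smooth=1.0):
    intersection = sum(p * t for p, t in zip(pred, target))
    return 1 - (2.0 * intersection + smooth) / (sum(pred) + sum(target) + smooth)

n = 30 * 64 * 64      # voxels in one [30, 64, 64] block
target = [0.0] * n    # all background
pred = [0.5] * n      # an untrained sigmoid hovers around 0.5

# intersection = 0, so loss = 1 - smooth / (sum(pred) + smooth)
# = 1 - 1/61441 ~= 0.99998 -- effectively the 1.0 I see in the log
loss = dice_loss_flat(pred, target)
```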
Is this right? This is not how it is supposed to work, and I am struggling with it. Here is my network output:
idx: 0 of 312 - Training Loss: 1.0 - Training Accuracy: 3.204042239857152e-11
idx: 5 of 312 - Training Loss: 0.9876335859298706 - Training Accuracy: 0.0119545953348279
idx: 10 of 312 - Training Loss: 1.0 - Training Accuracy: 7.269467666715101e-11
idx: 15 of 312 - Training Loss: 0.7320756912231445 - Training Accuracy: 0.22638492286205292
idx: 20 of 312 - Training Loss: 0.3599294424057007 - Training Accuracy: 0.49074622988700867
idx: 25 of 312 - Training Loss: 1.0 - Training Accuracy: 1.0720428988975073e-09
idx: 30 of 312 - Training Loss: 1.0 - Training Accuracy: 1.19782361807097e-09
idx: 35 of 312 - Training Loss: 1.0 - Training Accuracy: 1.956790285362331e-09
idx: 40 of 312 - Training Loss: 1.0 - Training Accuracy: 1.6055999862985004e-09
idx: 45 of 312 - Training Loss: 1.0 - Training Accuracy: 7.580232552761856e-10
idx: 50 of 312 - Training Loss: 1.0 - Training Accuracy: 9.510597864803572e-10
idx: 55 of 312 - Training Loss: 1.0 - Training Accuracy: 1.341515676323013e-09
idx: 60 of 312 - Training Loss: 0.7165247797966003 - Training Accuracy: 0.02658153884112835
idx: 65 of 312 - Training Loss: 1.0 - Training Accuracy: 4.528208030762926e-09
idx: 70 of 312 - Training Loss: 0.3205708861351013 - Training Accuracy: 0.6673439145088196
idx: 75 of 312 - Training Loss: 0.9305377006530762 - Training Accuracy: 2.3437689378624782e-05
idx: 80 of 312 - Training Loss: 1.0 - Training Accuracy: 5.305786885401176e-07
idx: 85 of 312 - Training Loss: 1.0 - Training Accuracy: 4.0612556517771736e-07
idx: 90 of 312 - Training Loss: 0.8207412362098694 - Training Accuracy: 0.0344742126762867
idx: 95 of 312 - Training Loss: 0.7463213205337524 - Training Accuracy: 0.19459737837314606
idx: 100 of 312 - Training Loss: 1.0 - Training Accuracy: 4.863646818620282e-09
idx: 105 of 312 - Training Loss: 0.35790306329727173 - Training Accuracy: 0.608722984790802
idx: 110 of 312 - Training Loss: 1.0 - Training Accuracy: 3.3852198821904267e-09
idx: 115 of 312 - Training Loss: 1.0 - Training Accuracy: 1.5268487585373691e-09
idx: 120 of 312 - Training Loss: 1.0 - Training Accuracy: 3.46353523639209e-09
idx: 125 of 312 - Training Loss: 1.0 - Training Accuracy: 2.5878148582347826e-11
idx: 130 of 312 - Training Loss: 1.0 - Training Accuracy: 2.3601216467272756e-11
idx: 135 of 312 - Training Loss: 1.0 - Training Accuracy: 1.1504343033763575e-09
idx: 140 of 312 - Training Loss: 0.4516671299934387 - Training Accuracy: 0.13879922032356262
I don't think the network is learning anything from this. Now I'm confused, since my problem should not be too complex: MRI scans must also have target tensors where the majority of voxels signify background. What am I doing wrong?
Thanks
Upvotes: 2
Views: 507
Reputation: 13113
The loss is only pinned near 1 because your prediction sum dwarfs smooth; if your algorithm predicted a value of exactly 0 for every background voxel, the loss on that patch would actually be 0, not 1. And since it predicts positive values everywhere (which it surely does if you're using a sigmoid activation), it can still improve the loss by outputting as little as possible. In other words, on an all-background patch the numerator cannot go above smooth, but the algorithm can still learn to keep the denominator as small as possible.
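A quick numeric illustration (plain Python, restating the flattened Dice so the snippet is self-contained; the values are made up):

```python
# Self-contained restatement of the flattened Dice loss, to show that
# an all-background patch still provides a learning signal.
def dice_loss_flat(pred, target, smooth=1.0):
    intersection = sum(p * t for p, t in zip(pred, target))
    return 1 - (2.0 * intersection + smooth) / (sum(pred) + sum(target) + smooth)

empty = [0.0] * 1000  # all-background target

# Large sigmoid outputs: the denominator is big, loss is pinned near 1...
big = dice_loss_flat([0.5] * 1000, empty)     # 1 - 1/501 ~= 0.998
# ...smaller outputs shrink the denominator and the loss drops...
small = dice_loss_flat([0.01] * 1000, empty)  # 1 - 1/11  ~= 0.909
# ...and an exact-zero prediction reaches a loss of 0, not 1.
zero = dice_loss_flat([0.0] * 1000, empty)    # exactly 0.0
```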
If you're unsatisfied with your algorithm's behavior you can try to either increase your batch size (so the chance of none of the volumes having any foreground drops) or straight up skip such batches. It may or may not help learning.
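Skipping such batches only needs a small guard in the training loop. A minimal sketch, where model, optimizer, loader, and dice_loss are placeholders for your own objects rather than a specific API:

```python
# Minimal sketch of skipping all-background batches during training.
# `model`, `optimizer`, `loader`, and `dice_loss` are hypothetical
# placeholders following the usual PyTorch-style interfaces.
def train_one_epoch(model, optimizer, loader, dice_loss):
    for inputs, targets in loader:
        if targets.sum() == 0:  # no foreground anywhere in this batch
            continue            # skip: Dice has nothing to match here
        optimizer.zero_grad()
        loss = dice_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
```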
That being said, I've personally never had any success learning segmentation with Dice/IoU as loss functions and generally opt for binary cross entropy or similar losses, keeping the former as validation metrics.
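That combination (binary cross entropy as the loss, Dice kept as a metric) can be sketched in plain Python as below; in a real PyTorch setup you would use torch.nn.BCEWithLogitsLoss for the loss instead, and the example values here are made up:

```python
import math

# Plain-Python sketch: BCE as the training loss, Dice kept only as a
# validation metric. In PyTorch, prefer torch.nn.BCEWithLogitsLoss
# (numerically stabler) and compute Dice the same way on predictions.
def bce_loss(pred, target, eps=1e-7):
    # mean binary cross entropy over flattened voxels, pred in (0, 1)
    return -sum(t * math.log(p + eps) + (1 - t) * math.log(1 - p + eps)
                for p, t in zip(pred, target)) / len(pred)

def dice_metric(pred, target, smooth=1.0):
    inter = sum(p * t for p, t in zip(pred, target))
    return (2.0 * inter + smooth) / (sum(pred) + sum(target) + smooth)

# Unlike Dice, BCE gives a clean, well-scaled signal on an
# all-background patch: it simply pushes every prediction toward 0.
target = [0.0] * 8
untrained = bce_loss([0.5] * 8, target)   # ~= ln 2 ~= 0.693
trained = bce_loss([0.01] * 8, target)    # ~= 0.010, much better
metric = dice_metric([0.0] * 8, target)   # 1.0: perfect Dice when empty
```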
Upvotes: 2