EJ Song

Which deep learning network structure should I refer to?

I am building a deep learning network that finds several points in 3D space.

The input is a stack of grayscale 1024 x 1024 images (the number of images varies from 5 to 20), and the output is a 64 x 64 x 64 volume. Each output voxel is either 0 or 1, but my dataset contains only 2000 ones, so it is hard to tell from the training loss alone whether my network is being trained well.

For example, if my network only output np.zeros((64,64,64)), the accuracy would still be 1 - 2000/(64x64x64) ~= 99.2%.
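A quick way to check this arithmetic is a small NumPy sketch. It assumes a single 64 x 64 x 64 volume containing 2000 positive voxels (their positions are arbitrary here) and a network that predicts all zeros:

```python
import numpy as np

# Hypothetical target volume: 64^3 voxels, 2000 of them positive,
# matching the numbers in the question. Positions are arbitrary.
shape = (64, 64, 64)
n_voxels = int(np.prod(shape))            # 262144 voxels
y_true = np.zeros(n_voxels, dtype=np.uint8)
y_true[:2000] = 1
y_true = y_true.reshape(shape)

# The "always predict 0" network.
y_pred = np.zeros(shape, dtype=np.uint8)

# Fraction of voxels where prediction and target agree.
accuracy = np.mean(y_true == y_pred)
print(f"{accuracy:.4f}")                  # 0.9924
```

So an all-zero output already scores about 99.2% accuracy on a single volume, which is why accuracy says almost nothing here.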

So my question is which deep learning network I should choose for finding a very small number of positive voxels in 3D space. The input size is (1024 x 1024 x #img) and the output size is (64 x 64 x 64). I am currently experimenting with 2D U-Net-like and 3D U-Net-like networks, with a ReLU-with-ceiling (clipped ReLU) final activation.

I would appreciate any recommendations for architectures or references. Thank you very much.

Answers (1)

Lue Mar

A U-Net-like network seems like a good choice. Your problem does not come from the network itself but from the loss and metrics you are using. Indeed, if you use binary cross-entropy as the loss and accuracy as the metric, the class imbalance means your score will stay close to 100% even when the network misses most of the positive voxels.

I suggest using the Dice or Jaccard coefficient as the metric and/or the loss (as a loss, use 1 - Dice coefficient), and computing it only on the items of interest (the voxels labeled 1), not on the background.
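As a minimal sketch of what these metrics measure, here they are in plain NumPy rather than in any particular framework. The helper names `dice_coef`, `jaccard_coef` and `dice_loss` and the `smooth=1.0` term are illustrative choices, and the target is a hypothetical volume with 2000 positive voxels. An all-zero prediction, which scores about 99.2% accuracy, gets a Dice score close to 0, while a prediction that finds half of the positives scores about 0.67:

```python
import numpy as np

def dice_coef(y_true, y_pred, smooth=1.0):
    # Sums run over every voxel; y_true is 0 on the background,
    # so only foreground voxels contribute to the intersection.
    t = y_true.astype(np.float64).ravel()
    p = y_pred.astype(np.float64).ravel()
    intersection = np.sum(t * p)
    total = np.sum(t) + np.sum(p)
    return (2.0 * intersection + smooth) / (total + smooth)

def jaccard_coef(y_true, y_pred, smooth=1.0):
    # Intersection over union, with the same foreground-only sums.
    t = y_true.astype(np.float64).ravel()
    p = y_pred.astype(np.float64).ravel()
    intersection = np.sum(t * p)
    union = np.sum(t) + np.sum(p) - intersection
    return (intersection + smooth) / (union + smooth)

def dice_loss(y_true, y_pred):
    # Loss version: 1 - Dice coefficient.
    return 1.0 - dice_coef(y_true, y_pred)

shape = (64, 64, 64)
y_true = np.zeros(shape)
y_true.flat[:2000] = 1                    # 2000 hypothetical positive voxels

all_zero = np.zeros(shape)                # ~99.2% accuracy, but no hits
half_hit = np.zeros(shape)
half_hit.flat[:1000] = 1                  # finds half of the positives

print(dice_coef(y_true, all_zero))        # about 0.0005
print(dice_coef(y_true, half_hit))        # about 0.667
print(jaccard_coef(y_true, half_hit))     # about 0.5
```

Because the score depends only on how the predicted positives overlap the true positives, the 99% of empty space can no longer inflate it.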

Whatever framework you are using, you should easily find an existing implementation of these metrics. Then adapt the code so that the background is excluded from the calculation.

For example, in Python/TensorFlow, for your binary volumes:

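import tensorflow as tf

def dice_coef(y_true, y_pred, smooth=1.0):
    # Flatten every voxel of the batch into one vector.
    y_true_f = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred_f = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
    # y_true is 0 on the background, so only the foreground voxels (the 1s)
    # contribute to the intersection. y_pred is kept continuous (not rounded)
    # so the function stays differentiable and can also serve as a loss.
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    total = tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f)
    return (2.0 * intersection + smooth) / (total + smooth)

def dice_loss(y_true, y_pred):
    # Loss = 1 - Dice coefficient
    return 1.0 - dice_coef(y_true, y_pred)

# Usage with your own Keras model: model.compile(optimizer="adam", loss=dice_loss, metrics=[dice_coef])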
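One detail worth checking in any Dice implementation: the intersection and the sums should be accumulated over the whole volume before dividing. If the ratio is formed per voxel and then averaged, the smoothing term makes every background voxel score (0 + 1)/(0 + 1) = 1, so the metric is once again dominated by the background, just like accuracy. A small NumPy comparison, assuming the same hypothetical volume with 2000 positives and an all-zero prediction:

```python
import numpy as np

shape = (64, 64, 64)
y_true = np.zeros(shape)
y_true.flat[:2000] = 1                    # 2000 hypothetical positive voxels
y_pred = np.zeros(shape)                  # predicts no positives at all
smooth = 1.0

t, p = y_true.ravel(), y_pred.ravel()

# Dice formed per voxel, then averaged: each background voxel scores 1.
per_voxel = np.mean((2.0 * t * p + smooth) / (t + p + smooth))

# Dice from sums over the whole volume: the background adds nothing.
global_dice = (2.0 * np.sum(t * p) + smooth) / (np.sum(t) + np.sum(p) + smooth)

print(round(float(per_voxel), 4), round(float(global_dice), 4))  # 0.9962 0.0005
```

The per-voxel version rewards the useless all-zero output with a near-perfect score, while the whole-volume version correctly reports almost no overlap.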
