prc777

Reputation: 211

Why is my basic CNN model NOT overfitting my segmentation image dataset?

I have a dataset of 2000 RGB images of shape 2000x256x256x3 (pink tissue containing purple/blue nuclei) and the corresponding ground truth of shape 2000x256x256x1. The ground-truth images are binary masks. Here is my model (TensorFlow 1.x and Keras):

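import math

from keras import backend as K
from keras.layers import Input, BatchNormalization, Conv2D
from keras.models import Model

def createFCNSameWidthModel(is1,fn,dpth,ksze,dm):
  # is1: input height/width, fn: filters per hidden layer, dpth: number of hidden conv layers,
  # ksze: hidden kernel size, dm: number of input channels
  input_shape=is1
  filter_num=fn
  depth=dpth
  ksize=ksze
  dim=dm

  def gelu(x):
    # tanh approximation of GELU; the constant is sqrt(2/pi), not sqrt(2*pi)
    constant=math.sqrt(2.0/math.pi)
    return 0.5*x*(1+K.tanh(constant*(x+0.044715*K.pow(x,3))))

  _input=Input(shape=(input_shape,input_shape,dim))
  batch1=BatchNormalization()(_input)  # normalizes the raw RGB input
  prev=batch1
  for i in range(0,depth):
    # stride-1 'same' convolutions: the spatial size stays 256x256 throughout
    conv=Conv2D(filters=filter_num,kernel_size=ksize,padding='same',activation=gelu)(prev)
    #maxpool=MaxPooling2D(strides=(1,1))(conv)
    #batch=BatchNormalization()(conv)
    prev=conv

  # 3x3 sigmoid head: one foreground probability per pixel
  _output=Conv2D(filters=1,kernel_size=3,padding='same',activation='sigmoid')(prev)

  model=Model(inputs=_input,outputs=_output)
  model.summary()
  return model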

I am using a custom GELU activation (the tanh approximation) for the hidden convolutional layers.
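The standard tanh approximation of GELU uses the constant sqrt(2/π) ≈ 0.798. A quick way to sanity-check a GELU implementation is to compare it with the exact form x·Φ(x), where Φ is the standard normal CDF. The sketch below does that in plain Python and, for comparison, also evaluates the variant with sqrt(2π) in place of sqrt(2/π) to show how far it drifts from the exact curve.

```python
import math

def gelu_exact(x):
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x, constant=math.sqrt(2.0 / math.pi)):
    # Standard tanh approximation; the default constant is sqrt(2/pi)
    return 0.5 * x * (1.0 + math.tanh(constant * (x + 0.044715 * x ** 3)))

for x in (-3.0, -1.0, -0.5, 0.5, 1.0, 3.0):
    good = gelu_tanh(x)
    bad = gelu_tanh(x, constant=math.sqrt(2.0 * math.pi))  # sqrt(2*pi) variant
    print(f"x={x:+.1f} exact={gelu_exact(x):+.4f} "
          f"sqrt(2/pi)={good:+.4f} sqrt(2*pi)={bad:+.4f}")
```

With sqrt(2π) the function saturates much faster (at x = 1 it gives about 0.995 instead of about 0.841), so it behaves more like a sharp gate than like GELU.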

Model summary:

[Run:AI] [DEBUG   ] [12-01-2021 18:48:01.575] [71] [optimizers.py          :16  ] Wrapping 'Adam' Keras optimizer with GA of 4 steps

Model: "model_58"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_62 (InputLayer)        (None, 256, 256, 3)       0         
_________________________________________________________________
batch_normalization_78 (Batc (None, 256, 256, 3)       12        
_________________________________________________________________
conv2d_964 (Conv2D)          (None, 256, 256, 16)      1216      
_________________________________________________________________
conv2d_965 (Conv2D)          (None, 256, 256, 16)      6416      
_________________________________________________________________
conv2d_966 (Conv2D)          (None, 256, 256, 16)      6416      
_________________________________________________________________
conv2d_967 (Conv2D)          (None, 256, 256, 16)      6416      
_________________________________________________________________
conv2d_968 (Conv2D)          (None, 256, 256, 16)      6416      
_________________________________________________________________
conv2d_969 (Conv2D)          (None, 256, 256, 16)      6416      
_________________________________________________________________
conv2d_970 (Conv2D)          (None, 256, 256, 16)      6416      
_________________________________________________________________
conv2d_971 (Conv2D)          (None, 256, 256, 16)      6416      
_________________________________________________________________
conv2d_972 (Conv2D)          (None, 256, 256, 16)      6416      
_________________________________________________________________
conv2d_973 (Conv2D)          (None, 256, 256, 16)      6416      
_________________________________________________________________
conv2d_974 (Conv2D)          (None, 256, 256, 16)      6416      
_________________________________________________________________
conv2d_975 (Conv2D)          (None, 256, 256, 16)      6416      
_________________________________________________________________
conv2d_976 (Conv2D)          (None, 256, 256, 16)      6416      
_________________________________________________________________
conv2d_977 (Conv2D)          (None, 256, 256, 16)      6416      
_________________________________________________________________
conv2d_978 (Conv2D)          (None, 256, 256, 16)      6416      
_________________________________________________________________
conv2d_979 (Conv2D)          (None, 256, 256, 16)      6416      
_________________________________________________________________
conv2d_980 (Conv2D)          (None, 256, 256, 1)       145       
=================================================================
Total params: 97,613
Trainable params: 97,607
Non-trainable params: 6
_________________________________________________________________
Effective batch size: 16
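As a cross-check on the summary, the parameter counts can be reproduced by hand: a Conv2D layer has k·k·c_in·c_out weights plus c_out biases, and a BatchNormalization layer holds 4 values per channel, of which the moving mean and moving variance (2 per channel) are non-trainable. The sketch also prints the number of labelled pixels in the 64-image subset for comparison with the parameter count.

```python
def conv2d_params(k, c_in, c_out):
    # k*k*c_in weights per output filter plus one bias per filter
    return k * k * c_in * c_out + c_out

def batchnorm_params(channels):
    # gamma and beta (trainable) + moving mean and moving variance (non-trainable)
    return 4 * channels

depth, width, k, in_ch = 16, 16, 5, 3
layers = [batchnorm_params(in_ch), conv2d_params(k, in_ch, width)]
layers += [conv2d_params(k, width, width)] * (depth - 1)
layers.append(conv2d_params(3, width, 1))  # final 3x3 sigmoid layer

total = sum(layers)
non_trainable = 2 * in_ch
labelled_pixels = 64 * 256 * 256

print(total, total - non_trainable, non_trainable)  # 97613 97607 6
print(labelled_pixels)  # 4194304
```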

What I'm trying to achieve: I am taking a subset of the dataset (64 images with their ground truth) and trying to overfit the model on it, to check that the model works at all.

Problem: The model is not overfitting the dataset (from now on, "dataset" means the 64-image subset only), and the loss plateaus at a value that would not be expected if the model were overfitting (memorizing) the dataset.

Specifications:

  1. Optimizer = Adam(learning_rate=0.001); 0.001 was found to give a fast reduction in loss (within 10 epochs).
  2. Gradient accumulation = Adam is wrapped in the Run:AI wrapper to perform gradient accumulation, because of GPU memory constraints.
  3. Loss function = I used the Dice coefficient loss, but found that it is non-convex (paper: "A survey of loss functions for semantic segmentation"), so I switched to logcosh(Dice loss). The metrics I monitor are accuracy, Dice coefficient and Jaccard index.
  4. Batch size = I found the best batch size to be 16. Note that this is the effective batch size, i.e. the number of samples whose gradients are accumulated before each weight update.
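A minimal NumPy sketch of why gradient accumulation can stand in for a larger batch, assuming the wrapper averages (rather than sums) the gradients of the accumulated micro-batches before applying one update; I have not checked the Run:AI implementation itself. For a loss that is a plain mean over samples, the average of 4 micro-batch gradients (4 samples each) equals the gradient of one batch of 16. Two caveats follow from the model and loss in the question: BatchNormalization still computes its statistics over each micro-batch of 4, and logcoshDice applies log-cosh after averaging over the batch, so for that loss the accumulated gradient is close to, but not exactly, the batch-of-16 gradient.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(16, 3))  # 16 samples = one effective batch (toy data)
y = rng.normal(size=16)
w = rng.normal(size=3)

def mse_grad(w, Xb, yb):
    # Gradient of mean((Xb @ w - yb) ** 2) with respect to w
    residual = Xb @ w - yb
    return 2.0 * Xb.T @ residual / len(yb)

full = mse_grad(w, X, y)

# Accumulate over 4 micro-batches of 4 and average, as assumed for the wrapper
micro_bs, steps = 4, 4
acc = np.zeros_like(w)
for s in range(steps):
    sl = slice(s * micro_bs, (s + 1) * micro_bs)
    acc += mse_grad(w, X[sl], y[sl])
accumulated = acc / steps

print(np.allclose(full, accumulated))  # True
```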

Related code:

import tensorflow as tf
from keras import backend as K
from keras.optimizers import Adam
import runai.ga.keras as rgk

def jaccard_index_iou(y_true,y_pred,smooth=1):
  # y_pred has shape (m, rows, cols, 1), i.e. axes 0,1,2,3; sum over the spatial axes 1 and 2 only
  intersection = K.sum(K.abs(y_true * y_pred), axis=[1,2])
  union = K.sum(y_true,axis=[1,2])+K.sum(y_pred,axis=[1,2])-intersection
  iou = K.mean((intersection + smooth) / (union + smooth), axis=0)  # average over the batch
  return iou

def dice_coef_f1(y_true, y_pred, smooth=1):
  # per-image soft Dice over the spatial axes, then averaged over the batch
  intersection = K.sum(y_true * y_pred, axis=[1,2])
  union = K.sum(y_true,axis=[1,2]) + K.sum(y_pred, axis=[1,2])
  dice = K.mean((2. * intersection + smooth)/(union + smooth), axis=0)
  return dice

def logcoshDice(y_true,y_pred):
  dice=dice_coef_f1(y_true,y_pred)
  diceloss=1-dice
  return K.log((K.exp(diceloss)+K.exp(-diceloss))/2.0) # log of cosh of dice loss

model=createFCNSameWidthModel(is1=256,fn=16,dpth=16,ksze=5,dm=3)
bs=4        # micro-batch size that fits in GPU memory
my_steps=4  # gradient accumulation steps
my_optimizer=Adam(learning_rate=0.001)
my_optimizer=rgk.optimizers.Optimizer(my_optimizer,steps=my_steps)
print("Effective batch size:",my_steps*bs)
model.compile(optimizer=my_optimizer,loss=logcoshDice,metrics=['acc',dice_coef_f1,jaccard_index_iou])
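To check the metric and loss values outside the graph, here is a NumPy sketch that mirrors dice_coef_f1 and logcoshDice (same smooth=1, same per-image sum over axes 1 and 2, same mean over the batch). It also prints tanh of the Dice loss: since d/dx log(cosh(x)) = tanh(x), the log-cosh wrapper scales the gradient of the Dice loss by tanh(1 − Dice). That factor shrinks toward zero as the Dice loss gets small (tanh(0.1) ≈ 0.0997), whereas plain 1 − Dice passes the gradient through with a factor of 1.

```python
import numpy as np

def dice_coef(y_true, y_pred, smooth=1.0):
    # Per-image soft Dice over H and W (axes 1, 2), then mean over the batch
    intersection = np.sum(y_true * y_pred, axis=(1, 2))
    union = np.sum(y_true, axis=(1, 2)) + np.sum(y_pred, axis=(1, 2))
    return np.mean((2.0 * intersection + smooth) / (union + smooth), axis=0)

def logcosh_dice(y_true, y_pred):
    # log(cosh(1 - Dice)), as in logcoshDice
    loss = 1.0 - dice_coef(y_true, y_pred)
    return np.log(np.cosh(loss))

# d/dx log(cosh(x)) = tanh(x): the factor applied to the Dice-loss gradient
for dice_loss in (0.8, 0.4, 0.1, 0.01):
    print(f"dice loss={dice_loss:.2f} "
          f"logcosh={np.log(np.cosh(dice_loss)):.5f} "
          f"grad factor={np.tanh(dice_loss):.4f}")
```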

  1. Dataset: the images are stored in a NumPy array of shape (64, 256, 256, 3) and the ground truth in one of shape (64, 256, 256, 1). The images are not normalized beforehand, because they pass through a BatchNormalization layer first.

  2. Training:

Related code:

history=model.fit(X_data,Y_data,validation_data=(X_val,Y_val),batch_size=bs,epochs=50)

Results: Loss and Dice coefficient plateau. This should not happen if the model were overfitting.

[Image: training graphs]

What I have tried:

  1. Playing with optimizers (Adam, SGD), batch size, learning rate and number of epochs -> no effect.
  2. Playing with learning rate scheduling, ReduceLROnPlateau and EarlyStopping -> no effect.
  3. Playing with model width, depth and kernel size -> a width of 16 and a depth of 16 give the performance above. Increasing the width or depth further reduces performance. A kernel size of 5 gives the best performance.
  4. Playing with different loss functions -> Dice loss, log IoU loss, log-cosh loss, BCE, BCE + Dice loss, sensitivity-specificity loss, Tversky loss and focal Tversky loss. logcosh(Dice loss) gave the best reduction, but the plateau eventually occurs.
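Because every layer in the model is a stride-1 'same' convolution with no pooling, the receptive field grows only linearly with depth: each k×k layer adds k − 1 pixels. A minimal sketch of the standard receptive-field recurrence, applied to the 16 hidden 5×5 layers plus the final 3×3 layer, gives 67 pixels, so each output pixel is predicted from a 67×67 window of the 256×256 input. Whether that is enough context for these tissue images is a separate question this calculation does not answer.

```python
def receptive_field(kernel_sizes, strides=None):
    # Standard recurrence: r grows by (k - 1) * jump; jump is the product of earlier strides
    strides = strides or [1] * len(kernel_sizes)
    r, jump = 1, 1
    for k, s in zip(kernel_sizes, strides):
        r += (k - 1) * jump
        jump *= s
    return r

kernels = [5] * 16 + [3]  # 16 hidden 5x5 convs + final 3x3 sigmoid conv
print(receptive_field(kernels))  # 67
```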

According to the universal approximation theorem, isn't my model deep enough to memorize this small dataset? It should at least overfit, even if it does not learn to generalize. I am at the end of my debugging knowledge and badly need help to proceed further.

I also suspect that the issue may be that my dataset is too hard for the model to learn. But even so, shouldn't it be able to overfit? Here is an example dataset image (with the ground truth on the right):

[Image: dataset example]

Upvotes: 3

Views: 670

Answers (1)

Gamze

Reputation: 89

There are a few different approaches to consider when your CNN model does not overfit as expected:

  1. Experiment with well-known segmentation architectures such as U-Net or SegNet, which are designed specifically for segmentation tasks.
  2. Adjust hyperparameters such as kernel size, number of filters and model depth; in some cases a smaller model may overfit more readily.
  3. Try loss functions other than Dice loss, such as focal loss or a combination of Dice loss and binary cross-entropy, to determine which one best fits your purpose.
  4. Data augmentation increases data variability and can improve generalization, although it makes the training set harder, not easier, to memorize.
  5. Use a learning rate finder or a learning rate schedule to help avoid plateaus and make training more effective.
  6. If feasible, confirm your results on a larger subset of the dataset to observe whether overfitting becomes more noticeable.
  7. Lastly, inspect the quality of your data and your preprocessing to make sure nothing is preventing the model from learning.
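A minimal NumPy sketch of the combined binary cross-entropy + Dice loss mentioned above. The equal weighting and the bce_weight parameter are assumptions for illustration; neither the question nor the answer specifies a weighting. In Keras the same formula would be written with backend ops; the NumPy version only makes the arithmetic explicit.

```python
import numpy as np

def bce(y_true, y_pred, eps=1e-7):
    # Pixel-wise binary cross-entropy averaged over all pixels; clip to avoid log(0)
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))

def soft_dice_loss(y_true, y_pred, smooth=1.0):
    # 1 - soft Dice, computed per image over axes 1, 2 and then averaged
    inter = np.sum(y_true * y_pred, axis=(1, 2))
    denom = np.sum(y_true, axis=(1, 2)) + np.sum(y_pred, axis=(1, 2))
    return 1.0 - np.mean((2.0 * inter + smooth) / (denom + smooth))

def bce_dice_loss(y_true, y_pred, bce_weight=0.5):
    # bce_weight is a hypothetical tuning knob, not taken from the question
    return bce_weight * bce(y_true, y_pred) + (1.0 - bce_weight) * soft_dice_loss(y_true, y_pred)
```

The BCE term gives every pixel its own gradient, while the Dice term rewards overlap at the image level, which is the usual motivation for combining them.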

Upvotes: 0
