Soumil Kanwal

Reputation: 93

UNet: Multi-Class Image Segmentation

I have recently started learning about image segmentation and UNet. I am trying to do multi-class image segmentation with 7 classes: the input is a (256, 256, 3) RGB image and the output is a (256, 256, 1) grayscale image in which each intensity value corresponds to one class. I am applying a pixel-wise softmax, and I use sparse categorical cross-entropy to avoid one-hot encoding the labels.
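To illustrate why the one-hot step can be skipped, here is a minimal NumPy sketch (not Keras code; the helper names `softmax`, `sparse_ce` and `categorical_ce` are hypothetical) on a random 4x4 "image" with 7 classes. Cross-entropy computed from integer labels of shape (H, W, 1) gives the same value as cross-entropy computed from the equivalent one-hot labels of shape (H, W, 7).

```python
import numpy as np

def softmax(logits, axis=-1):
    # Numerically stable softmax over the class axis
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def sparse_ce(labels, probs):
    # labels: (H, W, 1) integer class ids; probs: (H, W, C)
    idx = labels[..., 0]
    # Pick the predicted probability of the true class at every pixel
    p_true = np.take_along_axis(probs, idx[..., None], axis=-1)[..., 0]
    return -np.log(p_true).mean()

def categorical_ce(one_hot, probs):
    # one_hot: (H, W, C) with exactly one 1 per pixel
    return -(one_hot * np.log(probs)).sum(axis=-1).mean()

rng = np.random.default_rng(0)
H, W, C = 4, 4, 7
probs = softmax(rng.normal(size=(H, W, C)))
labels = rng.integers(0, C, size=(H, W, 1))
one_hot = np.eye(C)[labels[..., 0]]  # (H, W, C)

print(np.isclose(sparse_ce(labels, probs), categorical_ce(one_hot, probs)))  # True
```

Only the label layout differs: the sparse form indexes the true-class probability directly instead of multiplying by a one-hot mask.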

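from tensorflow import keras
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
                                     Conv2DTranspose, Dropout, MaxPooling2D,
                                     Reshape, concatenate)
from tensorflow.keras.models import Model

image_height, image_width = 256, 256  # input size from the question

def soft1(x):
    # softmax over the last (class) axis
    return keras.activations.softmax(x, axis = -1)

def conv2d_block(input_tensor, n_filters, kernel_size = 3, batchnorm = True):
    # first 3x3 conv -> (optional) batch norm -> ReLU
    x = Conv2D(filters = n_filters, kernel_size = (kernel_size, kernel_size),
               kernel_initializer = 'he_normal', padding = 'same')(input_tensor)
    if batchnorm:
        x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # the second conv must consume x, not input_tensor; otherwise the first
    # conv/BN/ReLU is discarded and the block ends up with a single convolution
    x = Conv2D(filters = n_filters, kernel_size = (kernel_size, kernel_size),
               kernel_initializer = 'he_normal', padding = 'same')(x)
    if batchnorm:
        x = BatchNormalization()(x)
    x = Activation('relu')(x)

    return x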

def get_unet(input_img, n_classes, n_filters = 16, dropout = 0.1, batchnorm = True):
    # Contracting Path
    c1 = conv2d_block(input_img, n_filters * 1, kernel_size = 3, batchnorm = batchnorm)
    p1 = MaxPooling2D((2, 2))(c1)
    p1 = Dropout(dropout)(p1)

    c2 = conv2d_block(p1, n_filters * 2, kernel_size = 3, batchnorm = batchnorm)
    p2 = MaxPooling2D((2, 2))(c2)
    p2 = Dropout(dropout)(p2)

    c3 = conv2d_block(p2, n_filters * 4, kernel_size = 3, batchnorm = batchnorm)
    p3 = MaxPooling2D((2, 2))(c3)
    p3 = Dropout(dropout)(p3)

    c4 = conv2d_block(p3, n_filters * 8, kernel_size = 3, batchnorm = batchnorm)
    p4 = MaxPooling2D((2, 2))(c4)
    p4 = Dropout(dropout)(p4)

    c5 = conv2d_block(p4, n_filters = n_filters * 16, kernel_size = 3, batchnorm = batchnorm)

    # Expansive Path
    u6 = Conv2DTranspose(n_filters * 8, (3, 3), strides = (2, 2), padding = 'same')(c5)
    u6 = concatenate([u6, c4])
    u6 = Dropout(dropout)(u6)
    c6 = conv2d_block(u6, n_filters * 8, kernel_size = 3, batchnorm = batchnorm)

    u7 = Conv2DTranspose(n_filters * 4, (3, 3), strides = (2, 2), padding = 'same')(c6)
    u7 = concatenate([u7, c3])
    u7 = Dropout(dropout)(u7)
    c7 = conv2d_block(u7, n_filters * 4, kernel_size = 3, batchnorm = batchnorm)

    u8 = Conv2DTranspose(n_filters * 2, (3, 3), strides = (2, 2), padding = 'same')(c7)
    u8 = concatenate([u8, c2])
    u8 = Dropout(dropout)(u8)
    c8 = conv2d_block(u8, n_filters * 2, kernel_size = 3, batchnorm = batchnorm)

    u9 = Conv2DTranspose(n_filters * 1, (3, 3), strides = (2, 2), padding = 'same')(c8)
    u9 = concatenate([u9, c1])
    u9 = Dropout(dropout)(u9)
    c9 = conv2d_block(u9, n_filters * 1, kernel_size = 3, batchnorm = batchnorm)

    outputs = Conv2D(n_classes, (1, 1))(c9)
    outputs = Reshape((image_height * image_width, 1, n_classes))(outputs)
    outputs = Activation(soft1)(outputs)

    model = Model(inputs=[input_img], outputs=[outputs])
    print(outputs.shape)

    return model

My model summary is:

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_12 (InputLayer)           (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
conv2d_211 (Conv2D)             (None, 256, 256, 16) 448         input_12[0][0]                   
__________________________________________________________________________________________________
batch_normalization_200 (BatchN (None, 256, 256, 16) 64          conv2d_211[0][0]                 
__________________________________________________________________________________________________
activation_204 (Activation)     (None, 256, 256, 16) 0           batch_normalization_200[0][0]    
__________________________________________________________________________________________________
max_pooling2d_45 (MaxPooling2D) (None, 128, 128, 16) 0           activation_204[0][0]             
__________________________________________________________________________________________________
dropout_89 (Dropout)            (None, 128, 128, 16) 0           max_pooling2d_45[0][0]           
__________________________________________________________________________________________________
conv2d_213 (Conv2D)             (None, 128, 128, 32) 4640        dropout_89[0][0]                 
__________________________________________________________________________________________________
batch_normalization_202 (BatchN (None, 128, 128, 32) 128         conv2d_213[0][0]                 
__________________________________________________________________________________________________
activation_206 (Activation)     (None, 128, 128, 32) 0           batch_normalization_202[0][0]    
__________________________________________________________________________________________________
max_pooling2d_46 (MaxPooling2D) (None, 64, 64, 32)   0           activation_206[0][0]             
__________________________________________________________________________________________________
dropout_90 (Dropout)            (None, 64, 64, 32)   0           max_pooling2d_46[0][0]           
__________________________________________________________________________________________________
conv2d_215 (Conv2D)             (None, 64, 64, 64)   18496       dropout_90[0][0]                 
__________________________________________________________________________________________________
batch_normalization_204 (BatchN (None, 64, 64, 64)   256         conv2d_215[0][0]                 
__________________________________________________________________________________________________
activation_208 (Activation)     (None, 64, 64, 64)   0           batch_normalization_204[0][0]    
__________________________________________________________________________________________________
max_pooling2d_47 (MaxPooling2D) (None, 32, 32, 64)   0           activation_208[0][0]             
__________________________________________________________________________________________________
dropout_91 (Dropout)            (None, 32, 32, 64)   0           max_pooling2d_47[0][0]           
__________________________________________________________________________________________________
conv2d_217 (Conv2D)             (None, 32, 32, 128)  73856       dropout_91[0][0]                 
__________________________________________________________________________________________________
batch_normalization_206 (BatchN (None, 32, 32, 128)  512         conv2d_217[0][0]                 
__________________________________________________________________________________________________
activation_210 (Activation)     (None, 32, 32, 128)  0           batch_normalization_206[0][0]    
__________________________________________________________________________________________________
max_pooling2d_48 (MaxPooling2D) (None, 16, 16, 128)  0           activation_210[0][0]             
__________________________________________________________________________________________________
dropout_92 (Dropout)            (None, 16, 16, 128)  0           max_pooling2d_48[0][0]           
__________________________________________________________________________________________________
conv2d_219 (Conv2D)             (None, 16, 16, 256)  295168      dropout_92[0][0]                 
__________________________________________________________________________________________________
batch_normalization_208 (BatchN (None, 16, 16, 256)  1024        conv2d_219[0][0]                 
__________________________________________________________________________________________________
activation_212 (Activation)     (None, 16, 16, 256)  0           batch_normalization_208[0][0]    
__________________________________________________________________________________________________
conv2d_transpose_45 (Conv2DTran (None, 32, 32, 128)  295040      activation_212[0][0]             
__________________________________________________________________________________________________
concatenate_45 (Concatenate)    (None, 32, 32, 256)  0           conv2d_transpose_45[0][0]        
                                                                 activation_210[0][0]             
__________________________________________________________________________________________________
dropout_93 (Dropout)            (None, 32, 32, 256)  0           concatenate_45[0][0]             
__________________________________________________________________________________________________
conv2d_221 (Conv2D)             (None, 32, 32, 128)  295040      dropout_93[0][0]                 
__________________________________________________________________________________________________
batch_normalization_210 (BatchN (None, 32, 32, 128)  512         conv2d_221[0][0]                 
__________________________________________________________________________________________________
activation_214 (Activation)     (None, 32, 32, 128)  0           batch_normalization_210[0][0]    
__________________________________________________________________________________________________
conv2d_transpose_46 (Conv2DTran (None, 64, 64, 64)   73792       activation_214[0][0]             
__________________________________________________________________________________________________
concatenate_46 (Concatenate)    (None, 64, 64, 128)  0           conv2d_transpose_46[0][0]        
                                                                 activation_208[0][0]             
__________________________________________________________________________________________________
dropout_94 (Dropout)            (None, 64, 64, 128)  0           concatenate_46[0][0]             
__________________________________________________________________________________________________
conv2d_223 (Conv2D)             (None, 64, 64, 64)   73792       dropout_94[0][0]                 
__________________________________________________________________________________________________
batch_normalization_212 (BatchN (None, 64, 64, 64)   256         conv2d_223[0][0]                 
__________________________________________________________________________________________________
activation_216 (Activation)     (None, 64, 64, 64)   0           batch_normalization_212[0][0]    
__________________________________________________________________________________________________
conv2d_transpose_47 (Conv2DTran (None, 128, 128, 32) 18464       activation_216[0][0]             
__________________________________________________________________________________________________
concatenate_47 (Concatenate)    (None, 128, 128, 64) 0           conv2d_transpose_47[0][0]        
                                                                 activation_206[0][0]             
__________________________________________________________________________________________________
dropout_95 (Dropout)            (None, 128, 128, 64) 0           concatenate_47[0][0]             
__________________________________________________________________________________________________
conv2d_225 (Conv2D)             (None, 128, 128, 32) 18464       dropout_95[0][0]                 
__________________________________________________________________________________________________
batch_normalization_214 (BatchN (None, 128, 128, 32) 128         conv2d_225[0][0]                 
__________________________________________________________________________________________________
activation_218 (Activation)     (None, 128, 128, 32) 0           batch_normalization_214[0][0]    
__________________________________________________________________________________________________
conv2d_transpose_48 (Conv2DTran (None, 256, 256, 16) 4624        activation_218[0][0]             
__________________________________________________________________________________________________
concatenate_48 (Concatenate)    (None, 256, 256, 32) 0           conv2d_transpose_48[0][0]        
                                                                 activation_204[0][0]             
__________________________________________________________________________________________________
dropout_96 (Dropout)            (None, 256, 256, 32) 0           concatenate_48[0][0]             
__________________________________________________________________________________________________
conv2d_227 (Conv2D)             (None, 256, 256, 16) 4624        dropout_96[0][0]                 
__________________________________________________________________________________________________
batch_normalization_216 (BatchN (None, 256, 256, 16) 64          conv2d_227[0][0]                 
__________________________________________________________________________________________________
activation_220 (Activation)     (None, 256, 256, 16) 0           batch_normalization_216[0][0]    
__________________________________________________________________________________________________
conv2d_228 (Conv2D)             (None, 256, 256, 7)  119         activation_220[0][0]             
__________________________________________________________________________________________________
reshape_12 (Reshape)            (None, 65536, 1, 7)  0           conv2d_228[0][0]                 
__________________________________________________________________________________________________
activation_221 (Activation)     (None, 65536, 1, 7)  0           reshape_12[0][0]                 
==================================================================================================
Total params: 1,179,511
Trainable params: 1,178,039
Non-trainable params: 1,472
__________________________________________________________________________________________________

Is my model right? Shouldn't the final output be (65536, 1, 1), since I am using softmax? The code compiles, but the Dice coefficient is very low.
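The Dice implementation used in the question is not shown, so as an assumption, here is one common multi-class definition sketched in NumPy: take the argmax of the 7-channel output, compute Dice per class against the integer mask, and average over the classes that appear. The helper name `mean_dice` and the `eps` smoothing term are illustrative. A metric like this is only meaningful when prediction and target share the same spatial layout, e.g. (256, 256, 7) probabilities against a (256, 256, 1) mask.

```python
import numpy as np

def mean_dice(y_true, y_pred, n_classes, eps=1e-7):
    # y_true: integer class map, e.g. (H, W, 1) or (N, H, W, 1)
    # y_pred: per-pixel class probabilities, e.g. (H, W, C) or (N, H, W, C)
    if y_true.ndim == y_pred.ndim and y_true.shape[-1] == 1:
        y_true = y_true[..., 0]  # drop the trailing channel axis of size 1
    pred = y_pred.argmax(axis=-1)  # hard prediction: one class id per pixel
    scores = []
    for c in range(n_classes):
        t = y_true == c
        p = pred == c
        denom = t.sum() + p.sum()
        if denom == 0:
            continue  # class absent from both maps: skip it
        scores.append(2.0 * np.logical_and(t, p).sum() / (denom + eps))
    return float(np.mean(scores))

rng = np.random.default_rng(1)
labels = rng.integers(0, 7, size=(8, 8, 1))
perfect = np.eye(7)[labels[..., 0]]  # all probability mass on the true class
print(round(mean_dice(labels, perfect, 7), 4))  # 1.0
```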

Upvotes: 4

Views: 3546

Answers (2)

Daniel Möller

Reputation: 86600

Your model should end in (256, 256, 7).

That is 7 class probabilities per pixel. With your (256, 256, 1) target images, this shape works only with 'sparse_categorical_crossentropy' or a custom loss.

So, up to conv2d_228 the model seems fine (I didn't look in detail, though).
There is no need for the Reshape or anything else that comes after this convolution.

You can place the softmax directly in conv2d_228 (activation='softmax') or in a layer right after it.

Each sample in y_train should be (256, 256, 1) for this, i.e. y_train has shape (N, 256, 256, 1) and holds integer class ids.
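A minimal tf.keras sketch of the ending this answer describes, assuming TensorFlow 2.x. A single convolution stands in for the U-Net body, the image size is reduced to 32x32 and the data are random, so it only checks shapes and loss wiring, not segmentation quality.

```python
import numpy as np
from tensorflow.keras import Model, layers

n_classes = 7
H, W = 32, 32  # reduced size for a quick check; the question uses 256 x 256

inputs = layers.Input((H, W, 3))
# Stand-in for the U-Net body (assumption for brevity)
x = layers.Conv2D(8, 3, padding="same", activation="relu")(inputs)
# 1x1 conv with softmax over the last (class) axis: one distribution per pixel
outputs = layers.Conv2D(n_classes, 1, activation="softmax")(x)
model = Model(inputs, outputs)

model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

x_train = np.random.rand(2, H, W, 3).astype("float32")
# Integer class ids with shape (N, H, W, 1); no one-hot encoding needed
y_train = np.random.randint(0, n_classes, size=(2, H, W, 1))
model.fit(x_train, y_train, epochs=1, verbose=0)

pred = model.predict(x_train, verbose=0)
print(pred.shape)  # (2, 32, 32, 7)
mask = pred.argmax(axis=-1)  # (2, 32, 32): predicted class id per pixel
```

Here the 1x1 convolution with `activation="softmax"` takes the place of the `Reshape` plus `Activation(soft1)` pair, and the argmax over the last axis turns the prediction into a mask that can be compared with the (256, 256, 1) labels.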

Upvotes: 5

Chris Tosh

Reputation: 591

Your output in fact represents each pixel of your image: for every pixel you get a 1x7 output. Since it is a softmax, these 7 values lie between 0 and 1 and sum to 1, so the largest one marks the predicted class for that pixel, and that gives you the segmentation. If the output were (65536, 1, 1), you would have a single dense value per pixel instead of a categorical representation.
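A minimal NumPy sketch of how a per-pixel 1x7 softmax output becomes a segmentation mask, using random logits as a stand-in for the network output (an assumption for illustration). Taking the argmax over the 7 values picks one class per pixel, and the (65536, 1, 7) layout from the question holds the same values as (256, 256, 7), only flattened.

```python
import numpy as np

H, W, C = 256, 256, 7
rng = np.random.default_rng(0)
logits = rng.normal(size=(H, W, C))  # stand-in for the network output
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax per pixel

mask = probs.argmax(axis=-1)  # (256, 256): one class id per pixel

# The question's (65536, 1, 7) layout holds the same numbers, just flattened
flat = probs.reshape(H * W, 1, C)
mask_from_flat = flat.argmax(axis=-1).reshape(H, W)
print(np.array_equal(mask, mask_from_flat))  # True
```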

Upvotes: 3
