Soumil Kanwal

Reputation: 93

Unet: Multi Class Image Segmentation

I have recently started learning about image segmentation and U-Net. I am trying to do multi-class image segmentation with 7 classes, where the input is a (256, 256, 3) RGB image and the output is a (256, 256, 1) grayscale image in which each intensity value corresponds to one class. I am doing a pixel-wise softmax, and I am using sparse categorical cross-entropy to avoid one-hot encoding.

def soft1(x):
    return keras.activations.softmax(x, axis = -1)

def conv2d_block(input_tensor, n_filters, kernel_size = 3, batchnorm = True):

    x = Conv2D(filters = n_filters, kernel_size = (kernel_size, kernel_size),\
              kernel_initializer = 'he_normal', padding = 'same')(input_tensor)
    if batchnorm:
        x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(filters = n_filters, kernel_size = (kernel_size, kernel_size),\
              kernel_initializer = 'he_normal', padding = 'same')(input_tensor)
    if batchnorm:
        x = BatchNormalization()(x)
    x = Activation('relu')(x)

    return x

def get_unet(input_img, n_classes, n_filters = 16, dropout = 0.1, batchnorm = True):
    # Contracting Path
    c1 = conv2d_block(input_img, n_filters * 1, kernel_size = 3, batchnorm = batchnorm)
    p1 = MaxPooling2D((2, 2))(c1)
    p1 = Dropout(dropout)(p1)

    c2 = conv2d_block(p1, n_filters * 2, kernel_size = 3, batchnorm = batchnorm)
    p2 = MaxPooling2D((2, 2))(c2)
    p2 = Dropout(dropout)(p2)

    c3 = conv2d_block(p2, n_filters * 4, kernel_size = 3, batchnorm = batchnorm)
    p3 = MaxPooling2D((2, 2))(c3)
    p3 = Dropout(dropout)(p3)

    c4 = conv2d_block(p3, n_filters * 8, kernel_size = 3, batchnorm = batchnorm)
    p4 = MaxPooling2D((2, 2))(c4)
    p4 = Dropout(dropout)(p4)

    c5 = conv2d_block(p4, n_filters = n_filters * 16, kernel_size = 3, batchnorm = batchnorm)

    # Expansive Path
    u6 = Conv2DTranspose(n_filters * 8, (3, 3), strides = (2, 2), padding = 'same')(c5)
    u6 = concatenate([u6, c4])
    u6 = Dropout(dropout)(u6)
    c6 = conv2d_block(u6, n_filters * 8, kernel_size = 3, batchnorm = batchnorm)

    u7 = Conv2DTranspose(n_filters * 4, (3, 3), strides = (2, 2), padding = 'same')(c6)
    u7 = concatenate([u7, c3])
    u7 = Dropout(dropout)(u7)
    c7 = conv2d_block(u7, n_filters * 4, kernel_size = 3, batchnorm = batchnorm)

    u8 = Conv2DTranspose(n_filters * 2, (3, 3), strides = (2, 2), padding = 'same')(c7)
    u8 = concatenate([u8, c2])
    u8 = Dropout(dropout)(u8)
    c8 = conv2d_block(u8, n_filters * 2, kernel_size = 3, batchnorm = batchnorm)

    u9 = Conv2DTranspose(n_filters * 1, (3, 3), strides = (2, 2), padding = 'same')(c8)
    u9 = concatenate([u9, c1])
    u9 = Dropout(dropout)(u9)
    c9 = conv2d_block(u9, n_filters * 1, kernel_size = 3, batchnorm = batchnorm)

    outputs = Conv2D(n_classes, (1, 1))(c9)
    outputs = Reshape((image_height*image_width, 1, n_classes), input_shape = (image_height, image_width, n_classes))(outputs)
    outputs = Activation(soft1)(outputs)

    model = Model(inputs=[input_img], outputs=[outputs])

    return model

My Model Summary is:

Model: "model_2"
Layer (type)                    Output Shape         Param #     Connected to                     
input_12 (InputLayer)           (None, 256, 256, 3)  0                                            
conv2d_211 (Conv2D)             (None, 256, 256, 16) 448         input_12[0][0]                   
batch_normalization_200 (BatchN (None, 256, 256, 16) 64          conv2d_211[0][0]                 
activation_204 (Activation)     (None, 256, 256, 16) 0           batch_normalization_200[0][0]    
max_pooling2d_45 (MaxPooling2D) (None, 128, 128, 16) 0           activation_204[0][0]             
dropout_89 (Dropout)            (None, 128, 128, 16) 0           max_pooling2d_45[0][0]           
conv2d_213 (Conv2D)             (None, 128, 128, 32) 4640        dropout_89[0][0]                 
batch_normalization_202 (BatchN (None, 128, 128, 32) 128         conv2d_213[0][0]                 
activation_206 (Activation)     (None, 128, 128, 32) 0           batch_normalization_202[0][0]    
max_pooling2d_46 (MaxPooling2D) (None, 64, 64, 32)   0           activation_206[0][0]             
dropout_90 (Dropout)            (None, 64, 64, 32)   0           max_pooling2d_46[0][0]           
conv2d_215 (Conv2D)             (None, 64, 64, 64)   18496       dropout_90[0][0]                 
batch_normalization_204 (BatchN (None, 64, 64, 64)   256         conv2d_215[0][0]                 
activation_208 (Activation)     (None, 64, 64, 64)   0           batch_normalization_204[0][0]    
max_pooling2d_47 (MaxPooling2D) (None, 32, 32, 64)   0           activation_208[0][0]             
dropout_91 (Dropout)            (None, 32, 32, 64)   0           max_pooling2d_47[0][0]           
conv2d_217 (Conv2D)             (None, 32, 32, 128)  73856       dropout_91[0][0]                 
batch_normalization_206 (BatchN (None, 32, 32, 128)  512         conv2d_217[0][0]                 
activation_210 (Activation)     (None, 32, 32, 128)  0           batch_normalization_206[0][0]    
max_pooling2d_48 (MaxPooling2D) (None, 16, 16, 128)  0           activation_210[0][0]             
dropout_92 (Dropout)            (None, 16, 16, 128)  0           max_pooling2d_48[0][0]           
conv2d_219 (Conv2D)             (None, 16, 16, 256)  295168      dropout_92[0][0]                 
batch_normalization_208 (BatchN (None, 16, 16, 256)  1024        conv2d_219[0][0]                 
activation_212 (Activation)     (None, 16, 16, 256)  0           batch_normalization_208[0][0]    
conv2d_transpose_45 (Conv2DTran (None, 32, 32, 128)  295040      activation_212[0][0]             
concatenate_45 (Concatenate)    (None, 32, 32, 256)  0           conv2d_transpose_45[0][0]        
dropout_93 (Dropout)            (None, 32, 32, 256)  0           concatenate_45[0][0]             
conv2d_221 (Conv2D)             (None, 32, 32, 128)  295040      dropout_93[0][0]                 
batch_normalization_210 (BatchN (None, 32, 32, 128)  512         conv2d_221[0][0]                 
activation_214 (Activation)     (None, 32, 32, 128)  0           batch_normalization_210[0][0]    
conv2d_transpose_46 (Conv2DTran (None, 64, 64, 64)   73792       activation_214[0][0]             
concatenate_46 (Concatenate)    (None, 64, 64, 128)  0           conv2d_transpose_46[0][0]        
dropout_94 (Dropout)            (None, 64, 64, 128)  0           concatenate_46[0][0]             
conv2d_223 (Conv2D)             (None, 64, 64, 64)   73792       dropout_94[0][0]                 
batch_normalization_212 (BatchN (None, 64, 64, 64)   256         conv2d_223[0][0]                 
activation_216 (Activation)     (None, 64, 64, 64)   0           batch_normalization_212[0][0]    
conv2d_transpose_47 (Conv2DTran (None, 128, 128, 32) 18464       activation_216[0][0]             
concatenate_47 (Concatenate)    (None, 128, 128, 64) 0           conv2d_transpose_47[0][0]        
dropout_95 (Dropout)            (None, 128, 128, 64) 0           concatenate_47[0][0]             
conv2d_225 (Conv2D)             (None, 128, 128, 32) 18464       dropout_95[0][0]                 
batch_normalization_214 (BatchN (None, 128, 128, 32) 128         conv2d_225[0][0]                 
activation_218 (Activation)     (None, 128, 128, 32) 0           batch_normalization_214[0][0]    
conv2d_transpose_48 (Conv2DTran (None, 256, 256, 16) 4624        activation_218[0][0]             
concatenate_48 (Concatenate)    (None, 256, 256, 32) 0           conv2d_transpose_48[0][0]        
dropout_96 (Dropout)            (None, 256, 256, 32) 0           concatenate_48[0][0]             
conv2d_227 (Conv2D)             (None, 256, 256, 16) 4624        dropout_96[0][0]                 
batch_normalization_216 (BatchN (None, 256, 256, 16) 64          conv2d_227[0][0]                 
activation_220 (Activation)     (None, 256, 256, 16) 0           batch_normalization_216[0][0]    
conv2d_228 (Conv2D)             (None, 256, 256, 7)  119         activation_220[0][0]             
reshape_12 (Reshape)            (None, 65536, 1, 7)  0           conv2d_228[0][0]                 
activation_221 (Activation)     (None, 65536, 1, 7)  0           reshape_12[0][0]                 
Total params: 1,179,511
Trainable params: 1,178,039
Non-trainable params: 1,472

Is my model right? Shouldn't the final output be (65536, 1, 1), since I am using softmax? The code compiles, but the Dice coefficient is very low.

Upvotes: 4

Views: 3546

Answers (2)

Daniel Möller

Reputation: 86600

Your model should end in (256,256,7).

That is 7 class scores per pixel, which agrees with target images of shape (256,256,1) that hold one class index per pixel. This will work only with 'sparse_categorical_crossentropy' or a custom loss.
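To make the shape agreement concrete, here is a minimal NumPy sketch of a per-pixel softmax followed by a sparse categorical cross-entropy. It is not Keras's actual implementation; the function names `pixelwise_softmax` and `sparse_cce` and the small 4x4 image size are illustrative assumptions.

```python
import numpy as np

def pixelwise_softmax(logits):
    """Softmax over the last (class) axis, computed independently for each pixel."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def sparse_cce(y_true, probs, eps=1e-7):
    """Mean cross-entropy between integer labels (H, W, 1) and probabilities (H, W, C)."""
    labels = y_true[..., 0].astype(int)                                      # (H, W)
    picked = np.take_along_axis(probs, labels[..., None], axis=-1)[..., 0]   # (H, W)
    return float(-np.log(np.clip(picked, eps, 1.0)).mean())

rng = np.random.default_rng(0)
n_classes = 7
logits = rng.normal(size=(4, 4, n_classes))          # stand-in for the last conv output
y_true = rng.integers(0, n_classes, size=(4, 4, 1))  # one class index per pixel

probs = pixelwise_softmax(logits)   # shape (4, 4, 7)
loss = sparse_cce(y_true, probs)    # scalar
```

The labels keep a trailing axis of size 1 while the predictions have a trailing axis of size 7; the loss simply looks up, for each pixel, the probability assigned to the labelled class.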

So, up to conv2d_228 the model seems fine (I didn't look in detail, though).
There is no need for anything that comes after this convolution.
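One detail in the earlier layers may be worth a second look: in the question's `conv2d_block`, the second `Conv2D` is called on `input_tensor` rather than on `x`, which would explain why the summary lists only one convolution per block. Assuming two chained convolutions were intended (a guess, since the question doesn't say), a sketch of the block could look like this, using `tensorflow.keras` imports that the question doesn't show:

```python
from tensorflow.keras.layers import Activation, BatchNormalization, Conv2D, Input
from tensorflow.keras.models import Model

def conv2d_block(input_tensor, n_filters, kernel_size=3, batchnorm=True):
    """Two Conv -> (BatchNorm) -> ReLU stages applied one after the other."""
    x = input_tensor
    for _ in range(2):
        # Each stage consumes the previous stage's output (x), not input_tensor.
        x = Conv2D(n_filters, (kernel_size, kernel_size),
                   kernel_initializer='he_normal', padding='same')(x)
        if batchnorm:
            x = BatchNormalization()(x)
        x = Activation('relu')(x)
    return x

# Quick check on a single block: it should now contain two Conv2D layers.
inputs = Input((256, 256, 3))
block = Model(inputs, conv2d_block(inputs, 16))
n_convs = sum(isinstance(layer, Conv2D) for layer in block.layers)
```

With this change the parameter count and the summary would differ from the ones posted in the question.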

You can place the softmax directly in conv2d_228 (through its activation argument) or in a layer directly after it.
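A minimal sketch of such an ending, assuming `tensorflow.keras` and using a single convolution as a stand-in for the U-Net body (the `features` tensor plays the role of `c9` in the question):

```python
from tensorflow.keras.layers import Conv2D, Input
from tensorflow.keras.models import Model

n_classes = 7
inputs = Input((256, 256, 3))

# Stand-in for the U-Net body; one conv keeps the sketch short.
features = Conv2D(16, (3, 3), padding='same', activation='relu')(inputs)

# 1x1 convolution with softmax over the channel (class) axis; no Reshape layer.
outputs = Conv2D(n_classes, (1, 1), activation='softmax')(features)

model = Model(inputs, outputs)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
```

The model's output shape is then (None, 256, 256, 7), and the softmax is taken across the 7 channels of each pixel.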

y_train should be (256,256,1) for this.
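Sparse categorical cross-entropy expects integer class indices from 0 to n_classes - 1. If the grayscale masks store other intensity values (an assumption; the question doesn't say which values are used), they would need to be remapped first. A NumPy sketch with hypothetical gray levels and a hypothetical helper name:

```python
import numpy as np

def intensities_to_class_ids(mask):
    """Map each distinct gray level in `mask` to an index 0..K-1 (sorted order)."""
    levels, class_ids = np.unique(mask, return_inverse=True)
    return levels, class_ids.reshape(mask.shape).astype(np.int32)

# Hypothetical 2x3 mask with three gray levels; real masks would be (256, 256, 1).
mask = np.array([[0, 40, 40],
                 [255, 0, 255]], dtype=np.uint8)[..., None]   # shape (2, 3, 1)

levels, y = intensities_to_class_ids(mask)
```

In practice the list of gray levels should be fixed once for the whole dataset rather than derived per image, because a single image may not contain every class.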

Upvotes: 5

Chris Tosh

Reputation: 591

Your output in fact represents each pixel of your image. For each pixel, you have an output of size 1x7. Since the activation is softmax, each of these 7 values lies between 0 and 1 (and they sum to 1 for a pixel). The entry for the predicted class is therefore the largest one, and that is what gives you the segmentation. If the output were (65536, 1, 1), you would have a single value per pixel, i.e. a dense rather than a categorical representation.
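As a rough illustration of how the per-pixel 1x7 output turns into a segmentation, the following NumPy sketch takes a fake output in the question's (height*width, 1, n_classes) layout and picks the most probable class for each pixel. The random values and the small 4x4 size are placeholders for a real prediction.

```python
import numpy as np

height, width, n_classes = 4, 4, 7   # small stand-ins for 256, 256, 7
rng = np.random.default_rng(0)

# Fake network output in the question's layout: (height*width, 1, n_classes).
scores = rng.random((height * width, 1, n_classes))
probs = scores / scores.sum(axis=-1, keepdims=True)   # each pixel's 7 values sum to 1

# Pick the most probable class per pixel, then restore the image layout.
label_map = probs.argmax(axis=-1).reshape(height, width)
```

`label_map` holds one class index per pixel, which is the single-value form that a (65536, 1, 1) output would correspond to.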

Upvotes: 3
