Reputation: 93
I have recently started learning about image segmentation and UNet. I am trying to do multi-class image segmentation with 7 classes: the input is a (256, 256, 3) RGB image and the output is a (256, 256, 1) grayscale image where each intensity value corresponds to one class. I am doing a pixel-wise softmax, and I am using sparse categorical cross-entropy so that I can avoid one-hot encoding.
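For reference, here is why the sparse loss avoids one-hot encoding, shown for a single pixel in plain Python. The helper names and logit values are made up for illustration, and this is not the Keras implementation: taking the negative log-probability of the integer label gives the same value as the full one-hot cross-entropy sum, because every other term is multiplied by zero. Keras computes this per pixel and then averages.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability; the result sums to 1.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sparse_categorical_ce(label, probs):
    # Integer label: use the probability of the true class directly.
    return -math.log(probs[label])

def categorical_ce(one_hot, probs):
    # One-hot label: sum over all classes; only the true class contributes.
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs))

n_classes = 7
logits = [0.2, 1.5, -0.3, 0.0, 2.1, -1.0, 0.7]  # one pixel's 7 channel scores (made up)
probs = softmax(logits)
label = 4  # that pixel's class id, as stored in the (256, 256, 1) mask
one_hot = [1.0 if c == label else 0.0 for c in range(n_classes)]

print(sparse_categorical_ce(label, probs))
print(categorical_ce(one_hot, probs))  # same value as the line above
```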
```python
# Imports are not shown in the question; tf.keras is assumed here.
from tensorflow import keras
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
                                     Conv2DTranspose, Dropout, MaxPooling2D,
                                     Reshape, concatenate)
from tensorflow.keras.models import Model

def soft1(x):
    # Pixel-wise softmax over the last (class) axis.
    return keras.activations.softmax(x, axis = -1)

def conv2d_block(input_tensor, n_filters, kernel_size = 3, batchnorm = True):
    # First 3x3 convolution -> optional batch norm -> ReLU.
    x = Conv2D(filters = n_filters, kernel_size = (kernel_size, kernel_size),
               kernel_initializer = 'he_normal', padding = 'same')(input_tensor)
    if batchnorm:
        x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Second 3x3 convolution, applied to x. The summary below was produced with
    # input_tensor here instead, which discards the first convolution.
    x = Conv2D(filters = n_filters, kernel_size = (kernel_size, kernel_size),
               kernel_initializer = 'he_normal', padding = 'same')(x)
    if batchnorm:
        x = BatchNormalization()(x)
    x = Activation('relu')(x)
    return x

def get_unet(input_img, n_classes, n_filters = 16, dropout = 0.1, batchnorm = True):
    # Contracting Path
    c1 = conv2d_block(input_img, n_filters * 1, kernel_size = 3, batchnorm = batchnorm)
    p1 = MaxPooling2D((2, 2))(c1)
    p1 = Dropout(dropout)(p1)

    c2 = conv2d_block(p1, n_filters * 2, kernel_size = 3, batchnorm = batchnorm)
    p2 = MaxPooling2D((2, 2))(c2)
    p2 = Dropout(dropout)(p2)

    c3 = conv2d_block(p2, n_filters * 4, kernel_size = 3, batchnorm = batchnorm)
    p3 = MaxPooling2D((2, 2))(c3)
    p3 = Dropout(dropout)(p3)

    c4 = conv2d_block(p3, n_filters * 8, kernel_size = 3, batchnorm = batchnorm)
    p4 = MaxPooling2D((2, 2))(c4)
    p4 = Dropout(dropout)(p4)

    c5 = conv2d_block(p4, n_filters = n_filters * 16, kernel_size = 3, batchnorm = batchnorm)

    # Expansive Path
    u6 = Conv2DTranspose(n_filters * 8, (3, 3), strides = (2, 2), padding = 'same')(c5)
    u6 = concatenate([u6, c4])
    u6 = Dropout(dropout)(u6)
    c6 = conv2d_block(u6, n_filters * 8, kernel_size = 3, batchnorm = batchnorm)

    u7 = Conv2DTranspose(n_filters * 4, (3, 3), strides = (2, 2), padding = 'same')(c6)
    u7 = concatenate([u7, c3])
    u7 = Dropout(dropout)(u7)
    c7 = conv2d_block(u7, n_filters * 4, kernel_size = 3, batchnorm = batchnorm)

    u8 = Conv2DTranspose(n_filters * 2, (3, 3), strides = (2, 2), padding = 'same')(c7)
    u8 = concatenate([u8, c2])
    u8 = Dropout(dropout)(u8)
    c8 = conv2d_block(u8, n_filters * 2, kernel_size = 3, batchnorm = batchnorm)

    u9 = Conv2DTranspose(n_filters * 1, (3, 3), strides = (2, 2), padding = 'same')(c8)
    u9 = concatenate([u9, c1])
    u9 = Dropout(dropout)(u9)
    c9 = conv2d_block(u9, n_filters * 1, kernel_size = 3, batchnorm = batchnorm)

    outputs = Conv2D(n_classes, (1, 1))(c9)
    # image_height and image_width are globals not shown in the question (256 x 256 here).
    outputs = Reshape((image_height*image_width, 1, n_classes), input_shape = (image_height, image_width, n_classes))(outputs)
    outputs = Activation(soft1)(outputs)
    model = Model(inputs=[input_img], outputs=[outputs])
    print(outputs.shape)
    return model
```
My model summary is:
```
Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_12 (InputLayer)           (None, 256, 256, 3)  0
__________________________________________________________________________________________________
conv2d_211 (Conv2D)             (None, 256, 256, 16) 448         input_12[0][0]
__________________________________________________________________________________________________
batch_normalization_200 (BatchN (None, 256, 256, 16) 64          conv2d_211[0][0]
__________________________________________________________________________________________________
activation_204 (Activation)     (None, 256, 256, 16) 0           batch_normalization_200[0][0]
__________________________________________________________________________________________________
max_pooling2d_45 (MaxPooling2D) (None, 128, 128, 16) 0           activation_204[0][0]
__________________________________________________________________________________________________
dropout_89 (Dropout)            (None, 128, 128, 16) 0           max_pooling2d_45[0][0]
__________________________________________________________________________________________________
conv2d_213 (Conv2D)             (None, 128, 128, 32) 4640        dropout_89[0][0]
__________________________________________________________________________________________________
batch_normalization_202 (BatchN (None, 128, 128, 32) 128         conv2d_213[0][0]
__________________________________________________________________________________________________
activation_206 (Activation)     (None, 128, 128, 32) 0           batch_normalization_202[0][0]
__________________________________________________________________________________________________
max_pooling2d_46 (MaxPooling2D) (None, 64, 64, 32)   0           activation_206[0][0]
__________________________________________________________________________________________________
dropout_90 (Dropout)            (None, 64, 64, 32)   0           max_pooling2d_46[0][0]
__________________________________________________________________________________________________
conv2d_215 (Conv2D)             (None, 64, 64, 64)   18496       dropout_90[0][0]
__________________________________________________________________________________________________
batch_normalization_204 (BatchN (None, 64, 64, 64)   256         conv2d_215[0][0]
__________________________________________________________________________________________________
activation_208 (Activation)     (None, 64, 64, 64)   0           batch_normalization_204[0][0]
__________________________________________________________________________________________________
max_pooling2d_47 (MaxPooling2D) (None, 32, 32, 64)   0           activation_208[0][0]
__________________________________________________________________________________________________
dropout_91 (Dropout)            (None, 32, 32, 64)   0           max_pooling2d_47[0][0]
__________________________________________________________________________________________________
conv2d_217 (Conv2D)             (None, 32, 32, 128)  73856       dropout_91[0][0]
__________________________________________________________________________________________________
batch_normalization_206 (BatchN (None, 32, 32, 128)  512         conv2d_217[0][0]
__________________________________________________________________________________________________
activation_210 (Activation)     (None, 32, 32, 128)  0           batch_normalization_206[0][0]
__________________________________________________________________________________________________
max_pooling2d_48 (MaxPooling2D) (None, 16, 16, 128)  0           activation_210[0][0]
__________________________________________________________________________________________________
dropout_92 (Dropout)            (None, 16, 16, 128)  0           max_pooling2d_48[0][0]
__________________________________________________________________________________________________
conv2d_219 (Conv2D)             (None, 16, 16, 256)  295168      dropout_92[0][0]
__________________________________________________________________________________________________
batch_normalization_208 (BatchN (None, 16, 16, 256)  1024        conv2d_219[0][0]
__________________________________________________________________________________________________
activation_212 (Activation)     (None, 16, 16, 256)  0           batch_normalization_208[0][0]
__________________________________________________________________________________________________
conv2d_transpose_45 (Conv2DTran (None, 32, 32, 128)  295040      activation_212[0][0]
__________________________________________________________________________________________________
concatenate_45 (Concatenate)    (None, 32, 32, 256)  0           conv2d_transpose_45[0][0]
                                                                 activation_210[0][0]
__________________________________________________________________________________________________
dropout_93 (Dropout)            (None, 32, 32, 256)  0           concatenate_45[0][0]
__________________________________________________________________________________________________
conv2d_221 (Conv2D)             (None, 32, 32, 128)  295040      dropout_93[0][0]
__________________________________________________________________________________________________
batch_normalization_210 (BatchN (None, 32, 32, 128)  512         conv2d_221[0][0]
__________________________________________________________________________________________________
activation_214 (Activation)     (None, 32, 32, 128)  0           batch_normalization_210[0][0]
__________________________________________________________________________________________________
conv2d_transpose_46 (Conv2DTran (None, 64, 64, 64)   73792       activation_214[0][0]
__________________________________________________________________________________________________
concatenate_46 (Concatenate)    (None, 64, 64, 128)  0           conv2d_transpose_46[0][0]
                                                                 activation_208[0][0]
__________________________________________________________________________________________________
dropout_94 (Dropout)            (None, 64, 64, 128)  0           concatenate_46[0][0]
__________________________________________________________________________________________________
conv2d_223 (Conv2D)             (None, 64, 64, 64)   73792       dropout_94[0][0]
__________________________________________________________________________________________________
batch_normalization_212 (BatchN (None, 64, 64, 64)   256         conv2d_223[0][0]
__________________________________________________________________________________________________
activation_216 (Activation)     (None, 64, 64, 64)   0           batch_normalization_212[0][0]
__________________________________________________________________________________________________
conv2d_transpose_47 (Conv2DTran (None, 128, 128, 32) 18464       activation_216[0][0]
__________________________________________________________________________________________________
concatenate_47 (Concatenate)    (None, 128, 128, 64) 0           conv2d_transpose_47[0][0]
                                                                 activation_206[0][0]
__________________________________________________________________________________________________
dropout_95 (Dropout)            (None, 128, 128, 64) 0           concatenate_47[0][0]
__________________________________________________________________________________________________
conv2d_225 (Conv2D)             (None, 128, 128, 32) 18464       dropout_95[0][0]
__________________________________________________________________________________________________
batch_normalization_214 (BatchN (None, 128, 128, 32) 128         conv2d_225[0][0]
__________________________________________________________________________________________________
activation_218 (Activation)     (None, 128, 128, 32) 0           batch_normalization_214[0][0]
__________________________________________________________________________________________________
conv2d_transpose_48 (Conv2DTran (None, 256, 256, 16) 4624        activation_218[0][0]
__________________________________________________________________________________________________
concatenate_48 (Concatenate)    (None, 256, 256, 32) 0           conv2d_transpose_48[0][0]
                                                                 activation_204[0][0]
__________________________________________________________________________________________________
dropout_96 (Dropout)            (None, 256, 256, 32) 0           concatenate_48[0][0]
__________________________________________________________________________________________________
conv2d_227 (Conv2D)             (None, 256, 256, 16) 4624        dropout_96[0][0]
__________________________________________________________________________________________________
batch_normalization_216 (BatchN (None, 256, 256, 16) 64          conv2d_227[0][0]
__________________________________________________________________________________________________
activation_220 (Activation)     (None, 256, 256, 16) 0           batch_normalization_216[0][0]
__________________________________________________________________________________________________
conv2d_228 (Conv2D)             (None, 256, 256, 7)  119         activation_220[0][0]
__________________________________________________________________________________________________
reshape_12 (Reshape)            (None, 65536, 1, 7)  0           conv2d_228[0][0]
__________________________________________________________________________________________________
activation_221 (Activation)     (None, 65536, 1, 7)  0           reshape_12[0][0]
==================================================================================================
Total params: 1,179,511
Trainable params: 1,178,039
Non-trainable params: 1,472
__________________________________________________________________________________________________
```
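The parameter counts in this summary can be reproduced by hand. The plain-Python sketch below assumes n_filters = 16, 7 classes, the usual bias-included formula k*k*c_in*c_out + c_out for Conv2D and Conv2DTranspose, and 4 parameters per channel for BatchNormalization (2 of them non-trainable). Counting only one Conv2D + BatchNormalization per conv2d_block call lands exactly on the reported totals, which is consistent with the second Conv2D in conv2d_block reading input_tensor instead of x: the first convolution of each block never reaches the output, so it is left out of the model.

```python
def conv_params(k, c_in, c_out):
    # Conv2D or Conv2DTranspose with bias: k*k*c_in weights per output channel + 1 bias each.
    return k * k * c_in * c_out + c_out

n_filters, n_classes = 16, 7
widths = [n_filters * m for m in (1, 2, 4, 8, 16)]  # 16, 32, 64, 128, 256

total = 0
non_trainable = 0
c_in = 3  # RGB input

# Contracting path + bottleneck: ONE Conv2D + BatchNormalization per block.
for w in widths:
    total += conv_params(3, c_in, w) + 4 * w  # BN: gamma, beta, moving mean, moving variance
    non_trainable += 2 * w                    # moving mean/variance are not trained
    c_in = w

# Expansive path: Conv2DTranspose, concat with the skip (2*w channels), one Conv2D + BN.
for w in reversed(widths[:-1]):  # 128, 64, 32, 16
    total += conv_params(3, c_in, w)           # transpose conv
    total += conv_params(3, 2 * w, w) + 4 * w  # conv on the concatenation, plus BN
    non_trainable += 2 * w
    c_in = w

total += conv_params(1, c_in, n_classes)  # final 1x1 conv (conv2d_228)
print(total, non_trainable)  # 1179511 1472
```

With two convolutions per block the total would be noticeably larger, so the summary is a quick way to spot this kind of wiring mistake.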
Is my model right? Shouldn't the final output be (65536, 1, 1), since I am using softmax? The code compiles, but the Dice coefficient is very low.
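The question doesn't show how the Dice coefficient is computed, so the following is only a hypothetical sketch (plain Python, made-up toy data, helper names invented for illustration) of a per-class Dice score on integer label maps. What it illustrates is that the model emits 7 probabilities per pixel while the mask holds one class id per pixel, so the prediction has to be reduced with an argmax before the two can be compared.

```python
def argmax(values):
    # Index of the largest value: the predicted class for one pixel.
    return max(range(len(values)), key=lambda i: values[i])

def dice_per_class(pred_labels, true_labels, n_classes, eps=1e-7):
    # pred_labels / true_labels: flat lists of integer class ids, one per pixel.
    # eps avoids 0/0 when a class is absent from both (that class then scores 1.0).
    scores = []
    for c in range(n_classes):
        pred_c = [p == c for p in pred_labels]
        true_c = [t == c for t in true_labels]
        inter = sum(p and t for p, t in zip(pred_c, true_c))
        denom = sum(pred_c) + sum(true_c)
        scores.append((2 * inter + eps) / (denom + eps))
    return scores

# Toy 2x3 image with 3 classes, flattened row by row (made-up values).
y_true = [0, 0, 1, 1, 2, 2]
probs = [[0.8, 0.1, 0.1], [0.6, 0.3, 0.1], [0.2, 0.7, 0.1],
         [0.5, 0.4, 0.1], [0.1, 0.2, 0.7], [0.1, 0.1, 0.8]]
y_pred = [argmax(p) for p in probs]  # [0, 0, 1, 0, 2, 2]
scores = dice_per_class(y_pred, y_true, n_classes=3)
print([round(s, 3) for s in scores])  # [0.8, 0.667, 1.0]
```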
Upvotes: 4
Views: 3546
Reputation: 86600
Your model should end in `(256,256,7)`. That is 7 classes per pixel, and the spatial shape should agree with your label images, which are `(256,256,1)`. This will work only with `'sparse_categorical_crossentropy'` or a custom loss.
So, up to `conv2d_228` the model seems fine (I didn't look in detail, though). There is no need for anything that comes after this convolution. You can place the softmax directly in `conv2d_228` (as its activation) or right after it.
For this, each target in `y_train` should be `(256,256,1)`, holding integer class indices.
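A minimal sketch of the output head this answer describes, assuming tf.keras. The single Conv2D standing in for the UNet body and the random arrays are placeholders, not the asker's model or data. The final 1x1 Conv2D uses `activation="softmax"` (applied over the last, class axis), so the output stays `(256, 256, 7)` and is trained directly against `(256, 256, 1)` integer masks with `sparse_categorical_crossentropy`.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

n_classes = 7
inputs = layers.Input(shape=(256, 256, 3))
# Placeholder for the UNet body (c9 in get_unet): any (256, 256, filters) tensor works here.
c9 = layers.Conv2D(16, 3, padding="same", activation="relu")(inputs)
# 1x1 conv to one channel per class, softmax over the last (class) axis. No Reshape needed.
outputs = layers.Conv2D(n_classes, 1, activation="softmax")(c9)
model = tf.keras.Model(inputs, outputs)

# Targets are integer class ids per pixel, so no one-hot encoding is required.
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

x = np.random.rand(2, 256, 256, 3).astype("float32")            # placeholder images
y = np.random.randint(0, n_classes, size=(2, 256, 256, 1))       # placeholder masks
history = model.fit(x, y, epochs=1, verbose=0)

preds = model.predict(x, verbose=0)
print(preds.shape)  # (2, 256, 256, 7)
```

In `get_unet` this corresponds to replacing the Conv2D/Reshape/Activation tail with a single `Conv2D(n_classes, (1, 1), activation='softmax')(c9)`; `np.argmax(preds, axis=-1)` then turns the output back into a label map.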
Upvotes: 5
Reputation: 591
Your output in fact represents each pixel of your image: for each pixel, you have an output of 1x7. Since it is a softmax, the values of this representation lie between 0 and 1 (and sum to 1), so the output fires for the desired class, which gives you the segmentation. If it were (65536, 1, 1), you would have a dense (single-value) representation rather than a categorical one.
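To make the (65536, 1, 1) point concrete, here is a plain-Python sketch (the helper name is hypothetical): a softmax over a single value always returns 1.0, so a one-channel softmax output carries no information, whereas 7 channels give one probability per class and the argmax of those is the pixel's label.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability; the result sums to 1.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]

# One output channel: softmax normalizes the value against itself, so it is always 1.0.
for logit in (-3.0, 0.0, 5.0):
    print(softmax([logit]))  # [1.0] every time

# Seven channels: one probability per class; the argmax is the predicted class id.
probs = softmax([0.1, -1.2, 2.3, 0.0, 0.4, -0.5, 1.1])
label = max(range(len(probs)), key=probs.__getitem__)
print(label)  # 2
```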
Upvotes: 3