pepe

Reputation: 9919

U-net: how to improve accuracy of multiclass segmentation?

I have been using U-Nets for a while now, and notice that in most of my applications the prediction over-estimates the extent of a specific class, bleeding past its true boundary.

For example, here's a grayscale image:

[image: grayscale input]

And a manual segmentation of 3 classes (lesion [green], tissue [magenta], background [all else]):

[image: manual 3-class segmentation]

The issue I notice on prediction (over-estimation at boundaries):

[image: prediction showing over-estimation at boundaries]

The typical architecture used looks something like this:

from keras.models import Model
from keras.layers import (Input, Conv2D, MaxPooling2D, UpSampling2D,
                          Dropout, concatenate, Reshape, Activation)

def get_unet(dim=128, dropout=0.5, n_classes=3):

 inputs = Input((dim, dim, 1))
 conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
 conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
 pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

 conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
 conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
 pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

 conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
 conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
 pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

 conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
 conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
 conv4 = Dropout(dropout)(conv4)
 pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

 conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
 conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
 conv5 = Dropout(dropout)(conv5)

 up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=3)
 conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up6)
 conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)

 up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=3)
 conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up7)
 conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)

 up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=3)
 conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up8)
 conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)

 up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=3)
 conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up9)
 conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)

 conv10 = Conv2D(n_classes, (1, 1), activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
 conv10 = Reshape((dim * dim, n_classes))(conv10)

 output = Activation('softmax')(conv10)

 model = Model(inputs=[inputs], outputs=[output])

 return model

Plus:

mgpu_model.compile(optimizer='adadelta', loss='categorical_crossentropy',
                   metrics=['accuracy'], sample_weight_mode='temporal')  
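With `sample_weight_mode='temporal'`, Keras expects one weight per pixel, i.e. a `sample_weight` array of shape `(n_samples, dim * dim)` matching the model's `Reshape` output. A minimal sketch of building such weights from one-hot masks (the class-weight values here are purely illustrative):

```python
# Sketch: per-pixel sample weights for sample_weight_mode='temporal'.
# Assumes masks are one-hot with shape (n_samples, dim, dim, n_classes);
# the weights [1, 5, 10] below are made-up examples, not recommendations.
import numpy as np

def make_sample_weights(masks, class_weights):
    """Return a (n_samples, dim*dim) array of per-pixel weights."""
    n, h, w, c = masks.shape
    labels = masks.argmax(axis=-1).reshape(n, h * w)       # per-pixel class index
    weights = np.asarray(class_weights, dtype=np.float32)  # one weight per class
    return weights[labels]

# Example: background weighted 1, tissue 5, lesion 10
masks = np.zeros((2, 4, 4, 3), dtype=np.float32)
masks[..., 0] = 1.0  # all background
sample_weights = make_sample_weights(masks, [1.0, 5.0, 10.0])
```

The resulting array can be passed directly as `sample_weight` to `fit`, provided the targets have also been flattened to `(n_samples, dim * dim, n_classes)`.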

with open(p, 'w') as fp:
    fp.write(json_string)

model_checkpoint = callbacks.ModelCheckpoint(f, save_best_only=True)
reduce_lr_cback = callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.2,
    patience=5, verbose=1,
    min_lr=0.05 * 0.0001)

h = mgpu_model.fit(train_gray, train_masks,
                   batch_size=64, epochs=50,
                   verbose=1, shuffle=True, validation_split=0.2, sample_weight=sample_weights,
                   callbacks=[model_checkpoint, reduce_lr_cback])

My Question: Do you have any insight or suggestions on how to change the architecture or hyperparameters to mitigate the over-estimation? I would even consider a different architecture that is better at precise boundary segmentation. (Note that I already apply class balancing/weighting to compensate for imbalances in class frequency.)
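Since the errors concentrate at class boundaries, one option beyond plain class weighting is to up-weight pixels near boundaries, in the spirit of the weight map from the original U-Net paper. Below is a simplified sketch (not the paper's exact formula; `w0` and `sigma` are illustrative values) that could feed the per-pixel `sample_weight` array above:

```python
# Sketch: a simplified boundary-emphasis weight map, loosely inspired by the
# U-Net paper's weighting scheme. Pixels close to a class boundary get extra
# weight that decays with distance; w0 and sigma are illustrative defaults.
import numpy as np
from scipy.ndimage import distance_transform_edt

def boundary_weight_map(label_img, w0=10.0, sigma=5.0):
    """label_img: 2-D array of integer class labels. Returns per-pixel weights."""
    weight = np.zeros_like(label_img, dtype=np.float32)
    for cls in np.unique(label_img):
        mask = label_img == cls
        # distance of each in-class pixel to the nearest pixel outside the class
        d = distance_transform_edt(mask)
        weight[mask] = w0 * np.exp(-(d[mask] ** 2) / (2 * sigma ** 2))
    return 1.0 + weight  # base weight 1 everywhere, extra near boundaries
```

Computed per training mask and flattened to `(dim * dim,)`, this makes boundary mistakes cost more than interior ones, which directly targets the halo-style over-estimation shown above.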

Upvotes: 0

Views: 3344

Answers (1)

jkr
jkr

Reputation: 19310

You can experiment with loss functions other than cross entropy; for multi-class segmentation, Dice-based losses are a common choice.
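As a concrete example, here is a soft multi-class Dice loss, shown in NumPy for clarity; in Keras you would replace `np` with `keras.backend` ops and pass the function to `model.compile(loss=...)`. Shapes match the question's `Reshape` output, `(batch, dim*dim, n_classes)`:

```python
# Sketch: soft multi-class Dice loss in NumPy (swap np for keras.backend
# to use it as a Keras loss). Inputs have shape (batch, pixels, n_classes).
import numpy as np

def soft_dice_loss(y_true, y_pred, eps=1e-7):
    # sum over batch and pixel axes, keeping one overlap score per class
    intersection = np.sum(y_true * y_pred, axis=(0, 1))
    denom = np.sum(y_true, axis=(0, 1)) + np.sum(y_pred, axis=(0, 1))
    dice_per_class = (2.0 * intersection + eps) / (denom + eps)
    return 1.0 - dice_per_class.mean()  # 0 for a perfect prediction
```

Because Dice scores each class by overlap rather than per-pixel counts, it tends to be less biased by class imbalance than plain cross entropy.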

The winner of BraTS 2018 used autoencoder regularization (https://github.com/IAmSuyogJadhav/3d-mri-brain-tumor-segmentation-using-autoencoder-regularization). You could try this as well. The idea in that paper is that the model also learns to encode the features well in the latent space, which appears to help the segmentation.

Upvotes: 1
