I have been using U-Nets for a while now, and I notice that in most of my applications the predicted segmentation over-estimates the region surrounding a specific class.
For example, here's a grayscale image:
And a manual segmentation of 3 classes (lesion [green], tissue [magenta], background [all else]):
And the issue I see in the prediction (over-estimation at the boundaries):
The architecture I typically use looks something like this:
```python
from keras.layers import (Input, Conv2D, MaxPooling2D, Dropout, UpSampling2D,
                          Reshape, Activation, concatenate)
from keras.models import Model


def get_unet(dim=128, dropout=0.5, n_classes=3):
    inputs = Input((dim, dim, 1))

    # Encoder: two 3x3 convolutions per level, then 2x2 max pooling
    conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
    conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
    conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
    conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
    conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
    conv4 = Dropout(dropout)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    # Bottleneck
    conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
    conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
    conv5 = Dropout(dropout)(conv5)

    # Decoder: upsample, concatenate the encoder features of the same
    # resolution (skip connection), then two 3x3 convolutions
    up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=3)
    conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up6)
    conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
    up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=3)
    conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up7)
    conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)
    up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=3)
    conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up8)
    conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)
    up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=3)
    conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up9)
    conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)

    # 1x1 convolution: one output channel per class
    conv10 = Conv2D(n_classes, (1, 1), activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
    # Flatten the spatial grid so each pixel becomes one "timestep"; this
    # matches sample_weight_mode='temporal' (one weight per pixel)
    conv10 = Reshape((dim * dim, n_classes))(conv10)
    output = Activation('softmax')(conv10)

    model = Model(inputs=[inputs], outputs=[output])
    return model
```
The model is then compiled and trained as follows:
```python
from keras import callbacks

# Per-pixel ("temporal") sample weights: one weight per row of the
# (dim * dim, n_classes) output
mgpu_model.compile(optimizer='adadelta', loss='categorical_crossentropy',
                   metrics=['accuracy'], sample_weight_mode='temporal')

# Write json_string to the file at path p
with open(p, 'w') as fh:
    fh.write(json_string)

model_checkpoint = callbacks.ModelCheckpoint(f, save_best_only=True)
reduce_lr_cback = callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.2,
    patience=5, verbose=1,
    min_lr=0.05 * 0.0001)

h = mgpu_model.fit(train_gray, train_masks,
                   batch_size=64, epochs=50,
                   verbose=1, shuffle=True, validation_split=0.2,
                   sample_weight=sample_weights,
                   callbacks=[model_checkpoint, reduce_lr_cback])
```
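The `sample_weights` array passed to `fit` is not shown. With `sample_weight_mode='temporal'` and the `(dim * dim, n_classes)` output, Keras expects one weight per pixel, i.e. an array of shape `(n_samples, dim * dim)`. As a minimal sketch, assuming the masks are one-hot arrays of shape `(n_samples, dim, dim, n_classes)` and that the weighting is plain inverse class frequency (the actual scheme used here is not stated), such weights could be built as below; `temporal_sample_weights` is a hypothetical helper name:

```python
import numpy as np


def temporal_sample_weights(masks):
    """Hypothetical helper: per-pixel weights for sample_weight_mode='temporal'.

    masks: one-hot array of shape (n_samples, dim, dim, n_classes).
    Returns an array of shape (n_samples, dim * dim) in which each pixel gets
    the inverse frequency of its true class, rescaled so the weights average
    to 1 over all pixels.
    """
    n, h, w, c = masks.shape
    flat = masks.reshape(n, h * w, c)            # same layout as the model's Reshape output
    freq = flat.sum(axis=(0, 1)) / (n * h * w)   # fraction of pixels belonging to each class
    class_w = 1.0 / np.maximum(freq, 1e-8)       # inverse frequency; guard against absent classes
    pixel_w = flat @ class_w                     # picks the weight of each pixel's true class
    return pixel_w / pixel_w.mean()
```

For example, `sample_weights = temporal_sample_weights(masks)` together with `train_masks = masks.reshape(len(masks), -1, n_classes)` keeps targets and weights aligned pixel for pixel with the model output.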
My question: do you have any insight or suggestions on how to change either the architecture or the hyperparameters to mitigate this over-estimation? This could even include switching to a different architecture that may be better at precise segmentation. (Please note that I already use class balancing/weighting to compensate for imbalances in class frequency.)
You can experiment with various loss functions instead of cross-entropy. For multi-class segmentation, you can try:
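One widely used alternative for segmentation is a soft Dice loss. It scores the overlap between the predicted and true region of each class, so probability mass placed outside a class's true region lowers that class's score directly. Below is a minimal NumPy sketch of the multi-class version, assuming one-hot targets and softmax outputs shaped `(n_samples, n_pixels, n_classes)` like the reshaped U-Net output; `soft_dice_loss` and `eps` are illustrative names, not a library API:

```python
import numpy as np


def soft_dice_loss(y_true, y_pred, eps=1e-6):
    """Multi-class soft Dice loss, averaged over classes (illustrative sketch).

    y_true: one-hot targets; y_pred: softmax probabilities.
    Both have shape (n_samples, n_pixels, n_classes).
    """
    axes = (0, 1)  # sum over samples and pixels, keep the class axis
    intersection = np.sum(y_true * y_pred, axis=axes)
    denom = np.sum(y_true, axis=axes) + np.sum(y_pred, axis=axes)
    dice_per_class = (2.0 * intersection + eps) / (denom + eps)
    # Dice is 1 for perfect overlap, so 1 - mean Dice is 0 at the optimum
    return 1.0 - np.mean(dice_per_class)
```

To train with it in Keras, the same arithmetic would have to be written with backend tensor operations and passed as the `loss` argument to `compile`; the NumPy version only shows the computation.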
The winner of BraTS 2018 used autoencoder regularization (https://github.com/IAmSuyogJadhav/3d-mri-brain-tumor-segmentation-using-autoencoder-regularization). You could try this as well. The idea in that paper is that, alongside the segmentation, the model also learns to reconstruct its input image from the shared encoder's latent representation; this extra objective pushes the encoder to learn better features, which in turn appears to help the segmentation.
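As far as I understand the paper behind the linked repository (3D MRI brain tumor segmentation using autoencoder regularization), an extra variational-autoencoder branch decodes the shared encoder's latent code back into the input image, and the training loss is the Dice loss plus 0.1 times an L2 reconstruction term plus 0.1 times a KL term normalized by the number of image voxels. The sketch below only combines already-computed pieces into that loss; `vae_regularized_loss` and its argument names are hypothetical, and the weights and KL normalization reflect my reading of the paper, so check them against it before relying on them:

```python
import numpy as np


def vae_regularized_loss(seg_loss, x, x_rec, mu, log_var, w_l2=0.1, w_kl=0.1):
    """Hypothetical sketch: segmentation loss plus autoencoder regularization.

    seg_loss: scalar segmentation loss (e.g. a soft Dice loss).
    x, x_rec: input image and its reconstruction from the VAE branch.
    mu, log_var: mean and log-variance of the latent Gaussian.
    w_l2, w_kl: term weights (0.1 each, per my reading of the paper).
    """
    l2 = np.sum((x - x_rec) ** 2)                         # squared L2 reconstruction error
    var = np.exp(log_var)
    # KL-style penalty toward N(0, 1), normalized by the number of image voxels
    kl = np.sum(mu ** 2 + var - log_var - 1.0) / x.size
    return seg_loss + w_l2 * l2 + w_kl * kl
```

The reconstruction branch is only needed during training; at inference time only the segmentation output is used.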