
Reputation: 879

Keras U-Net weighted loss implementation

I'm trying to separate touching (closely adjacent) objects, as shown in the U-Net paper (Ronneberger et al., 2015). For this, one generates weight maps that are used in a pixel-wise weighted loss. The following code describes the network I use, taken from this blog post.

x_train_val = ...  # list of images (imgs, 256, 256, 3)
y_train_val = ...  # list of masks (imgs, 256, 256, 1)
y_weights = ...  # list of weight maps (imgs, 256, 256, 1) according to the blog post
# visual inspection confirms the correct calculation of these maps

# Blog post's loss function
def my_loss(target, output):
    """Negative sum of ``target * output`` over the last (class) axis.

    ``output`` is the model output, which already holds per-pixel weighted
    log-probabilities, so the result is a pixel-weighted cross-entropy term
    per pixel.
    """
    class_axis = len(output.get_shape()) - 1
    weighted_terms = target * output
    return -tf.reduce_sum(weighted_terms, class_axis)

# Standard Unet model from blog post
# Clipping bound for probabilities (used by the clip Lambda in the model) so log() never sees 0 or 1.
_epsilon = tf.convert_to_tensor(K.epsilon(), dtype=np.float32)

def _unet_conv_block(x, filters, dropout_rate):
    """Apply two 3x3 ReLU convolutions followed by dropout (one U-Net stage)."""
    for _ in range(2):
        x = L.Conv2D(filters, 3, activation='relu', padding='same',
                     kernel_initializer='he_normal')(x)
    return L.Dropout(dropout_rate)(x)


def _unet_up_block(x, skip, filters, dropout_rate):
    """Upsample ``x`` by 2, concatenate the encoder ``skip`` tensor, then run a conv stage."""
    up = L.Conv2DTranspose(filters, 2, strides=2, kernel_initializer='he_normal',
                           padding='same')(x)
    merged = L.Concatenate()([up, skip])
    return _unet_conv_block(merged, filters, dropout_rate)


def make_weighted_loss_unet(input_shape, n_classes):
    """Build a U-Net whose output is the pixel-weighted log-probability map.

    The model takes ``[image, weight_map]`` and outputs
    ``weight_map * log(clip(softmax))``; pair it with a loss such as
    ``-sum(target * output, axis=-1)`` to get a pixel-weighted cross-entropy.
    For prediction without weight maps, build a second ``Model`` from the image
    input to the softmax tensor (``c10``): the weight map only enters after it.

    Args:
        input_shape: Image shape ``(height, width, channels)``; height and
            width (when not None) must be divisible by 16 because of the four
            2x poolings.
        n_classes: Number of softmax output channels; must be at least 2.

    Returns:
        A Keras ``Model`` with inputs ``[image, weight_map]``, where
        ``weight_map`` has shape ``(height, width, n_classes)``.

    Raises:
        ValueError: if ``n_classes < 2`` or the spatial size is not divisible by 16.
    """
    if n_classes < 2:
        # A softmax over one channel is identically 1.0, so log-probabilities are
        # constant, gradients vanish and predictions come out blank.
        raise ValueError(
            "n_classes must be >= 2: a softmax over a single channel is always 1, "
            "so the network cannot learn. For binary segmentation use n_classes=2 "
            "with one-hot masks and a weight map replicated to 2 channels."
        )
    spatial_shape = tuple(input_shape[:2])  # tuple() so list shapes also work
    for dim in spatial_shape:
        # Each MaxPool floors the size and each transposed conv doubles it, so the
        # skip concatenations only line up when the size is divisible by 2**4.
        if dim is not None and dim % 16 != 0:
            raise ValueError(
                "input height and width must be divisible by 16, got {}".format(spatial_shape)
            )

    ip = L.Input(shape=input_shape)
    weight_ip = L.Input(shape=spatial_shape + (n_classes,))

    # Contracting path: each stage's output is kept as a skip connection.
    conv1 = _unet_conv_block(ip, 64, 0.1)
    conv2 = _unet_conv_block(L.MaxPool2D()(conv1), 128, 0.2)
    conv3 = _unet_conv_block(L.MaxPool2D()(conv2), 256, 0.3)
    conv4 = _unet_conv_block(L.MaxPool2D()(conv3), 512, 0.4)
    conv5 = _unet_conv_block(L.MaxPool2D()(conv4), 1024, 0.5)

    # Expanding path: upsample and merge with the matching encoder stage.
    conv6 = _unet_up_block(conv5, conv4, 512, 0.4)
    conv7 = _unet_up_block(conv6, conv3, 256, 0.3)
    conv8 = _unet_up_block(conv7, conv2, 128, 0.2)
    conv9 = _unet_up_block(conv8, conv1, 64, 0.1)

    c10 = L.Conv2D(n_classes, 1, activation='softmax', kernel_initializer='he_normal')(conv9)

    # Mimic crossentropy loss: renormalise over classes, clip away from 0/1, take the log.
    c11 = L.Lambda(lambda x: x / tf.reduce_sum(x, axis=-1, keepdims=True))(c10)
    c11 = L.Lambda(lambda x: tf.clip_by_value(x, _epsilon, 1. - _epsilon))(c11)
    c11 = L.Lambda(lambda x: K.log(x))(c11)
    weighted_sm = L.multiply([c11, weight_ip])

    return Model(inputs=[ip, weight_ip], outputs=[weighted_sm])

I then compile and fit the model as shown below:

# A softmax over a single channel is identically 1, so the model needs two output
# channels for binary segmentation; expand the masks to one-hot form accordingly.
# NOTE(review): assumes y_train_val holds 0/1 masks shaped (imgs, 256, 256, 1) - confirm.
masks = np.asarray(y_train_val)
y_train_onehot = np.concatenate([1 - masks, masks], axis=-1)
# The same per-pixel weight applies to both class channels.
y_weights_2ch = np.repeat(np.asarray(y_weights), 2, axis=-1)

model = make_weighted_loss_unet((256, 256, 3), 2)  # shape of input, number of classes
model.compile(optimizer='adam', loss=my_loss, metrics=['acc'])
model.fit([x_train_val, y_weights_2ch], y_train_onehot, validation_split=0.1, epochs=1)

The model then trains as usual. However, the loss doesn't seem to improve much. Furthermore, when I try to predict on new images, I obviously don't have the weight maps (because they are calculated from the labeled masks). I tried using empty / zero arrays shaped like the weight map, but that only yields blank / zero predictions. I also tried different metrics and more standard losses without any success.

Has anyone faced the same issue, or does anyone have an alternative way of implementing this weighted loss? Thanks in advance. BBQuercus

Upvotes: 4

Views: 3420

Answers (1)


Reputation: 2930

A simpler way to write custom loss with pixel weights

In your code, the loss is scattered across the my_loss and make_weighted_loss_unet functions. You can add the targets as an input and use model.add_loss to structure the code better:

def make_weighted_loss_unet(input_shape, n_classes):
    """Build a U-Net that carries its own pixel-weighted loss via ``add_loss``.

    Inputs are ``[image, weight_map, one_hot_targets]``; the model outputs the
    softmax probabilities ``c10``.
    """
    ip = L.Input(shape=input_shape)
    weight_ip = L.Input(shape=input_shape[:2] + (n_classes,))
    targets = L.Input(shape=input_shape[:2] + (n_classes,))
    # .... rest of your model definition code ...

    c10 = L.Conv2D(n_classes, 1, activation='softmax', kernel_initializer='he_normal')(conv9)
    # The model must exist before a loss can be attached to it.
    model = Model(inputs=[ip, weight_ip, targets], outputs=[c10])
    model.add_loss(pixel_weighted_cross_entropy(weight_ip, targets, c10))
    # NO NEED to specify loss in model.compile; train with
    # model.fit([images, weight_maps, one_hot_masks], epochs=...) and no y argument.
    # For prediction, Model(inputs=ip, outputs=c10) reuses these layers without the extra inputs.
    return model

def pixel_weighted_cross_entropy(weights, targets, predictions):
    """Mean categorical cross-entropy with every pixel scaled by its weight.

    ``weights`` and ``targets`` have the same (batch, H, W, n_classes) layout as
    the model's weight/target inputs; ``targets`` are assumed one-hot (TODO confirm
    against the data pipeline).
    """
    # categorical_crossentropy reduces the class axis, giving (batch, H, W).
    loss_val = keras.losses.categorical_crossentropy(targets, predictions)
    # Collapse the weights to one value per pixel (the weight of the true class).
    # Multiplying (batch, H, W, C) by (batch, H, W) directly would broadcast the
    # wrong axes against each other.
    pixel_weights = K.sum(weights * targets, axis=-1)
    weighted_loss_val = pixel_weights * loss_val
    return K.mean(weighted_loss_val)

If you don't refactor your code to the above approach, the next section shows how you can still run inference.

How to run your model in inference

Option 1 : Use another Model object for inference

You can create one Model for training and another for inference. Both are largely the same, except that the inference Model does not take weight_ip as an input and instead outputs the softmax tensor c10 directly.

Here's an example code that adds an argument is_training=True to decide which Model to return :

def make_weighted_loss_unet(input_shape, n_classes, is_training=True):
    """Return the training model or, with ``is_training=False``, the inference model.

    Training model: ``[image, weight_map] -> weight_map * log(clip(softmax))``.
    Inference model: ``image -> softmax`` (no weight map needed).
    The extra training-only layers (Lambda / multiply) should carry no weights, so
    trained weights are expected to transfer with
    ``infer_model.set_weights(train_model.get_weights())`` - verify after building.
    """
    ip = L.Input(shape=input_shape)

    conv1 = L.Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(ip)
    # .... rest of your model definition code ...
    c10 = L.Conv2D(n_classes, 1, activation='softmax', kernel_initializer='he_normal')(conv9)

    if is_training:
        # Mimic crossentropy loss
        c11 = L.Lambda(lambda x: x / tf.reduce_sum(x, len(x.get_shape()) - 1, True))(c10)
        c11 = L.Lambda(lambda x: tf.clip_by_value(x, _epsilon, 1. - _epsilon))(c11)
        c11 = L.Lambda(lambda x: K.log(x))(c11)
        weight_ip = L.Input(shape=input_shape[:2] + (n_classes,))
        weighted_sm = L.multiply([c11, weight_ip])
        return Model(inputs=[ip, weight_ip], outputs=[weighted_sm])

    # Inference: the weight map is only needed for the loss, so stop at the softmax.
    return Model(inputs=[ip], outputs=[c10])

Option 2 : Use K.function

If you don't want to mess with your Model definition method (make_weighted_loss_unet) and want to achieve the same result outside, you can use a function that extracts the subgraph relevant for inference.

In your inference function:

from keras import backend as K

model = make_weighted_loss_unet(input_shape, n_classes)
# ... train the model (or load its trained weights) before running inference ...
# K.function maps a list of input tensors to a list of output tensors.
inference_function = K.function([model.get_layer("input_layer").input],
                                [model.get_layer("output_softmax_layer").output])
# Inputs are passed as a list; new_image is assumed to be batched, e.g. (1, H, W, 3).
predicted_heatmap = inference_function([new_image])[0]

Note that you have to pass name= to your ip and c10 layers to be able to retrieve them via model.get_layer(name):

# Named so the input can later be fetched with model.get_layer("input_layer").
ip = L.Input(name="input_layer", shape=input_shape)


# Named so the softmax output can be fetched with model.get_layer("output_softmax_layer").
c10 = L.Conv2D(
    n_classes,
    1,
    activation='softmax',
    kernel_initializer='he_normal',
    name="output_softmax_layer",
)(conv9)

Upvotes: 5

Related Questions