Sasha Korekov

Reputation: 343

Keras, binary segmentation, add weight to loss function

I'm solving a binary segmentation problem with Keras (with the TensorFlow backend). How can I give more weight to the center of each region of the mask?

I've tried a Dice coefficient with an added cv2.erode(), but it doesn't work:

import cv2
import keras.backend as K

def dice_coef_eroded(y_true, y_pred):
    kernel = (3, 3)
    y_true = cv2.erode(y_true.eval(), kernel, iterations=1)
    y_pred = cv2.erode(y_pred.eval(), kernel, iterations=1)
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + 1) / (K.sum(y_true_f) + K.sum(y_pred_f) + 1)

Keras 2.1.3, tensorflow 1.4
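cv2.erode works on NumPy arrays, while a Keras loss receives symbolic tensors. At compile time y_true.eval() has no data to evaluate, and even if it did, the result would be a constant with no gradient back to the model. Erosion of a binary mask with a square kernel is simply a sliding-window minimum, so it can be expressed with pooling ops that stay in the graph, e.g. `-K.pool2d(-x, (3, 3), strides=(1, 1), padding='same', pool_mode='max')` on a 4D tensor. The NumPy sketch below (erode_min_pool is a hypothetical helper, not a library function) shows the min-filter form of erosion on a 2D mask:

```python
import numpy as np

def erode_min_pool(mask, size=3):
    """Binary erosion with a size x size square kernel, written as a
    sliding-window minimum, i.e. the NumPy analogue of -max_pool(-mask)."""
    pad = size // 2
    # Pad with +inf so the image border itself does not erode the mask.
    padded = np.pad(mask.astype(np.float32), pad, mode="constant",
                    constant_values=np.inf)
    h, w = mask.shape
    out = np.full((h, w), np.inf, dtype=np.float32)
    # Minimum over every offset inside the window.
    for dy in range(size):
        for dx in range(size):
            out = np.minimum(out, padded[dy:dy + h, dx:dx + w])
    return out
```

A 3x3 block of ones inside a 5x5 zero mask erodes to its single center pixel, while an all-ones mask is left unchanged.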

Upvotes: 1

Views: 2410

Answers (2)

eLearner

Reputation: 71

I'm implementing this solution, but I wonder what ground truth we should give the network. Since the output is now the loss and we want the loss to be 0, should we train the network as follows?

model = get_unet_w_lambda_loss()
model.fit([inputs, weights, masks], zero_images)
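Since identity_loss returns y_pred unchanged and never reads y_true, the target passed to fit() is only there because Keras expects one target per output; Keras minimizes the mean of the loss tensor the model outputs, whatever the target contains. Zeros shaped like the output, (N, 1024, 1024, 1), therefore work. A minimal sketch, assuming masks is such a NumPy array; make_dummy_targets is a hypothetical helper, and identity_loss is restated here on plain arrays only to show that the target is unused:

```python
import numpy as np

def identity_loss(y_true, y_pred):
    # As in the answer: the model output already *is* the loss.
    return y_pred

def make_dummy_targets(masks):
    """Hypothetical helper: placeholder targets with the output's shape.
    identity_loss never reads them, so zeros are fine."""
    return np.zeros(masks.shape, dtype=np.float32)

# With the Keras model (not run here):
# model.compile(optimizer='adam', loss=identity_loss)
# model.fit([inputs, weights, masks], make_dummy_targets(masks))
```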

Upvotes: 0

Sasha Korekov

Reputation: 343

All right, the solution I found is the following:

1) Add a method to your Iterator that returns a weight matrix with the same shape as the mask. The iterator's output must then contain [image, mask, weights].

2) Create a Lambda layer that computes the loss function.

3) Create an identity loss function that simply returns the model's output (the already computed loss).
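One way to build the weight matrix from step 1, assuming "center of each area" means "far from the region boundary": count how many 3x3 erosions each foreground pixel survives (a chessboard distance to the background), normalize it, and add it to a base weight of 1. This runs in NumPy inside the data pipeline, not in the loss, so no gradient is needed. center_weight_map, _erode3x3, batches and alpha are hypothetical names; the generator yields the inputs in the order get_unet_w_lambda_loss declares them ([images, weights, masks]) plus zero targets for the identity loss.

```python
import numpy as np

def _erode3x3(mask):
    """Binary erosion with a 3x3 square kernel (sliding-window minimum).
    Pixels outside the image count as background."""
    padded = np.pad(mask, 1, mode="constant", constant_values=False)
    h, w = mask.shape
    out = np.ones((h, w), dtype=bool)
    for dy in range(3):
        for dx in range(3):
            out &= padded[dy:dy + h, dx:dx + w]
    return out

def center_weight_map(mask, alpha=1.0):
    """Hypothetical weighting: 1 everywhere, plus a bonus inside each
    foreground region that grows toward its center."""
    current = mask.astype(bool)
    depth = np.zeros(mask.shape, dtype=np.float32)
    while current.any():
        depth += current              # surviving pixels gain one level
        current = _erode3x3(current)  # peel off the outer ring
    if depth.max() > 0:
        depth /= depth.max()          # scale to [0, 1] per image
    return 1.0 + alpha * depth

def batches(images, masks, batch_size, alpha=1.0):
    """Hypothetical generator: ([images, weights, masks], zero targets)."""
    n = len(images)
    while True:
        for start in range(0, n, batch_size):
            x = images[start:start + batch_size]
            y = masks[start:start + batch_size]
            w = np.stack([center_weight_map(m[..., 0], alpha) for m in y])[..., None]
            yield [x, w, y], np.zeros_like(y, dtype=np.float32)
```

Normalizing by the largest depth in the image means smaller regions get a smaller bonus; normalizing per region would need connected-component labeling (for example scipy.ndimage.label).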

Example:

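import keras.backend as K
from keras.layers import Input, Conv2D, Lambda
from keras.models import Model

def weighted_binary_loss(X):
    # Tensors arrive in the order they are passed to the Lambda layer below
    y_pred, weights, y_true = X
    # Keras 2 backend signature is binary_crossentropy(target, output)
    loss = K.binary_crossentropy(y_true, y_pred)
    # Scale the per-pixel loss by the weight map with a plain element-wise
    # product (the legacy merge(..., mode='mul') is not needed in a Lambda)
    loss = loss * weights
    return loss

def identity_loss(y_true, y_pred):
    # The model output is already the loss, so y_true is ignored
    return y_pred

def get_unet_w_lambda_loss(input_shape=(1024, 1024, 3), mask_shape=(1024, 1024, 1)):
    images = Input(input_shape)
    mask_weights = Input(mask_shape)
    true_masks = Input(mask_shape)
    ...
    y_pred = Conv2D(1, (1, 1), activation='sigmoid')(up1) # output of the original U-Net
    loss = Lambda(weighted_binary_loss, output_shape=(1024, 1024, 1))([y_pred, mask_weights, true_masks])
    model = Model(inputs=[images, mask_weights, true_masks], outputs=loss)
    return model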
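To check what the Lambda layer computes, here is a NumPy reference of the weighted loss: per-pixel binary cross-entropy multiplied by the weight map. It assumes Keras-style clipping of predictions to [eps, 1 - eps] with eps = 1e-7 (the default K.epsilon()); weighted_bce_numpy is a hypothetical name. Training then minimizes the mean of this map, because identity_loss passes it straight through (compile with something like `model.compile(optimizer='adam', loss=identity_loss)`).

```python
import numpy as np

EPS = 1e-7  # assumed to match Keras' default K.epsilon()

def weighted_bce_numpy(y_pred, weights, y_true):
    """Per-pixel binary cross-entropy scaled by a weight map."""
    # Clip so log() never sees exactly 0 or 1.
    p = np.clip(y_pred, EPS, 1.0 - EPS)
    bce = -(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))
    return bce * weights
```

Doubling a pixel's weight doubles its contribution to the mean, which is how a center-weighted map pushes training to get region centers right first.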

Upvotes: 3
