I am working on a semantic segmentation problem where I have an input x (a CT image) to a deep learning model of shape (batch_size, 1, 256, 256) and an output of shape (batch_size, 2, 256, 256), where the first channel represents one output mask (bone mask) and the second represents a second output mask (lesion mask). I am using a combined loss function for every channel's output, which is a combination of weighted BCE and soft Dice loss, both focusing on the foreground pixels:
class Custom_Loss(tf.keras.losses.Loss):
    def __init__(self, w1=0.3, w2=0.7, w3=0.4, w4=0.6, w5=1.6, reduction="sum_over_batch_size"):
        """
        w1 : Weight for the bone loss contribution to the total loss.
        w2 : Weight for the lesion loss contribution to the total loss.
        w3 : Weight for the soft Dice loss contribution to the combined loss.
        w4 : Weight for the BCE contribution to the combined loss.
        w5 : Weight for the foreground pixels in the BCE loss.
        """
        self.w1 = w1
        self.w2 = w2
        self.w3 = w3
        self.w4 = w4
        self.w5 = w5
        self.reduction = reduction
        super().__init__(reduction=reduction)
    def bce(self, y_true, y_pred):
        epsilon = 1e-8  # Small epsilon to avoid log(0)
        # Total number of pixels per image
        N = y_true.shape[1] * y_true.shape[2]
        # BCE loss per image, with foreground pixels weighted by w5
        bce_loss = (-1 / N) * tf.reduce_sum(
            (self.w5 * y_true * tf.math.log(y_pred + epsilon))
            + ((1 - y_true) * tf.math.log(1 - y_pred + epsilon)),
            axis=(1, 2))
        # Averaging over the batch is left to the Keras reduction
        # bce_loss = tf.reduce_mean(bce_loss)
        return bce_loss
    def soft_dice_loss(self, y_true, y_pred):
        epsilon = 1e-8  # Small epsilon to avoid division by zero
        # Numerator and denominator of the Dice coefficient
        numerator_dice_coef = 2 * tf.reduce_sum(y_true * y_pred, axis=(1, 2)) + epsilon
        den_dice_coef = tf.reduce_sum(y_true * y_true, axis=(1, 2)) + tf.reduce_sum(y_pred * y_pred, axis=(1, 2)) + epsilon
        # Dice coefficient per image in the batch
        dice_coef = numerator_dice_coef / den_dice_coef
        # Averaging over the batch is left to the Keras reduction
        # mean_dice_coef = tf.reduce_mean(dice_coef)
        return 1 - dice_coef
    def combined_loss(self, y_true, y_pred):
        loss = (self.w3 * self.soft_dice_loss(y_true, y_pred)) + (self.w4 * self.bce(y_true, y_pred))
        return loss
    def call(self, y_true, y_pred):
        bone_pred = y_pred[:, 0, :, :]
        lesion_pred = y_pred[:, 1, :, :]
        bone_ground_truth = y_true[:, 0, :, :]
        lesion_ground_truth = y_true[:, 1, :, :]
        # loss = (self.w1 * self.combined_loss(bone_ground_truth, bone_pred)) + (self.w2 * self.combined_loss(lesion_ground_truth, lesion_pred))
        # Compute the combined loss for the bone and lesion masks
        bone_loss = self.combined_loss(bone_ground_truth, bone_pred)        # Shape: (batch_size,)
        lesion_loss = self.combined_loss(lesion_ground_truth, lesion_pred)  # Shape: (batch_size,)
        # Weighted sum of the two components, per sample
        weighted_bone_loss = tf.multiply(self.w1, bone_loss)        # Shape: (batch_size,)
        weighted_lesion_loss = tf.multiply(self.w2, lesion_loss)    # Shape: (batch_size,)
        total_loss = weighted_bone_loss + weighted_lesion_loss      # Shape: (batch_size,)
        # Store loss components for logging
        # self.last_bone_loss = tf.reduce_mean(bone_loss)
        # self.last_lesion_loss = tf.reduce_mean(lesion_loss)
        return total_loss
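My understanding of reduction="sum_over_batch_size" is that, at least on a single replica, Keras sums the per-sample vector returned by call and divides by the batch size, i.e. it takes a plain mean. A toy check (values made up):

import tensorflow as tf

per_sample = tf.constant([0.5, 0.7, 0.6, 0.8])  # what call() would return for a batch of 4
reduced = tf.reduce_sum(per_sample) / tf.cast(tf.size(per_sample), tf.float32)
print(reduced.numpy())  # 0.65, identical to tf.reduce_mean(per_sample)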
I am following the Keras documentation, where the call function returns the loss per sample. I then wrote a custom callback to log the validation loss components (one per mask, bone and lesion) as well as the total loss, as follows:
# Callback for logging loss components at the end of each epoch for the validation data
class LossLoggerCallback(tf.keras.callbacks.Callback):
    def __init__(self, loss_fn, validation_data):
        """
        Callback to log loss components for validation data at the end of each epoch.
        Args:
            loss_fn: Custom loss function (instance of `Custom_Loss`).
            validation_data: Validation dataset (can be a tf.data.Dataset).
        """
        super().__init__()
        self.loss_fn = loss_fn
        self.validation_data = validation_data
    def on_epoch_end(self, epoch, logs=None):
        # Accumulate per-sample losses over the whole validation set
        total_losses = []
        bone_losses = []
        lesion_losses = []
        # Counter for the total number of samples processed
        total_samples = 0
        # Iterate over all batches in the validation data
        for x_val, y_val in self.validation_data:
            batch_size = x_val.shape[0]
            # Make predictions
            y_pred = self.model.predict(x_val, verbose=0)
            # Extract bone and lesion predictions and ground truths
            bone_pred = y_pred[:, 0, :, :]
            lesion_pred = y_pred[:, 1, :, :]
            bone_gt = y_val[:, 0, :, :]
            lesion_gt = y_val[:, 1, :, :]
            # Combined loss per sample for each mask
            bone_loss = self.loss_fn.combined_loss(bone_gt, bone_pred)        # Shape: (batch_size,)
            lesion_loss = self.loss_fn.combined_loss(lesion_gt, lesion_pred)  # Shape: (batch_size,)
            # Apply the same weighting as in `call`
            weighted_bone_loss = tf.multiply(self.loss_fn.w1, bone_loss)
            weighted_lesion_loss = tf.multiply(self.loss_fn.w2, lesion_loss)
            # Total loss per sample in the batch
            total_loss = weighted_bone_loss + weighted_lesion_loss
            # Accumulate the per-sample losses for averaging later
            total_losses.extend(total_loss.numpy())
            bone_losses.extend(bone_loss.numpy())
            lesion_losses.extend(lesion_loss.numpy())
            total_samples += batch_size
        # Mean loss across the entire validation dataset
        mean_bone_loss = np.sum(bone_losses) / total_samples
        mean_lesion_loss = np.sum(lesion_losses) / total_samples
        mean_total_loss = np.sum(total_losses) / total_samples
        # Print the results for the current epoch
        print(f"Epoch {epoch + 1}: Validation Bone Loss = {mean_bone_loss:.4f}, "
              f"Validation Lesion Loss = {mean_lesion_loss:.4f}, "
              f"Validation Total Loss = {mean_total_loss:.4f}")
The batch size used is 4, so one validation batch has shape (4, 2, 256, 256) for y and (4, 1, 256, 256) for x. As you can see in the custom loss class, I am using the reduction sum_over_batch_size. When I train the model in a multi-GPU environment:
with strategy.scope():
    loss_fn = Custom_Loss()
    # Initialize the model
    model = create_model()
    model.compile(optimizer=Adam(learning_rate=1e-4, beta_1=0.999, beta_2=0.999),
                  loss=loss_fn,
                  metrics=[IoU])

val_loss_logger = LossLoggerCallback(loss_fn, validation_data=val_dataset)
es = EarlyStopping(monitor='val_io_u', mode='max', verbose=1, patience=40)
mc = ModelCheckpoint('/kaggle/working/best_model.keras', monitor='val_io_u', mode='max', verbose=1, save_best_only=True)

# Train the model
history = model.fit(x=train_dataset,
                    batch_size=batch_size,
                    validation_data=val_dataset,
                    epochs=epochs,
                    steps_per_epoch=steps_per_epoch,
                    callbacks=[es, mc, val_loss_logger])
I get:
Epoch 1/100
237/237 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - io_u: 0.1837 - loss: 272.1821
Epoch 1: val_io_u improved from -inf to 0.00000, saving model to /kaggle/working/best_model.keras
Epoch 1: Validation Bone Loss = 0.6672, Validation Lesion Loss = 0.6553, Validation Total Loss = 0.6589
237/237 ━━━━━━━━━━━━━━━━━━━━ 443s 1s/step - io_u: 0.1842 - loss: 271.8742 - val_io_u: 0.0000e+00 - val_loss: 83.1071
Epoch 2/100
237/237 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - io_u: 0.3941 - loss: 56.6352
Epoch 2: val_io_u did not improve from 0.00000
Epoch 2: Validation Bone Loss = 0.5375, Validation Lesion Loss = 0.5063, Validation Total Loss = 0.5157
237/237 ━━━━━━━━━━━━━━━━━━━━ 339s 1s/step - io_u: 0.3941 - loss: 56.5627 - val_io_u: 0.0000e+00 - val_loss: 19.8229
Why is the Validation Total Loss reported by the callback not equal to the val_loss reported by Keras? I think I have a misunderstanding of how Keras calculates the loss with the sum_over_batch_size reduction. I want the loss returned by the callback to be identical to the one calculated by Keras.
The loss function I am using is:

Loss = w1 * Loss_combined_bone_mask + w2 * Loss_combined_lesion_mask

where

Loss_combined_bone_mask = w3 * soft_dice_loss(y_pred_bone, y_true_bone) + w4 * BCE(y_pred_bone, y_true_bone)

and

Loss_combined_lesion_mask = w3 * soft_dice_loss(y_pred_lesion, y_true_lesion) + w4 * BCE(y_pred_lesion, y_true_lesion)
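Written out in LaTeX (my reconstruction from the code above, with $y_i$ the ground-truth pixels, $\hat{y}_i$ the predictions, $N$ the number of pixels per image, and $\epsilon = 10^{-8}$):

$$\mathcal{L}_{\text{total}} = w_1 \mathcal{L}_{\text{comb}}^{\text{bone}} + w_2 \mathcal{L}_{\text{comb}}^{\text{lesion}}$$

$$\mathcal{L}_{\text{comb}} = w_3 \left(1 - \frac{2\sum_i y_i \hat{y}_i + \epsilon}{\sum_i y_i^2 + \sum_i \hat{y}_i^2 + \epsilon}\right) - \frac{w_4}{N} \sum_i \left[ w_5\, y_i \log(\hat{y}_i + \epsilon) + (1 - y_i)\log(1 - \hat{y}_i + \epsilon) \right]$$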