Kongsea

Reputation: 927

How to define a weighted loss function in TensorFlow?

I have a training dataset train_data with labels train_labels, which correspond to train_data_node and train_labels_node in the TensorFlow graph. As you know, I can use the TensorFlow loss function as below:

logits = model(train_data_node)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=logits, labels=train_labels_node))

However, this loss function treats all the training data equally. In our situation, we want to treat the data differently. For example, we have a CSV file corresponding to the training data that indicates whether each training sample is original or augmented. We then want to define a custom loss function in which the loss of the original data plays a more important role and the loss of the augmented data plays a less important role, such as:

loss_no_aug = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=noAugLogits, labels=noAugLabels))
loss_aug = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=augLogits, labels=augLabels))
loss = loss_no_aug * PENALTY_COEFFICIENT + loss_aug
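
The intended arithmetic can be sketched in plain numpy on hypothetical per-sample loss values (the values, the is_original flags, and PENALTY_COEFFICIENT = 2.0 are made up for illustration; in the real graph the per-sample losses would come from the cross-entropy op):

```python
import numpy as np

# Hypothetical per-sample cross-entropy losses for a batch of 4 samples.
per_sample_loss = np.array([0.2, 1.5, 0.7, 0.3])
is_original = np.array([1, 0, 1, 0], dtype=bool)  # True = original sample
PENALTY_COEFFICIENT = 2.0                          # assumed weight

loss_no_aug = per_sample_loss[is_original].mean()   # mean over original samples
loss_aug = per_sample_loss[~is_original].mean()     # mean over augmented samples
loss = loss_no_aug * PENALTY_COEFFICIENT + loss_aug # close to 1.8 here
```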

I have defined a loss function as below, but it doesn't work:

def calLoss(logits, labels, augs):
  noAugLogits = []
  noAugLabels = []
  augLogits = []
  augLabels = []
  for i in range(augs.shape[0]):
    if augs[i] == 1:
      noAugLogits.append(logits[i])
      noAugLabels.append(labels[i])
    else:
      augLogits.append(logits[i])
      augLabels.append(labels[i])
  noAugLogits = tf.convert_to_tensor(noAugLogits)
  noAugLabels = tf.convert_to_tensor(noAugLabels)
  augLogits = tf.convert_to_tensor(augLogits)
  augLabels = tf.convert_to_tensor(augLabels)
  return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits=noAugLogits, labels=noAugLabels)) * PENALTY_COEFFICIENT + \
      tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=augLogits, labels=augLabels))

I think we should write the loss function using tensor operations; however, I am not familiar with them. Could anyone give me some advice on how to define this loss function?

Thank you for your kind answers or suggestions.

Upvotes: 8

Views: 3973

Answers (1)

Kongsea

Reputation: 927

I have finally solved the problem myself using TensorFlow's tf.boolean_mask() function. The custom weighted loss function is defined as follows:

def calLoss(logits, labels, augs):
  # If the whole batch is augmented, there is nothing to split.
  augSum = tf.reduce_sum(augs)
  pred = tf.less(augSum, BATCH_SIZE)

  def noaug(logits, labels, augs):
    # Split the batch into augmented and non-augmented samples.
    augs = tf.cast(augs, tf.bool)
    noaugs = tf.logical_not(augs)
    noAugLogits = tf.boolean_mask(logits, noaugs)
    noAugLabels = tf.boolean_mask(labels, noaugs)
    augLogits = tf.boolean_mask(logits, augs)
    augLabels = tf.boolean_mask(labels, augs)
    noaugLoss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=noAugLogits, labels=noAugLabels))
    augLoss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=augLogits, labels=augLabels))
    return noaugLoss * PENALTY_COEFFICIENT + augLoss

  def aug(logits, labels):
    # Every sample is augmented: plain unweighted cross-entropy.
    return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels))

  return tf.cond(pred, lambda: noaug(logits, labels, augs), lambda: aug(logits, labels))

As you can see, I use a numpy array, augs, with 1s and 0s in the corresponding positions to indicate whether each sample in a batch is augmented or non-augmented. I then convert it to a bool tensor and use it as the mask for tf.boolean_mask() to fetch the augmented and non-augmented samples and calculate their losses separately.
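
The splitting step can be illustrated with numpy boolean indexing (a sketch with made-up values, not the TensorFlow graph itself; logits, labels, and augs here are hypothetical):

```python
import numpy as np

logits = np.array([[2.0, 0.1],
                   [0.3, 1.7],
                   [1.2, 0.8]])   # hypothetical batch of 3 samples, 2 classes
labels = np.array([0, 1, 0])
augs = np.array([0, 1, 0])        # 1 marks an augmented sample

aug_mask = augs.astype(bool)      # counterpart of tf.cast(augs, tf.bool)
noaug_mask = ~aug_mask            # counterpart of tf.logical_not(augs)

# numpy counterparts of tf.boolean_mask(logits, mask):
aug_logits, aug_labels = logits[aug_mask], labels[aug_mask]
noaug_logits, noaug_labels = logits[noaug_mask], labels[noaug_mask]

print(noaug_logits.shape)  # (2, 2)
print(aug_labels)          # [1]
```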

Upvotes: 1
