Legotin

Reputation: 2688

Custom loss function in TensorFlow 2: dealing with None in batch dimension

I'm training a model whose inputs and outputs are images of the same shape (H, W, C) in RGB color space.

My loss function is MSE over these images, but computed in a different color space.
The color space conversion is defined by the transform_space function, which takes and returns a single image.

I'm subclassing tf.keras.losses.Loss to implement this loss.
The call method, however, receives images not one by one but in batches of shape (None, H, W, C).
The problem is that the first dimension of this batch is None.

I tried iterating over these batches, but got the error "iterating over tf.Tensor is not allowed".
So, how should I calculate my loss?

For several reasons, I can't simply use the new color space as the input and output of my model.

I'm using tf.distribute.MirroredStrategy if it matters.

import tensorflow as tf

# Takes an image of shape (H, W, C),
# converts it to a new color space,
# and returns a new image of shape (H, W, C).
def transform_space(image):
  # ...color space transformation...
  return image_in_a_new_color_space

class MyCustomLoss(tf.keras.losses.Loss):

  def __init__(self):
    super().__init__()

    # The loss function is defined this way
    # due to the fact that I use "tf.distribute.MirroredStrategy"
    mse = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
    self.loss_fn = lambda true, pred: tf.math.reduce_mean(mse(true, pred))

  def call(self, true_batch, pred_batch):

    # Since the shape of true_batch/pred_batch is (None, H, W, C)
    # and transform_space expects shape (H, W, C),
    # the following calls do not work:
    true_batch_transformed = transform_space(true_batch)
    pred_batch_transformed = transform_space(pred_batch)

    return self.loss_fn(true_batch_transformed, pred_batch_transformed)

Upvotes: 1

Views: 1052

Answers (1)

Yaoshiang

Reputation: 1941

Batching is essentially hard-coded into TF's design. It's the best way to take advantage of GPU resources to run deep learning models fast. Looping is strongly discouraged in TF for the same reason: the whole point of using TF is vectorization, running many computations in parallel.

It's possible to break these design assumptions, but the correct way to solve this is to implement your transform in a vectorized way (i.e. make transform_space() accept batches).

FYI, TF natively supports YUV, YIQ, and HSV conversions in the tf.image package, in case you were using one of those. Or you can look at the source there and see if you can adapt it to your needs.
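For instance, if the target space happened to be YUV, the whole loss could stay vectorized, because tf.image.rgb_to_yuv operates on the trailing channel dimension and accepts batched tensors of shape (None, H, W, 3). A minimal sketch (the class name is illustrative, and it reuses the questioner's reduction setup for MirroredStrategy):

```python
import tensorflow as tf

class YuvMseLoss(tf.keras.losses.Loss):
    """MSE computed in YUV space; assumes RGB inputs of shape (None, H, W, 3)."""

    def __init__(self):
        super().__init__()
        # Reduction.NONE plus an explicit reduce_mean, as in the question,
        # to play well with tf.distribute.MirroredStrategy.
        mse = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE)
        self.loss_fn = lambda true, pred: tf.math.reduce_mean(mse(true, pred))

    def call(self, true_batch, pred_batch):
        # rgb_to_yuv converts the entire batch in one vectorized op;
        # no per-image loop is needed.
        true_yuv = tf.image.rgb_to_yuv(true_batch)
        pred_yuv = tf.image.rgb_to_yuv(pred_batch)
        return self.loss_fn(true_yuv, pred_yuv)
```
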

Anyway, to do what you want (with a potentially serious performance hit), you can use tf.map_fn:

true_batch_transformed = tf.map_fn(transform_space, true_batch)
pred_batch_transformed = tf.map_fn(transform_space, pred_batch)
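Dropped into the loss class, a runnable sketch might look like the following. Here transform_space is a placeholder (it just reverses the channel order) standing in for the real conversion, which must be written in traceable TF ops for map_fn to work inside a graph:

```python
import tensorflow as tf

# Placeholder per-image transform of shape (H, W, C) -> (H, W, C);
# substitute the real color space conversion here.
def transform_space(image):
    return tf.reverse(image, axis=[-1])

class MyCustomLoss(tf.keras.losses.Loss):
    def __init__(self):
        super().__init__()
        mse = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE)
        self.loss_fn = lambda true, pred: tf.math.reduce_mean(mse(true, pred))

    def call(self, true_batch, pred_batch):
        # tf.map_fn applies transform_space along the first (batch) axis,
        # so the unknown (None) batch size is never iterated in Python.
        true_t = tf.map_fn(transform_space, true_batch)
        pred_t = tf.map_fn(transform_space, pred_batch)
        return self.loss_fn(true_t, pred_t)
```

Note that map_fn runs the transform once per image, which is why it carries the performance cost mentioned above.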

Upvotes: 1
