Reputation: 2688
I'm training a model whose inputs and outputs are images of the same shape (H, W, C) in RGB color space. My loss function is MSE over these images, but computed in another color space. The color space conversion is defined by the transform_space function, which takes and returns a single image. I'm inheriting from tf.keras.losses.Loss to implement this loss.
The call method, however, receives images not one by one but in batches of shape (None, H, W, C). The problem is that the first dimension of such a batch is None. I tried to iterate over the batch, but got the error "iterating over tf.Tensor is not allowed".
So, how should I calculate my loss?
The reason why I can't use the new color space as my model's input and output: I'm using tf.keras.applications, which works with RGB.
I'm using tf.distribute.MirroredStrategy, if it matters.
import tensorflow as tf

# Takes an image of shape (H, W, C),
# converts it to a new color space
# and returns a new image with shape (H, W, C)
def transform_space(image):
    # ...color space transformation...
    return image_in_a_new_color_space

class MyCustomLoss(tf.keras.losses.Loss):
    def __init__(self):
        super().__init__()
        # The loss function is defined this way
        # because I use "tf.distribute.MirroredStrategy"
        mse = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
        self.loss_fn = lambda true, pred: tf.math.reduce_mean(mse(true, pred))

    def call(self, true_batch, pred_batch):
        # Since the shape of true_batch/pred_batch is (None, H, W, C)
        # and transform_space expects shape (H, W, C),
        # the following calls do not work:
        true_batch_transformed = transform_space(true_batch)
        pred_batch_transformed = transform_space(pred_batch)
        return self.loss_fn(true_batch_transformed, pred_batch_transformed)
Upvotes: 1
Views: 1052
Reputation: 1941
Batching is basically hard coded into TF's design. It's the best way to take advantage of GPU resources to run deep learning models fast. Looping is strongly discouraged in TF for the same reason - the whole point of using TF is vectorization: running many computations in parallel.
It's possible to break these design assumptions, but really the correct way to do this is to implement your transform in a vectorized way (e.g. make transform_space() accept batches).
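For illustration, here is a minimal sketch of a batched transform, assuming the conversion is linear (RGB to YUV is used purely as a placeholder for your actual space). Because the matrix multiplication only touches the last (channel) axis, the same function handles a single (H, W, C) image and a (None, H, W, C) batch:

import tensorflow as tf

# Linear RGB -> YUV kernel (placeholder for your actual conversion).
_RGB_TO_YUV = tf.constant([[0.299, -0.14714119,  0.61497538],
                           [0.587, -0.28886916, -0.51496512],
                           [0.114,  0.43601035, -0.10001026]], dtype=tf.float32)

def transform_space(images):
    # Works on any (..., 3) tensor: only the channel axis is contracted,
    # so extra leading (batch) dimensions pass through unchanged.
    return tf.tensordot(images, _RGB_TO_YUV, axes=[[-1], [0]])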
FYI, TF natively supports YUV, YIQ, and HSV conversions in the tf.image package, in case you were using one of those. Or you can look at the source there and see if you can adapt it to your needs.
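For example, if YUV happened to be your target space, the built-in conversion already accepts a leading batch dimension, so it could be used inside call directly:

# tf.image.rgb_to_yuv works on (H, W, 3) and (None, H, W, 3) alike.
true_batch_transformed = tf.image.rgb_to_yuv(true_batch)
pred_batch_transformed = tf.image.rgb_to_yuv(pred_batch)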
Anyway, to do what you want, albeit with a potentially serious performance hit, you can use tf.map_fn:
true_batch_transformed = tf.map_fn(transform_space, true_batch)
pred_batch_transformed = tf.map_fn(transform_space, pred_batch)
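For completeness, a sketch of your call method with the per-image mapping plugged in (if transform_space returns a different dtype or shape than its input, tf.map_fn also accepts an fn_output_signature argument describing the per-image output):

def call(self, true_batch, pred_batch):
    # Apply the (H, W, C) -> (H, W, C) transform to each image in the batch.
    true_batch_transformed = tf.map_fn(transform_space, true_batch)
    pred_batch_transformed = tf.map_fn(transform_space, pred_batch)
    return self.loss_fn(true_batch_transformed, pred_batch_transformed)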
Upvotes: 1