Loann Gio

Reputation: 85

How to prevent Keras from computing metrics during training

I'm using TensorFlow/Keras 2.4.1 and I have an (unsupervised) custom metric that takes several of my model inputs as parameters, such as:

model = build_model() # returns a tf.keras.Model object
my_metric = custom_metric(model.output, model.input[0], model.input[1])
model.add_metric(my_metric)
[...]
model.fit([...]) # training with fit

However, custom_metric is very expensive, so I would like it to be computed during validation only. I found this answer, but I don't understand how to adapt its solution to a metric that takes several model inputs as parameters, since the update_state method doesn't seem flexible enough.

In my context, is there a way to avoid computing my metric during training, short of writing my own training loop? Also, I am very surprised that we cannot natively tell TensorFlow that some metrics should only be computed at validation time. Is there a reason for that?

In addition, since the model is trained to optimize the loss, and the training dataset should not be used to evaluate a model, I don't even understand why TensorFlow computes metrics during training by default.

Upvotes: 8

Views: 1541

Answers (4)

user160623

Reputation: 316

Compiling the model again at testing time lets you change the metrics. This discards the optimizer state, but if you are not training the model any further, that is not an issue. This approach also saves you from writing and running a callback.

training_metrics = [...]
testing_metrics = [...]

model.compile(..., metrics=training_metrics)  # compile() returns None, so don't reassign model
model.fit(...)
model.compile(..., metrics=testing_metrics)
model.evaluate(...)
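A minimal runnable sketch of the recompile approach, assuming TF 2.x with tf.keras. The toy data, the single Dense layer and MeanAbsoluteError (standing in for an expensive metric) are illustrative choices, not part of the original answer.

```python
import numpy as np
import tensorflow as tf

# Made-up regression data, for illustration only.
rng = np.random.default_rng(0)
x = rng.uniform(size=(64, 4)).astype("float32")
y = rng.uniform(size=(64, 1)).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])

# Training: only the loss is computed on each batch.
model.compile(optimizer="adam", loss="mse")
model.fit(x, y, epochs=2, batch_size=16, verbose=0)

# Testing: recompile with the metric(s) you want. compile() returns None, so
# the model is not reassigned. Passing "adam" creates a fresh optimizer,
# which is why the previous optimizer state is lost.
model.compile(optimizer="adam", loss="mse",
              metrics=[tf.keras.metrics.MeanAbsoluteError()])
results = model.evaluate(x, y, verbose=0, return_dict=True)
print(results)  # keys: 'loss' and 'mean_absolute_error'
```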

Upvotes: 0

Nicolas Pinchaud

Reputation: 21

Since the metrics are computed inside the train_step method of keras.Model, filtering out train-disabled metrics without altering the API requires subclassing keras.Model.

We define a simple metric wrapper:

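import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.metrics import Metric
# Internal helpers used by CustomModel below. These are private module paths
# that differ between TF/Keras versions; adjust them to your installation.
from keras.engine import compile_utils, data_adapter
from tensorflow.python.util.nest import flatten_up_to, get_traverse_shallow_structure

class TrainDisabledMetric(Metric):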

  def __init__(self, metric: Metric):
    super().__init__(name=metric.name)
    self._metric = metric

  def update_state(self, *args, **kwargs):
    return self._metric.update_state(*args, **kwargs)

  def reset_state(self):
    return self._metric.reset_state()

  def result(self):
    return self._metric.result()

Then we subclass keras.Model to filter out those metrics during training:

class CustomModel(keras.Model):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)

  def compile(self, optimizer='rmsprop', loss=None, metrics=None,
              loss_weights=None, weighted_metrics=None, run_eagerly=None,
              steps_per_execution=None, jit_compile=None, **kwargs):

    from_serialized = kwargs.get('from_serialized', False)

    super().compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights,
                    weighted_metrics=weighted_metrics, run_eagerly=run_eagerly,
                    steps_per_execution=steps_per_execution,
                    jit_compile=jit_compile, **kwargs)

    self.on_train_compiled_metrics = self.compiled_metrics

    if metrics is not None:

      def get_on_train_traverse_tree(structure):
        flat = tf.nest.flatten(structure)
        on_train = [not isinstance(e, TrainDisabledMetric) for e in flat]
        full_tree = tf.nest.pack_sequence_as(structure, on_train)
        return get_traverse_shallow_structure(lambda s: any(tf.nest.flatten(s)),
                                              full_tree)

      on_train_sub_tree = get_on_train_traverse_tree(metrics)
      flat_on_train = flatten_up_to(on_train_sub_tree, metrics)

      def clean_tree(tree):
        if isinstance(tree, list):
          _list = []
          for t in tree:
            r = clean_tree(t)
            if r:
              _list.append(r)
          return _list

        elif isinstance(tree, dict):
          _tree = {}
          for k, v in tree.items():
            r = clean_tree(v)
            if r:
              _tree[k] = r
          return _tree
        else:
          return tree

      pruned_on_train_sub_tree = clean_tree(on_train_sub_tree)
      pruned_flat_on_train = [m for keep, m in
                              zip(tf.nest.flatten(on_train_sub_tree),
                                  flat_on_train) if keep]

      on_train_metrics = tf.nest.pack_sequence_as(pruned_on_train_sub_tree,
                                                  pruned_flat_on_train)

      self.on_train_compiled_metrics = compile_utils.MetricsContainer(
        on_train_metrics, weighted_metrics=None, output_names=self.output_names,
        from_serialized=from_serialized)

  def train_step(self, data):
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
    # Run forward pass.
    with tf.GradientTape() as tape:
      y_pred = self(x, training=True)
      loss = self.compute_loss(x, y, y_pred, sample_weight)
    self._validate_target_and_loss(y, loss)
    # Run backwards pass.
    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    return self.compute_metrics(x, y, y_pred, sample_weight, training=True)

  def compute_metrics(self, x, y, y_pred, sample_weight, training=False):
    del x  # The default implementation does not use `x`.

    if training:
      self.on_train_compiled_metrics.update_state(y, y_pred, sample_weight)
      metrics = self.on_train_metrics
    else:
      self.compiled_metrics.update_state(y, y_pred, sample_weight)
      metrics = self.metrics
    # Collect metrics to return
    return_metrics = {}
    for metric in metrics:
      result = metric.result()
      if isinstance(result, dict):
        return_metrics.update(result)
      else:
        return_metrics[metric.name] = result
    return return_metrics

  @property
  def on_train_metrics(self):
    metrics = []
    if self._is_compiled:
      # TODO(omalleyt): Track `LossesContainer` and `MetricsContainer` objects
      # so that attr names are not load-bearing.
      if self.compiled_loss is not None:
        metrics += self.compiled_loss.metrics
      if self.on_train_compiled_metrics is not None:
        metrics += self.on_train_compiled_metrics.metrics

    for l in self._flatten_layers():
      metrics.extend(l._metrics)  # pylint: disable=protected-access
    return metrics
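The clean_tree helper nested in compile() is plain Python, so its pruning can be checked on its own. The boolean tree below is a made-up example of the shape get_on_train_traverse_tree is meant to produce (True marks metrics kept during training); a branch that ends up empty is dropped entirely.

```python
def clean_tree(tree):
    """Recursively drop falsy branches (same logic as the helper in compile())."""
    if isinstance(tree, list):
        kept = []
        for t in tree:
            r = clean_tree(t)
            if r:
                kept.append(r)
        return kept
    elif isinstance(tree, dict):
        kept = {}
        for k, v in tree.items():
            r = clean_tree(v)
            if r:
                kept[k] = r
        return kept
    else:
        return tree


# Hypothetical keep/drop tree for a three-output model: True = keep during training.
on_train_sub_tree = {"out_a": [True, False], "out_b": [False, False], "out_c": True}
pruned = clean_tree(on_train_sub_tree)
print(pruned)  # {'out_a': [True], 'out_c': True}
```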

Now, given a Keras model, we can wrap it and compile it with train-disabled metrics:

model: keras.Model = ...
custom_model = CustomModel(inputs=model.input, outputs=model.output)

train_enabled_metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]

# wrap train disabled metrics with `TrainDisabledMetric`:
train_disabled_metrics = [
  TrainDisabledMetric(tf.keras.metrics.SparseCategoricalCrossentropy())]

metrics = train_enabled_metrics + train_disabled_metrics

custom_model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
                     loss=tf.keras.losses.SparseCategoricalCrossentropy(
                       from_logits=True), metrics=metrics, )

custom_model.fit(ds_train, epochs=6, validation_data=ds_test, )

The metric SparseCategoricalCrossentropy is computed only during validation:

Epoch 1/6
469/469 [==============================] - 2s 2ms/step - loss: 0.3522 - sparse_categorical_accuracy: 0.8366 - val_loss: 0.1978 - val_sparse_categorical_accuracy: 0.9086 - val_sparse_categorical_crossentropy: 1.3197
Epoch 2/6
469/469 [==============================] - 1s 1ms/step - loss: 0.1631 - sparse_categorical_accuracy: 0.9526 - val_loss: 0.1429 - val_sparse_categorical_accuracy: 0.9587 - val_sparse_categorical_crossentropy: 1.1910
Epoch 3/6
469/469 [==============================] - 1s 1ms/step - loss: 0.1178 - sparse_categorical_accuracy: 0.9654 - val_loss: 0.1139 - val_sparse_categorical_accuracy: 0.9661 - val_sparse_categorical_crossentropy: 1.1369
Epoch 4/6
469/469 [==============================] - 1s 1ms/step - loss: 0.0909 - sparse_categorical_accuracy: 0.9735 - val_loss: 0.0981 - val_sparse_categorical_accuracy: 0.9715 - val_sparse_categorical_crossentropy: 1.0434
Epoch 5/6
469/469 [==============================] - 1s 1ms/step - loss: 0.0735 - sparse_categorical_accuracy: 0.9784 - val_loss: 0.0913 - val_sparse_categorical_accuracy: 0.9721 - val_sparse_categorical_crossentropy: 0.9862
Epoch 6/6
469/469 [==============================] - 1s 1ms/step - loss: 0.0606 - sparse_categorical_accuracy: 0.9823 - val_loss: 0.0824 - val_sparse_categorical_accuracy: 0.9761 - val_sparse_categorical_crossentropy: 1.0024
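To check the wrapper in isolation, the sketch below (assuming TF 2.x; a tf.keras.metrics.Mean named "expensive" stands in for a costly metric) repeats TrainDisabledMetric and shows the two properties CustomModel relies on: the marker type can be detected with isinstance after tf.nest.flatten, and updates and results are forwarded to the wrapped metric.

```python
import tensorflow as tf


class TrainDisabledMetric(tf.keras.metrics.Metric):
    """Marker wrapper: behaves like the wrapped metric, but has a distinct type."""

    def __init__(self, metric):
        super().__init__(name=metric.name)
        self._metric = metric

    def update_state(self, *args, **kwargs):
        return self._metric.update_state(*args, **kwargs)

    def reset_state(self):
        return self._metric.reset_state()

    def result(self):
        return self._metric.result()


metrics = [
    tf.keras.metrics.SparseCategoricalAccuracy(),
    TrainDisabledMetric(tf.keras.metrics.Mean(name="expensive")),  # stand-in metric
]

# The marker type is what CustomModel.compile() uses to decide what runs in training.
on_train = [not isinstance(m, TrainDisabledMetric) for m in tf.nest.flatten(metrics)]
print(on_train)  # [True, False]

# Updates and results are forwarded to the wrapped Mean metric.
wrapped = metrics[1]
wrapped.update_state([1.0, 2.0, 3.0])
print(float(wrapped.result()))  # 2.0
```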

Upvotes: 2

Marco Cerliani

Reputation: 22031

I think the simplest way to compute a metric only on the validation data is to use a custom callback.

Here we define our dummy callback:

import numpy as np
import tensorflow as tf

class MyCustomMetricCallback(tf.keras.callbacks.Callback):

    def __init__(self, train=None, validation=None):
        super(MyCustomMetricCallback, self).__init__()
        self.train = train
        self.validation = validation

    def on_epoch_end(self, epoch, logs=None):
        # avoid a shared mutable default argument
        logs = logs if logs is not None else {}

        mse = tf.keras.losses.mean_squared_error

        if self.train:
            X_train, y_train = self.train[0], self.train[1]
            y_pred = self.model.predict(X_train)
            # mse returns one value per sample; average them into a scalar
            score = np.mean(mse(y_train, y_pred))
            logs['my_metric_train'] = round(float(score), 5)

        if self.validation:
            X_valid, y_valid = self.validation[0], self.validation[1]
            y_pred = self.model.predict(X_valid)
            val_score = np.mean(mse(y_valid, y_pred))
            logs['my_metric_val'] = round(float(val_score), 5)
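tf.keras.losses.mean_squared_error averages only over the last axis, so it returns one value per sample rather than a scalar, and a value written to logs should normally be reduced further. A small check with made-up arrays, assuming TF 2.x:

```python
import numpy as np
import tensorflow as tf

y_true = np.zeros((4, 1), dtype="float32")
y_pred = np.full((4, 1), 2.0, dtype="float32")

per_sample = tf.keras.losses.mean_squared_error(y_true, y_pred)
print(tuple(per_sample.shape))  # (4,): one value per sample

scalar = float(np.mean(per_sample))
print(scalar)  # 4.0
```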

Given this dummy model:

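from tensorflow.keras.layers import Concatenate, Dense, Input
from tensorflow.keras.models import Model

def build_model():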

  inp1 = Input((5,))
  inp2 = Input((5,))
  out = Concatenate()([inp1, inp2])
  out = Dense(1)(out)

  model = Model([inp1, inp2], out)
  model.compile(loss='mse', optimizer='adam')

  return model

and this data:

X_train1 = np.random.uniform(0,1, (100,5))
X_train2 = np.random.uniform(0,1, (100,5))
y_train = np.random.uniform(0,1, (100,1))

X_val1 = np.random.uniform(0,1, (100,5))
X_val2 = np.random.uniform(0,1, (100,5))
y_val = np.random.uniform(0,1, (100,1))

You can use the custom callback to compute the metric on both train and validation data:

model = build_model()

model.fit([X_train1, X_train2], y_train, epochs=10, 
          callbacks=[MyCustomMetricCallback(train=([X_train1, X_train2],y_train), validation=([X_val1, X_val2],y_val))])

only on validation:

model = build_model()

model.fit([X_train1, X_train2], y_train, epochs=10, 
          callbacks=[MyCustomMetricCallback(validation=([X_val1, X_val2],y_val))])

only on train:

model = build_model()

model.fit([X_train1, X_train2], y_train, epochs=10, 
          callbacks=[MyCustomMetricCallback(train=([X_train1, X_train2],y_train))])

Just keep in mind that the callback evaluates the metric in one shot on the whole dataset at the end of each epoch, like the metrics and losses Keras computes by default on validation_data.
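Whether computing "in one shot" gives the same number as accumulating batch by batch depends on the metric. The NumPy sketch below (plain NumPy, not Keras code; the arrays are made up) shows that a mean-type metric such as MSE accumulated as a running sum and count matches the one-shot value, while averaging per-batch values of a non-decomposable metric such as R² does not.

```python
import numpy as np


def mse(y_true, y_pred):
    return float(np.mean((y_true - y_pred) ** 2))


def r2(y_true, y_pred):
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    return float(1.0 - ss_res / ss_tot)


y_true = np.array([0.0, 1.0, 10.0, 11.0])
y_pred = np.array([0.5, 0.5, 10.5, 10.5])
batches = [slice(0, 2), slice(2, 4)]

# Streaming MSE: accumulate squared-error sum and count, divide at the end.
sq_sum = sum(float(np.sum((y_true[b] - y_pred[b]) ** 2)) for b in batches)
count = sum(len(y_true[b]) for b in batches)
streaming_mse = sq_sum / count
oneshot_mse = mse(y_true, y_pred)
print(streaming_mse, oneshot_mse)  # 0.25 0.25

# R²: the average of per-batch values differs from the one-shot value.
batch_avg_r2 = float(np.mean([r2(y_true[b], y_pred[b]) for b in batches]))
oneshot_r2 = r2(y_true, y_pred)
print(batch_avg_r2, oneshot_r2)  # 0.0 vs. about 0.990
```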


Upvotes: 3

Daniel Möller

Reputation: 86610

I was able to use learning_phase, but only in symbolic-tensor (graph) mode.

So, first we need to disable eager execution (this must be done right after importing TensorFlow):

import tensorflow as tf
tf.compat.v1.disable_eager_execution()

Then you can create your metric using a symbolic if (backend.switch):

from tensorflow.keras import backend as K

def metric_graph(in1, in2, out):
    actual_metric = out * (in1 + in2)
    return K.switch(K.learning_phase(), tf.zeros((1,)), actual_metric) 

The method add_metric will ask for a name and an aggregation method, which you can set to "mean".
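The "mean" aggregation is, as we understand it, backed by a tf.keras.metrics.Mean object, which keeps a running total and count across batches and divides at the end. A small sketch of that accumulation, assuming TF 2.x (the values are made up):

```python
import tensorflow as tf

m = tf.keras.metrics.Mean(name="my_metric")
m.update_state([1.0, 2.0, 3.0])  # first batch: total 6, count 3
m.update_state([4.0])            # second batch: total 10, count 4
result = float(m.result())
print(result)  # 2.5: mean over all values seen, not the mean of batch means (3.0)
```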

Here is a complete example:

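import numpy
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Concatenate, Dense, Input
from tensorflow.keras.models import Model

x1 = numpy.ones((5,3))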
x2 = numpy.ones((5,3))
y = 3*numpy.ones((5,1))

vx1 = numpy.ones((5,3))
vx2 = numpy.ones((5,3))
vy = 3*numpy.ones((5,1))

# Eager-style attempt, kept for comparison only; it is not used below.
def metric_eager(in1, in2, out):
    if (K.learning_phase()):
        return 0
    else:
        return out * (in1 + in2)

def metric_graph(in1, in2, out):
    actual_metric = out * (in1 + in2)
    return K.switch(K.learning_phase(), tf.zeros((1,)), actual_metric) 

ins1 = Input((3,))
ins2 = Input((3,))
outs = Concatenate()([ins1, ins2])
outs = Dense(1)(outs)
model = Model([ins1, ins2],outs)
model.add_metric(metric_graph(ins1, ins2, outs), name='my_metric', aggregation='mean')
model.compile(loss='mse', optimizer='adam')

model.fit([x1, x2],y, validation_data=([vx1, vx2], vy), epochs=3)

Upvotes: 3
