Reputation: 85
I'm using TensorFlow/Keras 2.4.1 and I have an (unsupervised) custom metric that takes several of my model inputs as parameters, such as:
model = build_model() # returns a tf.keras.Model object
my_metric = custom_metric(model.output, model.input[0], model.input[1])
model.add_metric(my_metric)
[...]
model.fit([...]) # training with fit
However, custom_metric is very expensive, so I would like it to be computed during validation only. I found this answer, but I don't see how to adapt that solution to my metric, which uses several model inputs as parameters, since the update_state method doesn't seem flexible enough.
In my context, is there a way to avoid computing my metric during training, aside from writing my own training loop? Also, I am very surprised that we cannot natively tell TensorFlow that some metrics should only be computed at validation time; is there a reason for that?
In addition, since the model is trained to optimize the loss, and the training dataset should not be used to evaluate a model, I don't even understand why TensorFlow computes metrics during training by default.
Upvotes: 8
Views: 1541
Reputation: 316
Compiling the model again at testing time allows the metrics to be changed. This discards the optimizer state, but if you are not training the model further, that is not an issue. This approach saves writing and running a callback function.
training_metrics = [...]
testing_metrics = [...]

# Note: compile() modifies the model in place and returns None, so don't reassign it.
model.compile(..., metrics=training_metrics)
model.fit(...)

model.compile(..., metrics=testing_metrics)
model.evaluate(...)
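As a minimal sketch of this pattern (not part of the original answer; the toy model, data, and metric choices below are purely illustrative):

import numpy as np
import tensorflow as tf

# Illustrative toy data and model.
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

# Train with cheap metrics only.
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
model.fit(x, y, epochs=2, verbose=0)

# Re-compile with the expensive metric just before evaluating.
# This resets the optimizer state, which only matters if you resume training afterwards.
model.compile(optimizer="adam", loss="mse",
              metrics=["mae", tf.keras.metrics.RootMeanSquaredError()])
model.evaluate(x, y, verbose=0)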
Upvotes: 0
Reputation: 21
Since the metrics are run within the train_step function of keras.Model, filtering out train-disabled metrics without altering the API requires subclassing keras.Model.
We define a simple metric wrapper:
import tensorflow as tf
from tensorflow.keras.metrics import Metric


class TrainDisabledMetric(Metric):
  def __init__(self, metric: Metric):
    super().__init__(name=metric.name)
    self._metric = metric

  def update_state(self, *args, **kwargs):
    return self._metric.update_state(*args, **kwargs)

  def reset_state(self):
    return self._metric.reset_state()

  def result(self):
    return self._metric.result()
and subclass keras.Model to filter out those metrics during training:
# These helpers live in TF/Keras internal modules; exact import paths may vary by TF version.
from tensorflow import keras
from tensorflow.python.keras.engine import compile_utils, data_adapter
from tensorflow.python.util.nest import (flatten_up_to,
                                         get_traverse_shallow_structure)


class CustomModel(keras.Model):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)

  def compile(self, optimizer='rmsprop', loss=None, metrics=None,
              loss_weights=None, weighted_metrics=None, run_eagerly=None,
              steps_per_execution=None, jit_compile=None, **kwargs):

    from_serialized = kwargs.get('from_serialized', False)

    super().compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights,
                    weighted_metrics=weighted_metrics, run_eagerly=run_eagerly,
                    steps_per_execution=steps_per_execution,
                    jit_compile=jit_compile, **kwargs)

    self.on_train_compiled_metrics = self.compiled_metrics

    if metrics is not None:

      def get_on_train_traverse_tree(structure):
        flat = tf.nest.flatten(structure)
        on_train = [not isinstance(e, TrainDisabledMetric) for e in flat]
        full_tree = tf.nest.pack_sequence_as(structure, on_train)
        return get_traverse_shallow_structure(
          lambda s: any(tf.nest.flatten(s)), full_tree)

      on_train_sub_tree = get_on_train_traverse_tree(metrics)
      flat_on_train = flatten_up_to(on_train_sub_tree, metrics)

      def clean_tree(tree):
        if isinstance(tree, list):
          _list = []
          for t in tree:
            r = clean_tree(t)
            if r:
              _list.append(r)
          return _list
        elif isinstance(tree, dict):
          _tree = {}
          for k, v in tree.items():
            r = clean_tree(v)
            if r:
              _tree[k] = r
          return _tree
        else:
          return tree

      pruned_on_train_sub_tree = clean_tree(on_train_sub_tree)

      pruned_flat_on_train = [m for keep, m in
                              zip(tf.nest.flatten(on_train_sub_tree),
                                  flat_on_train) if keep]

      on_train_metrics = tf.nest.pack_sequence_as(pruned_on_train_sub_tree,
                                                  pruned_flat_on_train)

      self.on_train_compiled_metrics = compile_utils.MetricsContainer(
        on_train_metrics, weighted_metrics=None,
        output_names=self.output_names, from_serialized=from_serialized)

  def train_step(self, data):
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
    # Run forward pass.
    with tf.GradientTape() as tape:
      y_pred = self(x, training=True)
      loss = self.compute_loss(x, y, y_pred, sample_weight)
    self._validate_target_and_loss(y, loss)
    # Run backwards pass.
    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    return self.compute_metrics(x, y, y_pred, sample_weight, training=True)

  def compute_metrics(self, x, y, y_pred, sample_weight, training=False):
    del x  # The default implementation does not use `x`.

    if training:
      self.on_train_compiled_metrics.update_state(y, y_pred, sample_weight)
      metrics = self.on_train_metrics
    else:
      self.compiled_metrics.update_state(y, y_pred, sample_weight)
      metrics = self.metrics

    # Collect metrics to return
    return_metrics = {}
    for metric in metrics:
      result = metric.result()
      if isinstance(result, dict):
        return_metrics.update(result)
      else:
        return_metrics[metric.name] = result
    return return_metrics

  @property
  def on_train_metrics(self):
    metrics = []
    if self._is_compiled:
      # TODO(omalleyt): Track `LossesContainer` and `MetricsContainer` objects
      # so that attr names are not load-bearing.
      if self.compiled_loss is not None:
        metrics += self.compiled_loss.metrics
      if self.on_train_compiled_metrics is not None:
        metrics += self.on_train_compiled_metrics.metrics

    for l in self._flatten_layers():
      metrics.extend(l._metrics)  # pylint: disable=protected-access
    return metrics
Now, given a Keras model, we can wrap it and compile it with train-disabled metrics:
model: keras.Model = ...
custom_model = CustomModel(inputs=model.input, outputs=model.output)

train_enabled_metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]

# wrap train-disabled metrics with `TrainDisabledMetric`:
train_disabled_metrics = [
  TrainDisabledMetric(tf.keras.metrics.SparseCategoricalCrossentropy())]

metrics = train_enabled_metrics + train_disabled_metrics

custom_model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
                     loss=tf.keras.losses.SparseCategoricalCrossentropy(
                       from_logits=True),
                     metrics=metrics)

custom_model.fit(ds_train, epochs=6, validation_data=ds_test)
The metric SparseCategoricalCrossentropy is computed only during validation:
Epoch 1/6
469/469 [==============================] - 2s 2ms/step - loss: 0.3522 - sparse_categorical_accuracy: 0.8366 - val_loss: 0.1978 - val_sparse_categorical_accuracy: 0.9086 - val_sparse_categorical_crossentropy: 1.3197
Epoch 2/6
469/469 [==============================] - 1s 1ms/step - loss: 0.1631 - sparse_categorical_accuracy: 0.9526 - val_loss: 0.1429 - val_sparse_categorical_accuracy: 0.9587 - val_sparse_categorical_crossentropy: 1.1910
Epoch 3/6
469/469 [==============================] - 1s 1ms/step - loss: 0.1178 - sparse_categorical_accuracy: 0.9654 - val_loss: 0.1139 - val_sparse_categorical_accuracy: 0.9661 - val_sparse_categorical_crossentropy: 1.1369
Epoch 4/6
469/469 [==============================] - 1s 1ms/step - loss: 0.0909 - sparse_categorical_accuracy: 0.9735 - val_loss: 0.0981 - val_sparse_categorical_accuracy: 0.9715 - val_sparse_categorical_crossentropy: 1.0434
Epoch 5/6
469/469 [==============================] - 1s 1ms/step - loss: 0.0735 - sparse_categorical_accuracy: 0.9784 - val_loss: 0.0913 - val_sparse_categorical_accuracy: 0.9721 - val_sparse_categorical_crossentropy: 0.9862
Epoch 6/6
469/469 [==============================] - 1s 1ms/step - loss: 0.0606 - sparse_categorical_accuracy: 0.9823 - val_loss: 0.0824 - val_sparse_categorical_accuracy: 0.9761 - val_sparse_categorical_crossentropy: 1.0024
Upvotes: 2
Reputation: 22031
I think the simplest solution for computing a metric only on the validation data is to use a custom callback.
Here we define our dummy callback:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Concatenate
from tensorflow.keras.models import Model


class MyCustomMetricCallback(tf.keras.callbacks.Callback):

    def __init__(self, train=None, validation=None):
        super(MyCustomMetricCallback, self).__init__()
        self.train = train
        self.validation = validation

    def on_epoch_end(self, epoch, logs={}):

        mse = tf.keras.losses.mean_squared_error

        if self.train:
            logs['my_metric_train'] = float('inf')
            X_train, y_train = self.train[0], self.train[1]
            y_pred = self.model.predict(X_train)
            score = mse(y_train, y_pred)
            logs['my_metric_train'] = np.round(score, 5)

        if self.validation:
            logs['my_metric_val'] = float('inf')
            X_valid, y_valid = self.validation[0], self.validation[1]
            y_pred = self.model.predict(X_valid)
            val_score = mse(y_pred, y_valid)
            logs['my_metric_val'] = np.round(val_score, 5)
Given this dummy model:
def build_model():
    inp1 = Input((5,))
    inp2 = Input((5,))
    out = Concatenate()([inp1, inp2])
    out = Dense(1)(out)
    model = Model([inp1, inp2], out)
    model.compile(loss='mse', optimizer='adam')
    return model
and this data:
X_train1 = np.random.uniform(0,1, (100,5))
X_train2 = np.random.uniform(0,1, (100,5))
y_train = np.random.uniform(0,1, (100,1))
X_val1 = np.random.uniform(0,1, (100,5))
X_val2 = np.random.uniform(0,1, (100,5))
y_val = np.random.uniform(0,1, (100,1))
you can use the custom callback to compute the metric both on train and validation:
model = build_model()
model.fit([X_train1, X_train2], y_train, epochs=10,
          callbacks=[MyCustomMetricCallback(train=([X_train1, X_train2], y_train),
                                            validation=([X_val1, X_val2], y_val))])
only on validation:
model = build_model()
model.fit([X_train1, X_train2], y_train, epochs=10,
          callbacks=[MyCustomMetricCallback(validation=([X_val1, X_val2], y_val))])
only on train:
model = build_model()
model.fit([X_train1, X_train2], y_train, epochs=10,
          callbacks=[MyCustomMetricCallback(train=([X_train1, X_train2], y_train))])
Just remember that the callback evaluates the metric in a single pass over the data at the end of each epoch, like any metric/loss computed by default by Keras on the validation_data.
Here is the running code.
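As a hedged sketch (not part of the original answer), the same callback idea could be adapted to the asker's multi-input custom_metric, assuming custom_metric can be called on plain prediction and input arrays; the class and argument names below are illustrative only, and it reuses the np/tf imports from the snippet above:

class ValidationOnlyMetricCallback(tf.keras.callbacks.Callback):
    """Illustrative sketch: compute a multi-input metric on validation data only."""

    def __init__(self, validation_inputs, metric_fn, name='my_metric_val'):
        super().__init__()
        self.validation_inputs = validation_inputs  # e.g. [X_val1, X_val2]
        self.metric_fn = metric_fn                  # hypothetical: custom_metric(output, in1, in2)
        self.metric_name = name

    def on_epoch_end(self, epoch, logs=None):
        logs = logs if logs is not None else {}
        y_pred = self.model.predict(self.validation_inputs)
        score = self.metric_fn(y_pred, self.validation_inputs[0], self.validation_inputs[1])
        # Reduce to a scalar so it shows up cleanly in the epoch logs.
        logs[self.metric_name] = float(np.mean(score))

Because the callback only runs at on_epoch_end, the expensive metric is never evaluated inside the training step.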
Upvotes: 3
Reputation: 86610
I was able to use learning_phase, but only in symbolic-tensor (graph) mode:
So, first we need to disable eager mode (this must be done right after importing TensorFlow):
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
Then you can create your metric using a symbolic if (backend.switch): since learning_phase() is 1 during training and 0 otherwise, the switch returns zeros while training and the actual metric during validation/evaluation:
def metric_graph(in1, in2, out):
    actual_metric = out * (in1 + in2)
    return K.switch(K.learning_phase(), tf.zeros((1,)), actual_metric)
The method add_metric will ask for a name and an aggregation method, which you can set to "mean".
So, here is one example:
import numpy
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Dense, Concatenate
from tensorflow.keras.models import Model

x1 = numpy.ones((5, 3))
x2 = numpy.ones((5, 3))
y = 3 * numpy.ones((5, 1))

vx1 = numpy.ones((5, 3))
vx2 = numpy.ones((5, 3))
vy = 3 * numpy.ones((5, 1))

def metric_eager(in1, in2, out):
    if K.learning_phase():
        return 0
    else:
        return out * (in1 + in2)

def metric_graph(in1, in2, out):
    actual_metric = out * (in1 + in2)
    return K.switch(K.learning_phase(), tf.zeros((1,)), actual_metric)

ins1 = Input((3,))
ins2 = Input((3,))
outs = Concatenate()([ins1, ins2])
outs = Dense(1)(outs)
model = Model([ins1, ins2], outs)
model.add_metric(metric_graph(ins1, ins2, outs), name='my_metric', aggregation='mean')

model.compile(loss='mse', optimizer='adam')
model.fit([x1, x2], y, validation_data=([vx1, vx2], vy), epochs=3)
Upvotes: 3