Ofri Mann

Reputation: 381

Getting an error when using tf.keras.metrics.Mean in functional Keras API

I'm trying to add a Mean metric to a Keras functional model (TensorFlow 2.5), and I get the following error:

ValueError: Expected a symbolic Tensor for the metric value, received: tf.Tensor(0.0, shape=(), dtype=float32)

Here is the code:

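import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

x = [1, 2, 3, 4, 5, 6, 7, 8]
y = [5 + i * 3 for i in x]
a = Input(shape=(1,))
output = Dense(1)(a)
model = Model(inputs=a, outputs=output)
model.add_metric(tf.keras.metrics.Mean()(output))  # the ValueError is raised here
model.compile(loss='mse')
model.fit(x=x, y=y, epochs=100)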

If I remove the following line (from which the exception is thrown):

model.add_metric(tf.keras.metrics.Mean()(output))

the code works as expected.

I tried disabling eager execution, but then I get a different error:

ValueError: Using the result of calling a `Metric` object when calling `add_metric` on a Functional Model is not supported. Please pass the Tensor to monitor directly.

The usage above was copied almost verbatim from the tf.keras.metrics.Mean documentation (see "Usage with compile() API").

Upvotes: 1

Views: 1630

Answers (1)

Ofri Mann

Reputation: 381

I found a way to work around the problem: avoid model.add_metric altogether and pass a Metric object to the compile() method instead.
However, when I pass an instance of tf.keras.metrics.Mean as follows:

model.compile(loss='mse', metrics=tf.keras.metrics.Mean())

I get the following error from the compile() method:

TypeError: update_state() got multiple values for argument 'sample_weight'

To solve this, I had to subclass tf.keras.metrics.Mean and change the signature of update_state to the (y_true, y_pred, sample_weight) form that is expected, forwarding only y_pred to the parent class.
Here is the final (working) code:

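import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

class FixedMean(tf.keras.metrics.Mean):
    # Accept the (y_true, y_pred, sample_weight) arguments that metrics passed to compile() receive.
    def update_state(self, y_true, y_pred, sample_weight=None):
        # Mean takes a single values argument, so forward only the predictions.
        super().update_state(y_pred, sample_weight=sample_weight)

x = [1, 2, 3, 4, 5, 6, 7, 8]
y = [5 + i * 3 for i in x]
a = Input(shape=(1,))
output = Dense(1)(a)
model = Model(inputs=a, outputs=output)
model.compile(loss='mse', metrics=FixedMean())
model.fit(x=x, y=y, epochs=100)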
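
For anyone wondering where the TypeError comes from, here is a plain-Python sketch that reproduces it without TensorFlow. It assumes (inferred from the error message, not checked against the Keras source) that metrics passed to compile() are called as update_state(y_true, y_pred, sample_weight=...). PlainMean, FixedPlainMean and call_like_compile are hypothetical stand-ins written for this illustration; they are not part of any library.

```python
class PlainMean:
    """Hypothetical stand-in whose update_state takes (values, sample_weight),
    like tf.keras.metrics.Mean."""

    def __init__(self):
        self.total = 0.0
        self.weight_sum = 0.0

    def update_state(self, values, sample_weight=None):
        # Default to a weight of 1 per value when no weights are given.
        weights = sample_weight if sample_weight is not None else [1.0] * len(values)
        for value, weight in zip(values, weights):
            self.total += value * weight
            self.weight_sum += weight

    def result(self):
        return self.total / self.weight_sum if self.weight_sum else 0.0


class FixedPlainMean(PlainMean):
    """Same idea as FixedMean: accept y_true, then forward only y_pred."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        super().update_state(y_pred, sample_weight=sample_weight)


def call_like_compile(metric, y_true, y_pred, sample_weight=None):
    # Assumed calling convention for compiled metrics (an assumption, see above).
    metric.update_state(y_true, y_pred, sample_weight=sample_weight)


try:
    # y_pred lands in the positional sample_weight slot, which is then
    # also supplied by keyword, so Python rejects the call.
    call_like_compile(PlainMean(), [0.0, 0.0], [2.0, 4.0])
except TypeError as exc:
    print("PlainMean failed:", exc)

fixed = FixedPlainMean()
call_like_compile(fixed, [0.0, 0.0], [2.0, 4.0])
print("FixedPlainMean result:", fixed.result())
```

With the two-argument signature, the second positional argument collides with the sample_weight keyword, which is the same failure mode as in the traceback above. The subclass avoids the collision by accepting y_true explicitly and discarding it.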

Upvotes: 2
