Luke

Reputation: 7099

Calculate TensorFlow metric using more than one batch at a time

I'm using tf.keras and I have a metric that I'd like to calculate where I need multiple batches of validation data in order to calculate it reliably. Is there some way to accumulate batches before calculating the metric?

I'd like to do something like this:

class MultibatchMetric(tf.keras.metrics.Metric):
    def __init__(self, num_batches, name="sdr_metric", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_batches = num_batches
        self.batch_accumulator = []
        self.my_metric = []

    def update_state(self, y_true, y_pred, sample_weight=None):
        self.batch_accumulator.append((y_true, y_pred))
        if len(self.batch_accumulator) >= self.num_batches:
            metric = custom_multibatch_metric_func(self.batch_accumulator)
            self.my_metric.append(metric)
            self.batch_accumulator = []

    def result(self):
        # Average the metric values computed so far.
        return tf.reduce_mean(self.my_metric)

    def reset_states(self):
        self.my_metric = []
        self.batch_accumulator = []

However, this all needs to occur on the tensorflow graph, severely complicating things.
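To see why Python-side state like the list above breaks down on the graph: under `tf.function`, a list append executes only while the function is being traced, not on every call, so the accumulator never grows at runtime. A minimal sketch of the pitfall (names are illustrative):

```python
import tensorflow as tf

accumulator = []

@tf.function
def update(x):
    # This append runs only when the function is traced, not on every call.
    accumulator.append(x)
    return tf.reduce_sum(x)

for i in range(3):
    update(tf.constant([float(i)]))

# Despite three calls, the list grew only during the single trace.
print(len(accumulator))
```

This is why graph-compatible state has to live in TensorFlow variables rather than Python containers.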

Upvotes: 1

Views: 1207

Answers (1)

I had a go at your problem, and it seems the built-in add_weight method can provide a solution: create one state variable for a batch counter and one for an accumulator of shape (2, num_batches * batch_size, n_outputs). On each update, the incoming batch is stored by adding a zero-padded copy of it to the accumulator, and both variables are reset once the counter reaches the maximum number of batches. The result is then obtained by calling your metric on the accumulator state variable. I have added an example below.

class Metric(tf.keras.metrics.Metric):
    def __init__(self, num_batches, batch_size, n_outputs=2, name="sdr_metric", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_batches = num_batches
        self.batch_size = batch_size
        # Holds (y_true, y_pred) stacked along the first axis for a full window.
        self.batch_accumulator = self.add_weight(
            name='accumulator',
            shape=(2, num_batches * batch_size, n_outputs),
            initializer='zeros')
        # Integer counter so it can be used to build the tf.pad paddings.
        self.batch_counter = self.add_weight(
            name='counter', shape=(), dtype=tf.int32, initializer='zeros')

    @tf.function
    def update_state(self, y_true, y_pred, sample_weight=None):
        # Clear the previous window before writing the first batch of a new one.
        if self.batch_counter == self.num_batches:
            self.reset_states()
        batch_count = self.batch_counter
        batch = tf.stack([tf.cast(y_true, tf.float32), tf.cast(y_pred, tf.float32)])
        # Zero-pad the batch so it lands in its own slot of the accumulator.
        paddings = [[0, 0],
                    [batch_count * self.batch_size,
                     (self.num_batches - batch_count - 1) * self.batch_size],
                    [0, 0]]
        self.batch_accumulator.assign_add(tf.pad(batch, paddings))
        self.batch_counter.assign_add(1)

    @tf.function
    def result(self):
        # Only meaningful once a full window of batches has been accumulated.
        if self.batch_counter == self.num_batches:
            return custom_multibatch_metric_func(self.batch_accumulator)
        return 0.

    def reset_states(self):
        self.batch_counter.assign(0)
        self.batch_accumulator.assign(tf.zeros_like(self.batch_accumulator))
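To see how the padding trick slots each batch into its own region of the accumulator, here is a small standalone sketch (the shapes are arbitrary, chosen only for illustration):

```python
import tensorflow as tf

num_batches, batch_size, n_outputs = 3, 2, 2
accumulator = tf.zeros((2, num_batches * batch_size, n_outputs))

# Pretend this is the second batch (index 1) of stacked (y_true, y_pred).
batch_count = 1
batch = tf.ones((2, batch_size, n_outputs))

paddings = [[0, 0],
            [batch_count * batch_size,
             (num_batches - batch_count - 1) * batch_size],
            [0, 0]]
accumulator += tf.pad(batch, paddings)

# Only rows 2-3 (the slot for batch index 1) are non-zero.
print(tf.reduce_sum(accumulator, axis=[0, 2]).numpy())
```

Because each batch lands in a disjoint slice, `assign_add` never overwrites previously stored batches.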

Here is the test problem I used to verify it:

# data
n = 1028
batch_size = 32
num_batches = 3
f = 4
lr = 10e-3

x = tf.random.uniform((n, f), -1, 1)
y = tf.concat([tf.reduce_sum(x, axis=-1, keepdims=True),
               tf.reduce_mean(x, axis=-1, keepdims=True)], axis=-1)

ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size, drop_remainder=True)

model = tf.keras.models.Sequential([tf.keras.layers.Dense(f, activation='relu'),
                                    tf.keras.layers.Dense(2)])
model.compile(tf.keras.optimizers.SGD(lr),
              tf.keras.losses.mean_squared_error,
              metrics=[Metric(num_batches, batch_size)])
model.fit(ds, epochs=10)
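Note that `custom_multibatch_metric_func` is left undefined in both the question and the answer, so to actually run the test problem you need some stand-in. A hypothetical example that computes mean squared error over the full accumulated window:

```python
import tensorflow as tf

# Hypothetical stand-in for custom_multibatch_metric_func: MSE computed
# over the whole accumulated window at once. accumulator[0] is y_true,
# accumulator[1] is y_pred, per the stacking order in update_state.
def custom_multibatch_metric_func(accumulator):
    y_true, y_pred = accumulator[0], accumulator[1]
    return tf.reduce_mean(tf.square(y_true - y_pred))

acc = tf.stack([tf.ones((6, 2)), tf.zeros((6, 2))])
print(float(custom_multibatch_metric_func(acc)))  # 1.0
```

Any scalar-valued function of the accumulator works here, as long as it returns a float32 tensor to match the `0.` fallback in `result`.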

There are two large issues. Firstly, the if statement in the result call: depending on what you require of the resulting metric, you can return an idempotent value; here I assumed you just sum all the results, so returning 0 has no effect. Secondly, this approach requires you to drop the remainder unless your dataset size is divisible by your batch size.

I hope this was helpful, even though it is not an optimal solution by any means.

Upvotes: 2
