Reputation: 7099
I'm using tf.keras and I have a metric that I'd like to calculate, but I need multiple batches of validation data to calculate it reliably. Is there some way to accumulate batches before calculating the metric?
I'd like to do something like this:
class MultibatchMetric(tf.keras.metrics.Metric):
    def __init__(self, num_batches, name="sdr_metric", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_batches = num_batches
        self.batch_accumulator = []
        self.my_metric = []

    def update_state(self, y_true, y_pred, sample_weight=None):
        self.batch_accumulator.append((y_true, y_pred))
        if len(self.batch_accumulator) >= self.num_batches:
            metric = custom_multibatch_metric_func(self.batch_accumulator)
            self.my_metric.append(metric)
            self.batch_accumulator = []

    def result(self):
        return tf.reduce_mean(self.my_metric)

    def reset_states(self):
        self.my_metric = []
        self.batch_accumulator = []
However, this all needs to occur on the tensorflow graph, severely complicating things.
Upvotes: 1
Views: 1207
Reputation: 151
I had a go at your problem, and it seems the built-in add_weight method can provide a solution: make one state variable for a batch counter and another for an accumulator of shape (2, num_batches * batch_size, n_outputs). On each update, the incoming batch is zero-padded to the accumulator's shape and added to it, and both variables are reset once the counter reaches the maximum number of batches. You can then get the result by calling your metric on the accumulator state variable. I have added an example below.
class Metric(tf.keras.metrics.Metric):
    def __init__(self, num_batches, batch_size, name="sdr_metric", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_batches = num_batches
        self.batch_size = batch_size
        # Stores (y_true, y_pred) for a full window of num_batches batches.
        self.batch_accumulator = self.add_weight(
            name='accumulator',
            shape=(2, num_batches * batch_size, 2),
            initializer='zeros')
        self.batch_counter = self.add_weight(
            name='counter', shape=(), initializer='zeros')

    @tf.function
    def update_state(self, y_true, y_pred, sample_weight=None):
        # If the previous call completed a window, start a fresh one.
        if self.batch_counter == self.num_batches:
            self.reset_states()
        batch_count = tf.cast(self.batch_counter, tf.int32)
        batch = tf.stack([tf.cast(y_true, tf.float32),
                          tf.cast(y_pred, tf.float32)])
        # Zero-pad the batch so it lands in its slot of the accumulator.
        paddings = [[0, 0],
                    [batch_count * self.batch_size,
                     (self.num_batches - batch_count - 1) * self.batch_size],
                    [0, 0]]
        self.batch_accumulator.assign_add(tf.pad(batch, paddings))
        self.batch_counter.assign_add(1)

    @tf.function
    def result(self):
        # Only report a value once a full window has been accumulated.
        if self.batch_counter == self.num_batches:
            return custom_multibatch_metric_func(self.batch_accumulator)
        else:
            return 0.

    def reset_states(self):
        self.batch_counter.assign(0)
        self.batch_accumulator.assign(
            tf.zeros((2, self.num_batches * self.batch_size, 2)))
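Note that custom_multibatch_metric_func was left as a placeholder in the question, so to actually run the test below you need some definition for it. Here is a hypothetical stand-in (my assumption, not part of the original answer) that just computes the mean squared error over the accumulated window:

def custom_multibatch_metric_func(accumulator):
    # Hypothetical stand-in: accumulator[0] holds the stacked y_true values
    # and accumulator[1] the stacked y_pred values, each of shape
    # (num_batches * batch_size, n_outputs).
    y_true, y_pred = accumulator[0], accumulator[1]
    return tf.reduce_mean(tf.square(y_true - y_pred))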
And the test problem I used to verify:
import tensorflow as tf
from tensorflow.keras.layers import Dense

# data
n = 1028
batch_size = 32
num_batches = 3
f = 4
lr = 10e-3

x = tf.random.uniform((n, f), -1, 1)
y = tf.concat([tf.reduce_sum(x, axis=-1, keepdims=True),
               tf.reduce_mean(x, axis=-1, keepdims=True)], axis=-1)
ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size, drop_remainder=True)

model = tf.keras.models.Sequential([Dense(f, activation='relu'), Dense(2)])
model.compile(tf.keras.optimizers.SGD(lr),
              tf.keras.losses.mean_squared_error,
              metrics=[Metric(num_batches, batch_size)])
model.fit(ds, epochs=10)
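As a quick sanity check (my own sketch, assuming the stand-in metric function above), you can also feed exactly one full window of batches through the metric by hand and read the result once the window is complete:

# Accumulate one full window of num_batches batches, then read the metric.
m = Metric(num_batches, batch_size)
for x_b, y_b in ds.take(num_batches):
    m.update_state(y_b, model(x_b))
print(m.result().numpy())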
There are two large issues. First, the if statement in the result call: in the batches where no full window is available yet, you have to return some placeholder, and depending on what you require of the resulting metric you can pick a neutral value. Here I assumed you just sum all the results, so 0 has no effect. Second, this approach requires you to drop the remainder unless your dataset size is divisible by your batch size.
I hope this was helpful, even though this is not an optimal solution by any means.
Upvotes: 2