codedturtle

Reputation: 65

Updating internal state for TensorFlow custom metrics (aka using non-update_state vars in metric calculation)

Versions: Python 3.8.2 (I've also tried 3.6.8, but I don't think the Python version matters here), TensorFlow 2.3.0, NumPy 1.18.5

I'm training a model for a classification problem with a sparse labels tensor. How would I go about defining a metric that counts the number of times that the "0" label has appeared up until that point? What I'm trying to do in the code example below is to store all the labels that the metric has seen in an array and constantly concatenate the existing array with the new y_true every time update_state is called. (I know I could just store a count variable and use +=, but in the actual usage scenario, concatenating is ideal and memory is not an issue.) Here's minimal code to reproduce the problem:

import tensorflow as tf

class ZeroLabels(tf.keras.metrics.Metric):
    """Accumulates a list of all y_true sparse categorical labels (ints) and calculates the number of times the '0' label has appeared."""
    def __init__(self, *args, **kwargs):
        super(ZeroLabels, self).__init__(name="ZeroLabels")
        self.labels = self.add_weight(name="labels", shape=(), initializer="zeros", dtype=tf.int32)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """I'm using sparse categorical crossentropy, so labels are 1D array of integers."""
        if self.labels.shape == (): # if this is the first time update_state is being called
            self.labels = y_true
        else:
            self.labels = tf.concat((self.labels, y_true), axis=0)

    def result(self):
        return tf.reduce_sum(tf.cast(self.labels == 0, dtype=tf.int32))

    def reset_states(self):
        self.labels = tf.constant(0, dtype=tf.int32)

This code works on its own, but it throws the following error when I try to train a model using this metric:

TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2

I thought this might have something to do with the fact that self.labels isn't directly part of the graph when update_state is called. I've tried a few other approaches as well, but none of them worked.

If it helps, here's a more general version of this question: How can we incorporate tensors that aren't passed in as parameters to update_state in the update_state calculation? Any help would be greatly appreciated. Thank you in advance!

Upvotes: 0

Views: 1454

Answers (1)

Miguel Herrera Ruiz

Reputation: 134

The main problem is the assignment on the first iteration, when there is no initial value yet:

if self.labels.shape == ():
    self.labels = y_true
else:
    self.labels = tf.concat((self.labels, y_true), axis=0)

Inside the if block, the labels variable you defined in the constructor simply disappears: it is replaced by a plain tf.Tensor object (y_true). You have to use tf.Variable methods (assign, assign_add) instead, so that the contents change but the variable object itself is kept. Moreover, for a tf.Variable to change shape, you have to create it with an undefined dimension, in this case (None,), because you're concatenating along axis=0.
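
As a quick standalone illustration (a sketch, not from the original answer), the same pattern works outside of a metric:

import tensorflow as tf

# A variable created with an unknown first dimension can change size via assign().
v = tf.Variable([0.], shape=(None,), validate_shape=False)

@tf.function
def grow(new_values):
    # assign() mutates the existing tf.Variable in place; a plain Python
    # reassignment (v = ...) would replace it with a graph tensor instead.
    v.assign(tf.concat([v.value(), new_values], axis=0))

grow(tf.constant([1., 2.]))
print(v.value().numpy())  # [0. 1. 2.]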

Applying this to the metric:

class ZeroLabels(tf.keras.metrics.Metric):
    def __init__(self, *args, **kwargs):
        super(ZeroLabels, self).__init__(name="ZeroLabels")

        # Define a variable with an unknown first dimension; validate_shape=False
        # lets the variable be resized dynamically on each assign().
        self.labels = tf.Variable([], shape=(None,), validate_shape=False)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # On each update, assign the previous value concatenated with the new y_true.
        # Cast y_true to the variable's dtype (float32 by default) so the concat works.
        new_labels = tf.cast(y_true[:, 0], self.labels.dtype)
        self.labels.assign(tf.concat([self.labels.value(), new_labels], axis=0))

    def result(self):
        return tf.reduce_sum(tf.cast(self.labels.value() == 0, dtype=tf.int32))

    def reset_states(self):
        # To reset the metric, assign an empty tensor again.
        self.labels.assign([])
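
For completeness, here is one hypothetical way to exercise this metric during training; the toy model and random data below are illustrative assumptions, not from the question:

import numpy as np

# Toy data: 100 samples, 5 features, 3 classes; sparse integer labels of shape (100, 1)
# (the metric indexes y_true[:, 0], so the labels need to be 2-D).
x = np.random.rand(100, 5).astype("float32")
y = np.random.randint(0, 3, size=(100, 1))

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu", input_shape=(5,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=[ZeroLabels()])
model.fit(x, y, epochs=2, batch_size=10)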

But if you only want to count the zeros in the dataset, I suggest keeping a single integer variable as a counter instead, because the labels array grows with every batch processed, so summing over all of its elements takes more and more time and slows down your training.

class ZeroLabels_2(tf.keras.metrics.Metric):
    """Counts the number of times the '0' label has appeared in y_true (sparse integer labels)."""
    def __init__(self, *args, **kwargs):
        super(ZeroLabels_2, self).__init__(name="ZeroLabels")

        # Define a scalar integer variable to hold the running count.
        self.labels = tf.Variable(0, dtype=tf.int32)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Increase the counter with every batch.
        self.labels.assign_add(tf.reduce_sum(tf.cast(y_true == 0, dtype=tf.int32)))

    def result(self):
        # Simply return the variable's current value.
        return self.labels.value()

    def reset_states(self):
        self.labels.assign(0)
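
As a quick sanity check (hypothetical usage, not part of the original answer), the counter can be driven eagerly:

m = ZeroLabels_2()
m.update_state(tf.constant([[0], [2], [0], [1]]), y_pred=None)
m.update_state(tf.constant([[0], [1]]), y_pred=None)
print(m.result().numpy())  # 3 zeros seen so far
m.reset_states()
print(m.result().numpy())  # 0 after reset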

I hope this helps!

Upvotes: 7
