Natjo

Reputation: 2118

Reset local variables of metrics after each epoch

I use the built-in function tf.metrics.precision to evaluate my model. I looked through its definition, but the local variables are never reset.

Shouldn't they be reset after each epoch so that the counts from previous epochs are discarded? Is this done automatically and I just overlooked it in the source code, or am I supposed to do it myself? If the latter, how do I reset the local variables? I couldn't find anything about it in the documentation.

Upvotes: 0

Views: 1436

Answers (2)

javidcf

Reputation: 59731

Variables for keeping track of metrics are created with the metric_variable function, so they are added to the collection with key tf.GraphKeys.METRIC_VARIABLES (as well as to tf.GraphKeys.LOCAL_VARIABLES). After you have defined all your metrics, you can create a reset operation like this:

reset_metrics_op = tf.variables_initializer(tf.get_collection(tf.GraphKeys.METRIC_VARIABLES))

Then run it (sess.run(reset_metrics_op)) at the end of each epoch.
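To illustrate what the reset buys you, here is a plain-Python stand-in for the pattern, not TensorFlow code. The `StreamingPrecision` class and its `update`, `result` and `reset` methods are hypothetical: two counters play the role of the metric variables, `update` plays the role of `update_op`, and `reset` plays the role of running `reset_metrics_op` at the end of each epoch.

```python
class StreamingPrecision:
    """Hypothetical stand-in for the two local counters behind tf.metrics.precision."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Analogous to running reset_metrics_op: zero both counters.
        self.true_positives = 0
        self.false_positives = 0

    def update(self, labels, predictions):
        # Analogous to update_op: accumulate counts from one batch of 0/1 values.
        for y, p in zip(labels, predictions):
            if p == 1 and y == 1:
                self.true_positives += 1
            elif p == 1 and y == 0:
                self.false_positives += 1
        return self.result()

    def result(self):
        # Analogous to the value tensor: read the accumulated counts.
        predicted = self.true_positives + self.false_positives
        return self.true_positives / predicted if predicted else 0.0


def run_epochs(epochs):
    """Each epoch is a list of (labels, predictions) batches; returns one precision per epoch."""
    metric = StreamingPrecision()
    per_epoch = []
    for batches in epochs:
        for labels, predictions in batches:
            metric.update(labels, predictions)
        per_epoch.append(metric.result())
        metric.reset()  # start the next epoch from zero counts
    return per_epoch


epoch_1 = [([1, 1, 0, 0], [1, 1, 1, 0])]  # TP=2, FP=1
epoch_2 = [([1, 0, 1, 0], [1, 0, 1, 0])]  # TP=2, FP=0
print(run_epochs([epoch_1, epoch_2]))  # [0.6666666666666666, 1.0]
```

Without the `metric.reset()` call, the second value would be (2 + 2) / (3 + 2) = 0.8, because epoch 1's counts would still be included.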

Upvotes: 1

ARAT

Reputation: 963

Yes. You have to be careful about when you reset the variables if you process the data in batches, because the operations are arranged differently for overall metrics (e.g. precision, accuracy or AUC over the whole dataset) and for per-batch metrics. For a per-batch metric, you need to reset the running variables to zero before computing the metric on each new batch of data; for an overall (per-epoch) metric, you reset them once, before the epoch starts.
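The difference can be sketched in plain Python (not TensorFlow; the helper names `count`, `precision` and `batch_and_overall_precision` and the batch data are hypothetical). Per-batch precision uses counters that start from zero for every batch, while the overall precision keeps accumulating counters that are zeroed only once, before the epoch:

```python
def count(labels, predictions):
    """Return (true_positives, false_positives) for one batch of 0/1 values."""
    tp = sum(1 for y, p in zip(labels, predictions) if p == 1 and y == 1)
    fp = sum(1 for y, p in zip(labels, predictions) if p == 1 and y == 0)
    return tp, fp


def precision(tp, fp):
    return tp / (tp + fp) if tp + fp else 0.0


def batch_and_overall_precision(batches):
    running_tp = running_fp = 0  # "reset" once, before the epoch
    per_batch = []
    for labels, predictions in batches:
        tp, fp = count(labels, predictions)  # fresh counters: per-batch metric
        per_batch.append(precision(tp, fp))
        running_tp += tp                     # accumulated counters: overall metric
        running_fp += fp
    return per_batch, precision(running_tp, running_fp)


batches = [
    ([1, 0, 1, 1], [1, 1, 1, 0]),  # TP=2, FP=1
    ([1, 1, 0, 0], [1, 0, 0, 0]),  # TP=1, FP=0
]
print(batch_and_overall_precision(batches))  # ([0.6666666666666666, 1.0], 0.75)
```

Note that the overall value, 3 / 4 = 0.75, is not the mean of the per-batch values (about 0.833), which is why the placement of the reset matters.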

tf.metrics.precision creates two running variables in the computational graph: true_positives and false_positives. You can choose which variables to reset by passing the scope argument to tf.get_collection().
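One detail worth knowing: tf.get_collection documents its scope argument as a `re.match` filter on each item's name, so scope='precision' is a prefix match and would also select the variables of a second metric scoped as, say, `precision_1`. The stdlib sketch below reproduces only that filtering rule; `filter_by_scope` and the variable names are hypothetical.

```python
import re


def filter_by_scope(names, scope):
    """Mimic tf.get_collection's scope filter: keep names that re.match the scope."""
    return [name for name in names if re.match(scope, name)]


# Hypothetical names for two precision metrics and one AUC metric.
names = [
    "precision/true_positives/count:0",
    "precision/false_positives/count:0",
    "precision_1/true_positives/count:0",
    "auc/true_positives:0",
]
print(filter_by_scope(names, "precision"))   # also matches precision_1/...
print(filter_by_scope(names, "precision/"))  # only the first metric's variables
```

Adding a trailing "/" to the scope string is one way to restrict the reset to a single metric.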

import tensorflow as tf
import numpy as np

labels = np.array([[1,1,1,0],
                   [1,1,1,0],
                   [1,1,1,0],
                   [1,1,1,0]], dtype=np.uint8)

predictions = np.array([[1,0,0,0],
                        [1,1,0,0],
                        [1,1,1,0],
                        [0,1,1,1]], dtype=np.uint8)

precision, update_op = tf.metrics.precision(labels, predictions, name = 'precision')

print(precision)
#Tensor("precision/value:0", shape=(), dtype=float32)
print(update_op)
#Tensor("precision/update_op:0", shape=(), dtype=float32)

print(tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES))
#[<tf.Variable 'precision/true_positives/count:0' shape=() dtype=float32_ref>,
# <tf.Variable 'precision/false_positives/count:0' shape=() dtype=float32_ref>]

running_vars_precision = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope='precision')
running_vars_precision_initializer = tf.variables_initializer(var_list=running_vars_precision)

with tf.Session() as sess:
    sess.run(running_vars_precision_initializer)
    # Run update_op on its own first: fetching it together with `precision` in the
    # same sess.run() call leaves the order of the read and the update unspecified.
    sess.run(update_op)
    print("tf precision: {}".format(sess.run(precision)))
    #tf precision: 0.8888888955116272
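
As a sanity check on the printed value, the counts that true_positives and false_positives should hold after one update_op can be recomputed with plain NumPy on the same arrays, using precision = TP / (TP + FP). This is only a cross-check, not how TensorFlow computes it internally; TensorFlow works in float32, hence 0.8888889 rather than the float64 result below.

```python
import numpy as np

labels = np.array([[1, 1, 1, 0],
                   [1, 1, 1, 0],
                   [1, 1, 1, 0],
                   [1, 1, 1, 0]], dtype=np.uint8)

predictions = np.array([[1, 0, 0, 0],
                        [1, 1, 0, 0],
                        [1, 1, 1, 0],
                        [0, 1, 1, 1]], dtype=np.uint8)

# Predicted positive and actually positive / actually negative.
tp = int(np.sum((predictions == 1) & (labels == 1)))
fp = int(np.sum((predictions == 1) & (labels == 0)))
print(tp, fp, tp / (tp + fp))  # 8 1 0.8888888888888888
```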

Upvotes: 0
