Reputation: 2118
I use the built-in method tf.metrics.precision to evaluate my model. I looked through its definition, and the local variables never seem to be reset.
Shouldn't they be reset after each epoch to discard the counts from previous epochs? Is this done automatically and I just overlooked it in the source code, or am I supposed to do it myself? If the latter, how do I reset the local variables? I couldn't find anything about it in the documentation.
Upvotes: 0
Views: 1436
Reputation: 59731
Variables for keeping track of metrics are created with the metric_variable function and thus added to the collection with key tf.GraphKeys.METRIC_VARIABLES. After you have defined all your metrics, you can create a reset operation like this:
reset_metrics_op = tf.variables_initializer(tf.get_collection(tf.GraphKeys.METRIC_VARIABLES))
Then run it after each epoch finishes.
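As a rough sketch of how that reset op could fit into a training or evaluation loop: the helper name run_epochs and the toy label/prediction batches below are made up for illustration, and the code assumes the TF 1.x graph API (imported here through tensorflow.compat.v1 so it also runs on TF 2.x installs).

```python
import tensorflow.compat.v1 as tf


def run_epochs(epochs):
    """Return the precision measured at the end of each epoch (hypothetical helper)."""
    results = []
    with tf.Graph().as_default():
        labels = tf.placeholder(tf.int64, shape=[None])
        predictions = tf.placeholder(tf.int64, shape=[None])
        precision, update_op = tf.metrics.precision(labels, predictions, name="precision")

        # Build the reset op only after all metrics are defined, so the
        # collection already contains every metric variable.
        reset_metrics_op = tf.variables_initializer(
            tf.get_collection(tf.GraphKeys.METRIC_VARIABLES))

        with tf.Session() as sess:
            # The first run doubles as the initial initialization of the counters.
            sess.run(reset_metrics_op)
            for batches in epochs:
                for batch_labels, batch_predictions in batches:
                    # Accumulate true/false positive counts for this batch.
                    sess.run(update_op, feed_dict={labels: batch_labels,
                                                   predictions: batch_predictions})
                results.append(float(sess.run(precision)))
                # Zero the counters so the next epoch starts fresh.
                sess.run(reset_metrics_op)
    return results


# Toy data: each epoch is a list of (labels, predictions) batches.
epochs = [
    [([1, 1, 0, 0], [1, 0, 1, 0])],  # 1 true positive, 1 false positive
    [([1, 1, 1, 0], [1, 1, 1, 0])],  # 3 true positives, 0 false positives
]
results = run_epochs(epochs)
print(results)
```

With the reset in place, the second epoch reports its own precision (3 of 3 predicted positives correct) instead of a value mixed with the counts from the first epoch.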
Upvotes: 1
Reputation: 963
Yes. You need to be careful about when you reset the variables if you process the data in batches. The operations are arranged differently depending on whether you want an overall metric (e.g., precision, accuracy or AUC) or a per-batch metric. For per-batch values, reset the running variables to zero before computing the metric on each new batch of data.
With tf.metrics.precision, two running variables are created and placed into the computational graph: true_positives and false_positives. So, you can choose which variables to reset using the scope argument of tf.get_collection().
import numpy as np
import tensorflow as tf

labels = np.array([[1, 1, 1, 0],
                   [1, 1, 1, 0],
                   [1, 1, 1, 0],
                   [1, 1, 1, 0]], dtype=np.uint8)
predictions = np.array([[1, 0, 0, 0],
                        [1, 1, 0, 0],
                        [1, 1, 1, 0],
                        [0, 1, 1, 1]], dtype=np.uint8)

precision, update_op = tf.metrics.precision(labels, predictions, name='precision')
print(precision)
# Tensor("precision/value:0", shape=(), dtype=float32)
print(update_op)
# Tensor("precision/update_op:0", shape=(), dtype=float32)

print(tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES))
# [<tf.Variable 'precision/true_positives/count:0' shape=() dtype=float32_ref>,
#  <tf.Variable 'precision/false_positives/count:0' shape=() dtype=float32_ref>]

# Collect only the running variables created under the 'precision' scope
running_vars_precision = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope='precision')
# Running this op (re)initializes those counters, i.e. resets them to zero
running_vars_precision_initializer = tf.variables_initializer(var_list=running_vars_precision)

with tf.Session() as sess:
    sess.run(running_vars_precision_initializer)
    print("tf precision/update_op: {}".format(sess.run([precision, update_op])))
    # tf precision/update_op: [0.8888889, 0.8888889]
    print("tf precision: {}".format(sess.run(precision)))
    # tf precision: 0.8888888955116272
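To see why the reset changes per-batch values, here is a plain-Python stand-in for the two counters. It is not TensorFlow's implementation; the class name RunningPrecision and the toy batches are hypothetical and only mimic the true_positives / false_positives bookkeeping described above.

```python
class RunningPrecision:
    """Plain-Python stand-in for the two running counters (hypothetical helper)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Equivalent in spirit to re-running the variables initializer.
        self.true_positives = 0
        self.false_positives = 0

    def update(self, labels, predictions):
        # Add this batch's counts to the running totals, then report precision.
        for label, prediction in zip(labels, predictions):
            if prediction and label:
                self.true_positives += 1
            elif prediction and not label:
                self.false_positives += 1
        return self.result()

    def result(self):
        predicted_positives = self.true_positives + self.false_positives
        if predicted_positives == 0:
            return 0.0
        return self.true_positives / predicted_positives


batches = [
    ([1, 1, 0, 0], [1, 0, 1, 0]),  # 1 true positive, 1 false positive
    ([1, 1, 1, 0], [1, 1, 1, 0]),  # 3 true positives, 0 false positives
]

# Overall metric: never reset, so counts accumulate across batches.
metric = RunningPrecision()
cumulative = [metric.update(labels, predictions) for labels, predictions in batches]

# Per-batch metric: reset the counters before each new batch.
metric = RunningPrecision()
per_batch = []
for labels, predictions in batches:
    metric.reset()
    per_batch.append(metric.update(labels, predictions))

print(cumulative)
print(per_batch)
```

Without the reset, the second value blends both batches (4 true positives out of 5 predicted positives); with the reset, it reflects the second batch alone.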
Upvotes: 0