Reputation: 5968
I want to use tf.metrics.accuracy to track the accuracy of my predictions, but I am unsure of how to use the update_op (acc_update_op below) that the function returns:
accuracy, acc_update_op = tf.metrics.accuracy(labels, predictions)
I was thinking that adding it to tf.GraphKeys.UPDATE_OPS would make sense, but I am not sure how to do this.
Upvotes: 1
Views: 2679
Reputation: 29972
tf.metrics.accuracy is one of TensorFlow's many streaming metric operations (another is tf.metrics.recall). When the op is created, it adds two local variables (count and total) that accumulate all incoming results into one final outcome. The first returned value is a tensor that computes count / total. The second returned value is a stateful update op that updates these variables. Streaming metric functions are useful when evaluating the performance of a classifier over multiple batches of data. A quick example of use:
# building phase
with tf.name_scope("streaming"):
    accuracy, acc_update_op = tf.metrics.accuracy(labels, predictions)

test_fetches = {
    'accuracy': accuracy,
    'acc_op': acc_update_op
}

# when testing the classifier
with tf.name_scope("streaming"):
    # clear counters for a fresh evaluation
    sess.run(tf.local_variables_initializer())

for _i in range(n_batches_in_test):
    fd = get_test_batch()
    outputs = sess.run(test_fetches, feed_dict=fd)

print("Accuracy:", outputs['accuracy'])
I was thinking that adding it to tf.GraphKeys.UPDATE_OPS would make sense, but I am not sure how to do this.
That would not be a good idea unless you intend to use the UPDATE_OPS collection only for testing purposes. Usually the collection will already contain control operations for the training phase (such as updating the moving averages used by batch normalization) that are not meant to run alongside the validation phase. It may be best to either keep the update ops in a new collection or add them to the fetch dictionary manually, as sketched below.
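For example, here is a minimal sketch of the second suggestion, using a hypothetical collection name "test_update_ops" (the collection name and the evaluation loop are illustrative, not part of the original answer):

# build phase: register the update op in a dedicated collection
tf.add_to_collection("test_update_ops", acc_update_op)

# test phase: run every registered update op over the test batches,
# then read the accumulated accuracy once at the end
test_update_ops = tf.get_collection("test_update_ops")
for _i in range(n_batches_in_test):
    fd = get_test_batch()
    sess.run(test_update_ops, feed_dict=fd)
print("Accuracy:", sess.run(accuracy))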
Upvotes: 7