Add a summary of accuracy of the whole train/test dataset in TensorFlow

I am trying to use TensorBoard to visualize my training procedure. After every epoch completes, I would like to test the network's accuracy on the whole validation dataset and store the result in a summary file, so that I can visualize it in TensorBoard.

I know TensorFlow has summary_op for this, but it seems to work for only one batch when running sess.run(summary_op). I need to calculate the accuracy for the whole dataset. How?

Is there an example of how to do this?

Answers (3)


Evaluating your validation set in batches is possible if you use tf.metrics ops, which keep internal counters across sess.run calls. Here is a simplified example:

model = create_model()
tf.summary.scalar('cost', model.cost_op)
acc_value_op, acc_update_op = tf.metrics.accuracy(labels,predictions)

summary_common = tf.summary.merge_all()

summary_valid = tf.summary.merge([
    tf.summary.scalar('accuracy', acc_value_op),
    # other metrics here...
])

with tf.Session() as sess:
    train_writer = tf.summary.FileWriter(logs_path + '/train')
    valid_writer = tf.summary.FileWriter(logs_path + '/valid')

While training, only write the common summary using your train-writer:

summary = sess.run(summary_common)
train_writer.add_summary(summary, tf.train.global_step(sess, gstep_op))

After every validation, write both summaries using the valid-writer:

gstep, summaryc, summaryv = sess.run([gstep_op, summary_common, summary_valid])
valid_writer.add_summary(summaryc, gstep)
valid_writer.add_summary(summaryv, gstep)

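When using tf.metrics, run acc_update_op on every validation batch to accumulate the counts, and read acc_value_op only after the last batch. Don't forget to reset the internal counters (local variables) before every validation pass, for example with sess.run(tf.local_variables_initializer()).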
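To make the role of the internal counters concrete, here is a minimal plain-Python sketch of a streaming accuracy metric. It is only an analogy and not TensorFlow's implementation: the class StreamingAccuracy and its methods reset, update and value are hypothetical names that stand in for re-initializing the local variables, running acc_update_op on one batch, and reading acc_value_op.

```python
class StreamingAccuracy:
    """Plain-Python stand-in for a streaming accuracy metric (illustrative only)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Analogous to re-initializing the metric's local variables.
        self.correct = 0
        self.seen = 0

    def update(self, labels, predictions):
        # Analogous to running the update op on one batch.
        self.correct += sum(1 for l, p in zip(labels, predictions) if l == p)
        self.seen += len(labels)

    def value(self):
        # Analogous to reading the value op after the last batch.
        return self.correct / self.seen if self.seen else 0.0


metric = StreamingAccuracy()
# Two validation batches of (labels, predictions); the second one is smaller.
batches = [([1, 0, 1, 1], [1, 0, 0, 1]), ([0, 1], [0, 1])]

metric.reset()  # start every validation pass from zero
for labels, predictions in batches:
    metric.update(labels, predictions)
print(metric.value())
```

If reset is skipped, the counts from the previous validation pass leak into the next one, and the reported value becomes a running average over all passes instead of the accuracy of the current one.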

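I implemented a naive one-layer model as an example that classifies the MNIST dataset and visualizes the validation accuracy in TensorBoard. It feeds the whole validation set in a single sess.run call, so the summary covers the full set. It works for me.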

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
import os

# number of training iterations (one batch of 100 per iteration)
num_epoch = 1000
model_dir = '/tmp/tf/onelayer_model/accu_info'
# mnist dataset location, change if you need
data_dir = '../data/mnist'

# load MNIST dataset without one hot
dataset = read_data_sets(data_dir, one_hot=False)

# Create placeholder for input images X and labels y
X = tf.placeholder(tf.float32, [None, 784])
# one_hot = False
y = tf.placeholder(tf.int32)

# One layer model graph
W = tf.Variable(tf.truncated_normal([784, 10], stddev=0.1))
b = tf.Variable(tf.constant(0.1, shape=[10]))
logits = tf.nn.relu(tf.matmul(X, W) + b)

init = tf.global_variables_initializer()

cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
# loss function
loss = tf.reduce_mean(cross_entropy)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

_, top_1_op = tf.nn.top_k(logits)
top_1 = tf.reshape(top_1_op, shape=[-1])
correct_classification = tf.cast(tf.equal(top_1, y), tf.float32)
# accuracy function
acc = tf.reduce_mean(correct_classification)

# define the summary that is written by the summary writer
acc_summary = tf.summary.scalar('valid_accuracy', acc)
valid_summary_op = tf.summary.merge([acc_summary])

with tf.Session() as sess:
    # initialize all variables
    sess.run(init)

    print("Writing Summaries to %s" % model_dir)
    train_summary_writer = tf.summary.FileWriter(model_dir, sess.graph)

    # load validation dataset
    valid_x = dataset.validation.images
    valid_y = dataset.validation.labels

    for epoch in range(num_epoch):
        batch_x, batch_y = dataset.train.next_batch(100)
        feed_dict = {X: batch_x, y: batch_y}
        _, acc_value, loss_value = sess.run(
            [train_op, acc, loss], feed_dict=feed_dict)
        vsummary = sess.run(valid_summary_op,
                            feed_dict={X: valid_x,
                                       y: valid_y})

        # Write validation accuracy summary
        train_summary_writer.add_summary(vsummary, epoch)

Define a scalar summary (tf.summary.scalar, called tf.scalar_summary in old releases) that accepts a placeholder:

accuracy_value_ = tf.placeholder(tf.float32, shape=())
accuracy_summary = tf.summary.scalar('accuracy', accuracy_value_)

Then calculate the accuracy for the whole dataset: define a routine that computes the accuracy of every batch in the dataset and takes the mean, weighted by batch size if the batches are not all the same size. Save the result in a Python variable, let's call it va.
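The routine described above might look like the following sketch. The names dataset_accuracy and toy_batch_accuracy are hypothetical; toy_batch_accuracy is only a stand-in for a call such as sess.run(acc, feed_dict={X: xb, y: yb}) on your own graph, so that the example runs without TensorFlow.

```python
def dataset_accuracy(examples, labels, batch_size, batch_accuracy):
    """Return accuracy over all examples, evaluated batch by batch."""
    correct = 0.0
    for start in range(0, len(examples), batch_size):
        xb = examples[start:start + batch_size]
        yb = labels[start:start + batch_size]
        # Weight each batch accuracy by its size so a short final batch
        # does not count as much as a full one.
        correct += batch_accuracy(xb, yb) * len(xb)
    return correct / len(examples)


def toy_batch_accuracy(xb, yb):
    # Hypothetical stand-in for sess.run(acc, feed_dict={X: xb, y: yb});
    # the "model" here just predicts x % 2.
    hits = sum(1 for x, t in zip(xb, yb) if x % 2 == t)
    return hits / len(xb)


examples = [0, 1, 2, 3, 4]
labels = [0, 1, 0, 1, 1]
va = dataset_accuracy(examples, labels, batch_size=2,
                      batch_accuracy=toy_batch_accuracy)
print(va)
```

In this toy data the three batches have accuracies 1, 1 and 0, so a plain mean of the batch accuracies would give 2/3, while the size-weighted result is 4/5, which matches the fraction of correct predictions over all five examples.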

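Once you have the value of va, run the accuracy_summary op, feeding the accuracy_value_ placeholder, and pass the result to your summary writer:

summary = sess.run(accuracy_summary, feed_dict={accuracy_value_: va})
# summary_writer and epoch are assumed names: your own summary writer and the current epoch index
summary_writer.add_summary(summary, epoch)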

