matsui Cao

Reputation: 57

TensorFlow - validate accuracy with batch data

As the tutorials say, every certain number of steps I need to use the 'validation' dataset to check the model's accuracy so far, and use the 'test' dataset to measure the final accuracy.

example code:

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

validate_acc = sess.run(accuracy, feed_dict=validate_feed)

But I think the validation set is too big to feed at once on my device, so an OOM error may occur.

How can I feed the 'accuracy' op with batches of validate_feed and still get the total 'validate_acc'?

(If I make an iterator from the dataset, how can I feed next_batch into the 'accuracy' op?)

Thank you everyone for help!

Upvotes: 2

Views: 3820

Answers (2)

P-Gn

Reputation: 24581

Use tf.metrics.accuracy. It performs a streaming computation of the accuracy, meaning that it accumulates all the necessary information for you and returns the current estimate of the accuracy whenever you ask for it.

See this answer for an example of how to use it.
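To illustrate what "streaming" means here (this sketch is not from the original answer): tf.metrics.accuracy keeps two local variables, a running total of correct predictions and a running count of examples, and the reported accuracy is their ratio. A minimal pure-Python sketch of that accumulation, where the class name StreamingAccuracy is illustrative and not a TensorFlow API:

```python
class StreamingAccuracy:
    """Mimics the total/count accumulation behind tf.metrics.accuracy."""

    def __init__(self):
        self.total = 0   # correct predictions seen so far
        self.count = 0   # examples seen so far

    def update(self, predictions, labels):
        # Analogous to running the metric's update op on one batch.
        self.total += sum(p == l for p, l in zip(predictions, labels))
        self.count += len(labels)

    def result(self):
        # Analogous to evaluating the accuracy tensor at any point.
        return self.total / self.count if self.count else 0.0


# Feed two batches; the running estimate covers everything seen so far.
acc = StreamingAccuracy()
acc.update([1, 0, 1], [1, 1, 1])   # 2 of 3 correct
acc.update([0, 0], [0, 1])         # 1 of 2 correct
print(acc.result())                # 3 correct out of 5 -> 0.6
```

Because the state lives across calls, you can loop over validation batches calling the update op and read off the overall accuracy at the end, which is exactly the batched evaluation the question asks for.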

Upvotes: 3

Ufuk Can Bicici

Reputation: 3649

Normally, you use something similar to the following for measuring the accuracy:

correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

Here logits is the final pre-activation output that you usually pass into the softmax cross-entropy layer. The snippet above calculates the accuracy for a single batch, not over the whole dataset. You can do the following instead:

correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_, 1))
total_correct = tf.reduce_sum(tf.cast(correct_prediction, tf.float32))

Execute "total_correct" for every batch in your test set and accumulate them:

  correct_sum = 0
  for batch_xs, batch_ys in data_set:
       # Feed each batch separately; sum correct counts, don't average accuracies.
       batch_correct_count = sess.run(total_correct,
                                      feed_dict={x: batch_xs, y_: batch_ys})
       correct_sum += batch_correct_count

  total_accuracy = correct_sum / data_set.size()

With the formulation above, you can correctly calculate the overall accuracy while processing the data in batches. This assumes, of course, that the for loop runs over mutually exclusive batches of the dataset. You should avoid i.i.d. sampling or sampling with replacement from the dataset here, which is usually done for stochastic training.
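As a sanity check (not part of the original answer), the sum-then-divide scheme can be verified in plain Python: accumulating per-batch correct counts over disjoint batches gives exactly the accuracy of the whole dataset, whereas averaging per-batch accuracies would be biased when batch sizes differ.

```python
# Disjoint batches of unequal size, as in the loop above.
batches = [
    ([1, 1, 0], [1, 0, 0]),  # (predictions, labels): 2 of 3 correct
    ([1, 0], [1, 0]),        # 2 of 2 correct
]

correct_sum = 0
total_size = 0
for preds, labels in batches:
    correct_sum += sum(p == l for p, l in zip(preds, labels))
    total_size += len(labels)

total_accuracy = correct_sum / total_size
print(total_accuracy)  # 4 correct out of 5 -> 0.8

# Averaging per-batch accuracies instead would give (2/3 + 2/2) / 2 ≈ 0.833,
# which over-weights the smaller batch.
```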

Upvotes: 4
