Reputation: 1704
I am experimenting with TensorFlow. I've just posted a question regarding an issue I am facing with it. However, I also have a perhaps more theoretical question, but one with practical consequences.
When training models I find that the accuracy may vary between epochs, so the last epoch does not necessarily show the best accuracy. For instance, on epoch N I may have an accuracy of 85%, whereas on the last epoch the accuracy is 65%. I would like to predict using the weights from epoch N.
I was wondering whether there is a way to remember the weight values of the epoch with the best accuracy, for later use?
The very first and simplest approach would be:
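One simple approach of this kind is to keep a copy of the best weights in ordinary Python memory. Below is a minimal, self-contained sketch of that idea with canned per-epoch accuracies; `train_one_epoch` and `ACCURACIES` are illustrative placeholders standing in for a real training loop, not TensorFlow API:

```python
# Canned accuracy observed after each epoch (placeholder for real validation).
ACCURACIES = [0.60, 0.85, 0.70, 0.65]

def train_one_epoch(epoch):
    """Pretend to train; return (accuracy, weights) for this epoch."""
    weights = [[float(epoch)] * 2]  # fake weight "tensors"
    return ACCURACIES[epoch], weights

best_acc = -1.0
best_weights = None
for epoch in range(len(ACCURACIES)):
    acc, weights = train_one_epoch(epoch)
    if acc > best_acc:
        best_acc = acc
        # Copy, so later epochs cannot mutate the snapshot.
        best_weights = [list(w) for w in weights]

# After training, predict with the best snapshot instead of the last one.
print(best_acc)  # 0.85, from epoch 1, not the final epoch
```

This works, but holds the weights in RAM and loses them if the process dies.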
Is there a better one?
Upvotes: 3
Views: 1190
Reputation: 12908
Yes! You need to create a saver and save your session periodically throughout your training process. A pseudo-code implementation looks like:
import tensorflow as tf

model = my_model()
saver = tf.train.Saver()
min_error = float('inf')  # best validation error seen so far

with tf.Session() as sess:
    sess.run(init_op)
    for epoch in range(NUM_EPOCHS):
        for batch in range(NUM_BATCHES):
            # ... train your model ...
            if batch % VALIDATION_FREQUENCY == 0:
                # Periodically test against a validation set.
                error = sess.run(model.error, feed_dict=valid_dict)
                if error < min_error:
                    min_error = error  # store your best error so far
                    saver.save(sess, MODEL_PATH)  # save the best-performing network so far
Then when you want to test your model against your best-performing iteration:
with tf.Session() as sess:
    saver.restore(sess, MODEL_PATH)
    test_error = sess.run(model.error, feed_dict=test_dict)
Check out this tutorial on saving and loading metagraphs as well; I found the loading step to be a bit tricky depending on the use case.
Upvotes: 2