pafede2

Reputation: 1704

TensorFlow: remember weights of previous epochs

I am experimenting with TensorFlow. I've just posted a question about an issue I am facing with it. However, I also have a perhaps more theoretical question, one with practical consequences.

When training the models, I find that the accuracy fluctuates from epoch to epoch, so the last epoch does not necessarily show the best accuracy. For instance, at epoch N I may have an accuracy of 85%, whereas at the last epoch the accuracy is 65%. I would like to predict using the weights from epoch N.
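To make the 85% / 65% case concrete, here is a tiny sketch with made-up per-epoch validation accuracies (illustrative values only, not from a real run). The epoch whose weights you want is simply the one with the highest recorded accuracy, which need not be the last one:

```python
# Made-up validation accuracy recorded after each epoch (illustrative only).
accuracy_history = [0.52, 0.71, 0.85, 0.78, 0.65]

# Epochs are numbered from 1; pick the one with the highest accuracy.
best_epoch = max(range(1, len(accuracy_history) + 1),
                 key=lambda epoch: accuracy_history[epoch - 1])
last_epoch = len(accuracy_history)

print(f"best epoch: {best_epoch} ({accuracy_history[best_epoch - 1]:.0%})")  # best epoch: 3 (85%)
print(f"last epoch: {last_epoch} ({accuracy_history[-1]:.0%})")  # last epoch: 5 (65%)
```

Knowing the index is not enough on its own, though: by the time training ends, the weights from that epoch have been overwritten unless something saved them when it happened.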

I was wondering: is there a way to store the weights from the epoch with the best accuracy so I can use them later?

The first and simplest approach would be:

  1. Run N epochs
  2. Remember the best accuracy
  3. Restart the training until we reach an epoch that shows the same accuracy as the one stored in step 2.
  4. Predict using the current weights

Is there a better one?

Upvotes: 3

Views: 1190

Answers (1)

Engineero

Reputation: 12908

Yes! You need to create a tf.train.Saver and use it to save your session periodically during training, keeping the checkpoint of the best-performing model so far. The pseudo-code implementation looks like:

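import tensorflow as tf  # TF 1.x graph-mode API; in TF 2.x these live under tf.compat.v1

model = my_model()  # build the graph first, so the saver can see its variables
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
min_error = float("inf")  # best validation error seen so far
with tf.Session() as sess:
    sess.run(init_op)
    for epoch in range(NUM_EPOCHS):
        for batch in range(NUM_BATCHES):

            # ... train your model ...

            if batch % VALIDATION_FREQUENCY == 0:
                # Periodically test against a validation set.
                error = sess.run(model.error, feed_dict=valid_dict)
                if error < min_error:
                    min_error = error  # store your best error so far
                    saver.save(sess, MODEL_PATH)  # save the best-performing network so far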
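The keep-the-best pattern itself does not depend on TensorFlow. As a framework-free sketch (the toy linear model, the synthetic data, and the names train_keep_best, mse and make_data are all hypothetical, chosen only for illustration), plain Python can keep an in-memory copy of the best weights instead of writing a checkpoint file:

```python
import copy
import random


def mse(weights, data):
    """Mean squared error of the line y = w * x + b over (x, y) pairs."""
    w, b = weights["w"], weights["b"]
    return sum((w * x + b - y) ** 2 for x, y in data) / len(data)


def train_keep_best(train, valid, num_epochs=50, lr=0.05, seed=0):
    """Toy SGD on y = w * x + b that keeps a copy of the best weights seen.

    Returns (best_weights, min_error, history), where history holds the
    validation error measured after every epoch.
    """
    rng = random.Random(seed)
    train = list(train)  # shuffle a private copy, not the caller's list
    weights = {"w": rng.uniform(-1.0, 1.0), "b": 0.0}
    best_weights, min_error = None, float("inf")
    history = []
    for _ in range(num_epochs):
        rng.shuffle(train)
        for x, y in train:
            # Gradient of (w * x + b - y) ** 2 with respect to w and b.
            residual = weights["w"] * x + weights["b"] - y
            weights["w"] -= lr * 2 * residual * x
            weights["b"] -= lr * 2 * residual
        error = mse(weights, valid)  # validate once per epoch
        history.append(error)
        if error < min_error:
            min_error = error
            # Copy rather than alias: later updates mutate `weights` in place.
            best_weights = copy.deepcopy(weights)
    return best_weights, min_error, history


# Synthetic data from y = 2x + 1 plus Gaussian noise (made up for the demo).
data_rng = random.Random(42)


def make_data(n):
    xs = [data_rng.uniform(-1.0, 1.0) for _ in range(n)]
    return [(x, 2.0 * x + 1.0 + data_rng.gauss(0.0, 0.3)) for x in xs]


train_set, valid_set = make_data(40), make_data(20)
best, best_error, history = train_keep_best(train_set, valid_set)
print(best, best_error, len(history))
```

An in-memory copy is enough within a single run; if you need the weights after the process exits, they have to be written to disk, which is what tf.train.Saver does.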

Then when you want to test your model against your best-performing iteration:

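with tf.Session() as sess:
    # No need to run init_op: restoring assigns every variable the saver tracks.
    saver.restore(sess, MODEL_PATH)
    test_error = sess.run(model.error, feed_dict=test_dict)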
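Because saver.save and saver.restore go through files, you can come back to the best iteration later, even from a fresh session. A rough framework-free analogue, assuming the weights are a picklable Python object (the helper names, the file name best_model.pkl and the validation numbers are hypothetical), looks like this:

```python
import os
import pickle
import tempfile


def save_checkpoint(weights, path):
    """Write the weights to disk, replacing any earlier checkpoint at `path`."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(weights, f)
    # Rename into place; on POSIX this is atomic, so a crash mid-save
    # cannot leave a half-written file where the previous best used to be.
    os.replace(tmp_path, path)


def restore_checkpoint(path):
    """Read weights written by save_checkpoint (only unpickle files you trust)."""
    with open(path, "rb") as f:
        return pickle.load(f)


checkpoint_dir = tempfile.mkdtemp()  # throwaway directory for the demo
model_path = os.path.join(checkpoint_dir, "best_model.pkl")

# Made-up (validation error, weights) pairs; only improvements trigger a save.
min_error = float("inf")
for error, weights in [(0.40, {"w": 1.2}), (0.25, {"w": 1.8}), (0.31, {"w": 2.3})]:
    if error < min_error:
        min_error = error
        save_checkpoint(weights, model_path)

restored = restore_checkpoint(model_path)
print(restored)  # {'w': 1.8}
```

The third entry is worse than the second, so it is never written, and the restored weights are the ones from the best validation error, mirroring the if error < min_error check in the TensorFlow loop.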

Check out this tutorial on saving and loading metagraphs as well. I found the loading step to be a bit tricky depending on your use case.

Upvotes: 2
