han.liu

Reputation: 83

How to load and test a specific TensorFlow saved model?

After training, I have many saved checkpoints. For example, my checkpoint folder contains three checkpoints plus a file named checkpoint:

checkpoint
model.ckpt-1000.data-00000-of-00001
model.ckpt-1000.index
model.ckpt-1000.meta
model.ckpt-2000.data-00000-of-00001
model.ckpt-2000.index
model.ckpt-2000.meta
model.ckpt-3000.data-00000-of-00001
model.ckpt-3000.index
model.ckpt-3000.meta
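Each checkpoint in this listing is identified by its path prefix (for example model.ckpt-2000), not by any single file: the .index, .data-00000-of-00001 and .meta files all share that prefix. A minimal pure-Python sketch that recovers the prefixes from a file listing, ordered by step. The helper name checkpoint_prefixes is hypothetical, and it assumes the <name>-<step>.index naming shown above, with one .index file per checkpoint:

```python
import re

# Matches "<prefix>-<step>.index", e.g. "model.ckpt-2000.index" (assumed naming scheme).
_INDEX_RE = re.compile(r"^(?P<prefix>.+-(?P<step>\d+))\.index$")


def checkpoint_prefixes(filenames):
    """Return checkpoint prefixes (e.g. 'model.ckpt-1000') sorted by step number.

    Hypothetical helper: collects one prefix per .index file in the listing.
    """
    found = []
    for name in filenames:
        match = _INDEX_RE.match(name)
        if match:
            found.append((int(match.group("step")), match.group("prefix")))
    # Sort numerically, so step 10000 comes after step 9000.
    return [prefix for _, prefix in sorted(found)]


files = [
    "checkpoint",
    "model.ckpt-1000.data-00000-of-00001",
    "model.ckpt-1000.index",
    "model.ckpt-1000.meta",
    "model.ckpt-2000.data-00000-of-00001",
    "model.ckpt-2000.index",
    "model.ckpt-2000.meta",
    "model.ckpt-3000.data-00000-of-00001",
    "model.ckpt-3000.index",
    "model.ckpt-3000.meta",
]
print(checkpoint_prefixes(files))  # ['model.ckpt-1000', 'model.ckpt-2000', 'model.ckpt-3000']
```

With a real folder you would pass os.listdir(CHECKPOINT_DIR) and join each prefix back onto the directory before restoring it.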

I have tried two approaches:

First:

ckpt = tf.train.latest_checkpoint(CHECKPOINT_DIR)
saver.restore(sess, ckpt)

Second:

ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
saver.restore(sess, ckpt.model_checkpoint_path)

Both of them work, but they only ever load the newest checkpoint.

To test a specific checkpoint, I currently have to edit the checkpoint file by hand, changing model_checkpoint_path: "model.ckpt-3000" to model_checkpoint_path: "model.ckpt-2000".
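For reference, the checkpoint file is a small text-format record. Besides model_checkpoint_path, it normally lists each retained checkpoint (tf.train.Saver keeps the 5 most recent by default) on an all_model_checkpoint_paths line, and tf.train.get_checkpoint_state(CHECKPOINT_DIR).all_model_checkpoint_paths returns that same list. The dependency-free sketch below reads those lines; the helper name listed_checkpoints and the sample file contents are illustrative assumptions, not copied from the folder above:

```python
import re

# Each entry line is assumed to look like: all_model_checkpoint_paths: "model.ckpt-1000"
_PATH_RE = re.compile(r'^all_model_checkpoint_paths:\s*"(?P<path>[^"]*)"$')


def listed_checkpoints(checkpoint_file_text):
    """Return every all_model_checkpoint_paths entry, in file order.

    Hypothetical helper: a simplified reading of the text-format 'checkpoint'
    file that ignores escape sequences inside the quoted strings.
    """
    paths = []
    for line in checkpoint_file_text.splitlines():
        match = _PATH_RE.match(line.strip())
        if match:
            paths.append(match.group("path"))
    return paths


# Assumed contents of the 'checkpoint' file after three saves.
example = '''model_checkpoint_path: "model.ckpt-3000"
all_model_checkpoint_paths: "model.ckpt-1000"
all_model_checkpoint_paths: "model.ckpt-2000"
all_model_checkpoint_paths: "model.ckpt-3000"
'''
print(listed_checkpoints(example))  # ['model.ckpt-1000', 'model.ckpt-2000', 'model.ckpt-3000']
```

Any of these strings can be joined onto CHECKPOINT_DIR and passed straight to saver.restore(sess, path), so the file itself never needs editing.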

My question: how can I test all the checkpoints one by one (or, at least, a specific one)?

Upvotes: 1

Views: 1892

Answers (1)

Amir

Reputation: 16587

You can restore a specific checkpoint by passing its full path prefix, including the step number, to the restore method of tf.train.Checkpoint (saver.restore(sess, path) accepts such a prefix too). For example, to load the checkpoint saved at iteration 1000:

status = ckpnt.restore('./test/model.ckpt-1000')

To load the checkpoint from iteration 2000 instead:

status = ckpnt.restore('./test/model.ckpt-2000')
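Since restore accepts any prefix, testing every checkpoint comes down to a loop over the prefixes. A minimal sketch, assuming you supply two hypothetical hooks: restore_fn(path), which loads one checkpoint (for example lambda p: saver.restore(sess, p)), and eval_fn(), which runs your test set and returns a metric. The fake hooks at the bottom are placeholders that only demonstrate the control flow:

```python
import os


def evaluate_each_checkpoint(checkpoint_dir, prefixes, restore_fn, eval_fn):
    """Restore each checkpoint prefix in turn and record eval_fn()'s result.

    restore_fn and eval_fn are hypothetical hooks supplied by the caller.
    """
    results = {}
    for prefix in prefixes:
        path = os.path.join(checkpoint_dir, prefix)
        restore_fn(path)              # load this checkpoint's variable values
        results[prefix] = eval_fn()   # evaluate the model as currently restored
    return results


# Placeholder hooks: "restoring" just records the step parsed from the path.
state = {}


def fake_restore(path):
    state["step"] = int(path.rsplit("-", 1)[1])


def fake_eval():
    return state["step"] / 1000


print(evaluate_each_checkpoint("./test", ["model.ckpt-1000", "model.ckpt-2000"],
                               fake_restore, fake_eval))
# {'model.ckpt-1000': 1.0, 'model.ckpt-2000': 2.0}
```

With tf.train.Checkpoint in graph mode, restore_fn would also need to run the returned status (for example status.initialize_or_restore(sess)), so reusing the existing saver.restore(sess, p) is probably the simpler choice inside such a loop.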

Complete example (TensorFlow 1.x graph mode, since it uses tf.Session):

import tensorflow as tf

v1 = tf.Variable(9., name="v1")
v2 = tf.Variable(2., name="v2")
a = tf.add(v1, v2)

ckpnt = tf.train.Checkpoint(firstVar=v1, secondVar=v2)

with tf.Session() as sess:
    # Init v1 and v2
    sess.run(tf.global_variables_initializer())
    # Print value of v1
    print(sess.run(v1))
    # Save v1 and v2 (first save, written as ./test/myVar-1)
    ckpnt.save('./test/myVar', sess)
    # Change both values and save again (written as ./test/myVar-2)
    sess.run(v1.assign(90))
    sess.run(v2.assign(20))
    ckpnt.save('./test/myVar', sess)

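# Restore the first save ('./test/myVar-1'), where v1 was still 9.0
ckpnt = tf.train.Checkpoint(firstVar=v1, secondVar=v2)
status = ckpnt.restore('./test/myVar-1')
status.assert_consumed()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Run the restore ops, overwriting the freshly initialized values
    status.initialize_or_restore(sess)
    print(sess.run(v1))  # prints 9.0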

Upvotes: 2
