Miguel Monteiro

Reputation: 379

Restoring Tensorflow model from .pbtxt and .meta files

I trained a model using tf.train.MonitoredTrainingSession() with a tf.train.CheckpointSaverHook() that saves a checkpoint every 1000 steps. After training, the following files were created in the checkpoint directory:

events.out.tfevents.1511969396.cmle-training-master-ef2237c814-0-xn7pp
graph.pbtxt
model.ckpt-1.meta
model.ckpt-1001.meta
model.ckpt-2001.meta
model.ckpt-3001.meta
model.ckpt-4001.meta
model.ckpt-4119.meta

I want to restore the checkpoint but can't. Here is my code (assuming the files above are in the directory checkpoints):

tf.train.import_meta_graph('checkpoints/model.ckpt-4119.meta')
saver = tf.train.Saver()
with tf.Session() as sess:

    ckpt = tf.train.get_checkpoint_state('./checkpoints/')
    saver.restore(sess, ckpt.model_checkpoint_path)

The problem is that ckpt is None. I think I might be missing a file. What am I doing wrong?

This is how I save the checkpoints:

hooks = []
hooks.append(tf.train.CheckpointSaverHook(checkpoint_dir=checkpoint_dir, save_steps=checkpoint_iterations))

with tf.Graph().as_default():
    with tf.device(tf.train.replica_device_setter()):

        batch = model.input_fn(train_path, batch_size, epochs, 'train_queue')

        tensors = model.model_fn(batch, content_weight, style_weight, tv_weight, vgg_path, style_features,
                                 batch_size, learning_rate)

    with tf.train.MonitoredTrainingSession(master=target,
                                           is_chief=is_chief,
                                           checkpoint_dir=job_dir,
                                           hooks=hooks,
                                           save_checkpoint_secs=None,
                                           save_summaries_steps=None,
                                           log_step_count_steps=10) as sess:
        _ = sess.run(tensors)
       (...)

Upvotes: 0

Views: 1794

Answers (1)

GPhilo

Reputation: 19123

Restoring the full checkpoint

tf.train.get_checkpoint_state reads the file named checkpoint (no extension) inside the directory you pass as a parameter. This file usually has content similar to:

model_checkpoint_path: "model.ckpt-1"
all_model_checkpoint_paths: "model.ckpt-1"

If this file is missing, the function will return None. Add a text file with that name and content to your model folder and you'll be able to restore using the code you already have.
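As a rough illustration, the file can also be generated with a few lines of plain Python. This is only a sketch: write_checkpoint_index is a hypothetical helper (not part of TensorFlow), and it assumes the directory already exists and that the last prefix in the list is the one you want to restore.

```python
import os


def write_checkpoint_index(checkpoint_dir, prefixes):
    """Write a minimal `checkpoint` text file into checkpoint_dir.

    Hypothetical helper, not a TensorFlow API. `prefixes` is a list of
    checkpoint prefixes such as "model.ckpt-1"; the last one is recorded
    as the latest checkpoint. Assumes checkpoint_dir already exists.
    """
    latest = prefixes[-1]
    lines = ['model_checkpoint_path: "{}"'.format(latest)]
    lines += ['all_model_checkpoint_paths: "{}"'.format(p) for p in prefixes]

    path = os.path.join(checkpoint_dir, "checkpoint")
    with open(path, "w") as f:
        f.write("\n".join(lines) + "\n")
    return path
```

For the directory in the question, something like write_checkpoint_index('checkpoints', ['model.ckpt-4119']) would produce a file with the same shape as the example above.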

Very important note: To restore this way you need all the checkpoint data, i.e., the three files: .data-*, .meta and .index.
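Before attempting a restore, it may help to check which of those files are actually present for a given checkpoint prefix. The sketch below uses only the standard library; missing_checkpoint_parts is a hypothetical helper name, and the prefix is assumed to be a path like checkpoints/model.ckpt-4119.

```python
import glob
import os


def missing_checkpoint_parts(prefix):
    """Return the checkpoint parts that are missing for `prefix`.

    Hypothetical helper, not a TensorFlow API. Looks for
    <prefix>.meta, <prefix>.index and at least one <prefix>.data-* file.
    An empty list means all three kinds of file were found.
    """
    missing = []
    for ext in (".meta", ".index"):
        if not os.path.isfile(prefix + ext):
            missing.append(ext)
    # The variable values may be sharded across several .data-* files.
    if not glob.glob(glob.escape(prefix) + ".data-*"):
        missing.append(".data-*")
    return missing
```

With the file listing from the question, where only .meta files exist, this would report both .index and .data-* as missing, which suggests the variable values were never copied to that directory.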

Restoring just the graph

If, however, you're interested in restoring only the meta-graph, you can do so via import_meta_graph() as detailed in the official TF guide.

Note (from the definition of import_meta_graph()):

This function takes a MetaGraphDef protocol buffer as input. If the argument is a file containing a MetaGraphDef protocol buffer, it constructs a protocol buffer from the file content. The function then adds all the nodes from the graph_def field to the current graph, recreates all the collections, and returns a saver constructed from the saver_def field.

Using that saver won't work unless you have the .index and .data-* files in the same directory.

Upvotes: 2
