Ron Cohen

Reputation: 2925

Save and load a TensorFlow model

I want to save a TensorFlow (0.12.0) model, including the graph and variable values, then later load and execute it. I have read the docs and other posts on this but cannot get the basics to work. I am using the technique from this page in the TensorFlow docs. Code:

Save a simple model:

import tensorflow as tf

myVar = tf.Variable(7.1)
tf.add_to_collection('modelVariables', myVar) # why?
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run(myVar))
    saver0 = tf.train.Saver()
    saver0.save(sess, './myModel.ckpt')
    saver0.export_meta_graph('./myModel.meta')

Later, load and execute the model:

with tf.Session() as sess:
    saver1 = tf.train.import_meta_graph('./myModel.meta')
    saver1.restore(sess, './myModel.meta')
    print(sess.run(myVar))

Question 1: The saving code seems to work but the loading code produces this error:

W tensorflow/core/util/tensor_slice_reader.cc:95] Could not open ./myModel.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?

How can I fix this?

Question 2: I included this line to follow the pattern in the TF docs...

tf.add_to_collection('modelVariables', myVar)

... but why is that line necessary? Doesn't export_meta_graph export the entire graph by default? If not, does one need to add every variable in the graph to the collection before saving? Or do we add to the collection only those variables that will be accessed after the restore?

---------------------- Update January 12 2017 -----------------------------

Partial success based on Kashyap's suggestion below, but a mystery still exists. The code below works, but only if I include the lines containing tf.add_to_collection and tf.get_collection. Without those lines, 'load' mode throws an error in the last line: NameError: name 'myVar' is not defined. My understanding was that by default Saver.save saves and restores all variables in the graph, so why is it necessary to put the variables that will be used into a collection by name? I assume this has to do with mapping TensorFlow's variable names to Python names, but what are the rules of the game here? For which variables does this need to be done?

import tensorflow as tf

mode = 'load' # or 'save'
if mode == 'save':
    myVar = tf.Variable(7.1)
    init_op = tf.global_variables_initializer()
    saver0 = tf.train.Saver()
    tf.add_to_collection('myVar', myVar) ### WHY NECESSARY?
    with tf.Session() as sess:
        sess.run(init_op)
        print(sess.run(myVar))
        saver0.save(sess, './myModel')
if mode == 'load':
    with tf.Session() as sess:
        saver1 = tf.train.import_meta_graph('./myModel.meta')
        saver1.restore(sess, tf.train.latest_checkpoint('./'))
        myVar = tf.get_collection('myVar')[0]  ### WHY NECESSARY?
        print(sess.run(myVar))

Upvotes: 1

Views: 6728

Answers (2)

Achilles

Reputation: 1129

I've been trying to figure out the same thing and was able to do it successfully by using Supervisor. It automatically saves and loads all variables, your graph, etc. Here is the documentation - https://www.tensorflow.org/programmers_guide/supervisor. Below is my code -

sv = tf.train.Supervisor(logdir='/checkpoint', save_model_secs=60)
with sv.managed_session() as sess:
    if not sv.should_stop():
        pass  # do run/eval/train ops on sess as needed;
              # the above works for both saving and loading

As you can see, this is much simpler than using a Saver object and dealing with individual variables, as long as the graph stays the same (my understanding is that Saver comes in handy when we want to reuse a pre-trained model in a different graph).
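For that last case, a minimal sketch of restoring a checkpointed variable into a differently named variable in a new graph, using Saver's var_list mapping. This uses the tf.compat.v1 API so the thread's graph-mode code runs under TensorFlow 2.x; the names 'myVar' and 'pretrained/weights' are made up for the example.

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

# Save a checkpoint whose variable is stored under the name 'myVar'.
g_old = tf.Graph()
with g_old.as_default():
    v = tf.Variable(7.1, name='myVar')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, './remap_demo')

# Restore into a different graph where the variable has a new name, by
# passing an explicit {checkpoint_name: variable} mapping to Saver.
g_new = tf.Graph()
with g_new.as_default():
    w = tf.Variable(0.0, name='pretrained/weights')
    saver = tf.train.Saver(var_list={'myVar': w})
    with tf.Session() as sess:
        saver.restore(sess, './remap_demo')  # assigns 7.1 to w
        remapped = sess.run(w)
```

Note that restore itself initializes w, so no init_op is needed for the restored variable.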

Upvotes: 1

Kashyap

Reputation: 6689

Question1

This question has already been answered thoroughly here. You don't have to explicitly call export_meta_graph; just call the save method, which will generate the .meta file as well (the save method calls export_meta_graph internally).

For example

saver0.save(sess, './myModel.ckpt')

will produce the myModel.ckpt file and also the myModel.ckpt.meta file.

Then you can restore the model using

with tf.Session() as sess:
    saver1 = tf.train.import_meta_graph('./myModel.ckpt.meta')
    saver1.restore(sess, './myModel.ckpt')
    print(sess.run(myVar))

Question2

Collections are used to store custom information, like the learning rate, the regularisation factor you used, and other details, and these will be stored when you export the graph. TensorFlow itself defines some collections, like "TRAINABLE_VARIABLES", which are used to get all the trainable variables of the model you built. You can choose to export all the collections in your graph, or you can specify which collections to export in the export_meta_graph function.

Yes, TensorFlow will export all the variables that you define. But if there is any other information that needs to be exported with the graph, it can be added to a collection.
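To make the collection round trip concrete, here is a hedged sketch (using the tf.compat.v1 API so the graph-mode code runs under TensorFlow 2.x; the collection name 'hyperparams' is made up) showing that a value stashed in a custom collection before saving comes back after import_meta_graph:

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

# Export: stash a hyperparameter in a custom collection before saving.
g_train = tf.Graph()
with g_train.as_default():
    learning_rate = tf.constant(0.01, name='learning_rate')
    tf.add_to_collection('hyperparams', learning_rate)
    v = tf.Variable(7.1)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, './collections_demo')  # writes .meta too

# Import: collections are serialized in the MetaGraphDef and come back
# with the graph, so the tensor can be fetched by collection name.
g_restore = tf.Graph()
with g_restore.as_default():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./collections_demo.meta')
        saver.restore(sess, './collections_demo')
        lr = tf.get_collection('hyperparams')[0]
        lr_value = sess.run(lr)
```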

Upvotes: 2
