Reputation: 2925
I want to save a TensorFlow (0.12.0) model, including the graph and variable values, and then later load and execute it. I have read the docs and other posts on this but cannot get the basics to work. I am using the technique from this page in the TensorFlow docs. Code:
Save a simple model:
import tensorflow as tf
myVar = tf.Variable(7.1)
tf.add_to_collection('modelVariables', myVar) # why?
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run(myVar))
    saver0 = tf.train.Saver()
    saver0.save(sess, './myModel.ckpt')
    saver0.export_meta_graph('./myModel.meta')
Later, load and execute the model:
with tf.Session() as sess:
    saver1 = tf.train.import_meta_graph('./myModel.meta')
    saver1.restore(sess, './myModel.meta')
    print(sess.run(myVar))
Question 1: The saving code seems to work but the loading code produces this error:
W tensorflow/core/util/tensor_slice_reader.cc:95] Could not open ./myModel.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
How do I fix this?
Question 2: I included this line to follow the pattern in the TF docs...
tf.add_to_collection('modelVariables', myVar)
... but why is that line necessary? Doesn't export_meta_graph export the entire graph by default? If not, does one need to add every variable in the graph to the collection before saving? Or do we just add to the collection those variables that will be accessed after the restore?
---------------------- Update January 12 2017 -----------------------------
Partial success based on Kashyap's suggestion below, but a mystery still exists. The code below works, but only if I include the lines containing tf.add_to_collection and tf.get_collection. Without those lines, 'load' mode throws an error in the last line:
NameError: name 'myVar' is not defined
My understanding was that by default Saver saves and restores all variables in the graph, so why is it necessary to add the variables I will use to a collection? I assume this has to do with mapping TensorFlow's variable names to Python names, but what are the rules of the game here? For which variables does this need to be done?
import tensorflow as tf
mode = 'load'  # or 'save'
if mode == 'save':
    myVar = tf.Variable(7.1)
    init_op = tf.global_variables_initializer()
    saver0 = tf.train.Saver()
    tf.add_to_collection('myVar', myVar)  ### WHY NECESSARY?
    with tf.Session() as sess:
        sess.run(init_op)
        print(sess.run(myVar))
        saver0.save(sess, './myModel')
if mode == 'load':
    with tf.Session() as sess:
        saver1 = tf.train.import_meta_graph('./myModel.meta')
        saver1.restore(sess, tf.train.latest_checkpoint('./'))
        myVar = tf.get_collection('myVar')[0]  ### WHY NECESSARY?
        print(sess.run(myVar))
Upvotes: 1
Views: 6728
Reputation: 1129
I've been trying to figure out the same thing and was able to do it successfully by using Supervisor. It automatically loads all variables, your graph, etc. Here is the documentation: https://www.tensorflow.org/programmers_guide/supervisor. Below is my code:
import tensorflow as tf
sv = tf.train.Supervisor(logdir="/checkpoint", save_model_secs=60)
with sv.managed_session() as sess:
    if not sv.should_stop():
        # Do run/eval/train ops on sess as needed.
        # The above works for both saving and loading.
        pass
As you see, this is much simpler than using the Saver object and dealing with individual variables, as long as the graph stays the same (my understanding is that Saver comes in handy when we want to reuse a pre-trained model for a different graph).
Upvotes: 1
Reputation: 6689
Question1
This question has already been answered thoroughly here. You don't have to call export_meta_graph explicitly; just call the save method. It also generates the .meta file, because save calls export_meta_graph internally.
For example,
saver0.save(sess, './myModel.ckpt')
writes the checkpoint under the prefix myModel.ckpt and also writes the myModel.ckpt.meta file.
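Then you can restore the model by passing the .meta file to import_meta_graph and the checkpoint prefix (not the .meta file) to restore:
with tf.Session() as sess:
    saver1 = tf.train.import_meta_graph('./myModel.ckpt.meta')
    saver1.restore(sess, './myModel.ckpt')
    # myVar is not defined in a fresh Python process; get it back from its collection
    myVar = tf.get_collection('modelVariables')[0]
    print(sess.run(myVar))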
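The error in Question 1 comes from passing the .meta path to restore, which expects the checkpoint prefix. A small helper can derive that prefix from the .meta path. This is a minimal, hypothetical sketch (checkpoint_prefix_from_meta is not a TensorFlow function), and it assumes the .meta file was written by Saver.save, which names it prefix + '.meta' by default. It does not apply to a .meta file exported separately under an unrelated name, such as './myModel.meta' saved next to './myModel.ckpt' in the question.

```python
META_SUFFIX = ".meta"  # assumes Saver.save's default meta_graph_suffix ('meta')


def checkpoint_prefix_from_meta(meta_path):
    """Hypothetical helper: map 'prefix.meta' to the 'prefix' that Saver.restore expects.

    Assumes the .meta file was written by Saver.save, i.e. named prefix + '.meta'.
    """
    if not meta_path.endswith(META_SUFFIX):
        raise ValueError("expected a path ending in %r, got %r" % (META_SUFFIX, meta_path))
    # Strip only the trailing '.meta'; any '.ckpt' before it is part of the prefix.
    return meta_path[:-len(META_SUFFIX)]


# Intended use (TensorFlow calls shown as comments only):
#   meta_path = './myModel.ckpt.meta'
#   saver1 = tf.train.import_meta_graph(meta_path)
#   saver1.restore(sess, checkpoint_prefix_from_meta(meta_path))
print(checkpoint_prefix_from_meta('./myModel.ckpt.meta'))
print(checkpoint_prefix_from_meta('./myModel.meta'))
```

The helper rejects paths without the .meta suffix instead of guessing, so passing the checkpoint prefix by mistake fails loudly rather than producing a wrong path.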
Question2
Collections are used to store custom information, such as the learning rate or the regularisation factor you used, and they are stored when you export the graph. TensorFlow itself defines some collections, like "TRAINABLE_VARIABLES", which are used to get all the trainable variables of the model you built. You can choose to export all the collections in your graph, or specify which collections to export with the collection_list argument of export_meta_graph.
Yes, TensorFlow exports all the variables you define. But if you need any other information exported with the graph, add it to a collection. Collections are also a convenient way to get Python handles back after import_meta_graph, since Python names such as myVar do not survive from the saving process to the loading process.
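To make the role of collections concrete, here is a toy, pure-Python model. ToyGraph and its methods are hypothetical and only mimic the idea, not TensorFlow's implementation: a collection is a named list stored with the graph, so after exporting and re-importing, a handle can be looked up by key even though the original Python name is gone. This mirrors why the 'load' branch in the question's update needs tf.get_collection('myVar')[0].

```python
import json


class ToyGraph(object):
    """Toy stand-in for a graph: named nodes plus named collections (hypothetical)."""

    def __init__(self):
        self.nodes = {}        # node name -> value (stands in for a variable)
        self.collections = {}  # collection key -> list of node names

    def add_variable(self, name, value):
        self.nodes[name] = value
        return name  # the Python "handle" is just the node's name

    def add_to_collection(self, key, handle):
        self.collections.setdefault(key, []).append(handle)

    def get_collection(self, key):
        # Like tf.get_collection: a new list each call, [] for unknown keys.
        return list(self.collections.get(key, []))

    def export(self):
        # Simplification: values are stored in the same blob here. In TensorFlow
        # the values live in the checkpoint; the .meta file holds graph + collections.
        return json.dumps({"nodes": self.nodes, "collections": self.collections})

    @classmethod
    def import_graph(cls, blob):
        data = json.loads(blob)
        graph = cls()
        graph.nodes = data["nodes"]
        graph.collections = data["collections"]
        return graph


# "Save" process: my_var is a Python name that exists only in this process.
g = ToyGraph()
my_var = g.add_variable("Variable", 7.1)
g.add_to_collection("myVar", my_var)
blob = g.export()

# "Load" process: the graph is rebuilt from the blob; the handle comes from the collection.
restored = ToyGraph.import_graph(blob)
handle = restored.get_collection("myVar")[0]
print(restored.nodes[handle])
```

Without the collection, the loading side would need to know the node's name ("Variable" here) to find it, which is the toy analogue of looking up a restored variable by its graph name.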
Upvotes: 2