I want to use a pretrained TensorFlow model provided by an unknown author. I do not know how they managed to save the model (they used TensorFlow version >= 1.2) as a single file with the extension '.model'. Normally I get either three files ('.meta', '.data', '.index') or one file with '.ckpt'.
How can I restore this pretrained model? How can I save a model to this format later?
Thanks.
I have also asked this question on a number of platforms with no help so far, so I did some experiments myself. Here is what I found. This is long, but please bear with me.
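To import a model in TensorFlow (1.x), we use:

```python
import tensorflow as tf

with tf.Session() as sess:
    # Rebuild the graph (operations, variables, collections) from the .meta file.
    new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    # Load the saved variable values from the latest checkpoint in ./
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))
```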
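The file names in that snippet follow a fixed pattern. The sketch below, written in plain Python so it runs without TensorFlow, lists the files a TF1 `tf.train.Saver` (V2 checkpoint format) is expected to write for one `saver.save(sess, prefix, global_step=...)` call. The helper name `saver_v2_files` is hypothetical, and it assumes the default `.meta` suffix, a written meta graph, and one data shard unless told otherwise; the directory-level `checkpoint` text file that `save` also updates is left out.

```python
def saver_v2_files(prefix, global_step=None, num_shards=1):
    """Hypothetical helper: files a TF1 Saver (V2 format) writes for one save.

    Simplified sketch, not TensorFlow code. Assumes the default meta graph
    suffix ("meta") and that the meta graph is written.
    """
    # Saver appends "-<global_step>" to the prefix when a step is given.
    base = prefix if global_step is None else f"{prefix}-{global_step}"
    files = [base + ".meta", base + ".index"]
    # Variable values are stored in numbered shards: data-00000-of-00001, ...
    files += [f"{base}.data-{i:05d}-of-{num_shards:05d}" for i in range(num_shards)]
    return files


print(saver_v2_files("my_test_model", global_step=1000))
# ['my_test_model-1000.meta', 'my_test_model-1000.index', 'my_test_model-1000.data-00000-of-00001']
```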
The `.meta` file contains the graph definition of the trained model: its operations, variable definitions, collections, etc. `tf.train.latest_checkpoint('./')` reads the `checkpoint` file (which simply records the most recently saved checkpoints) and returns the path prefix of the latest one; `restore` then loads the values stored in `xxxx_model.data-00000-of-00001`, using the matching `.index` file to locate them. This `.data-00000-of-00001` file holds the saved values (weights, biases, optimizer state, etc.) that must be loaded into the variables defined in `my_test_model-1000.meta`.
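To make the role of the `checkpoint` file concrete, here is a simplified, TensorFlow-free stand-in for `tf.train.latest_checkpoint`. It is a sketch, not TensorFlow's implementation: the function name `latest_checkpoint_sketch` is hypothetical, it only understands the V2 format (it requires the `.index` file to exist), and it assumes the `checkpoint` file uses the usual text layout with a `model_checkpoint_path: "..."` line.

```python
import os
import re
import tempfile


def latest_checkpoint_sketch(checkpoint_dir):
    """Hypothetical, simplified stand-in for tf.train.latest_checkpoint.

    Reads the text "checkpoint" file, takes its model_checkpoint_path entry,
    and returns that path prefix if the matching V2 ".index" file exists.
    Returns None otherwise.
    """
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.isfile(state_file):
        return None
    with open(state_file) as f:
        match = re.search(r'^model_checkpoint_path:\s*"([^"]*)"', f.read(), re.MULTILINE)
    if match is None:
        return None
    prefix = match.group(1)
    # Relative entries are interpreted relative to the checkpoint directory.
    if not os.path.isabs(prefix):
        prefix = os.path.join(checkpoint_dir, prefix)
    # The prefix is only usable if the V2 index file is actually present.
    return prefix if os.path.isfile(prefix + ".index") else None


# Usage: fake a checkpoint directory containing only the files the sketch reads.
with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, "checkpoint"), "w") as f:
        f.write('model_checkpoint_path: "my_test_model-1000"\n'
                'all_model_checkpoint_paths: "my_test_model-1000"\n')
    open(os.path.join(d, "my_test_model-1000.index"), "w").close()
    print(latest_checkpoint_sketch(d))  # <d>/my_test_model-1000
```

Note that the function returns a prefix, not a file name: the reader then derives `.index` and `.data-*` from it, which is why `restore` takes that prefix directly.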
```python
import tensorflow as tf

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    # new_saver.restore(sess, tf.train.latest_checkpoint('./'))
    tensor_variable = tf.trainable_variables()
    for tensor_var in tensor_variable:
        # print(sess.run(tensor_var))
        print(tensor_var)
```
This code prints all the trainable variables defined in the `.meta` file. If you instead run `print(sess.run(tensor_var))`, you get an error (a `FailedPreconditionError` about using an uninitialized value) because the variables have not been initialized. However, if you uncomment `new_saver.restore(sess, tf.train.latest_checkpoint('./'))` and then run `print(sess.run(tensor_var))`, you get every variable together with the values loaded into it.
My best guess is that `xxxxxx.model` works much like TensorFlow's `xxxx_model.data-00000-of-00001`: it stores variable values but no variable definitions or graph. So if you try:
```python
import tensorflow as tf

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('xxx.model')
```
you will get an error. Again, the reason is that this `.model` file contains neither variable definitions nor an operation graph of any form. If you also try:
```python
import tensorflow as tf

with tf.Session() as sess:
    new_saver = tf.train.Saver()
    new_saver.restore(sess, "xxxx.model")
```
you will similarly get an error, because the current graph has no variables to load the values into. Therefore, if you ever obtain an `xxx.model` file, you will have to go through the pain of rebuilding all the variables and operations (with the same names as in the original model) before running `new_saver.restore(sess, "xxxx.model")`. If you manage to replicate the architecture, this should, hopefully, run without issues.
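Why replicating the architecture matters can be shown with a toy model of name-based restoring. This is plain Python, not TensorFlow's implementation, and `restore_by_name`, `NotFoundInCheckpoint`, and the variable names `dense/kernel` and `dense/bias` are all hypothetical. It mirrors the behaviour described above: every variable in the rebuilt graph needs a checkpoint entry with the same name and shape, while extra checkpoint entries are ignored.

```python
class NotFoundInCheckpoint(LookupError):
    """Hypothetical stand-in for the error raised when a key is missing."""


def restore_by_name(graph_vars, checkpoint):
    """Toy model of name-based restoring (not TensorFlow code).

    graph_vars: {variable_name: shape}, i.e. what the rebuilt graph defines.
    checkpoint: {variable_name: (shape, values)}, i.e. what the file stores.
    Returns {variable_name: values} for every variable in the graph.
    """
    restored = {}
    for name, shape in graph_vars.items():
        if name not in checkpoint:
            # A graph variable with no matching checkpoint key cannot be filled.
            raise NotFoundInCheckpoint(f"Key {name} not found in checkpoint")
        ckpt_shape, values = checkpoint[name]
        if ckpt_shape != shape:
            raise ValueError(f"Shape mismatch for {name}: graph {shape}, checkpoint {ckpt_shape}")
        restored[name] = values
    # Checkpoint entries that the graph does not define are simply ignored.
    return restored


ckpt = {"dense/kernel": ((2, 1), [[0.5], [-1.0]]), "dense/bias": ((1,), [0.1])}

# Rebuilt graph with matching names and shapes: restore succeeds.
print(restore_by_name({"dense/kernel": (2, 1), "dense/bias": (1,)}, ckpt))
# {'dense/kernel': [[0.5], [-1.0]], 'dense/bias': [0.1]}

# Rebuilt graph with different variable names: restore fails.
try:
    restore_by_name({"layer1/kernel": (2, 1)}, ckpt)
except NotFoundInCheckpoint as err:
    print("restore failed:", err)
```

In this model, the checkpoint alone is useless without a graph that declares matching names, which is the situation the answer describes for a bare `.model` file.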
Sorry this was long, but since there is almost no answer to this on the internet, I had to turn it into a lecture. :)