Reputation: 182
The Saver documentation states that the Saver object can take either a list or a dictionary as input and that, in the case of a dictionary, the keys must be the names that will be used to save or restore the variables. I have code that looks like the following:
import tensorflow as tf

create_network()
vars_to_load_list = ...
# Keys are the variables' .name attributes (used by the dict version below)
vars_to_load_dict = {v.name: v for v in vars_to_load_list}
loader = tf.train.Saver(var_list=vars_to_load_list, max_to_keep=FLAGS.max_epoch)
path = ...
latest_ckpt = tf.train.latest_checkpoint(path, latest_filename=None)
sess = tf.Session()
ckpt = tf.train.get_checkpoint_state(path)
# Restore only if a checkpoint actually exists in path
if ckpt and ckpt.model_checkpoint_path:
    loader.restore(sess, save_path=latest_ckpt)
The above code works, but if I pass in the variable dictionary instead of the variable list, i.e. change the definition of loader to:
loader = tf.train.Saver(var_list=vars_to_load_dict, max_to_keep=FLAGS.max_epoch)
Then I get a NotFoundError, and the loader complains that some tensor names were not found in the checkpoint files. I expected both versions of the code to behave the same. Am I missing something?
Upvotes: 0
Views: 1046
Reputation: 182
I figured out the problem. The name attribute of a variable is the name of the tensor it produces, which includes an output-index suffix: it returns "my_var:0", whereas the checkpoint stores the variable under its op name, "my_var". Modifying the definition of the dictionary in the above example to strip that suffix solves the problem:
vars_to_load_dict = {v.name[:-2]:v for v in vars_to_load_list}
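To make the key mismatch concrete without needing TensorFlow installed, here is a minimal pure-Python sketch. FakeOp and FakeVariable are hypothetical stand-ins for a TF1 variable, which exposes both name (the tensor name, e.g. "my_var:0") and op.name (the op name, e.g. "my_var"). The sketch assumes, consistent with the behaviour described above, that a list passed to Saver ends up keyed by op name, so a hand-built dictionary has to use the same keys. The hypothetical helper checkpoint_key strips the ":N" suffix with rpartition rather than slicing off a fixed two characters.

```python
from dataclasses import dataclass


@dataclass
class FakeOp:
    """Hypothetical stand-in for a TF op; only its name matters here."""
    name: str


@dataclass
class FakeVariable:
    """Hypothetical stand-in for a TF1 variable; .name is the tensor name."""
    op: FakeOp

    @property
    def name(self) -> str:
        # A variable's tensor is output 0 of its op, hence the ":0" suffix.
        return self.op.name + ":0"


def checkpoint_key(tensor_name: str) -> str:
    """Strip a ':<output index>' suffix, e.g. 'my_var:0' -> 'my_var'."""
    op_name, sep, index = tensor_name.rpartition(":")
    # Only strip when the text after the last ':' is a numeric output index.
    return op_name if sep and index.isdigit() else tensor_name


variables = [FakeVariable(FakeOp("my_var")), FakeVariable(FakeOp("scope/weights"))]

# Keys as in the question: tensor names, which do not match checkpoint entries.
wrong_dict = {v.name: v for v in variables}
# Keys with the suffix stripped: the op names stored in the checkpoint.
right_dict = {checkpoint_key(v.name): v for v in variables}

print(sorted(wrong_dict))  # ['my_var:0', 'scope/weights:0']
print(sorted(right_dict))  # ['my_var', 'scope/weights']
```

In real TF1 graph code, building the dictionary as {v.op.name: v for v in vars_to_load_list} should give the same keys directly; the [:-2] slice above also works, since a variable's tensor is always output 0 of its op.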
Upvotes: 4