Phoenix666

Reputation: 182

passing dictionaries to tensorflow Saver

In the Saver documentation, it is stated that the Saver object can take either a list or a dictionary as input and, in the case of dictionaries, the keys must be the names that will be used to save or restore variables. I have code that looks like the following:

create_network()
vars_to_load_list = ...
vars_to_load_dict = {v.name:v for v in vars_to_load_list}
loader = tf.train.Saver(var_list=vars_to_load_list, max_to_keep=FLAGS.max_epoch)
path = ...
latest_ckpt = tf.train.latest_checkpoint(path, latest_filename=None)
sess = tf.Session()
ckpt = tf.train.get_checkpoint_state(path)
if ckpt and ckpt.model_checkpoint_path:
  loader.restore(sess, save_path=latest_ckpt)

The above code works, but if I pass in the variables dictionary instead of the variable list, i.e. change the definition of loader to:

loader = tf.train.Saver(var_list=vars_to_load_dict, max_to_keep=FLAGS.max_epoch)

Then I get a NotFoundError, and the loader complains that some tensor names were not found in the checkpoint files. I would expect both versions of the code to work the same. Am I missing something?

Upvotes: 0

Views: 1046

Answers (1)

Phoenix666

Reputation: 182

I figured out the problem. Apparently, the name attribute of a variable refers to its output tensor rather than the underlying operation (if my understanding of these concepts is right): it returns "my_var:0", whereas the loader requires "my_var". Modifying the definition of the dictionary in the above example solves the problem:

vars_to_load_dict = {v.name[:-2]:v for v in vars_to_load_list}
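As a side note, in TensorFlow 1.x the same result can be obtained with `v.op.name`, which avoids manual slicing. If you do strip the suffix yourself, splitting on the last `:` is safer than `[:-2]`, since the slice assumes a single-digit output index. A minimal plain-Python sketch of that stripping logic (no TensorFlow required; the names below are made-up examples):

```python
def ckpt_key(tensor_name):
    """Map a tensor name like 'my_var:0' to the bare variable name
    the Saver expects as a checkpoint key, by dropping everything
    after the last ':' (the output index)."""
    return tensor_name.rsplit(":", 1)[0]

# Hypothetical tensor names for illustration:
names = ["conv1/weights:0", "conv1/biases:0", "my_var:0"]
print([ckpt_key(n) for n in names])
# ['conv1/weights', 'conv1/biases', 'my_var']
```

With this helper, the dictionary in the answer could equivalently be written as `{ckpt_key(v.name): v for v in vars_to_load_list}`.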

Upvotes: 4
