I just updated my local installation of TensorFlow to 0.11rc2 and got a message saying that I should add a parameter to my saver to make it save in version 2. I made that change, and now I cannot load models that were saved in this format. My model saves a checkpoint after every epoch. It used to save two files, `translate.ckpt-3916` and `translate.ckpt-3916.meta`. Now I get three files instead: `translate.ckpt-3916.index`, `translate.ckpt-3916.meta`, and `translate.ckpt-3916.data-00000-of-00001`.
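The difference between the two layouts can be sketched in plain Python, so it runs without TensorFlow. The helpers `v1_checkpoint_files` and `v2_checkpoint_files` are hypothetical, and the five-digit `data-XXXXX-of-YYYYY` shard naming is assumed from what a single-shard V2 save typically writes. (The saver parameter the warning refers to is `write_version=tf.train.SaverDef.V2` on `tf.train.Saver`.)

```python
def v1_checkpoint_files(prefix):
    # V1 writes all variables into a single file named exactly `prefix`.
    return [prefix, prefix + ".meta"]


def v2_checkpoint_files(prefix, num_shards=1):
    """Hypothetical helper: list the files a V2 save writes for `prefix`.

    Assumes the `.index` / `.meta` / `.data-XXXXX-of-YYYYY` naming pattern
    with five-digit zero-padded shard numbers.
    """
    data_files = [
        "%s.data-%05d-of-%05d" % (prefix, shard, num_shards)
        for shard in range(num_shards)
    ]
    return [prefix + ".index", prefix + ".meta"] + data_files


if __name__ == "__main__":
    print(v1_checkpoint_files("translate.ckpt-3916"))
    print(v2_checkpoint_files("translate.ckpt-3916"))
```

Note that the V2 list never contains a file named exactly `translate.ckpt-3916`, which matters for the existence check in the loading code.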
To load the model, I use the following code:

```python
# Inside the function that builds `model`: restore saved weights if a
# checkpoint is recorded in FLAGS.train_dir, otherwise start from scratch.
ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
    print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
    model.saver.restore(session, ckpt.model_checkpoint_path)
else:
    print("Created model with fresh parameters.")
    session.run(tf.initialize_all_variables())
return model
```
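Here `model` is a model object that has already been initialized with my program's standard hyperparameters. This works without issue with saver V1. `ckpt.model_checkpoint_path` evaluates to the path of `translate.ckpt-3916` regardless of version, but a V2 save does not write a file with that exact name (only the `.index`, `.meta`, and `.data-*` files), so when the checkpoint was saved with V2, `tf.gfile.Exists` finds no file and the model is created with fresh parameters instead.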
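The failing check can be reproduced, and worked around, with the standard library alone. This is a minimal sketch: `checkpoint_prefix_exists` is a hypothetical helper that treats a prefix as present if either the V1 file or the V2 `.index` file exists. Later TensorFlow 1.x releases ship `tf.train.checkpoint_exists`, which performs a similar V1-or-V2 check; whether 0.11 already has it is not something I would rely on.

```python
import os
import tempfile


def checkpoint_prefix_exists(prefix):
    """Hypothetical helper: True if a V1 or V2 checkpoint exists at `prefix`.

    V1 stores the variables in a file named exactly `prefix`; V2 stores them
    in data shards next to a `prefix.index` file, so a plain existence check
    on `prefix` only works for V1.
    """
    return os.path.exists(prefix) or os.path.exists(prefix + ".index")


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as train_dir:
        prefix = os.path.join(train_dir, "translate.ckpt-3916")
        # Simulate the files a V2 save leaves behind (empty placeholders).
        for suffix in (".index", ".meta", ".data-00000-of-00001"):
            open(prefix + suffix, "w").close()
        print(os.path.exists(prefix))            # False: no file named like the prefix
        print(checkpoint_prefix_exists(prefix))  # True: the .index file is present
```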
The contents of the `checkpoint` file in that directory (when saved with either version) are:

```
model_checkpoint_path: "translate.ckpt-3916"
all_model_checkpoint_paths: "translate.ckpt-3916"
```
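Both versions record only the shared prefix in this file, which is why `ckpt.model_checkpoint_path` looks the same either way. The sketch below is a simplified reader that assumes the file contains only `key: "value"` lines like the ones above. TensorFlow itself parses the file as a text-format `CheckpointState` protobuf, and joining relative paths with the directory is an assumption of this sketch; `read_checkpoint_state` is a hypothetical name.

```python
import os


def read_checkpoint_state(checkpoint_file_text, train_dir):
    """Hypothetical, simplified reader for the `checkpoint` text file.

    Handles only `key: "value"` lines; relative paths are joined with
    `train_dir`. Returns (latest_prefix, list_of_all_prefixes).
    """
    latest = None
    all_paths = []
    for line in checkpoint_file_text.splitlines():
        key, sep, value = line.partition(":")
        if not sep:
            continue  # skip blank or malformed lines
        path = value.strip().strip('"')
        if not os.path.isabs(path):
            path = os.path.join(train_dir, path)
        if key.strip() == "model_checkpoint_path":
            latest = path
        elif key.strip() == "all_model_checkpoint_paths":
            all_paths.append(path)
    return latest, all_paths


if __name__ == "__main__":
    text = (
        'model_checkpoint_path: "translate.ckpt-3916"\n'
        'all_model_checkpoint_paths: "translate.ckpt-3916"\n'
    )
    latest, all_paths = read_checkpoint_state(text, "/tmp/train")
    print(latest)  # on POSIX: /tmp/train/translate.ckpt-3916
```

The returned value is a prefix, not a file name, so with V2 it has to be passed straight to `saver.restore` rather than checked with a plain file-existence test.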
Is there a new method to load data with saver v2? Otherwise, how can I load my checkpoints?
EDIT:
Changing the line `if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):` to `if ckpt and ckpt.model_checkpoint_path:`, as shown in this question, gets a little further but then throws the following error:

```
InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [84] rhs shape= [98]
[[Node: save/Assign_54 = Assign[T=DT_FLOAT, _class=["loc:@NLC/Logistic/Linear/Bias"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](NLC/Logistic/Linear/Bias, save/RestoreV2_54)]]
```
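The method I posted in my edit was actually the correct way to get this to work. The error came from the data changing between when I made the checkpoint and when I tried to load it: the model built from the new data has an `NLC/Logistic/Linear/Bias` variable of shape [84], while the checkpoint stores a value of shape [98], so the restore cannot assign it.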
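To catch this kind of mismatch before calling `restore`, one option is to compare the shapes stored in the checkpoint with the shapes of the freshly built graph. In TensorFlow 1.x the saved shapes can be read with `tf.train.NewCheckpointReader(path).get_variable_to_shape_map()` and the current ones with `v.get_shape().as_list()` for each `v` in `tf.global_variables()`; in this sketch plain dictionaries stand in for both so it runs without TensorFlow, and `find_shape_mismatches` is a hypothetical helper.

```python
def find_shape_mismatches(saved_shapes, current_shapes):
    """Hypothetical helper: compare checkpoint shapes with the current graph.

    Both arguments map variable names to shape lists. Returns
    {name: (saved_shape, current_shape)} for every variable present in both
    whose shapes differ, i.e. the variables whose restore would fail with
    "Assign requires shapes of both tensors to match".
    """
    mismatches = {}
    for name, saved in saved_shapes.items():
        current = current_shapes.get(name)
        if current is not None and list(current) != list(saved):
            mismatches[name] = (list(saved), list(current))
    return mismatches


if __name__ == "__main__":
    saved = {"NLC/Logistic/Linear/Bias": [98]}    # shape stored in the checkpoint
    current = {"NLC/Logistic/Linear/Bias": [84]}  # shape in the rebuilt graph
    print(find_shape_mismatches(saved, current))
    # {'NLC/Logistic/Linear/Bias': ([98], [84])}
```

Variables missing from either side are ignored here, since they raise a different error than the shape check this sketch targets.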
To make it explicit: loading from a V2 checkpoint with the code above works once the line `if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):` is changed to `if ckpt and ckpt.model_checkpoint_path:`.