Reputation: 177
I need to save and restore the graph so that training can resume from the last checkpoint, but it is not working.
I create the saver with `saver = tf.train.Saver()` and then run:

```python
with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
    # Initializing saver
    sess.run(tf.global_variables_initializer())
    save_path = saver.save(sess, model_path + "/%s.ckpt" % model_name)
    if flag == "initial_train":
        training_loop(num_epochs)
        flag = None
    else:
        new_saver = tf.train.import_meta_graph(model_path + "/%s.ckpt.meta" % model_name)
        new_saver.restore(sess, save_path)
        print("Model loaded")
        training_loop(num_epochs)
```
I really don't understand why the weights are not being restored.
Reputation: 17159
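On subsequent runs you are still executing these two lines before the restore:

```python
sess.run(tf.global_variables_initializer())
saver.save(sess, model_path + "/%s.ckpt" % model_name)
```

So you overwrite the checkpoint with freshly initialized variables, and then you load back exactly what you just initialized and saved (the third and fourth lines of your session block).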
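A minimal sketch of a restore-or-initialize pattern, assuming a TensorFlow version that provides the `tf.compat.v1` graph API. The one-variable graph, the temporary directory, and the helpers `build_graph` / `restore_or_init` are hypothetical stand-ins for your model. The idea is that `tf.train.latest_checkpoint` decides whether a checkpoint exists, the initializer runs only when there is nothing to restore, and nothing is saved before the restore. Because the sketch rebuilds the graph in Python on every run, it restores with the same `Saver` rather than `import_meta_graph`.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # graph mode, as in the question's code

# Hypothetical stand-ins for the question's model_path / model_name.
model_path = tempfile.mkdtemp()
model_name = "my_model"
ckpt_prefix = os.path.join(model_path, model_name + ".ckpt")


def build_graph():
    """Rebuild a tiny graph: one variable, one 'training' op, one Saver."""
    tf1.reset_default_graph()
    w = tf1.Variable(0.0, name="w")
    train_op = w.assign_add(1.0)  # stands in for optimizer.minimize(loss)
    saver = tf1.train.Saver()
    return w, train_op, saver


def restore_or_init(sess, saver):
    """Restore the newest checkpoint if there is one, else initialize.

    Nothing is initialized or saved *before* the restore, so an existing
    checkpoint is never overwritten with fresh values."""
    latest = tf1.train.latest_checkpoint(model_path)
    if latest is not None:
        saver.restore(sess, latest)
        return True
    sess.run(tf1.global_variables_initializer())
    return False


# First run: nothing to restore, so variables are initialized and trained.
w, train_op, saver = build_graph()
with tf1.Session() as sess:
    first_restored = restore_or_init(sess, saver)
    for _ in range(3):
        sess.run(train_op)
    saver.save(sess, ckpt_prefix)  # save *after* training

# Second run: the graph is rebuilt and the trained value is restored.
w, train_op, saver = build_graph()
with tf1.Session() as sess:
    second_restored = restore_or_init(sess, saver)
    restored_w = sess.run(w)

print(first_restored, second_restored, restored_w)  # False True 3.0
```

On the second run the variable comes back as 3.0 (the trained value) instead of the initial 0.0.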
Also, I don't know how you pass state around, but `training_loop` does not receive a reference to the saver, and you never save the model after the training loop, so it looks like your trained models are not actually being saved anywhere.
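One way to address that is to pass the session and saver into the training loop explicitly and checkpoint at the end of each epoch. This is a minimal sketch, again assuming the `tf.compat.v1` graph API. The signature `training_loop(sess, saver, train_op, num_epochs)`, the dummy one-variable "model", and the temporary directory are hypothetical, not the question's actual code.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # graph mode, as in the question's code

# Hypothetical stand-ins for the question's model_path / model_name.
model_path = tempfile.mkdtemp()
model_name = "my_model"


def training_loop(sess, saver, train_op, num_epochs):
    """Run num_epochs (dummy) epochs and checkpoint after each one.

    The session and saver are passed in explicitly, so the loop can save
    what it trains. Returns the path of the last checkpoint written."""
    ckpt_prefix = os.path.join(model_path, model_name + ".ckpt")
    last_path = None
    for epoch in range(num_epochs):
        sess.run(train_op)  # one dummy training step per epoch
        # global_step appends "-<epoch>" to the filename and updates the
        # "checkpoint" index file that tf.train.latest_checkpoint reads.
        last_path = saver.save(sess, ckpt_prefix, global_step=epoch)
    return last_path


tf1.reset_default_graph()
w = tf1.Variable(0.0, name="w")
train_op = w.assign_add(1.0)  # stands in for optimizer.minimize(loss)
saver = tf1.train.Saver()

with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    last = training_loop(sess, saver, train_op, num_epochs=3)

print(os.path.basename(last))  # my_model.ckpt-2
```

Saving with `global_step` keeps a few numbered checkpoints (the `Saver` keeps the five most recent by default), so a crash mid-training loses at most one epoch of work.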