florpi

Restoring a checkpoint to continue training not working with TensorFlow

I need to save and restore the graph so I can continue training from the last checkpoint, but somehow it is not working.

I use saver = tf.train.Saver() to save the model, and then:

with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
    # Initializing saver
    sess.run(tf.global_variables_initializer())
    save_path = saver.save(sess,model_path+"/%s.ckpt"%model_name)
    if flag == "initial_train":
        training_loop(num_epochs)
        flag = None
    else:
        new_saver = tf.train.import_meta_graph(model_path+"/%s.ckpt.meta"%model_name)
        new_saver.restore(sess, save_path)
        print("Model loaded")
        training_loop(num_epochs)

I really don't know why it's not importing the weights.

Answers (1)

etarion

On subsequent runs, you are:

  1. Initializing all variables (so they get their initial random/constant values) with sess.run(tf.global_variables_initializer())
  2. Saving the initialized values to some file (saver.save(sess,model_path+"/%s.ckpt"%model_name))
  3. Loading those randomly initialized values from that file

So you are just loading back the values you initialized and saved on lines 3 and 4 (the initializer call and saver.save).
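To make the ordering problem concrete, here is a framework-free sketch that mimics the three steps above with a JSON file standing in for the checkpoint. The helpers `initialize`, `save` and `restore` are hypothetical stand-ins for `sess.run(tf.global_variables_initializer())`, `saver.save` and `saver.restore`; this is not TensorFlow's API, just an illustration of why the trained values disappear.

```python
import json
import os
import random
import tempfile

def initialize():
    # Hypothetical stand-in for tf.global_variables_initializer(): fresh random weights.
    return {"w": random.random(), "b": 0.0}

def save(weights, path):
    # Hypothetical stand-in for saver.save(sess, path): write current values to disk.
    with open(path, "w") as f:
        json.dump(weights, f)

def restore(path):
    # Hypothetical stand-in for saver.restore(sess, path): read values back from disk.
    with open(path) as f:
        return json.load(f)

def buggy_run(path):
    """Mirrors the order of operations in the question's session block."""
    weights = initialize()    # 1. overwrite everything with initial values
    save(weights, path)       # 2. clobber the existing checkpoint with them
    restored = restore(path)  # 3. load back exactly what was just written
    return weights, restored

with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "model.json")
    save({"w": 42.0, "b": 1.0}, ckpt)  # pretend an earlier run trained to these values
    fresh, restored = buggy_run(ckpt)
    print(restored == fresh)           # True: we got the fresh initial values back
    print(restored["w"] == 42.0)       # False: the "trained" value was overwritten
```

Since `random.random()` returns a value in [0, 1), the restored weight can never equal the previously "trained" 42.0.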

Also, I don't know how you pass state around, but training_loop does not receive a reference to the saver, and you never save the model after the training loop, so it seems your trained models are not actually being saved anywhere.
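One way to restructure the flow is sketched below, again framework-free with a JSON file as the checkpoint and hypothetical helper names: restore *or* initialize (never both), train, and save only after training. In TensorFlow 1.x terms this roughly corresponds to checking `tf.train.latest_checkpoint(model_path)`, calling `saver.restore(sess, ...)` if it returns a path and `sess.run(tf.global_variables_initializer())` otherwise, then calling `saver.save(...)` after `training_loop`. Treat it as an illustration of the ordering, not a drop-in replacement.

```python
import json
import os
import random
import tempfile

def initialize():
    # Hypothetical stand-in for running the variable initializer.
    return {"w": random.random(), "step": 0}

def save(weights, path):
    # Hypothetical stand-in for saver.save: persist current values.
    with open(path, "w") as f:
        json.dump(weights, f)

def restore(path):
    # Hypothetical stand-in for saver.restore: load persisted values.
    with open(path) as f:
        return json.load(f)

def training_loop(weights, num_epochs):
    # Hypothetical training: nudge the weight and count completed epochs.
    for _ in range(num_epochs):
        weights["w"] += 0.1
        weights["step"] += 1
    return weights

def run(path, num_epochs):
    # Restore if a checkpoint exists, otherwise initialize; never both.
    if os.path.exists(path):
        weights = restore(path)
    else:
        weights = initialize()
    weights = training_loop(weights, num_epochs)
    save(weights, path)  # save AFTER training so the next run can resume
    return weights

with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "model.json")
    first = run(ckpt, num_epochs=3)   # no checkpoint yet: initializes
    second = run(ckpt, num_epochs=2)  # checkpoint found: resumes
    print(first["step"], second["step"])  # 3 5
```

The second call picks up at step 3 instead of starting over, which is the behavior the question is after.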
