sandboxj

Reputation: 1254

Tensorflow: How to have saver.save() and .restore() in one module?

I have a module called neural.py

I initialize the variables in the body.

import tensorflow as tf 
tf_x = tf.placeholder(tf.float32, [None, length])
tf_y = tf.placeholder(tf.float32, [None, num_classes])
...

I save the checkpoint in a function train() after training:

def train():
    ...
    pred = tf.layers.dense(dropout, num_classes, tf.identity) 
    ...
    cross_entropy = tf.losses.softmax_cross_entropy(tf_y, pred)
    ...
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        saver = tf.train.Saver(tf.trainable_variables())
        for ep in range(epochs):
            ... (training steps) ...
    saver.save(sess, "checkpoints/cnn") 

I want to also restore and run the network after training in the run() function of this module:

def run():
    # I have tried adding tf.reset_default_graph() here
    # I have also tried with tf.Graph().as_default() as g: and adding (graph=g) in tf.Session()

    saver = tf.train.Saver()

    with tf.Session() as sess:
        saver.restore(sess, "checkpoints/cnn")
        ... (run network etc)

It just doesn't work. It gives me either NotFoundError (see above for traceback): Key beta2_power not found in checkpoint, or ValueError: No variables to save if I add tf.reset_default_graph() at the top of run(), as commented above.

However, if I put the exact same code for run() in a new module without train() and with tf.reset_default_graph() at the top, it works perfectly. How do I make it work in the same module?

Final snippet:

if __name__ == '__main__':
    print("Start training")
    train()
    print("Finished training. Generate prediction")
    run()

Upvotes: 1

Views: 1207

Answers (1)

Maxim

Reputation: 53766

This might be a typo, but saver.save(sess, "checkpoints/cnn") should definitely be inside the with tf.Session() as sess: block; otherwise you're saving a closed session.
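In other words, roughly this (keeping the rest of your train() as it is):

def train():
    ...
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.trainable_variables())
        for ep in range(epochs):
            ...  # training steps
        saver.save(sess, "checkpoints/cnn")  # still inside the with block, so the session is open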

NotFoundError (see above for traceback): Key beta2_power not found in checkpoint

I think the problem is that part of your graph is defined in train. The beta1_power and beta2_power are internal variables of AdamOptimizer, which, along with pred and softmax_cross_entropy, are not in the graph if train() is not invoked (e.g., commented out). So one solution would be to make the whole graph accessible in both train and run.
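A rough sketch of that first idea, with a hypothetical build_graph() helper (the single dense layer stands in for your actual model; length, num_classes, epochs are your existing values):

def build_graph():
    # one place that defines every op, so train() and run() always agree
    tf_x = tf.placeholder(tf.float32, [None, length], name='tf_x')
    tf_y = tf.placeholder(tf.float32, [None, num_classes], name='tf_y')
    pred = tf.layers.dense(tf_x, num_classes, tf.identity)
    cross_entropy = tf.losses.softmax_cross_entropy(tf_y, pred)
    train_op = tf.train.AdamOptimizer().minimize(cross_entropy)  # creates beta1_power / beta2_power
    return tf_x, tf_y, pred, train_op

def run():
    tf.reset_default_graph()                          # drop whatever train() built
    tf_x, tf_y, pred, _ = build_graph()               # rebuild the identical graph
    saver = tf.train.Saver(tf.trainable_variables())  # match the variable list used when saving
    with tf.Session() as sess:
        saver.restore(sess, "checkpoints/cnn")
        ...  # run the network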

Another solution is to separate them and use the restored graph in run, instead of the default one. Like this:

tf.reset_default_graph()
saver = tf.train.import_meta_graph('checkpoints/cnn.meta')
with tf.Session() as sess:
  saver.restore(sess, "checkpoints/cnn")
  print("Model restored.")
  tf_x = sess.graph.get_tensor_by_name('tf_x:0')
  ...

But you'll need to give names to all of your variables (a good idea anyway) and then find those tensors in the graph; you can't use the previously defined Python variables here. This method ensures that run works with the saved model version, can easily be extracted into a separate script, etc.
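For instance, a small sketch of the naming side (the names tf_x and pred are just examples; dropout and some_batch stand in for your own tensors and data):

# while building the graph (in train), name the tensors you'll want back later:
tf_x = tf.placeholder(tf.float32, [None, length], name='tf_x')
logits = tf.layers.dense(dropout, num_classes)
pred = tf.identity(logits, name='pred')          # explicit name -> tensor 'pred:0'

# after import_meta_graph + restore (in run), look them up by name:
tf_x = sess.graph.get_tensor_by_name('tf_x:0')
pred = sess.graph.get_tensor_by_name('pred:0')
output = sess.run(pred, feed_dict={tf_x: some_batch})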

Upvotes: 1
