Reputation: 1254
I have a module called neural.py. I initialize the variables at module level (in the body):

import tensorflow as tf

tf_x = tf.placeholder(tf.float32, [None, length])
tf_y = tf.placeholder(tf.float32, [None, num_classes])
...
I save the checkpoint in a function train() after training:
def train():
    ...
    pred = tf.layers.dense(dropout, num_classes, tf.identity)
    ...
    cross_entropy = tf.losses.softmax_cross_entropy(tf_y, pred)
    ...
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        saver = tf.train.Saver(tf.trainable_variables())
        for ep in range(epochs):
            ... (training steps) ...
    saver.save(sess, "checkpoints/cnn")
I want to also restore and run the network after training, in the run() function of the same module:
def run():
    # I have tried adding tf.reset_default_graph() here
    # I have also tried with tf.Graph().as_default() as g: and adding (graph=g) in tf.Session()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, "checkpoints/cnn")
        ... (run network etc.)
It just doesn't work. I get either

NotFoundError (see above for traceback): Key beta2_power not found in checkpoint

or, if I add tf.reset_default_graph() under run() as commented above,

ValueError: No variables to save
However, if I put the exact same code for run() in a new module without train(), with tf.reset_default_graph() at the top, it works perfectly. How do I make it work in the same module?
Final snippet:

if __name__ == '__main__':
    print("Start training")
    train()
    print("Finished training. Generate prediction")
    run()
Upvotes: 1
Views: 1207
Reputation: 53766
This might be a typo, but saver.save(sess, "checkpoints/cnn") should definitely be inside the with tf.Session() as sess: block; otherwise you're saving with an already-closed session.
NotFoundError (see above for traceback): Key beta2_power not found in checkpoint
I think the problem is that part of your graph is defined in train(). The beta1_power and beta2_power are internal variables of AdamOptimizer which, along with pred and softmax_cross_entropy, are not in the graph if train() is not invoked (e.g., commented out?). So one solution would be to make the whole graph accessible in both train() and run().
Another solution is to separate the two graphs entirely and use the restored graph in run(), instead of the default one. Like this:
tf.reset_default_graph()
saver = tf.train.import_meta_graph('checkpoints/cnn.meta')
with tf.Session() as sess:
    saver.restore(sess, "checkpoints/cnn")
    print("Model restored.")
    tf_x = sess.graph.get_tensor_by_name('tf_x:0')
    ...
But you'll need to give names to all of your variables (a good idea anyway) and then find those tensors in the graph by name; you can't use the previously defined Python variables here. This method ensures that run() works with the saved model version, and it can easily be extracted into a separate script, etc.
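A self-contained sketch of this second approach. The tensor names 'tf_x' and 'pred' are hypothetical, chosen when the toy model is defined and saved; save_toy_model() only exists here to produce a checkpoint to restore from:

```python
import os
import numpy as np
import tensorflow as tf
if not hasattr(tf, "placeholder"):   # running under TF 2.x: fall back to the v1 API
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()

LENGTH, NUM_CLASSES = 4, 2  # toy sizes for this sketch

def save_toy_model():
    """Build and save a tiny model whose tensors carry explicit names."""
    tf.reset_default_graph()
    tf_x = tf.placeholder(tf.float32, [None, LENGTH], name='tf_x')
    logits = tf.layers.dense(tf_x, NUM_CLASSES)
    tf.identity(logits, name='pred')   # give the output tensor a stable name
    os.makedirs("checkpoints", exist_ok=True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, "checkpoints/cnn")  # also writes cnn.meta

def run():
    tf.reset_default_graph()
    # Recreate the graph structure from the .meta file -- no Python model code needed
    saver = tf.train.import_meta_graph('checkpoints/cnn.meta')
    with tf.Session() as sess:
        saver.restore(sess, "checkpoints/cnn")
        tf_x = sess.graph.get_tensor_by_name('tf_x:0')   # looked up by name
        pred = sess.graph.get_tensor_by_name('pred:0')
        return sess.run(pred, {tf_x: np.zeros((1, LENGTH), np.float32)}).shape
```

Since run() never touches the model-building code, it keeps working even if train() lives in the same module, and it can be moved to its own script unchanged.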
Upvotes: 1