I have two Python files. The first one saves a TensorFlow model; the second one restores the saved model.
Question:
When I run the two files one after another, it works.
When I run the first one, restart the editor, and then run the second one, it tells me that w1 is not defined.
What I want to do is:
Save a TensorFlow model
Restore the saved model
What's wrong with it? Thanks for your kind help!
model_save.py
import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'SR\\my-model')
model_restore.py
import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('SR\\my-model.meta')
    saver.restore(sess, 'SR\\my-model')
    print(sess.run(w1))
In short, in your model_restore.py file you should use
print(sess.run(tf.get_default_graph().get_tensor_by_name('w1:0')))
instead of print(sess.run(w1)).
model_save.py
import tensorflow as tf

w1_node = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2_node = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(w1_node.eval())  # [ 0.43350926 1.02784836]
    # print(w1.eval())     # NameError: name 'w1' is not defined
    saver.save(sess, 'my-model')
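w1_node is a Python name defined only in model_save.py, so model_restore.py cannot recognize it. What the .meta file and the checkpoint store is the graph node named 'w1' (set by name='w1'), not the Python variable.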
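Why it works when the files run one after another but fails after a restart can be sketched with the standard library alone. This assumes your editor runs both files in the same interpreter session until you restart it; the snippets passed to the interpreter are placeholders standing in for the two scripts.

```python
import subprocess
import sys

# Two separate interpreter processes: like running model_save.py,
# restarting, then running model_restore.py. Names do not carry over.
first = subprocess.run([sys.executable, '-c', 'w1 = [0.43, 1.03]'],
                       capture_output=True, text=True)
second = subprocess.run([sys.executable, '-c', 'print(w1)'],
                        capture_output=True, text=True)
print(first.returncode)              # 0
print(second.returncode)             # 1 (uncaught NameError)
print('NameError' in second.stderr)  # True

# One shared namespace: like running both files in the same session
# without restarting. The second snippet still sees w1.
shared = {}
exec('w1 = [0.43, 1.03]', shared)
exec('seen = w1', shared)
print(shared['seen'])                # [0.43, 1.03]
```

The same-session case only works by accident of the shared namespace, which is why the restore script should look the variable up in the restored graph instead.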
To fetch a tensor by its name, use get_tensor_by_name, as suggested in the post "Tensorflow: How to get a tensor by name?". The name 'w1:0' means output 0 of the operation named 'w1'.
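A minimal sketch of how graph names relate to Python names, assuming the TF1-style API via tensorflow.compat.v1 (so it should run under TF 1.x or 2.x). disable_v2_behavior() is called so that tf.Variable creates ref variables (float32_ref), as in the output above; w1_again and the other local names are illustrative.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1 graphs and ref variables, as in the answer

graph = tf.Graph()
with graph.as_default():
    # The Python name (w1_node) and the graph name ('w1') are independent.
    w1_node = tf.Variable(tf.random_normal(shape=[2]), name='w1')
    # A second variable asking for the same graph name is renamed to w1_1.
    w1_again = tf.Variable(tf.zeros(shape=[2]), name='w1')

op_name = w1_node.op.name      # 'w1'
tensor_name = w1_node.name     # 'w1:0' (output 0 of op 'w1')
dup_name = w1_again.name       # 'w1_1:0'

# Dropping the Python names does not touch the graph nodes.
del w1_node, w1_again
found = graph.get_tensor_by_name('w1:0')
print(op_name, tensor_name, dup_name, found.name)
```

Because TensorFlow uniquifies clashing names, the string passed to get_tensor_by_name must be the name actually in the graph, which tf.global_variables() or get_operations() can show.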
model_restore.py
import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('my-model.meta')
    saver.restore(sess, 'my-model')
    print(sess.run(tf.get_default_graph().get_tensor_by_name('w1:0')))
    # [ 0.43350926 1.02784836]
    print(tf.global_variables())  # print the restored variables
    # [<tf.Variable 'w1:0' shape=(2,) dtype=float32_ref>,
    #  <tf.Variable 'w2:0' shape=(5,) dtype=float32_ref>]
    for op in tf.get_default_graph().get_operations():
        print(op.name)  # print the names of all operation nodes
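The two scripts can be checked together in one self-contained sketch that simulates the restart by building the restore side in a fresh tf.Graph(), so no Python variable from the save side is reused. It assumes tensorflow.compat.v1 and writes the checkpoint to a temporary directory (built with os.path.join rather than a relative 'my-model' path); the variable names are illustrative.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1 graphs and ref variables

ckpt_dir = tempfile.mkdtemp()
ckpt_prefix = os.path.join(ckpt_dir, 'my-model')

# Save side (model_save.py): build, initialize, save.
save_graph = tf.Graph()
with save_graph.as_default():
    w1_node = tf.Variable(tf.random_normal(shape=[2]), name='w1')
    w2_node = tf.Variable(tf.random_normal(shape=[5]), name='w2')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saved_w1 = sess.run(w1_node)
        saver.save(sess, ckpt_prefix)  # writes my-model.meta, .index, .data-*

# Restore side (model_restore.py): a fresh graph, lookups by name only.
restore_graph = tf.Graph()
with restore_graph.as_default():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(ckpt_prefix + '.meta')
        saver.restore(sess, ckpt_prefix)
        restored_w1 = sess.run(restore_graph.get_tensor_by_name('w1:0'))
        var_names = [v.op.name for v in tf.global_variables()]

print(np.allclose(saved_w1, restored_w1))  # True
print(var_names)                           # ['w1', 'w2']
```

Using a separate graph for each side mirrors what happens after a restart: the only link between the two halves is the checkpoint on disk and the names stored in it.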