Reputation: 2893
Let's say I have two identical networks, A and B. I saved (using Saver) a previous state of network A, and now I would like to load it into network B (this all happens within the same run). How can I do this?
Upvotes: 1
Views: 138
Reputation: 11895
Let me provide an example. First, let's define and save some variables:
import tensorflow as tf

v1 = tf.Variable(tf.ones(1), name='v1')
v2 = tf.Variable(2 * tf.ones(1), name='v2')
saver = tf.train.Saver(tf.trainable_variables())  # saves each variable under its name

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './tmp.ckpt')  # write the checkpoint to disk
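If you want to double-check what was stored, tf.train.list_variables (part of the TF 1.x API) returns the names and shapes found in a checkpoint. This is an optional sanity check, not required for the restore:

print(tf.train.list_variables('./tmp.ckpt'))
# expected to show something like [('v1', [1]), ('v2', [1])]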
Now, let's define some variables with the same names in a new graph, and load their values from the checkpoint:
with tf.Graph().as_default():
    # Fresh graph: no variables defined yet.
    assert len(tf.trainable_variables()) == 0
    v1 = tf.Variable(tf.zeros(1), name='v1')
    v2 = tf.Variable(tf.zeros(1), name='v2')
    saver = tf.train.Saver(tf.trainable_variables())
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, './tmp.ckpt')  # overwrite the zeros with the saved values
        print(sess.run([v1, v2]))
The last line prints:
[array([1.], dtype=float32), array([2.], dtype=float32)]
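To map this back to your setup (networks A and B built side by side in the same graph and the same run), pass the Saver an explicit var_list dictionary so the checkpoint names line up. Below is a minimal sketch; the variable scopes 'A' and 'B', the one-variable build_net helper, and the './a.ckpt' path are illustrative assumptions, not part of your question:

import tensorflow as tf

def build_net(scope):
    # Hypothetical one-variable "network", just enough to show the name mapping.
    with tf.variable_scope(scope):
        return tf.get_variable('v', initializer=tf.ones(1))

a = build_net('A')   # variable 'A/v'
b = build_net('B')   # variable 'B/v'

# Both savers use the same checkpoint name 'v', so whatever A saves, B can restore.
saver_a = tf.train.Saver({'v': a})
saver_b = tf.train.Saver({'v': b})

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(a.assign(3 * tf.ones(1)))   # change A so the copy is visible
    saver_a.save(sess, './a.ckpt')       # snapshot A's state
    saver_b.restore(sess, './a.ckpt')    # load it into B
    print(sess.run([a, b]))              # both should now be [3.]

The key point is that the dictionary keys are the names used inside the checkpoint, so two differently scoped sets of variables can be wired to the same stored values.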
Upvotes: 1