shakedzy

TensorFlow: how to save variables and load them into different variables?

Let's say I have two identical networks, A and B. I saved (using Saver) a previous state of network A, and now I would like to load it into network B (all within the same run). How can I do this?

Answers (1)

rvinas

Let me provide an example. First, let's define and save some variables:

import tensorflow as tf  # TensorFlow 1.x API; under 2.x use tf.compat.v1 with eager execution disabled

v1 = tf.Variable(tf.ones(1), name='v1')
v2 = tf.Variable(2 * tf.ones(1), name='v2')

saver = tf.train.Saver(tf.trainable_variables())
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './tmp.ckpt')
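Saver.restore looks variables up by checkpoint key, which by default is the variable's name. One way to see which keys a checkpoint holds is tf.train.list_variables. This is a minimal sketch, assuming TF 2.x with the tf.compat.v1 graph API and a temporary checkpoint path in place of ./tmp.ckpt; it rebuilds the checkpoint above and lists its keys.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API, available under TF 2.x

# Assumption: a temporary path instead of './tmp.ckpt'.
ckpt_path = os.path.join(tempfile.mkdtemp(), 'tmp.ckpt')

graph = tf.Graph()
with graph.as_default():
    v1 = tf.Variable(tf.ones(1), name='v1')
    v2 = tf.Variable(2 * tf.ones(1), name='v2')
    saver = tf1.train.Saver(tf1.trainable_variables())
    with tf1.Session(graph=graph) as sess:
        sess.run(tf1.global_variables_initializer())
        saver.save(sess, ckpt_path)

# Each entry is (checkpoint key, shape); restore() matches variables by these keys.
entries = tf.train.list_variables(ckpt_path)
for name, shape in entries:
    print(name, shape)
```

Because the keys are plain names such as v1 and v2, any graph that defines variables with those names can restore from this checkpoint.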

Now, let's define variables with the same names in a new graph and load their values from the checkpoint (Saver.restore matches checkpoint entries to variables by name):

with tf.Graph().as_default():
    assert len(tf.trainable_variables()) == 0
    v1 = tf.Variable(tf.zeros(1), name='v1')
    v2 = tf.Variable(tf.zeros(1), name='v2')

    saver = tf.train.Saver(tf.trainable_variables())
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # optional: restore() overwrites every variable in the saver's list
        saver.restore(sess, './tmp.ckpt')
        print(sess.run([v1, v2]))

The last line prints:

[array([1.], dtype=float32), array([2.], dtype=float32)]
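The example above restores into a separate graph whose variables share the checkpoint's names. When networks A and B live in the same graph, as in the question, their variables must have different names, but tf.train.Saver also accepts var_list as a dict that maps checkpoint keys to variables. This is a minimal sketch, assuming TF 2.x with the tf.compat.v1 graph API, networks built under the hypothetical scopes 'A' and 'B', and a temporary checkpoint path.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API, available under TF 2.x

ckpt_path = os.path.join(tempfile.mkdtemp(), 'net_a.ckpt')

graph = tf.Graph()
with graph.as_default():
    # Two structurally identical "networks", distinguished only by scope (hypothetical names).
    with tf1.variable_scope('A'):
        tf1.get_variable('w', initializer=tf.constant([1.0, 2.0]))
        tf1.get_variable('b', initializer=tf.constant([3.0]))
    with tf1.variable_scope('B'):
        tf1.get_variable('w', initializer=tf.constant([0.0, 0.0]))
        tf1.get_variable('b', initializer=tf.constant([0.0]))

    a_vars = tf1.global_variables(scope='A/')
    b_vars = tf1.global_variables(scope='B/')

    # Saves A's variables under their own keys ('A/w', 'A/b').
    saver_a = tf1.train.Saver(var_list=a_vars)
    # Maps each of A's checkpoint keys to B's matching variable, e.g. 'A/w' -> B/w.
    a_to_b = tf1.train.Saver(
        var_list={'A/' + v.op.name[len('B/'):]: v for v in b_vars})

    with tf1.Session(graph=graph) as sess:
        sess.run(tf1.global_variables_initializer())
        saver_a.save(sess, ckpt_path)    # snapshot of network A
        a_to_b.restore(sess, ckpt_path)  # load that snapshot into network B
        restored = {v.op.name: sess.run(v) for v in b_vars}

print(restored)
```

Only the dict comprehension depends on the 'A/' and 'B/' prefixes; any other naming scheme works as long as each checkpoint key maps to the intended variable.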
