lina

Reputation: 293

TensorFlow saves only the initialized value

I am trying to save some variables and check whether I can restore them later. Here is my saving code:

    import tensorflow as tf
    my_a = tf.Variable(2, name="my_a")
    my_b = tf.Variable(3, name="my_b")
    my_c = tf.Variable(4, name="my_c")
    my_c = tf.add(my_a, my_b)

    with tf.Session() as sess:
        init = tf.initialize_all_variables()
        sess.run(init)
        print("my_c =  ", sess.run(my_c))
        saver = tf.train.Saver()
        saver.save(sess, "test.ckpt")

This prints out:

    my_c =   5

When I restore it:

    import tensorflow as tf
    c = tf.Variable(3100, dtype=tf.int32)
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        saver = tf.train.Saver({"my_c": c})
        saver.restore(sess, "test.ckpt")
        cc = sess.run(c)
        print(cc)

This gives me:

    4

I expected the restored value of my_c to be 5, since it is the sum of my_a and my_b. However, it is 4, which is the initial value of my_c. Could anyone explain why this happens, and how to save changes to a variable?

Upvotes: 1

Views: 209

Answers (1)

martianwars

Reputation: 6500

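In your original code, you never actually assign my_a + my_b to the variable named my_c (its TensorFlow name, that is).

Writing my_c = tf.add(my_a, my_b) only rebinds the Python name my_c to a new add operation. The tf.Variable created with name="my_c" still exists in the graph, unchanged at 4, and that variable is what tf.train.Saver() writes to the checkpoint.

So when you call sess.run(my_c), you only evaluate the add operation, which returns 5; nothing writes that result back into the variable.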
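To make the distinction concrete without needing TensorFlow installed, here is a minimal pure-Python sketch of the same situation. `ToyVariable`, `add`, `save`, and `make_graph` are hypothetical stand-ins, not TensorFlow APIs: the dict plays the role of the graph's variable collection that a saver checkpoints, and calling a returned function plays the role of `sess.run()`.

```python
# Hypothetical toy model of TF1-style graph semantics, for illustration only.
# None of these names are TensorFlow APIs.


class ToyVariable:
    """A named, stateful cell, loosely mimicking a tf.Variable."""

    def __init__(self, graph, value, name):
        self.value = value
        graph[name] = self  # the "graph" tracks variables by name

    def assign(self, op):
        """Return a deferred op that stores op's result in this variable."""
        def run():
            self.value = op()
            return self.value
        return run


def add(a, b):
    """Deferred op: computes a + b when run, but stores the result nowhere."""
    return lambda: a.value + b.value


def save(graph):
    """Snapshot every registered variable by name, like a checkpoint."""
    return {name: var.value for name, var in graph.items()}


def make_graph():
    graph = {}
    my_a = ToyVariable(graph, 2, "my_a")
    my_b = ToyVariable(graph, 3, "my_b")
    my_c = ToyVariable(graph, 4, "my_c")
    return graph, my_a, my_b, my_c


# Question's version: the Python name my_c is rebound to an add op.
graph, my_a, my_b, my_c = make_graph()
my_c = add(my_a, my_b)        # the variable registered as "my_c" is untouched
print("my_c =", my_c())       # running the op yields 5 ...
question_ckpt = save(graph)
print(question_ckpt["my_c"])  # ... but the saved variable still holds 4

# Answer's version: an assign op writes the sum into the variable.
graph, my_a, my_b, my_c = make_graph()
update = my_c.assign(add(my_a, my_b))
update()                      # running the assign op mutates the variable
answer_ckpt = save(graph)
print(answer_ckpt["my_c"])    # 5
```

The toy prints `my_c = 5`, then `4`, then `5`: evaluating an op and updating a variable are separate steps, and only the latter changes what gets saved.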

To make this work, write the sum into the variable with assign() instead (see the comments for the changes):

    import tensorflow as tf
    my_a = tf.Variable(2, name="my_a")
    my_b = tf.Variable(3, name="my_b")
    my_c = tf.Variable(4, name="my_c")
    # Use the assign() function to set the new value
    add = my_c.assign(tf.add(my_a, my_b))

    with tf.Session() as sess:
        init = tf.initialize_all_variables()
        sess.run(init)
        # Run the assign op, which writes my_a + my_b (5) into the variable my_c
        sess.run(add)
        print("my_c =  ", sess.run(my_c))
        saver = tf.train.Saver()
        saver.save(sess, "test.ckpt")

Upvotes: 2
