luckcul

Reputation: 35

Why TensorFlow can not restore a variable initialized by a constant?

I am new to TensorFlow. While reading the TensorFlow manual on saving and restoring variables, I ran into a problem. I saved a variable initialized from a constant, but I cannot restore it. The code is as follows:

import tensorflow as tf

a = tf.get_variable("name_a", initializer=[1, 2, 3])
op1 = a.assign(a+1)
saver = tf.train.Saver()
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    op1.op.run()
    print(a.eval())
    saver.save(sess,"log1/model.ckpt")

Then I restore it.

a = tf.get_variable("name_a", shape=[3])
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "log1/model.ckpt")
    print(a.eval())

I expected output like [2 3 4], but I got [ 2.80259693e-45 4.20389539e-45 5.60519386e-45]. The values are essentially zero.

However, when I modify the first line in the first code snippet to

a = tf.get_variable("name_a", initializer=tf.zeros([3]))

I can get the right restored variable: [ 1. 1. 1.]

What is the reason for this behavior?

Upvotes: 1

Views: 1125

Answers (2)

Maxim

Reputation: 53758

I'm not 100% sure, but it looks like the reason is that your two variables:

  • tf.get_variable("name_a", initializer=[1,2,3])

  • tf.get_variable("name_a", shape=[3])

are not equivalent and can't simply be substituted for one another (Update: the dtype is different; thanks @BlueSun for noticing this).

You will get a stable result if you define the tensor in the restoring code exactly as in the saving code: a = tf.get_variable("name_a", initializer=[1,2,3]). However, it is even better to work with the saved graph directly:

saver = tf.train.import_meta_graph('log1/model.ckpt.meta')
with tf.Session() as sess:
  saver.restore(sess, "log1/model.ckpt")
  saved = sess.graph.get_tensor_by_name('name_a:0')
  print(sess.run(saved))

This works correctly no matter how you define the initializer.

Upvotes: 1

BlueSun

Reputation: 3570

You have to define the variable a with the same data type. If you don't specify one and don't provide an initializer, the dtype defaults to tf.float32, and loading the saved tf.int32 values will fail. Simply setting the data type to tf.int32 solves the problem:

import tensorflow as tf

a = tf.get_variable("name_a", shape=[3], dtype=tf.int32)
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "log1/model.ckpt")
    print(a.eval())

Using a = tf.get_variable("name_a", initializer=tf.zeros([3])) worked because tf.zeros([3]) has the same dtype as [2, 3, 4]. It is safer to always set the dtype whenever you create a variable.
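Incidentally, the strange numbers in the question are not random garbage: they look like the saved int32 bit patterns of [2, 3, 4] reinterpreted as float32 subnormals (an int32 value n, read as float32 bits, is n * 2**-149). A minimal NumPy sketch of that reinterpretation (assuming only NumPy; the variable names are illustrative):

```python
import numpy as np

# The checkpoint stored the int32 bits of [2, 3, 4]; reading those same
# bits back as float32 reinterprets each integer n as the subnormal
# float n * 2**-149.
bits = np.array([2, 3, 4], dtype=np.int32)
reinterpreted = bits.view(np.float32)
print(reinterpreted)  # tiny subnormal values on the order of 1e-45
```

This reproduces values matching [2.80259693e-45, 4.20389539e-45, 5.60519386e-45] from the question, which is consistent with the dtype-mismatch explanation above.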

Upvotes: 1
