Lerner Zhang

Variable name doesn't work for init_from_checkpoint?

Please see this toy model:

import tensorflow as tf
import os

if not os.path.isdir('./temp'):
    os.mkdir('./temp')


def create_and_save_varialbe(sess=tf.Session()):
    a = tf.get_variable("a", [])
    saver_a = tf.train.Saver({"a": a})
    init = tf.global_variables_initializer()
    sess.run(init)
    saver_a.save(sess, './temp/temp_model')
    a = sess.run(a)
    print('the initialized a is %f' % a)
    return a


def init_variable(sess=tf.Session()):
    b = tf.Variable(tf.constant(1.0, shape=[]), name="b", dtype=tf.float32) 
    tf.train.init_from_checkpoint('./temp/temp_model', 
            {'a': 'b'})
    init = tf.global_variables_initializer()
    sess.run(init)
    b = sess.run(b)
    print(b)
    return b


def init_get_variable(sess=tf.Session()):
    c = tf.get_variable("c", shape=[])
    tf.train.init_from_checkpoint('./temp/temp_model', 
            {'a': 'c'})
    init = tf.global_variables_initializer()
    sess.run(init)
    c = sess.run(c)
    print(c)
    return c


a = create_and_save_varialbe()
b = init_variable()
c = init_get_variable()

The function init_get_variable works, but init_variable fails with:

ValueError: Assignment map with scope only name should map to scope only a. Should be 'scope/': 'other_scope/'.

Why doesn't the name of a variable created with tf.Variable work here, and how can I work around it?

TensorFlow version: 1.12


Answers (1)

Lerner Zhang

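This is because of the difference between tf.Variable and tf.get_variable: tf.get_variable registers every variable it creates, under its name, in the default variable store, while tf.Variable does not. When a value in the assignment map is a string, init_from_checkpoint resolves it through that store, so it can find c but not b.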
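To make the difference concrete, here is a toy pure-Python model of the two creation paths. It is not TensorFlow code: ToyVariable, ToyVariableStore, toy_get_variable and toy_variable are hypothetical names, and the store is simply a dict keyed by name, loosely mirroring the private _VariableStore._vars.

```python
class ToyVariable:
    """Stand-in for a graph variable: just a name and a value."""

    def __init__(self, name, value=0.0):
        self.name = name
        self.value = value


class ToyVariableStore:
    """Stand-in for TensorFlow's private _VariableStore: a dict keyed by name."""

    def __init__(self):
        self._vars = {}


def toy_get_variable(store, name, value=0.0):
    """get_variable-style creation: the new variable is registered by name."""
    if name in store._vars:
        raise ValueError("Variable %s already exists" % name)
    var = ToyVariable(name, value)
    store._vars[name] = var
    return var


def toy_variable(name, value=0.0):
    """Variable-style creation: no store ever learns about this name."""
    return ToyVariable(name, value)


store = ToyVariableStore()
c = toy_get_variable(store, "c")
b = toy_variable("b", 1.0)

# A name-based lookup, like the one init_from_checkpoint performs:
print(store._vars.get("c") is c)  # True
print(store._vars.get("b"))       # None
```

Only the get_variable-style path leaves a name the store can find later, which is exactly why the string 'c' resolves and 'b' does not.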

There are two ways to tackle it:

1) Pass the variable object itself instead of its name:

def init_variable(sess=tf.Session()):
    b = tf.Variable(tf.constant(1.0, shape=[]), name="b", dtype=tf.float32) 
    tf.train.init_from_checkpoint('./temp/temp_model', 
            {'a': b})
    init = tf.global_variables_initializer()
    sess.run(init)
    b = sess.run(b)
    print(b)
    return b

This works because, when a value in the assignment map is a variable (or a list of variables), init_from_checkpoint uses it directly:

if _is_variable(current_var_or_name) or (
    isinstance(current_var_or_name, list)
    and all(_is_variable(v) for v in current_var_or_name)):
  var = current_var_or_name

Otherwise, it looks the name up in the default variable store:

  store_vars = vs._get_default_variable_store()._vars 

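But a variable created with tf.Variable is never registered in that store (the _VariableStore object kept in the ('__variable_store',) graph collection), as explained in this answer. The lookup therefore fails, the string is treated as a scope name, and because 'a' does not end with '/', the ValueError above is raised.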
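The resolution logic quoted above can be modeled in a few lines of plain Python. This is a simplified sketch, not TensorFlow's implementation: resolve_target and ToyVariable are hypothetical, the store is a plain dict standing in for _get_default_variable_store()._vars, and scope-to-scope mappings are not modeled. It shows that a variable object always resolves, a registered name resolves, and an unregistered name falls through to the scope-only branch that produces the error from the question.

```python
class ToyVariable:
    """Stand-in for a graph variable."""

    def __init__(self, name):
        self.name = name


def _is_variable(x):
    return isinstance(x, ToyVariable)


def resolve_target(current_var_or_name, tensor_name_in_ckpt, store_vars):
    """Return the variable(s) one assignment-map entry refers to."""
    # Case 1: a variable (or a list of variables) is used as-is.
    if _is_variable(current_var_or_name) or (
            isinstance(current_var_or_name, list)
            and all(_is_variable(v) for v in current_var_or_name)):
        return current_var_or_name
    # Case 2: a string is looked up by name in the variable store.
    var = store_vars.get(current_var_or_name)
    if var is not None:
        return var
    # Case 3: not found, so the string is treated as a scope name, which
    # must map to a checkpoint scope ending in '/'.
    if not tensor_name_in_ckpt.endswith("/"):
        scopes = ""
        if "/" in current_var_or_name:
            scopes = current_var_or_name[:current_var_or_name.rindex("/")]
        raise ValueError(
            "Assignment map with scope only name {} should map to scope only "
            "{}. Should be 'scope/': 'other_scope/'.".format(
                scopes, tensor_name_in_ckpt))
    return None  # scope-to-scope mapping: not modeled in this sketch


c = ToyVariable("c")
b = ToyVariable("b")
store_vars = {"c": c}  # only the get_variable-created variable is registered

assert resolve_target(b, "a", store_vars) is b    # pass the object itself
assert resolve_target("c", "a", store_vars) is c  # name found in the store
try:
    resolve_target("b", "a", store_vars)          # name not in the store
except ValueError as e:
    print(e)
store_vars["b"] = b                               # register the name
assert resolve_target("b", "a", store_vars) is b
```

The last two lines mirror the second workaround: once the name is in the store, the string lookup succeeds.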

2) Register the variable in the variable store yourself:

from tensorflow.python.ops import variable_scope as vs

def init_variable(sess=tf.Session()):
    b = tf.Variable(tf.constant(1.0, shape=[]), name="b", dtype=tf.float32)
    # Register b under the name 'b' in the default variable store (private API).
    # Reuse the existing store: tf.get_variable("a") has already created one,
    # and only the first store in the ('__variable_store',) collection is used.
    vs._get_default_variable_store()._vars['b'] = b
    tf.train.init_from_checkpoint('./temp/temp_model',
            {'a': 'b'})
    init = tf.global_variables_initializer()
    sess.run(init)
    b = sess.run(b)
    print(b)
    return b

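Both work. Note that the second approach relies on private TensorFlow internals (the variable store and its _vars dict), which may change between versions, so the first approach is usually the safer choice.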
