Reputation: 7140
Please see this toy model:
import tensorflow as tf
import os
# Make sure the checkpoint directory exists. exist_ok=True (Python 3.2+)
# replaces the check-then-create pair, so there is no window in which the
# directory can appear between the test and the mkdir call.
os.makedirs('./temp', exist_ok=True)
def create_and_save_varialbe(sess=tf.Session()):
    """Create the scalar variable "a", initialize it and checkpoint it.

    The variable is stored under the checkpoint key "a" at
    './temp/temp_model'.

    Args:
        sess: session used to run the ops. NOTE: the default session is
            built once, when this ``def`` statement is executed.

    Returns:
        The initialized value of "a" as fetched by ``sess.run``.
    """
    var_a = tf.get_variable("a", [])
    checkpoint_writer = tf.train.Saver({"a": var_a})
    sess.run(tf.global_variables_initializer())
    checkpoint_writer.save(sess, './temp/temp_model')
    value = sess.run(var_a)
    print('the initialized a is %f' % value)
    return value
def init_variable(sess=tf.Session()):
    """Try to restore checkpoint tensor "a" into a ``tf.Variable`` named "b".

    The target is referenced by its name string ('b') in the assignment
    map. This is the version reported as failing with
    "ValueError: Assignment map with scope only name ...".

    Args:
        sess: session used to run the ops (default built at definition time).

    Returns:
        The value of "b" fetched after initialization.
    """
    target = tf.Variable(tf.constant(1.0, shape=[]), name="b", dtype=tf.float32)
    # Checkpoint name -> name of the variable in the current graph.
    assignment_map = {'a': 'b'}
    tf.train.init_from_checkpoint('./temp/temp_model', assignment_map)
    sess.run(tf.global_variables_initializer())
    restored = sess.run(target)
    print(restored)
    return restored
def init_get_variable(sess=tf.Session()):
    """Restore checkpoint tensor "a" into a ``tf.get_variable`` variable "c".

    The target is referenced by its name string ('c') in the assignment map.

    Args:
        sess: session used to run the ops (default built at definition time).

    Returns:
        The value of "c" fetched after initialization.
    """
    target = tf.get_variable("c", shape=[])
    # Checkpoint name -> name of the variable in the current graph.
    assignment_map = {'a': 'c'}
    tf.train.init_from_checkpoint('./temp/temp_model', assignment_map)
    sess.run(tf.global_variables_initializer())
    restored = sess.run(target)
    print(restored)
    return restored
# Write the checkpoint first; the two restore attempts below read it.
a = create_and_save_varialbe()
# Restore into a tf.Variable, referenced by name.
b = init_variable()
# Restore into a tf.get_variable variable, referenced by name.
c = init_get_variable()
The function init_get_variable works, but the function init_variable does not. It fails with:
ValueError: Assignment map with scope only name should map to scope only a. Should be 'scope/': 'other_scope/'.
Why doesn't the name of a variable created with tf.Variable work in this scenario, and how can I fix it?
Tensorflow version: 1.12
Upvotes: 0
Views: 966
Reputation: 7140
This is because of a difference between tf.Variable and tf.get_variable.
There are two ways to fix it:
1) Pass the variable object itself instead of its name.
def init_variable(sess=tf.Session()):
    """Restore checkpoint tensor "a" into a ``tf.Variable`` named "b".

    The assignment map carries the variable object itself rather than its
    name string.

    Args:
        sess: session used to run the ops (default built at definition time).

    Returns:
        The value of "b" fetched after initialization.
    """
    target = tf.Variable(tf.constant(1.0, shape=[]), name="b", dtype=tf.float32)
    # Checkpoint name -> variable object in the current graph.
    assignment_map = {'a': target}
    tf.train.init_from_checkpoint('./temp/temp_model', assignment_map)
    sess.run(tf.global_variables_initializer())
    restored = sess.run(target)
    print(restored)
    return restored
This works because, when the value is a variable object, TensorFlow uses it directly:
if _is_variable(current_var_or_name) or (
isinstance(current_var_or_name, list)
and all(_is_variable(v) for v in current_var_or_name)):
var = current_var_or_name
Otherwise, it has to look the name up in the default variable store:
store_vars = vs._get_default_variable_store()._vars
But a variable created with tf.Variable is not registered in the variable store
(kept in the ('__variable_store',) collection), as explained in this answer.
2) Alternatively, you can add a variable store containing it to that collection yourself:
from tensorflow.python.ops.variable_scope import _VariableStore
from tensorflow.python.framework import ops
def init_variable(sess=tf.Session()):
    """Restore checkpoint tensor "a" into a ``tf.Variable`` named "b" by name.

    A ``_VariableStore`` holding "b" is added to the ('__variable_store',)
    graph collection before the name-based assignment map is used.

    Args:
        sess: session used to run the ops (default built at definition time).

    Returns:
        The value of "b" fetched after initialization.
    """
    target = tf.Variable(tf.constant(1.0, shape=[]), name="b", dtype=tf.float32)
    # Register the variable under its name in a hand-built store.
    # NOTE(review): this relies on private TensorFlow internals; if the graph
    # already holds a variable store (e.g. after an earlier tf.get_variable
    # call) this appends a second one - confirm which store is consulted.
    manual_store = _VariableStore()
    manual_store._vars = {'b': target}
    ops.add_to_collection(('__variable_store',), manual_store)
    assignment_map = {'a': 'b'}
    tf.train.init_from_checkpoint('./temp/temp_model', assignment_map)
    sess.run(tf.global_variables_initializer())
    restored = sess.run(target)
    print(restored)
    return restored
Both work.
Upvotes: 1