I'm having issues with the function tf.contrib.framework.init_from_checkpoint. It simply does not work (very likely I'm doing something wrong). I crafted the example below to demonstrate the behaviour:
```python
import tensorflow as tf

model_name = "./my_model.ckp"

### MY MODEL IS COMPOSED BY 2 VARIABLES
with tf.variable_scope("A"):
    A = tf.Variable([1, 2, 3], name="A1")

with tf.variable_scope("B"):
    B = tf.Variable([4, 5, 6], name="B1")

# INITIALIZING AND SAVING THE MODEL
with tf.Session() as sess:
    tf.global_variables_initializer().run(session=sess)
    print(sess.run([A, B]))
    saver = tf.train.Saver()
    saver.save(sess, model_name)

#### CLEANING UP
tf.reset_default_graph()

### CREATING OTHER "MODEL"
with tf.variable_scope("C"):
    A = tf.Variable([0, 0, 0], name="A1")

with tf.variable_scope("B"):
    B = tf.Variable([0, 0, 0], name="B1")

# MAPPING THE VARIABLES FROM MY CHECKPOINT TO MY NEW SET OF VARIABLES
tf.contrib.framework.init_from_checkpoint(
    model_name,
    {"A/": "C/",
     "B/": "B/"})

with tf.Session() as sess:
    tf.global_variables_initializer().run(session=sess)
    print(sess.run([A, B]))
```
The output is [array([1, 2, 3], dtype=int32), array([4, 5, 6], dtype=int32)], which is expected, followed by [array([0, 0, 0], dtype=int32), array([0, 0, 0], dtype=int32)], which is not: the variables of the second model were not initialized from the checkpoint.
What's going on?
Thanks
The problem is that you are using the low-level tf.Variable constructor to create the variables, so they are not stored in the variable store, which is where init_from_checkpoint looks for the variables to initialize.
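In the ### CREATING OTHER "MODEL" part of your code, create the variables with tf.get_variable instead:

```python
with tf.variable_scope("C"):
    A = tf.get_variable(name='A1', initializer=[0, 0, 0])
```

The same change is needed for B1 in scope B. I tested it, and with this change the variables are successfully restored from the checkpoint.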
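To illustrate why the scope mapping silently skips variables created with tf.Variable, here is a toy, pure-Python model. It is a simplified sketch, not TensorFlow's implementation: ToyGraph, its variable/get_variable methods and this init_from_checkpoint are hypothetical stand-ins. It assumes, as the answer states, that a scope mapping such as {"A/": "C/"} is resolved only against the variable store, and it assigns values directly, whereas the real function changes the variables' initializers.

```python
# Toy model, NOT TensorFlow: every name below is hypothetical.

class ToyGraph:
    def __init__(self):
        self.all_variables = {}     # every variable in the graph, by full name
        self.variable_store = set()  # names created via get_variable only

    def variable(self, name, value):
        """Like tf.Variable: added to the graph, but NOT to the store."""
        self.all_variables[name] = list(value)

    def get_variable(self, name, value):
        """Like tf.get_variable: added to the graph AND to the store."""
        self.all_variables[name] = list(value)
        self.variable_store.add(name)


def init_from_checkpoint(graph, checkpoint, assignment_map):
    """Copy checkpoint values into store variables matched by scope prefix.

    Variables that are not in graph.variable_store are never visited,
    so they keep their original values and no error is raised.
    """
    for ckpt_scope, new_scope in assignment_map.items():
        for var_name in graph.variable_store:
            if var_name.startswith(new_scope):
                # "C/A1" with mapping "A/" -> "C/" becomes "A/A1"
                ckpt_name = ckpt_scope + var_name[len(new_scope):]
                if ckpt_name in checkpoint:
                    graph.all_variables[var_name] = list(checkpoint[ckpt_name])


checkpoint = {"A/A1": [1, 2, 3], "B/B1": [4, 5, 6]}
mapping = {"A/": "C/", "B/": "B/"}

# Question's version: variables created outside the store.
g1 = ToyGraph()
g1.variable("C/A1", [0, 0, 0])
g1.variable("B/B1", [0, 0, 0])
init_from_checkpoint(g1, checkpoint, mapping)
print(g1.all_variables)  # {'C/A1': [0, 0, 0], 'B/B1': [0, 0, 0]}

# Answer's version: variables registered in the store.
g2 = ToyGraph()
g2.get_variable("C/A1", [0, 0, 0])
g2.get_variable("B/B1", [0, 0, 0])
init_from_checkpoint(g2, checkpoint, mapping)
print(g2.all_variables)  # {'C/A1': [1, 2, 3], 'B/B1': [4, 5, 6]}
```

The first graph reproduces the zeros from the question; the second shows the values arriving once both variables are registered in the store.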