Tiago Freitas Pereira

Reputation: 690

The function tf.contrib.framework.init_from_checkpoint does not work properly

I'm having issues with the function tf.contrib.framework.init_from_checkpoint. It does not restore the values from my checkpoint (very likely I'm doing something wrong). I crafted the example below to demonstrate the behaviour:

import tensorflow as tf
model_name = "./my_model.ckp"

### MY MODEL IS COMPOSED OF 2 VARIABLES
with tf.variable_scope("A"):
    A = tf.Variable([1, 2, 3], name="A1")

with tf.variable_scope("B"):
    B = tf.Variable([4, 5, 6], name="B1")

# INITIALIZING AND SAVING THE MODEL    
with tf.Session() as sess:
    tf.global_variables_initializer().run(session=sess)
    print(sess.run([A, B]))

    saver = tf.train.Saver()
    saver.save(sess, model_name)

#### CLEANING UP
tf.reset_default_graph()


### CREATING OTHER "MODEL"
with tf.variable_scope("C"):
    A = tf.Variable([0, 0, 0], name="A1")

with tf.variable_scope("B"):
    B = tf.Variable([0, 0, 0], name="B1")

# MAPPING THE VARIABLES FROM MY CHECKPOINT TO MY NEW SET OF VARIABLES
tf.contrib.framework.init_from_checkpoint(
    model_name,
    {"A/": "C/", 
    "B/": "B/"})

with tf.Session() as sess:
    tf.global_variables_initializer().run(session=sess)
    print(sess.run([A, B]))

The first print outputs [array([1, 2, 3], dtype=int32), array([4, 5, 6], dtype=int32)], which is expected. The second print outputs [array([0, 0, 0], dtype=int32), array([0, 0, 0], dtype=int32)], which is not expected: I expected the values from the checkpoint to be restored.

What's going on?

Thanks

Upvotes: 0

Views: 578

Answers (1)

LKS

Reputation: 673

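The problem is that you create the variables with the low-level tf.Variable constructor, so they are not registered in the variable store. When the assignment map contains scope names such as "C/", init_from_checkpoint looks for matching variables in that store, finds none, and leaves the initializers untouched.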
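To make the "variable store" point concrete, here is a toy model in plain Python. It is only a sketch of the idea, not TensorFlow's implementation: variable_store, ToyVariable, make_variable and the simplified get_variable / init_from_checkpoint below are all hypothetical stand-ins.

```python
# Toy model only: every name here is a hypothetical stand-in,
# not TensorFlow's real implementation.
variable_store = {}  # plays the role of TensorFlow's variable store


class ToyVariable:
    def __init__(self, name, value):
        self.name = name
        self.value = list(value)


def make_variable(name, value):
    """Stand-in for tf.Variable: creates a variable, never registers it."""
    return ToyVariable(name, value)


def get_variable(name, value):
    """Stand-in for tf.get_variable: creates a variable and registers it."""
    var = ToyVariable(name, value)
    variable_store[name] = var
    return var


def init_from_checkpoint(checkpoint, assignment_map):
    """Apply scope mappings like {"A/": "C/"} to registered variables only."""
    restored = []
    for ckpt_scope, current_scope in assignment_map.items():
        for name, var in variable_store.items():
            if name.startswith(current_scope):
                # Swap the current scope prefix for the checkpoint one.
                ckpt_name = ckpt_scope + name[len(current_scope):]
                var.value = list(checkpoint[ckpt_name])
                restored.append(name)
    return restored


checkpoint = {"A/A1": [1, 2, 3], "B/B1": [4, 5, 6]}
a = make_variable("C/A1", [0, 0, 0])  # not in the store, so never matched
b = get_variable("B/B1", [0, 0, 0])   # in the store, so it gets restored
restored = init_from_checkpoint(checkpoint, {"A/": "C/", "B/": "B/"})
print(a.value, b.value)  # [0, 0, 0] [4, 5, 6]
```

In this toy model the unregistered variable is skipped silently, which mirrors the symptom in the question: no error, just the original zeros.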

In your ### CREATING OTHER "MODEL" section, create the variables with tf.get_variable instead, for example:

with tf.variable_scope("C"):
    A = tf.get_variable(name='A1', initializer=[0,0,0])

Apply the same change to B. I tested this, and the variable is then restored from the checkpoint successfully.

Upvotes: 1
