user270700

Reputation: 759

How to restore variables using CheckpointReader in Tensorflow

I'm trying to restore variables from a checkpoint file when a variable with the same name exists in the current model.
I found an approach for this in the TensorFlow GitHub repository.

What I want to do is check whether each variable name exists in the checkpoint file using has_tensor("variable.name"), as below:

...
reader = tf.train.NewCheckpointReader(ckpt_path)
for v in tf.trainable_variables():
    print(v.name)
    if reader.has_tensor(v.name):
        print('has tensor')
...

But I found that v.name returns the variable name followed by a colon and a number. For example, for variables named W_o and b_o, v.name returns W_o:0 and b_o:0.

However, reader.has_tensor() requires the name without the colon and number, i.e. W_o and b_o.

My question is: how can I remove the colon and number from the end of the variable name in order to look the variables up?
Is there a better way to restore such variables?

Upvotes: 3

Views: 8122

Answers (3)

Vishal Keshav

Reputation: 1

Simple answer:

reader = tf.train.NewCheckpointReader(checkpoint_file)

variable1 = reader.get_tensor('layer_name1/layer_type_name')
variable2 = reader.get_tensor('layer_name2/layer_type_name')

After modifying these arrays, you can assign them back to the layer:

layer_name1_var.set_weights([variable1, variable2])

Upvotes: 0

Alexander Ponamarev

Reputation: 66

tf.train.NewCheckpointReader is a handy function that creates a CheckpointReader object. CheckpointReader has several useful methods; the one most relevant to your question is get_variable_to_shape_map().

  • get_variable_to_shape_map() returns a dictionary mapping variable names to their shapes:

saved_shapes = reader.get_variable_to_shape_map()
print('fire9/squeeze1x1/kernels:', saved_shapes['fire9/squeeze1x1/kernels'])
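To see everything a checkpoint contains rather than a single entry, the same dictionary can be walked in sorted order. This is only a sketch: describe_checkpoint is a hypothetical helper, and the example dictionary (including its shapes) is made up so the snippet runs without TensorFlow. In practice the dictionary would come from reader.get_variable_to_shape_map().

```python
def describe_checkpoint(shape_map):
    """Return one 'name: shape' line per saved variable, sorted by name."""
    return ['%s: %s' % (name, list(shape_map[name])) for name in sorted(shape_map)]


# Hypothetical contents standing in for reader.get_variable_to_shape_map().
example_map = {
    'fire9/squeeze1x1/kernels': [1, 1, 512, 64],
    'fire9/squeeze1x1/biases': [64],
}

for line in describe_checkpoint(example_map):
    print(line)
```

The names printed this way are the exact keys that has_tensor() and get_tensor() expect.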

Please take a look at this quick tutorial below: Loading Variables from Existing Checkpoints

Upvotes: 1

rvinas

Reputation: 11895

You could use str.split() to drop the colon and number and get the name as it is stored in the checkpoint:

...
reader = tf.train.NewCheckpointReader(ckpt_path)
for v in tf.trainable_variables():
    tensor_name = v.name.split(':')[0]
    print(tensor_name)
    if reader.has_tensor(tensor_name):
        print('has tensor')
...
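If you'd rather not split blindly, here is a slightly more defensive sketch. The helper name checkpoint_key is hypothetical, not a TensorFlow function. It strips the suffix only when it is a numeric output index, so names that have no suffix pass through unchanged. In TensorFlow 1.x, v.op.name should give the same result for a variable.

```python
def checkpoint_key(variable_name):
    """Strip a trailing ':<number>' output index, e.g. 'W_o:0' -> 'W_o'."""
    base, sep, index = variable_name.rpartition(':')
    if sep and index.isdigit():
        return base
    # No numeric suffix found: return the name untouched.
    return variable_name


print(checkpoint_key('W_o:0'))        # W_o
print(checkpoint_key('scope/b_o:0'))  # scope/b_o
print(checkpoint_key('W_o'))          # W_o
```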

Next, let me use an example to show how I would restore every variable that can be restored from a .ckpt file. First, let's save v2 and v3 in tmp.ckpt:

import tensorflow as tf

v1 = tf.Variable(tf.ones([1]), name='v1')
v2 = tf.Variable(2 * tf.ones([1]), name='v2')
v3 = tf.Variable(3 * tf.ones([1]), name='v3')

saver = tf.train.Saver({'v2': v2, 'v3': v3})

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'tmp.ckpt')

This is how I would restore every variable of a new graph that also appears in tmp.ckpt:

with tf.Graph().as_default():
    assert len(tf.trainable_variables()) == 0
    v1 = tf.Variable(tf.zeros([1]), name='v1')
    v2 = tf.Variable(tf.zeros([1]), name='v2')

    reader = tf.train.NewCheckpointReader('tmp.ckpt')
    restore_dict = dict()
    for v in tf.trainable_variables():
        tensor_name = v.name.split(':')[0]
        if reader.has_tensor(tensor_name):
            print('has tensor ', tensor_name)
            restore_dict[tensor_name] = v

    saver = tf.train.Saver(restore_dict)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, 'tmp.ckpt')
        print(sess.run([v1, v2])) # prints [array([ 0.], dtype=float32), array([ 2.], dtype=float32)]

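Also, you may want to check that shapes and dtypes match before restoring: has_tensor() only checks the name, and saver.restore() fails if a variable's shape differs from that of the saved tensor.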
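A shape check could look roughly like the sketch below. It works on plain dictionaries so it runs without TensorFlow, and select_restorable is a hypothetical helper name. I'm assuming model_shapes would be built from v.get_shape().as_list() for each v in tf.trainable_variables(), and saved_shapes from reader.get_variable_to_shape_map().

```python
def select_restorable(model_shapes, saved_shapes):
    """Split model variables into restorable ones and skipped ones.

    model_shapes: dict of variable name (e.g. 'v2:0') -> shape list.
    saved_shapes: dict of checkpoint name (e.g. 'v2') -> shape list.
    Returns (dict of checkpoint name -> variable name, list of skipped names).
    """
    restore_names = {}
    skipped = []
    for name, shape in sorted(model_shapes.items()):
        key = name.split(':')[0]
        # Restore only when the name exists and the shapes are identical.
        if key in saved_shapes and list(saved_shapes[key]) == list(shape):
            restore_names[key] = name
        else:
            skipped.append(name)
    return restore_names, skipped


# Made-up shapes: v1 is not in the checkpoint, v3 has a different shape.
model_shapes = {'v1:0': [1], 'v2:0': [1], 'v3:0': [2]}
saved_shapes = {'v2': [1], 'v3': [1]}

restore_names, skipped = select_restorable(model_shapes, saved_shapes)
print(restore_names)  # {'v2': 'v2:0'}
print(skipped)        # ['v1:0', 'v3:0']
```

The keys of restore_names are the ones you would put in the dictionary passed to tf.train.Saver. A dtype check could follow the same pattern, using reader.get_variable_to_dtype_map() in versions that provide it.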

Upvotes: 6
