BiBi

Reputation: 7908

Modifying shape of tensor in tensorflow checkpoint

I have a TensorFlow checkpoint that I can load, after redefining the corresponding graph, with the usual routine: tf.train.Saver() followed by saver.restore(session, 'my_checkpoint.ckpt').

Now I would like to modify the first layer of the network so that it accepts an input of shape, say, [200, 200, 1] instead of [200, 200, 10].

To this end, I would like to change the shape of the first layer's kernel tensor from [3, 3, 10, 32] (3x3 kernel, 10 input channels, 32 output channels) to [3, 3, 1, 32] by summing across the 3rd dimension (the input channels).

How could I do that?

Upvotes: 3

Views: 1341

Answers (2)

jawen zhang

Reputation: 11

You can use tensorflow::BundleReader to read the source checkpoint and tensorflow::BundleWriter to write the modified one.

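// Header: tensorflow/core/util/tensor_bundle/tensor_bundle.h
// Assumes `using namespace tensorflow;` and that model_path_prefix and
// new_model_path_prefix are defined by the caller.
BundleReader reader(Env::Default(), model_path_prefix);
// Collect all tensor names. The first entry (key "") is the bundle header, so skip it.
std::vector<std::string> tensor_names;
reader.Seek("");
reader.Next();
for (; reader.Valid(); reader.Next()) {
    tensor_names.emplace_back(reader.key());
}
BundleWriter writer(Env::Default(), new_model_path_prefix);
for (const auto &tensor_name : tensor_names) {
    DataType dtype;
    TensorShape shape;
    TF_CHECK_OK(reader.LookupDtypeAndShape(tensor_name, &dtype, &shape));
    Tensor val(dtype, shape);
    TF_CHECK_OK(reader.Lookup(tensor_name, &val));
    std::cout << tensor_name << " " << val.DebugString() << std::endl;
    // Modify dtype and shape here (e.g. pad or reduce the tensor).
    // Placeholder: new_shape and new_val pass the original through unchanged.
    TensorShape new_shape = shape;
    Tensor new_val = val;
    TensorSlice slice(new_shape.dims());  // full slice over every dimension
    TF_CHECK_OK(writer.AddSlice(tensor_name, new_shape, slice, new_val));
}
TF_CHECK_OK(writer.Finish());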

Upvotes: 1

BiBi

Reputation: 7908

I found a way to do it, though not a very straightforward one. Given a checkpoint, we can dump its variables into a dictionary of numpy arrays and serialize it with np.save (or any other format suitable for saving a dictionary of numpy arrays) as follows:

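import numpy as np
import tensorflow as tf

# Assumes the original graph is already built and saver = tf.train.Saver().
checkpoint = {}
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, 'my_checkpoint.ckpt')
    for x in tf.global_variables():
        checkpoint[x.name] = x.eval()  # variable names end in ':0'
    np.save('checkpoint.npy', checkpoint)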

There might be some exceptions to handle but let's keep the code simple.

Then, we can do whichever operations we like on the numpy arrays:

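# np.save pickles the dict into a 0-d object array; .item() recovers the dict.
checkpoint = np.load('checkpoint.npy', allow_pickle=True).item()
checkpoint = ...
np.save('checkpoint.npy', checkpoint)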
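
For the case in the question, the operation could be a sum over the input-channel axis. This is only a minimal sketch: the key 'conv1/kernel:0' is a hypothetical name for the first layer's kernel, and random values stand in for the real weights loaded from 'checkpoint.npy'.

```python
import numpy as np

# Stand-in for the dictionary loaded from 'checkpoint.npy'; the key name
# 'conv1/kernel:0' is hypothetical and the values are random.
checkpoint = {'conv1/kernel:0': np.random.rand(3, 3, 10, 32).astype(np.float32)}

kernel = checkpoint['conv1/kernel:0']            # shape [3, 3, 10, 32]
# Sum over the input-channel axis (axis 2), keeping it as a size-1 dimension.
reduced = kernel.sum(axis=2, keepdims=True)      # shape [3, 3, 1, 32]
checkpoint['conv1/kernel:0'] = reduced

print(reduced.shape)
```

Summing the channels gives the same first-layer output as feeding the original kernel a single-channel image replicated across all 10 input channels, which may or may not be the behaviour you want.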

Finally, after building the new graph, we can load the weights manually as follows:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # allow_pickle=True is needed to load the pickled dict; .item() unwraps it.
    checkpoint = np.load('checkpoint.npy', allow_pickle=True).item()
    for key, data in checkpoint.items():  # .iteritems() exists only in Python 2
        var_scope = ... # to be extracted from key
        var_name = ...  # to be extracted from key
        with tf.variable_scope(var_scope, reuse=True):
            var = tf.get_variable(var_name)
            sess.run(var.assign(data))
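
How the two placeholders above are filled in depends on how the variables were named. Here is a minimal sketch of one way to split a key, assuming keys have the 'scope/name:0' form that x.name produces. The helper split_key and the example keys are hypothetical.

```python
def split_key(key):
    """Split a variable name such as 'scope/sub/name:0' into (scope, name)."""
    # Drop the output-index suffix (':0') that tensor names carry.
    full_name = key.rsplit(':', 1)[0]
    # Everything before the last '/' is the scope; a variable created
    # outside any scope has an empty scope string.
    var_scope, _, var_name = full_name.rpartition('/')
    return var_scope, var_name

print(split_key('conv1/kernel:0'))   # ('conv1', 'kernel')
print(split_key('global_step:0'))    # ('', 'global_step')
```

Variables created outside any scope yield an empty scope string, which may need separate handling in the loop.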

If there is a more straightforward approach, I'm all ears!

Upvotes: 1
