DsCpp

Reputation: 2489

Restore tf variables in a different graph

I want to reuse my pretrained separable convolution (which is part of a bigger module) in a separable convolution in another model.
In the trained module I tried

with tf.variable_scope('sep_conv_ker' + str(input_shape[-1])):
    sep_conv2d = tf.reshape(
        tf.layers.separable_conv2d(inputs_flatten, input_shape[-1],
                                   [1, input_shape[-2]],
                                   trainable=trainable),
        [inputs_flatten.shape[0], 1, input_shape[-1], INNER_LAYER_WIDTH])

and

        all_variables = tf.trainable_variables()
        scope1_variables = tf.contrib.framework.filter_variables(all_variables, include_patterns=['sep_conv_ker'])
        sep_conv_weights_saver = tf.train.Saver(scope1_variables, sharded=True, max_to_keep=20)

and, inside the training session,

sep_conv_weights_saver.save(sess,os.path.join(LOG_DIR + MODEL_SPEC_LOG_DIR,
                                                              "init_weights",MODEL_SPEC_SUFFIX + 'epoch_' + str(epoch) + '.ckpt'))

But I cannot understand when and how I should load the weights into the separable convolution in the other module; it has a different name and a different scope.
Furthermore, since I'm using a predefined tf.layer, does that mean I need to access each individual weight in the new graph and assign it?

My current solution doesn't work, and I think the weights are somehow being initialized after the assignment.
Furthermore, loading a whole new graph just for a few weights seems odd, doesn't it?

        ###IN THE OLD GRAPH###
        all_variables = tf.trainable_variables()
        scope1_variables = tf.contrib.framework.filter_variables(all_variables, include_patterns=['sep_conv_ker'])
        vars = dict((var.op.name.split("/")[-1] + str(idx), var) for idx,var in enumerate(scope1_variables))
        sep_conv_weights_saver = tf.train.Saver(vars, sharded=True, max_to_keep=20)

In the new graph there is a function that basically takes the variables from the old graph and assigns them; loading the meta graph is redundant:

def load_pretrained(sess):
    sep_conv2d_vars = [var for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if ("sep_conv_ker" in var.op.name)]
    var_dict = dict((var.op.name.split("/")[-1] + str(idx), var) for idx, var in enumerate(sep_conv2d_vars))
    new_saver = tf.train.import_meta_graph(
        tf.train.latest_checkpoint('log/train/sep_conv_ker/global_neighbors40/init_weights') + '.meta')
    # saver = tf.train.Saver(var_list=var_dict)
    new_saver.restore(sess,
                      tf.train.latest_checkpoint('log/train/sep_conv_ker/global_neighbors40/init_weights'))

    graph = tf.get_default_graph()
    sep_conv2d_trained = dict(("".join(var.op.name.split("/")[-2:]),var) for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if ("sep_conv_ker_init" in var.op.name))
    for var in sep_conv2d_vars:
        tf.assign(var,sep_conv2d_trained["".join(var.op.name.split("/")[-2:])])
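Since only a few weights are needed, importing the whole meta graph can be avoided: a Saver built from a {checkpoint_name: new_graph_variable} map restores directly into the current graph. A hedged sketch, not tested code; `load_pretrained_by_name_map` and the sorted-order pairing are assumptions:

```python
import tensorflow as tf

def load_pretrained_by_name_map(sess, ckpt_dir):
    """Hypothetical alternative: restore old-graph weights into the
    current graph without importing the old meta graph."""
    ckpt = tf.train.latest_checkpoint(ckpt_dir)
    # Names as stored in the checkpoint (old graph, old scope).
    ckpt_names = sorted(name for name, _ in tf.train.list_variables(ckpt))
    # Variables in the new graph that should receive the weights.
    new_vars = sorted((v for v in tf.trainable_variables()
                       if 'sep_conv_ker' in v.op.name),
                      key=lambda v: v.op.name)
    # Pair old names with new variables; this assumes both graphs
    # create the layer's variables in the same (sorted) order.
    name_map = dict(zip(ckpt_names, new_vars))
    tf.train.Saver(var_list=name_map).restore(sess, ckpt)
```

Because the restore is done by the Saver itself, there are no tf.assign ops to forget to run, and no second graph is loaded.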

Upvotes: 0

Views: 666

Answers (1)

Jindřich

Reputation: 11220

You need to make sure that the variables have the same names in the checkpoint file and in the graph where you load them. You can write a script that converts the variable names.

  1. With tf.contrib.framework.list_variables(ckpt), you can find out what variables of what shapes are in the checkpoint, and create variables with the new names (I believe you can write a regex that fixes the names) and the correct shapes.
  2. Then load the original values with tf.contrib.framework.load_checkpoint(ckpt) and create assign ops tf.assign(var, loaded) that assign the saved values to the variables with the new names.
  3. Run the assign ops in a session.
  4. Save the new variables.

Minimal example:

Original model (variables in scope "regression"):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 3]) 
regression = tf.layers.dense(x, 1, name="regression")

session = tf.Session()
session.run(tf.global_variables_initializer())
saver = tf.train.Saver(tf.trainable_variables())

saver.save(session, './model')

Renaming script:

import tensorflow as tf

assign_ops = []
reader = tf.contrib.framework.load_checkpoint("./model")
for name, shape in tf.contrib.framework.list_variables("./model"):
    new_name = name.replace("regression/", "foo/bar/")
    new_var = tf.get_variable(new_name, shape)
    assign_ops.append(tf.assign(new_var, reader.get_tensor(name)))

session = tf.Session()
saver = tf.train.Saver(tf.trainable_variables())

session.run(assign_ops)
saver.save(session, './model-renamed')

Model where you load the renamed variables (the same variables, now in scope "foo/bar"):

import tensorflow as tf

with tf.variable_scope("foo"):
    x = tf.placeholder(tf.float32, [None, 3]) 
    regression = tf.layers.dense(x, 1, name="bar")

session = tf.Session()
session.run(tf.global_variables_initializer())
saver = tf.train.Saver(tf.trainable_variables())

saver.restore(session, './model-renamed')

Upvotes: 1
