Amor Fati

Reputation: 367

How to grab one tensor from an existing model and use it in another one?

What I want to do is to grab some weights and biases from an existing trained model, and then use them in my customized op (model or graph).

I can restore the model with:

# Create context
with tf.Graph().as_default(), tf.Session() as sess:
    # Create model
    with tf.variable_scope('train'):
        train_model = MyModel(some_args)

And then grab tensor:

saver = tf.train.Saver()
latest_ckpt = tf.train.latest_checkpoint(path)
if latest_ckpt:
    saver.restore(sess, latest_ckpt)
weight = tf.get_default_graph().get_tensor_by_name("example:0")

My question is: if I want to use that weight in another context (a different model or graph), how can I safely copy its value into the new graph? For example:

with self.test_session(use_gpu=True, graph=ops.Graph()) as sess:
    with vs.variable_scope("test", initializer=initializer):
        # How can I make it possible?
        w = tf.get_variable('name', initializer=weight)

Any help is welcome, thank you so much.


Thanks @Sorin for the inspiration. I found a simple and clean way to do this:

# `graph` is the graph restored from the checkpoint
z = graph.get_tensor_by_name('prefix/NN/W1:0')

with tf.Session(graph=graph) as sess:
    z_value = sess.run(z)

with tf.Graph().as_default() as new_graph, tf.Session(graph=new_graph) as sess:
    w = tf.get_variable('w', initializer=z_value)
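For reference, here is the same run-then-reinitialize pattern made end-to-end runnable. The variable `prefix/NN/W1` and its constant initializer below are toy stand-ins (the initializer plays the role of weights restored via `saver.restore`); the code uses the TF 1.x graph-mode API via `tf.compat.v1` so it also works on a TF 2 install:

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF 1.x graph-mode API
tf.disable_v2_behavior()

# Stand-in for the trained graph; the constant initializer plays
# the role of weights restored from a checkpoint.
graph = tf.Graph()
with graph.as_default():
    with tf.variable_scope('prefix'):
        with tf.variable_scope('NN'):
            tf.get_variable('W1', initializer=tf.constant([[1.0, 2.0]]))
    init = tf.global_variables_initializer()

z = graph.get_tensor_by_name('prefix/NN/W1:0')

# Evaluate the tensor; z_value is a plain numpy array, detached from any graph.
with tf.Session(graph=graph) as sess:
    sess.run(init)  # stands in for saver.restore(sess, ckpt)
    z_value = sess.run(z)

# A numpy array is a valid initializer for a variable in a brand-new graph.
with tf.Graph().as_default() as new_graph, tf.Session(graph=new_graph) as sess:
    w = tf.get_variable('w', initializer=z_value)
    sess.run(tf.global_variables_initializer())
    result = sess.run(w)
```

Because `sess.run(z)` returns a numpy array, nothing ties the value to the old graph, so it can be fed to `tf.get_variable` anywhere.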

Upvotes: 0

Views: 62

Answers (1)

Sorin

Reputation: 11968

The hacky way is to use tf.assign to assign the weight to the variable you want. Make sure this happens only once at the beginning, not on every iteration; otherwise the model won't be able to adjust those weights.
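A minimal sketch of this tf.assign approach, with a made-up 2x2 array standing in for the value fetched from the trained model (TF 1.x graph-mode API via `tf.compat.v1`):

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF 1.x graph-mode API
tf.disable_v2_behavior()

# Pretend this array is the value fetched from the trained model's graph.
restored_value = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

with tf.Graph().as_default(), tf.Session() as sess:
    # Variable in the new model; its own random initializer is irrelevant here.
    w = tf.get_variable('w', shape=[2, 2], dtype=tf.float32)
    # Build the assign op once, at graph-construction time.
    load_w = tf.assign(w, restored_value)

    sess.run(tf.global_variables_initializer())
    sess.run(load_w)  # run once at the beginning, not on every training step
    loaded = sess.run(w)
```

Running `load_w` inside the training loop would overwrite whatever the optimizer learned on each step, which is why it must happen only once.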

The slightly less hacky way is to load the graph and session of the trained model and modify the graph to add the operations you need. This makes the graph a bit messier, since it still contains the entire graph of the original model, but it's cleaner in that you can depend directly on the operations instead of the raw weights (if the original model applied a sigmoid activation, for instance, this copies the activation as well). The unused parts of the graph will be automatically pruned by TensorFlow.
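A self-contained sketch of this second approach. Since there is no real checkpoint here, the first block saves a tiny stand-in model to a temp directory; the second block reloads the whole graph with `tf.train.import_meta_graph` and builds a new op on top of a restored tensor (names like `NN/W1` are illustrative):

```python
import os
import tempfile
import tensorflow.compat.v1 as tf  # TF 1.x graph-mode API
tf.disable_v2_behavior()

ckpt_path = os.path.join(tempfile.mkdtemp(), 'model.ckpt')

# Save a tiny stand-in for the original trained model.
with tf.Graph().as_default(), tf.Session() as sess:
    with tf.variable_scope('NN'):
        tf.get_variable('W1', initializer=tf.constant([[2.0]]))
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, ckpt_path)  # also writes model.ckpt.meta

# Reload the full graph and add new operations on top of the restored tensor.
with tf.Graph().as_default() as g, tf.Session(graph=g) as sess:
    saver = tf.train.import_meta_graph(ckpt_path + '.meta')
    saver.restore(sess, ckpt_path)
    w1 = g.get_tensor_by_name('NN/W1:0')
    doubled = 2.0 * w1  # new op: depends on the restored op, not a copied value
    result = sess.run(doubled)
```

Here the new op depends on the original graph's tensor directly, so anything upstream of `W1` (activations, preprocessing) is carried along for free.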

The clean way to do it is to use TensorFlow Hub (www.tensorflow.org/hub). It's a library that lets you define parts of the graph as modules that you can export and import into any graph. It handles all the dependencies and configuration, and also gives you nice control over training (e.g. if you want to freeze the weights, or delay training for some number of iterations).

Upvotes: 1
