DsCpp
DsCpp

Reputation: 2489

Load the same weight into several variables in the new graph

I want to load the same variable from the pretrained model into several variables in the new model:

# Variable that will be written to the checkpoint; it starts out as zeros.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
# Op that adds 1 to v1 and stores the result back into v1.
inc_v1 = v1.assign(v1+1)
init_op = tf.global_variables_initializer()
# Saver expects a list (or a name -> variable dict), not a bare Variable.
saver = tf.train.Saver([v1])

with tf.Session() as sess:
  sess.run(init_op)
  # Run the assign op so the increment is actually stored in v1;
  # evaluating `v1+1` alone only computes a temporary tensor.
  sess.run(inc_v1)
  # Write the current value of v1 to the checkpoint.
  save_path = saver.save(sess, "/tmp/model.ckpt")

and afterwards

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[3])

# Add ops to save and restore all the variables.
# NOTE(review): this dict literal repeats the key "v1". Python keeps only the
# last value for a repeated key, so the Saver receives {"v1": v2} and v1 is
# left out of the restore. A single name -> variable dict cannot map one
# checkpoint entry onto two variables.
saver = tf.train.Saver({"v1" : v1,"v1":v2})

with tf.Session() as sess:
  # Restore the variables known to the Saver from the checkpoint at this path.
  saver.restore(sess, "/tmp/model.ckpt")

I.e., I want both variables to be initialized from the v1 variable of the previous model.
The example above crashes, saying that the graphs are different.

Upvotes: 1

Views: 42

Answers (2)

DsCpp
DsCpp

Reputation: 2489

Here's another method: iterate over the variables of the new graph and assign each one from the matching variable in the checkpoint:

def load_pretrained(sess, checkpoint_path='pretrainedmodel.ckpt',
                    scope='some_scope', name_matches=None):
    """Assign checkpoint tensors to matching variables of the current graph.

    Every global variable whose op name contains ``scope`` is paired with the
    first checkpoint entry accepted by ``name_matches``; all resulting assign
    ops are then run in ``sess``. The same checkpoint entry may be paired with
    several graph variables.

    Args:
        sess: Session in which the assign ops are run.
        checkpoint_path: Path of the checkpoint to read from.
        scope: Substring a variable's op name must contain to be loaded.
        name_matches: Optional callable ``(var_name, ckpt_name) -> bool``
            deciding whether a graph variable is loaded from a checkpoint
            entry. The default accepts an exact name match or a graph
            variable name that ends with ``'/' + ckpt_name`` (the same
            variable nested under extra scopes). Pass your own rule, e.g.
            one built on ``re.search``, for any other naming scheme.
    """
    if name_matches is None:
        def name_matches(var_name, ckpt_name):
            return (var_name == ckpt_name
                    or var_name.endswith('/' + ckpt_name))

    vars_to_load = [
        var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        if scope in var.op.name
    ]

    reader = tf.contrib.framework.load_checkpoint(checkpoint_path)
    # The checkpoint listing does not change per variable, so read it once.
    ckpt_names = [
        name for name, _ in tf.contrib.framework.list_variables(checkpoint_path)
    ]

    assign_ops = []
    for var in vars_to_load:
        for name in ckpt_names:
            if name_matches(var.op.name, name):
                assign_ops.append(tf.assign(var, reader.get_tensor(name)))
                # Only the first matching checkpoint entry is used.
                break

    sess.run(assign_ops)

Upvotes: 0

Vlad
Vlad

Reputation: 8585

Evaluate the assigned value of the variable from the original graph and then initialize new variables from new graph with this value:

import tensorflow as tf

# Graph that owns the original (source) variable.
source_graph = tf.Graph()
with source_graph.as_default():
    v0 = tf.Variable(tf.random_normal([2, 2]))

# Initialize v0 and fetch its concrete value out of the source graph.
with tf.Session(graph=source_graph) as sess:
    sess.run(v0.initializer)
    init_val = sess.run(v0)
    print('original graph:')
    print(init_val)
# Sample output (the values are random, so yours will differ):
# original graph:
# [[-1.7466899   1.1560178 ]
#  [-0.46535382  1.7059366 ]]

# A separate graph whose variables both start from the fetched value.
target_graph = tf.Graph()
with target_graph.as_default():
    v1 = tf.Variable(init_val)
    v2 = tf.Variable(init_val)

with tf.Session(graph=target_graph) as sess:
    sess.run([v1.initializer, v2.initializer])
    print('new graph:')
    for value in sess.run([v1, v2]):
        print(value)
# Sample output (both variables hold the value taken from the source graph):
# new graph:
# [[-1.7466899   1.1560178 ]
#  [-0.46535382  1.7059366 ]]
# [[-1.7466899   1.1560178 ]
#  [-0.46535382  1.7059366 ]]

Upvotes: 1

Related Questions