I've been trying to research model/weight saving for a while, but I still can't fully grasp it. I feel what I'd like to do should be simple enough, but I've not found a solution.
The final goal is to do transfer learning with a collection of pretrained networks. I write my models/layers as classes, so class method(s) for saving and restoring the weights would be ideal.
For example, say I have a graph features > A > B > labels, where A and B are sub-networks, and I'd like to save and/or restore the weights of each section. If I have already trained the weights for A but the variable scope is now different, how would I restore those weights from a different training session? At the end of training this new graph, I'd like one directory for my new A weights, one directory for my new B weights, and one directory for the full graph (I can handle the full-graph part).
It's very possible I keep overlooking the solution, but model saving is so poorly documented.
Hope I've explained the scenario well.
You can do this with tf.train.init_from_checkpoint.
Define your model, giving each sub-network its own variable scope:
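
    import tensorflow as tf  # TensorFlow 1.x API

    def model_fn(features):
        # any layers work here; tf.layers.dense and the unit counts are only examples
        with tf.variable_scope('One'):
            net = tf.layers.dense(features, 64, activation=tf.nn.relu)  # sub-network A
        with tf.variable_scope('Two'):
            logits = tf.layers.dense(net, 10)  # sub-network B
        return logits
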
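List the variable names stored in the checkpoint file (ckpt_file is the checkpoint path, e.g. from tf.train.latest_checkpoint(checkpoint_dir)):

    ckpt_var_names = [name for name, shape in tf.train.list_variables(ckpt_file)]

Then create an assignment map that loads only the variables that exist both in the checkpoint and in your model. Keys are variable names in the checkpoint and values are variables (or variable names) in the current graph, so you can also restore a checkpoint variable into a variable with a different name or scope:

    assignment_map = {v.op.name: v for v in tf.global_variables() if v.op.name in ckpt_var_names}

Call this after the model's variables have been created but before they are initialized (i.e. before running tf.global_variables_initializer() in a session). With the Estimator API, that means inside model_fn, after the layers are built:

    tf.train.init_from_checkpoint(ckpt_file, assignment_map)
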
https://www.tensorflow.org/api_docs/python/tf/train/init_from_checkpoint
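For the scenario in the question (A trained under one scope, now living under another), the assignment map just needs checkpoint names as keys and the new names as values. Per its documentation, tf.train.init_from_checkpoint also accepts a whole-scope entry such as {'A/': 'One/'}. If you prefer an explicit per-variable map, here is a minimal pure-Python sketch; remap_scope, the scope names 'A' and 'One', and the sample variable names are hypothetical. It only manipulates name strings, so it runs without TensorFlow.

```python
from typing import Dict, Iterable, Optional, Set


def remap_scope(
    ckpt_names: Iterable[str],
    old_scope: str,
    new_scope: str,
    current_names: Optional[Set[str]] = None,
) -> Dict[str, str]:
    """Build a {checkpoint_name: current_name} map moving old_scope to new_scope.

    Hypothetical helper. Checkpoint variables outside old_scope are skipped, so
    only that sub-network is restored. If current_names is given, entries whose
    new name does not exist in the current graph (e.g. optimizer slots) are
    dropped as well.
    """
    old_prefix = old_scope.rstrip('/') + '/'
    new_prefix = new_scope.rstrip('/') + '/'
    mapping = {}
    for name in ckpt_names:
        if not name.startswith(old_prefix):
            continue  # belongs to another sub-network (e.g. B) or no scope at all
        new_name = new_prefix + name[len(old_prefix):]
        if current_names is None or new_name in current_names:
            mapping[name] = new_name
    return mapping


# Names as returned by [name for name, _ in tf.train.list_variables(ckpt_file)]
ckpt_names = ['A/dense/kernel', 'A/dense/bias', 'B/dense/kernel', 'global_step']
print(remap_scope(ckpt_names, 'A', 'One'))
# {'A/dense/kernel': 'One/dense/kernel', 'A/dense/bias': 'One/dense/bias'}
```

The result can be passed directly as the assignment map, e.g. tf.train.init_from_checkpoint(ckpt_file, remap_scope(ckpt_names, 'A', 'One', current_names)), where current_names could be {v.op.name for v in tf.global_variables()}.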
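You can also do it with tf.train.Saver.
First, find the variable names in the current graph and in the checkpoint:

    import tensorflow as tf  # TensorFlow 1.x API

    for var_current in tf.global_variables():
        print(var_current)          # the Variable, including name, shape and dtype
        print(var_current.op.name)  # name only
    for var_ckpt in tf.train.list_variables(ckpt):
        print(var_ckpt[0])          # name only (var_ckpt is a (name, shape) pair)

Once you know the exact names of all variables, you can load any checkpoint variable into any variable of your graph, provided the two have the same shape and dtype. Build a dictionary whose keys are checkpoint names and whose values are the existing variables in the current graph. Pass the variable object itself: calling tf.get_variable again with the same name fails unless reuse is enabled, because the variable already exists.

    vars_dict = {}
    # for each (var_ckpt, var_current) pair you want to restore:
    vars_dict[var_ckpt[0]] = var_current  # checkpoint name -> existing variable
    saver = tf.train.Saver(vars_dict)
    saver.restore(sess, ckpt)             # inside an open tf.Session
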
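Since a shape mismatch between checkpoint and graph is typically reported only when the restore actually runs, it can help to check a name map beforehand. A minimal sketch, assuming you collect checkpoint shapes with tf.train.list_variables(ckpt) and current shapes with {v.op.name: v.shape.as_list() for v in tf.global_variables()}; shape_mismatches and the sample names and shapes are hypothetical. It does not check dtypes (list_variables reports only names and shapes) and runs without TensorFlow.

```python
from typing import Dict, List, Sequence, Tuple


def shape_mismatches(
    ckpt_vars: Sequence[Tuple[str, Sequence[int]]],
    current_shapes: Dict[str, Sequence[int]],
    name_map: Dict[str, str],
) -> List[str]:
    """Return a list of problems for a {checkpoint_name: current_name} map.

    Hypothetical helper. ckpt_vars uses the (name, shape) pairs returned by
    tf.train.list_variables; current_shapes maps current variable names to
    shapes. An empty result means every pair exists and the shapes agree.
    """
    ckpt_shapes = {name: list(shape) for name, shape in ckpt_vars}
    problems = []
    for ckpt_name, current_name in name_map.items():
        if ckpt_name not in ckpt_shapes:
            problems.append(f'{ckpt_name}: not in checkpoint')
        elif current_name not in current_shapes:
            problems.append(f'{current_name}: not in current graph')
        elif ckpt_shapes[ckpt_name] != list(current_shapes[current_name]):
            problems.append(
                f'{ckpt_name} {ckpt_shapes[ckpt_name]} -> '
                f'{current_name} {list(current_shapes[current_name])}: shape mismatch'
            )
    return problems


ckpt_vars = [('A/dense/kernel', [10, 64]), ('A/dense/bias', [64])]
current = {'One/dense/kernel': [10, 64], 'One/dense/bias': [32]}
name_map = {'A/dense/kernel': 'One/dense/kernel', 'A/dense/bias': 'One/dense/bias'}
print(shape_mismatches(ckpt_vars, current, name_map))
# ['A/dense/bias [64] -> One/dense/bias [32]: shape mismatch']
```

Once the list is empty, the Saver dictionary can be built from the same map by looking up each current name among tf.global_variables().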
Take a look at my other answer to a similar question, "How to restore pretrained checkpoint for current model in Tensorflow?"