Reputation: 17468
I have two models, A and B, with the same architecture; both have the same variable names and model settings, for example
['A1/B1/C1', 'A2/B2/C2', 'A3/B3/C3']
I have checkpoint files for A and B, and I want to combine ['A1/B1/C1', 'A2/B2/C2'] from A with 'A3/B3/C3' from B into a single checkpoint file and restore it into model A. How can I do that with saver.restore()?
Upvotes: 1
Views: 1725
Reputation: 17468
Answering my own question.
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

def load_weights(ckpt_path, prefix_list):
    """Read every variable whose name starts with one of the given prefixes from a checkpoint."""
    vars_weights = {}
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in sorted(var_to_shape_map):
        for _pref in prefix_list:
            if key.startswith(_pref):
                # append ':0' to turn the checkpoint variable name into the tensor name used in the graph
                vars_weights[key + ':0'] = reader.get_tensor(key)
    return vars_weights
# Build model
...

# Init variables
sess.run(tf.global_variables_initializer())

# Restore all of model A from its checkpoint
saver.restore(sess, load_dir_A)

prefix = ['A3/B3/C3']

# Get weights from ckpt of B
B_weights = load_weights(load_dir_B, prefix)

# Assign weights from B to A
assign_ops = [tf.assign(tf.get_default_graph().get_tensor_by_name(_name), _value)
              for _name, _value in B_weights.items()]
sess.run(assign_ops)
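If you also want the merged weights written out as a checkpoint file (as the question asks) rather than just living in the session, a minimal sketch, assuming the same sess and saver as above and a hypothetical output path merged_ckpt_path:

# Minimal sketch (merged_ckpt_path is a hypothetical output path; sess and saver are from the snippet above).
# After the assign ops have run, the session holds A's weights for 'A1/B1/C1' and 'A2/B2/C2'
# and B's weights for 'A3/B3/C3', so saving now writes the combined checkpoint,
# which saver.restore can reload into model A later.
save_path = saver.save(sess, merged_ckpt_path)
print('Merged checkpoint written to', save_path)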
Upvotes: 1
Reputation: 4533
You can do it with init_from_checkpoint. After defining the current model, create an assignment map.
dir = 'path_to_A_and_B_checkpoint_files'
vars_to_load = [name for name, _ in tf.train.list_variables(dir)]
assignment_map = {variable.op.name: variable
                  for variable in tf.global_variables()
                  if variable.op.name in vars_to_load}
This creates a dict whose keys are the variable names as they appear in the checkpoint and whose values are the corresponding variables in the current graph.
tf.train.init_from_checkpoint(dir, assignment_map)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # do usual stuff
init_from_checkpoint is called before the session is created; the checkpoint values are applied when the variables are initialized, so it replaces the saver.restore call.
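To address the original question directly (keep 'A1/B1/C1' and 'A2/B2/C2' from A's checkpoint and take 'A3/B3/C3' from B's), a minimal sketch along the same lines, assuming ckpt_dir_A and ckpt_dir_B are hypothetical paths to the two checkpoints and that the graph for model A has already been built:

# Minimal sketch; ckpt_dir_A and ckpt_dir_B are hypothetical checkpoint paths.
keep_from_A = ('A1/B1/C1', 'A2/B2/C2')
take_from_B = ('A3/B3/C3',)

map_A = {v.op.name: v for v in tf.global_variables()
         if v.op.name.startswith(keep_from_A)}
map_B = {v.op.name: v for v in tf.global_variables()
         if v.op.name.startswith(take_from_B)}

# Each call only overrides the initializers of the variables in its map.
tf.train.init_from_checkpoint(ckpt_dir_A, map_A)
tf.train.init_from_checkpoint(ckpt_dir_B, map_B)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The variables now hold the merged weights; write them out if a
    # combined checkpoint file is needed, e.g.:
    # tf.train.Saver().save(sess, 'merged_checkpoint_path')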
Upvotes: 1