Reputation: 321
In TensorFlow, my model is based on a pre-trained model: I added a few new variables and removed some of the pre-trained ones. When I restore the variables from the checkpoint file, I have to explicitly list every variable I added to the graph so that it is excluded. For example:
import tensorflow as tf
slim = tf.contrib.slim
exclude = [...]  # explicitly list all variables to exclude
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
saver = tf.train.Saver(variables_to_restore)
Is there a simpler way to do this? That is, if a variable is not in the checkpoint, just don't try to restore it.
Upvotes: 5
Views: 7313
Reputation: 7140
First find the variables that are actually useful, i.e. those that exist both in your graph and in the checkpoint, and restore only that intersection rather than everything in the checkpoint. Because tf.train.list_variables returns (name, shape) pairs rather than variable objects, the intersection has to be computed on names:
ckpt_var_names = {name for name, _ in tf.train.list_variables(checkpoint_dir)}
variables_can_be_restored = [v for v in tf.global_variables() if v.op.name in ckpt_var_names]
Then define a saver for just those variables and restore:
temp_saver = tf.train.Saver(variables_can_be_restored)
ckpt_state = tf.train.get_checkpoint_state(checkpoint_dir)  # add latest_filename=... only if the state file is not named 'checkpoint'
print('Loading checkpoint %s' % ckpt_state.model_checkpoint_path)
temp_saver.restore(sess, ckpt_state.model_checkpoint_path)
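Whatever form the restore code takes, the matching step is where it usually goes wrong: tf.train.list_variables returns (name, shape) pairs, while Variable.name carries an output suffix such as ":0", so graph and checkpoint can only be intersected by bare name. The plain-Python sketch below shows that matching on lists of that form, without TensorFlow; split_restorable and the example variable names are hypothetical.

```python
def split_restorable(graph_var_names, checkpoint_entries):
    """Split graph variable names into (restorable, not_in_checkpoint).

    graph_var_names: names like "conv1/weights:0" (the form Variable.name has).
    checkpoint_entries: (name, shape) pairs, the form tf.train.list_variables returns.
    """
    # Checkpoint keys have no ":0" suffix, so compare against the bare op name.
    ckpt_names = {name for name, _shape in checkpoint_entries}
    restorable, missing = [], []
    for full_name in graph_var_names:
        op_name = full_name.split(":")[0]
        if op_name in ckpt_names:
            restorable.append(full_name)
        else:
            missing.append(full_name)
    return restorable, missing


graph_vars = ["conv1/weights:0", "conv1/biases:0", "new_fc/weights:0"]
ckpt = [("conv1/weights", [3, 3, 3, 64]), ("conv1/biases", [64]), ("old_fc/weights", [64, 10])]
print(split_restorable(graph_vars, ckpt))
# → (['conv1/weights:0', 'conv1/biases:0'], ['new_fc/weights:0'])
```

The second list holds the variables a restricted saver will not touch, so their initializers still have to be run, for example with tf.variables_initializer.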
Upvotes: 3
Reputation: 21
Here is a more complete answer, which works in a non-distributed setting:
import tensorflow as tf
from tensorflow.contrib.framework.python.framework import checkpoint_utils
slim = tf.contrib.slim

def scan_checkpoint_for_vars(checkpoint_path, vars_to_check):
    check_var_list = checkpoint_utils.list_variables(checkpoint_path)
    check_var_list = [x[0] for x in check_var_list]
    check_var_set = set(check_var_list)
    # Compare on the variable name without its ":0" suffix
    vars_in_checkpoint = [x for x in vars_to_check if x.name[:x.name.index(":")] in check_var_set]
    vars_not_in_checkpoint = [x for x in vars_to_check if x.name[:x.name.index(":")] not in check_var_set]
    return vars_in_checkpoint, vars_not_in_checkpoint

def create_easy_going_scaffold(vars_in_checkpoint, vars_not_in_checkpoint):
    # Only the restored variables must be initialized before local_init_op runs
    model_ready_for_local_init_op = tf.report_uninitialized_variables(var_list=vars_in_checkpoint)
    # Variables missing from the checkpoint get their initializers run instead
    model_init_vars_not_in_checkpoint = tf.variables_initializer(vars_not_in_checkpoint)
    restoration_saver = tf.train.Saver(vars_in_checkpoint)
    eg_scaffold = tf.train.Scaffold(saver=restoration_saver,
                                    ready_for_local_init_op=model_ready_for_local_init_op,
                                    local_init_op=model_init_vars_not_in_checkpoint)
    return eg_scaffold

# output_chkpt_dir, save_checkpoint_secs and save_summaries_secs are defined elsewhere
all_vars = slim.get_variables()
ckpoint_file = tf.train.latest_checkpoint(output_chkpt_dir)
vars_in_checkpoint, vars_not_in_checkpoint = scan_checkpoint_for_vars(ckpoint_file, all_vars)
is_checkpoint_complete = len(vars_not_in_checkpoint) == 0

# Create a session that can handle the current checkpoint
if is_checkpoint_complete:
    # Checkpoint is complete: all variables can be found there
    print('Using normal session')
    sess = tf.train.MonitoredTrainingSession(checkpoint_dir=output_chkpt_dir,
                                             save_checkpoint_secs=save_checkpoint_secs,
                                             save_summaries_secs=save_summaries_secs)
else:
    # Checkpoint is partial: some variables need to be initialized
    print('Using easy going session')
    eg_scaffold = create_easy_going_scaffold(vars_in_checkpoint, vars_not_in_checkpoint)
    # Save all variables to the next checkpoint
    saver = tf.train.Saver()
    hooks = [tf.train.CheckpointSaverHook(checkpoint_dir=output_chkpt_dir,
                                          save_secs=save_checkpoint_secs,
                                          saver=saver)]
    # Such a session is a little slower during the first iteration
    sess = tf.train.MonitoredTrainingSession(checkpoint_dir=output_chkpt_dir,
                                             scaffold=eg_scaffold,
                                             hooks=hooks,
                                             save_summaries_secs=save_summaries_secs,
                                             save_checkpoint_secs=None)

with sess:
    ...
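scan_checkpoint_for_vars matches on names only. If a layer keeps its name but changes shape (a hypothetical example: an output layer resized for a different number of classes), it still lands in vars_in_checkpoint, and restoring it would fail on the shape mismatch. The plain-Python sketch below adds a shape check; scan_with_shapes is a hypothetical helper, and it assumes graph shapes come from Variable.get_shape().as_list() and checkpoint shapes from the (name, shape) pairs of checkpoint_utils.list_variables.

```python
def scan_with_shapes(graph_vars, checkpoint_entries):
    """Return (restorable, to_initialize) lists of graph variable names.

    graph_vars: (name, shape) pairs for graph variables, e.g. ("fc/weights:0", [512, 10]).
    checkpoint_entries: (name, shape) pairs as listed from the checkpoint.
    """
    ckpt_shapes = {name: list(shape) for name, shape in checkpoint_entries}
    restorable, to_initialize = [], []
    for full_name, shape in graph_vars:
        key = full_name.split(":")[0]
        # Restore only when the name exists AND the stored shape matches exactly;
        # a missing name gives None from .get(), which never equals a list.
        if ckpt_shapes.get(key) == list(shape):
            restorable.append(full_name)
        else:
            to_initialize.append(full_name)
    return restorable, to_initialize


graph_vars = [("conv1/weights:0", [3, 3, 3, 64]), ("logits/weights:0", [64, 5])]
ckpt = [("conv1/weights", [3, 3, 3, 64]), ("logits/weights", [64, 1000])]
print(scan_with_shapes(graph_vars, ckpt))
# → (['conv1/weights:0'], ['logits/weights:0'])
```

Variables in the second list would then be handled exactly like vars_not_in_checkpoint: left out of the restoring Saver and initialized by the scaffold's local_init_op.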
Upvotes: 0
Reputation: 1665
The only thing you can do is first build the same model as in the checkpoint, then restore the checkpoint values into that model. After restoring the variables, you can add new layers, delete existing layers, or change the weights of the layers.
There is one important caveat, though: after adding new layers you need to initialize them, and if you use tf.global_variables_initializer() you will lose the values of the restored layers. So initialize only the uninitialized weights; you can use the following function for this:
import tensorflow as tf

def initialize_uninitialized(sess):
    global_vars = tf.global_variables()
    # True for variables that already hold a value (e.g. restored from the checkpoint)
    is_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [v for (v, f) in zip(global_vars, is_initialized) if not f]
    # for i in not_initialized_vars:  # only for testing
    #     print(i.name)
    if len(not_initialized_vars):
        sess.run(tf.variables_initializer(not_initialized_vars))
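The ordering is what matters here: restore first, then initialize only what is still uninitialized. The toy model below is plain Python, not TensorFlow: a dict stands in for session state (a missing key means "uninitialized"), and INIT_VALUE stands in for whatever an initializer would produce. It is only meant to show why a global initializer run after restoring discards the restored values while the selective version keeps them; all names are hypothetical.

```python
INIT_VALUE = 0.0  # stand-in for a variable's initializer output


def restore(state, checkpoint, var_names):
    """Copy checkpoint values into state for variables present in both."""
    for name in var_names:
        if name in checkpoint:
            state[name] = checkpoint[name]


def initialize_all(state, var_names):
    """Analogue of running tf.global_variables_initializer(): resets every variable."""
    for name in var_names:
        state[name] = INIT_VALUE


def initialize_uninitialized(state, var_names):
    """Analogue of the helper above: only touches variables with no value yet."""
    for name in var_names:
        if name not in state:
            state[name] = INIT_VALUE


var_names = ["conv1/weights", "new_fc/weights"]
checkpoint = {"conv1/weights": 1.5}

clobbered = {}
restore(clobbered, checkpoint, var_names)
initialize_all(clobbered, var_names)
print(clobbered)  # → {'conv1/weights': 0.0, 'new_fc/weights': 0.0}

kept = {}
restore(kept, checkpoint, var_names)
initialize_uninitialized(kept, var_names)
print(kept)  # → {'conv1/weights': 1.5, 'new_fc/weights': 0.0}
```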
Upvotes: 1