I want to use pretrained weights for two parts of my model. I have two checkpoints from different models, but because I'm using the Estimator API, I can only load one of them into my main model with tf.estimator.WarmStartSettings:
tf.estimator.WarmStartSettings(ckpt_to_initialize_from=X)
Either the directory or a specific checkpoint can be provided (in the case of the former, the latest checkpoint will be used).
I can't see a way to add a second checkpoint. Is there perhaps a way to merge the weights from both checkpoints into one and load that?
You can use tf.train.init_from_checkpoint.
First, define an assignment map:
dir = 'path_to_checkpoint_files'
vars_to_load = [i[0] for i in tf.train.list_variables(dir)]
This creates a list of the names of all variables in the checkpoint.
assignment_map = {variable.op.name: variable for variable in tf.global_variables() if variable.op.name in vars_to_load}
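And this creates a dict whose keys are variable names in the checkpoint and whose values are the matching variables of the current graph. Here the names on both sides are identical, because only graph variables whose names also appear in the checkpoint are kept.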
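To load two checkpoints you need one assignment map per checkpoint. The filter above only works when checkpoint names match graph names exactly, which is not the case if a pretrained sub-model is built under a different scope in the main model. Below is a minimal pure-Python sketch, assuming each sub-model's variables sit under a scope prefix; the helper `build_assignment_map` and all variable and scope names are hypothetical. It takes the (name, shape) pairs that tf.train.list_variables returns plus the graph's variable names, and maps checkpoint names to graph variable names (init_from_checkpoint accepts variable names as the map's values; its documentation also describes `'old_scope/': 'new_scope/'` entries for whole scopes).

```python
from typing import Dict, Iterable, Sequence, Tuple


def build_assignment_map(
    ckpt_vars: Iterable[Tuple[str, Sequence[int]]],
    graph_var_names: Iterable[str],
    ckpt_scope: str = "",
    graph_scope: str = "",
) -> Dict[str, str]:
    """Hypothetical helper: map checkpoint variable names to graph variable names.

    ckpt_vars: (name, shape) pairs, the format tf.train.list_variables returns.
    graph_var_names: names of the current graph's variables, e.g.
        [v.op.name for v in tf.compat.v1.global_variables()].
    ckpt_scope / graph_scope: prefixes (ending in "/") to swap; with both
        empty this reproduces the exact-name filter from the answer.
    """
    graph_names = set(graph_var_names)
    assignment_map = {}
    for name, _shape in ckpt_vars:
        if not name.startswith(ckpt_scope):
            continue  # variable belongs to another part of the checkpoint
        target = graph_scope + name[len(ckpt_scope):]
        if target in graph_names:  # skip checkpoint vars with no graph counterpart
            assignment_map[name] = target
    return assignment_map


# Usage with made-up names: checkpoint A holds an encoder that the main
# model builds under the scope "image_net/".
ckpt_a = [("encoder/w", [3, 4]), ("encoder/b", [4]), ("global_step", [])]
graph = ["image_net/encoder/w", "image_net/encoder/b", "text_net/emb", "head/w"]
print(build_assignment_map(ckpt_a, graph, graph_scope="image_net/"))
# {'encoder/w': 'image_net/encoder/w', 'encoder/b': 'image_net/encoder/b'}
```

Here `global_step` is dropped because the graph has no `image_net/global_step`. The second checkpoint gets its own call with its own scopes.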
tf.train.init_from_checkpoint(dir, assignment_map)
Place this call inside the estimator's model_fn. It overrides the default initializers of the variables in the assignment map, so those variables are loaded from the checkpoint instead.
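With one map per checkpoint, both pretrained parts can be warm-started by calling init_from_checkpoint once per checkpoint inside model_fn. The sketch below keeps TensorFlow out of the loop so it runs standalone: `warm_start_from_checkpoints` is a hypothetical helper, and its `init_fn` parameter stands in for tf.compat.v1.train.init_from_checkpoint (tf.train.init_from_checkpoint in TF 1.x). The overlap check is an assumption of this sketch, not something TensorFlow requires; it only makes it unambiguous which checkpoint each variable comes from.

```python
from typing import Callable, Dict, List, Tuple

# (checkpoint dir or path, {checkpoint var name: graph var name})
Plan = List[Tuple[str, Dict[str, str]]]


def warm_start_from_checkpoints(
    plan: Plan, init_fn: Callable[[str, Dict[str, str]], None]
) -> None:
    """Hypothetical helper: warm-start from several checkpoints.

    init_fn stands in for tf.compat.v1.train.init_from_checkpoint and is
    called once per checkpoint. Before any call, the maps are checked so
    that no graph variable is assigned from two checkpoints.
    """
    source_of = {}  # graph var name -> checkpoint it is loaded from
    for ckpt, assignment_map in plan:
        for graph_name in assignment_map.values():
            if graph_name in source_of:
                raise ValueError(
                    f"{graph_name!r} is mapped from both "
                    f"{source_of[graph_name]!r} and {ckpt!r}"
                )
            source_of[graph_name] = ckpt
    # Validation passed: one init call per checkpoint.
    for ckpt, assignment_map in plan:
        init_fn(ckpt, assignment_map)


# Usage with a fake init_fn that just records its calls.
calls = []
plan = [
    ("path/to/ckpt_a", {"encoder/w": "image_net/encoder/w"}),
    ("path/to/ckpt_b", {"embedding": "text_net/emb"}),
]
warm_start_from_checkpoints(plan, lambda ckpt, amap: calls.append((ckpt, amap)))
print(calls)
```

In a real model_fn the call would be `warm_start_from_checkpoints([(dir_a, map_a), (dir_b, map_b)], tf.compat.v1.train.init_from_checkpoint)`, placed after the model has been built, since the mapped variables must already exist in the graph.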