user2368505

Reputation: 426

Load (or combine) several pretrained checkpoints with tf.estimator.WarmStartSettings

I want to use pretrained weights for two parts of my model. I have two checkpoints from different models, but since I'm using the Estimator architecture, tf.estimator.WarmStartSettings lets me load only one of them into my main model:

tf.estimator.WarmStartSettings(ckpt_to_initialize_from=X)

from the doc:

Either the directory or a specific checkpoint can be provided (in the case of the former, the latest checkpoint will be used).

I can't see how to add an additional checkpoint. Maybe there is a way to merge the weights from both checkpoints into a single checkpoint and warm-start from that?
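Something like the following is what I have in mind, sketched for TF 1.x (CKPT_A, CKPT_B, and MERGED are placeholder paths, and it assumes the two checkpoints have no overlapping variable names):

import tensorflow as tf

CKPT_A = '/path/to/model_a/model.ckpt'  # placeholder
CKPT_B = '/path/to/model_b/model.ckpt'  # placeholder
MERGED = '/path/to/merged/model.ckpt'   # placeholder

with tf.Graph().as_default():
    # Recreate every variable from both checkpoints in one graph.
    merged_vars = []
    for ckpt in (CKPT_A, CKPT_B):
        reader = tf.train.load_checkpoint(ckpt)
        for name, _ in tf.train.list_variables(ckpt):
            merged_vars.append(tf.Variable(reader.get_tensor(name), name=name))

    # Save them all into a single combined checkpoint.
    saver = tf.train.Saver(merged_vars)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, MERGED)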

Upvotes: 1

Views: 896

Answers (1)

Sharky

Reputation: 4533

You can use tf.train.init_from_checkpoint.

First, define the assignment map:

ckpt_dir = 'path_to_checkpoint_files'
vars_to_load = [name for name, shape in tf.train.list_variables(ckpt_dir)]

This creates a list of the names of all variables in the checkpoint.

assignment_map = {
    variable.op.name: variable
    for variable in tf.global_variables()
    if variable.op.name in vars_to_load
}

And this builds a dict whose keys are the variable names as they appear in the checkpoint and whose values are the corresponding variables in the current graph.

tf.train.init_from_checkpoint(ckpt_dir, assignment_map)

Call this function inside the estimator's model_fn. It overrides the standard initializers of the mapped variables, so they are restored from the checkpoint instead of being freshly initialized.
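Since init_from_checkpoint can be called more than once in the same graph, this extends naturally to the two-checkpoint case from the question. A minimal sketch, where ENCODER_CKPT and DECODER_CKPT are hypothetical paths and the 'encoder'/'decoder' scopes are assumed to match the variable names stored in the respective checkpoints:

import tensorflow as tf

ENCODER_CKPT = '/path/to/encoder_ckpt'  # placeholder
DECODER_CKPT = '/path/to/decoder_ckpt'  # placeholder

def model_fn(features, labels, mode, params):
    # Build the graph. The scope names are assumed to match the
    # variable names stored in the respective checkpoints.
    with tf.variable_scope('encoder'):
        net = tf.layers.dense(features['x'], 128, activation=tf.nn.relu)
    with tf.variable_scope('decoder'):
        logits = tf.layers.dense(net, params['n_classes'])

    # One init_from_checkpoint call per pretrained checkpoint; each call
    # overrides the initializers of only the variables it maps.
    for ckpt in (ENCODER_CKPT, DECODER_CKPT):
        vars_in_ckpt = [name for name, _ in tf.train.list_variables(ckpt)]
        assignment_map = {
            v.op.name: v
            for v in tf.global_variables()
            if v.op.name in vars_in_ckpt
        }
        tf.train.init_from_checkpoint(ckpt, assignment_map)

    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    train_op = tf.train.AdamOptimizer().minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

Each call only touches the variables in its assignment map, so the two checkpoints can initialize disjoint parts of the model.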

Upvotes: 1
