daniel451

TensorFlow: load checkpoint, but only parts of it (convolutional layers)

Is it possible to only load specific layers (convolutional layers) out of one checkpoint file?

I've trained some CNNs fully supervised and saved my progress (I'm doing object localization). To do auto-labelling, I thought of building a weakly supervised CNN out of my current model... but since the weakly supervised version has different fully connected layers, I would like to load only the convolutional filters from my TensorFlow checkpoint file.

Of course I could manually save the weights of the corresponding layers, but since they're already included in TensorFlow's checkpoint file, I would like to extract them from there, so that I only have a single storage file.

Answers (1)

Grwlf

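TensorFlow 2.1 has several public facilities for saving and loading checkpoints (model.save, tf.train.Checkpoint, SavedModel, etc.), but to the best of my knowledge none of them offers an API for choosing which variables get restored. So let me suggest a snippet for hard cases which uses the lower-level checkpoint reader from tensorflow.python.training.checkpoint_utils (the same functions are also exported publicly as tf.train.load_checkpoint and tf.train.list_variables).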

import tensorflow as tf
from tensorflow.python.training.checkpoint_utils import load_checkpoint, list_variables

checkpoint_filename = '/path/to/our/weird/checkpoint.ckpt'
model = tf.keras.Model( ... )  # TF2 model to initialize from the checkpoint above
variables_to_load = [ ... ]    # Names of the model weights to update (e.g. only the conv layers)

# (Optional) Map model variable names to checkpoint keys where they differ
renames = {'/var_name1/in/model': '/var_name1/in/checkpoint',
           '/var_name2/in/model': '/var_name2/in/checkpoint',
           # ... and so on
           }

reader = load_checkpoint(checkpoint_filename)
for w in model.weights:
    name = w.name.split(':')[0]  # Strip the ':0' output suffix (see b/29227106)
    if name in variables_to_load:
        print(f"Updating {name}")
        w.assign(reader.get_tensor(renames.get(name, name)))
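As a self-contained sketch of the same technique (assuming TF 2.x; TinyNet, conv_kernel and fc_kernel are hypothetical names, not part of any API), the code below saves a small "source" network with tf.train.Checkpoint and then restores only its convolution kernel into a "target" network whose fully connected layer has a different shape. It uses the public tf.train.load_checkpoint instead of the internal import. In an object-based TF2 checkpoint the keys follow the attribute path (here model/conv_kernel/.ATTRIBUTES/VARIABLE_VALUE) rather than the variable name, which is where the renaming dictionary comes in.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class TinyNet(tf.Module):
    """Hypothetical stand-in for a CNN: one conv kernel plus one fully connected kernel."""

    def __init__(self, fc_units):
        super().__init__()
        # Explicit names, so w.name is predictable ("conv_kernel:0", "fc_kernel:0").
        self.conv_kernel = tf.Variable(tf.random.normal([3, 3, 1, 4]), name="conv_kernel")
        self.fc_kernel = tf.Variable(tf.random.normal([4, fc_units]), name="fc_kernel")


# 1) "Fully supervised" network, saved as an object-based TF2 checkpoint.
source = TinyNet(fc_units=10)
ckpt_path = tf.train.Checkpoint(model=source).save(os.path.join(tempfile.mkdtemp(), "ckpt"))

# Inspect the checkpoint: keys follow attribute paths, not w.name.
for key, shape in tf.train.list_variables(ckpt_path):
    print(key, shape)

# 2) "Weakly supervised" network: same conv shape, different fc shape.
target = TinyNet(fc_units=2)
fc_before = target.fc_kernel.numpy().copy()

variables_to_load = ["conv_kernel"]  # only the convolutional weights
# Model variable name -> checkpoint key (assumes the attribute path printed above).
renames = {"conv_kernel": "model/conv_kernel/.ATTRIBUTES/VARIABLE_VALUE"}

reader = tf.train.load_checkpoint(ckpt_path)
for w in target.trainable_variables:
    name = w.name.split(":")[0]  # strip the ':0' suffix
    if name in variables_to_load:
        print(f"Updating {name}")
        w.assign(reader.get_tensor(renames.get(name, name)))
```

Restoring fc_kernel the same way would fail with a shape mismatch ([4, 10] in the checkpoint vs. [4, 2] in the model), which is why the filter is needed.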

Note: model.weights and list_variables can help you inspect the variable names in the model and in the checkpoint, respectively.
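To fill variables_to_load with only the convolutional layers, one option is to filter the (name, shape) pairs that list_variables returns with a regular expression. A minimal pure-Python sketch, assuming hypothetical layer names conv1, conv2 and fc1 from a name-based checkpoint, and assuming the model uses the same names as the checkpoint (otherwise map them with the renaming dictionary); adjust the pattern to whatever your own listing shows:

```python
import re
from typing import List, Sequence, Tuple

# (variable name, shape), the form returned by list_variables / tf.train.list_variables
Entry = Tuple[str, Sequence[int]]


def select_variables(entries: Sequence[Entry], pattern: str) -> List[str]:
    """Return the names matching the regex `pattern`, preserving input order."""
    regex = re.compile(pattern)
    return [name for name, _shape in entries if regex.search(name)]


# Hypothetical listing of a name-based checkpoint.
entries = [
    ("conv1/weights", [5, 5, 3, 32]),
    ("conv1/biases", [32]),
    ("conv2/weights", [5, 5, 32, 64]),
    ("conv2/biases", [64]),
    ("fc1/weights", [4096, 1024]),
    ("fc1/biases", [1024]),
    ("global_step", []),
]

variables_to_load = select_variables(entries, r"^conv\d+/")
print(variables_to_load)  # ['conv1/weights', 'conv1/biases', 'conv2/weights', 'conv2/biases']
```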

Note also that this method does not restore the model's optimizer state.
