Kevin

Reputation: 3239

Keras - Proper way to extract weights from a nested model

I have a nested model which has an input layer, and has some final dense layers before the output. Here is the code for it:

image_input = Input(shape, name='image_input')
x = DenseNet121(input_shape=shape, include_top=False, weights=None,backend=keras.backend,
layers=keras.layers,
models=keras.models,
utils=keras.utils)(image_input)
x = GlobalAveragePooling2D(name='avg_pool')(x)
x = Dense(1024, activation='relu', name='dense_layer1_image')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)        
x = Dense(512, activation='relu', name='dense_layer2_image')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
output = Dense(num_class, activation='softmax', name='image_output')(x)
classificationModel = Model(inputs=[image_input], outputs=[output])

Now, say I want to extract the DenseNet weights from this model and use them for transfer learning in a larger model that nests the same DenseNet but has some additional layers after it, such as:

image_input = Input(shape, name='image_input')
x = DenseNet121(input_shape=shape, include_top=False, weights=None,backend=keras.backend,
layers=keras.layers,
models=keras.models,
utils=keras.utils)(image_input)
x = GlobalAveragePooling2D(name='avg_pool')(x)
x = Dense(1024, activation='relu', name='dense_layer1_image')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)        
x = Dense(512, activation='relu', name='dense_layer2_image')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
x = Dense(256, activation='relu', name='dense_layer3_image')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
output = Dense(num_class, activation='sigmoid', name='image_output')(x)
classificationModel = Model(inputs=[image_input], outputs=[output])

Would I just need to call modelB.load_weights(<weights.hdf5>, by_name=True)? Also, should I name the internal DenseNet, and if so, how?

Upvotes: 0

Views: 722

Answers (2)

Daniel Möller

Reputation: 86600

Before using the nested model, you can assign it to a variable. That makes everything a lot easier:

densenet = DenseNet121(input_shape=shape, include_top=False, 
                       weights=None,backend=keras.backend,
                       layers=keras.layers,
                       models=keras.models,
                       utils=keras.utils)

image_input = Input(shape, name='image_input')
x = densenet(image_input)
x = GlobalAveragePooling2D(name='avg_pool')(x)
......

Now it's super simple to:

weights = densenet.get_weights()
another_densenet.set_weights(weights)
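As a minimal sketch of the idea above, the snippet below copies the weights of one nested backbone into another. It assumes tf.keras is available; make_backbone and make_classifier are hypothetical helpers, and the tiny convolutional backbone is only a stand-in for DenseNet121 so the example runs quickly. The name argument also shows one way to name the internal model.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers


def make_backbone(name):
    # Hypothetical tiny stand-in for DenseNet121(include_top=False, ...).
    inp = keras.Input(shape=(8, 8, 3))
    x = layers.Conv2D(4, 3, padding="same", activation="relu")(inp)
    x = layers.BatchNormalization()(x)
    return keras.Model(inp, x, name=name)


def make_classifier(backbone, hidden_units, num_class):
    # Wraps the backbone with a pooling layer and a stack of dense layers.
    image_input = keras.Input(shape=(8, 8, 3), name="image_input")
    x = backbone(image_input)
    x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
    for i, units in enumerate(hidden_units, start=1):
        x = layers.Dense(units, activation="relu", name=f"dense_layer{i}_image")(x)
    output = layers.Dense(num_class, activation="softmax", name="image_output")(x)
    return keras.Model(image_input, output)


# Keep each nested model in its own variable before wiring it into a model.
densenet = make_backbone("backbone_a")
another_densenet = make_backbone("backbone_b")
model_a = make_classifier(densenet, [16, 8], num_class=3)
model_b = make_classifier(another_densenet, [16, 8, 4], num_class=3)

# Copy every weight array (kernels, biases, batch-norm statistics) across.
another_densenet.set_weights(densenet.get_weights())

weights_match = all(
    np.array_equal(a, b)
    for a, b in zip(densenet.get_weights(), another_densenet.get_weights())
)
```

Because model_b was built around another_densenet, it sees the copied weights immediately. The two backbones must have identical architectures for set_weights to succeed.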

The loaded file

You can also print a model.summary() of your loaded model. The dense net will be the first or second layer (you must check this).

You can then get it like densenet = loaded_model.layers[i].

You can then transfer these weights to the new DenseNet, either with the get_weights()/set_weights() approach shown above or directly with new_model.layers[i].set_weights(densenet.get_weights()).
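Instead of reading the index off model.summary(), the nested model can also be located programmatically. The sketch below assumes tf.keras; build and find_nested_model are hypothetical helpers, the small backbone stands in for DenseNet121, and loaded_model is built in-process as a stand-in for a model restored from a file.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers


def build(hidden_units):
    # Hypothetical builder: a tiny nested backbone followed by dense layers.
    backbone_in = keras.Input(shape=(8, 8, 3))
    backbone_out = layers.Conv2D(4, 3, padding="same")(backbone_in)
    backbone = keras.Model(backbone_in, backbone_out)

    image_input = keras.Input(shape=(8, 8, 3), name="image_input")
    x = backbone(image_input)
    x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
    for units in hidden_units:
        x = layers.Dense(units, activation="relu")(x)
    output = layers.Dense(3, activation="softmax", name="image_output")(x)
    return keras.Model(image_input, output)


def find_nested_model(model):
    # Return the index and object of the first layer that is itself a Model.
    for index, layer in enumerate(model.layers):
        if isinstance(layer, keras.Model):
            return index, layer
    raise ValueError("no nested model found")


loaded_model = build([16, 8])   # stand-in for a model loaded from disk
new_model = build([16, 8, 4])   # larger model with one extra dense layer

src_index, densenet = find_nested_model(loaded_model)
dst_index, _ = find_nested_model(new_model)
new_model.layers[dst_index].set_weights(densenet.get_weights())

transferred = all(
    np.array_equal(a, b)
    for a, b in zip(densenet.get_weights(),
                    new_model.layers[dst_index].get_weights())
)
```

Searching by type avoids hard-coding the layer index, which can differ depending on whether the input layer is counted.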

Upvotes: 1

Jake Tae

Reputation: 1741

Perhaps the easiest way to go about this is to reuse the trained model itself rather than loading its weights separately. Say you have trained the initial model (copied from the provided source code with minimal edits to variable names):

image_input = Input(shape, name='image_input')
# ... intermediate layers elided
x = BatchNormalization()(x)
output = Dropout(0.5)(x)
model_output = Dense(num_class, activation='softmax', name='image_output')(output)
smaller_model = Model(inputs=[image_input], outputs=[model_output])

To use the trained weights of this model for a larger model, we can simply declare another model that uses the trained weights, then use that newly defined model as a component of the larger model.

new_model = Model(image_input, output) # Model that uses trained weights

main_input = Input(shape, name='main_input')
x = new_model(main_input)
x = Dense(256, activation='relu', name='dense_layer3_image')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
output = Dense(num_class, activation='sigmoid', name='image_output')(x)
final_model = Model(inputs=[main_input], outputs=[output])
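The following is a small runnable sketch of this approach, assuming tf.keras. The Flatten and Dense layers are hypothetical stand-ins for the elided DenseNet and dense stack, and features plays the role of the intermediate tensor called output in the first snippet above.

```python
from tensorflow import keras
from tensorflow.keras import layers

shape = (8, 8, 3)   # assumed toy input shape
num_class = 3       # assumed number of classes

# "Trained" smaller model (training itself is omitted in this sketch).
image_input = keras.Input(shape, name="image_input")
x = layers.Flatten()(image_input)
x = layers.Dense(16, activation="relu", name="dense_layer2_image")(x)
features = layers.Dropout(0.5)(x)
model_output = layers.Dense(num_class, activation="softmax",
                            name="image_output")(features)
smaller_model = keras.Model(image_input, model_output)

# Sub-model ending at the intermediate tensor; it reuses the same layer objects.
new_model = keras.Model(image_input, features)

# Larger model built on top of the shared sub-model.
main_input = keras.Input(shape, name="main_input")
x = new_model(main_input)
x = layers.Dense(8, activation="relu", name="dense_layer3_image")(x)
output = layers.Dense(num_class, activation="sigmoid", name="image_output")(x)
final_model = keras.Model(main_input, output)

# The dense layer is the very same object in both models, so no copy is needed.
shared = (new_model.get_layer("dense_layer2_image")
          is smaller_model.get_layer("dense_layer2_image"))
```

Since the layers are shared rather than copied, training final_model also updates the weights seen by smaller_model. Setting new_model.trainable = False before compiling keeps the transferred weights frozen.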

If anything is unclear, I'd be more than happy to elaborate.

Upvotes: 1
