Jonathan Roy

Reputation: 441

keras access layer parameter of pre-trained model to freeze

I saved an LSTM with multiple layers. Now, I want to load it and just fine-tune the last LSTM layer. How can I target this layer and change its parameters?

Example of a simple model trained and saved:

from keras.models import Sequential
from keras.layers import LSTM, Dense

model = Sequential()
# first layer: 100 units
model.add(LSTM(100, return_sequences=True,
               input_shape=(X.shape[1], X.shape[2])))
model.add(LSTM(50, return_sequences=True))
model.add(LSTM(25))
model.add(Dense(1))
model.compile(loss='mae', optimizer='adam')

I can load and retrain it, but I can't find a way to target a specific layer and freeze all the other layers.

Upvotes: 1

Views: 2981

Answers (2)

today

Reputation: 33420

If you have previously built and saved the model and now want to load it and fine-tune only the last LSTM layer, you need to set the other layers' trainable attribute to False. First, find the name of the layer (or its index, counting from zero starting at the top) using the model.summary() method. For example, this is the output produced for one of my models:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_10 (InputLayer)        (None, 400, 16)           0         
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 400, 32)           4128      
_________________________________________________________________
lstm_2 (LSTM)                (None, 32)                8320      
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 33        
=================================================================
Total params: 12,481
Trainable params: 12,481
Non-trainable params: 0
_________________________________________________________________

Then set the trainable attribute of all the layers except the LSTM layer to False.

Approach 1:

for layer in model.layers:
    if layer.name != 'lstm_2':
        layer.trainable = False

Approach 2:

for layer in model.layers:
    layer.trainable = False

model.layers[2].trainable = True  # set lstm to be trainable

# to make sure 2 is the index of the layer
print(model.layers[2].name)    # prints 'lstm_2'

Don't forget to compile the model again to apply these changes.
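Putting Approach 2 together, here is a minimal end-to-end sketch. The architecture mirrors the question's model, but the input shape (20 timesteps, 3 features) is an assumption chosen just to make the example self-contained:

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, LSTM, Dense

# Stand-in for the question's model; (20, 3) input shape is assumed.
model = Sequential([
    Input(shape=(20, 3)),
    LSTM(100, return_sequences=True),
    LSTM(50, return_sequences=True),
    LSTM(25),
    Dense(1),
])

# Freeze everything, then unfreeze only the last LSTM layer.
for layer in model.layers:
    layer.trainable = False
model.layers[2].trainable = True  # index 2 is the last LSTM

# Recompile so the change takes effect for training.
model.compile(loss='mae', optimizer='adam')

print([(layer.name, layer.trainable) for layer in model.layers])
```

After recompiling, model.summary() should report the frozen layers' weights under "Non-trainable params", which is a quick way to confirm the freeze worked.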

Upvotes: 0

KonstantinosKokos

Reputation: 3473

An easy solution would be to name each layer, i.e.

model.add(LSTM(50, return_sequences=True, name='2nd_lstm'))

Then, upon loading the model you can iterate over the layers and freeze the ones matching a name condition:

for layer in model.layers:
    if layer.name == '2nd_lstm':
        layer.trainable = False

Then you need to recompile your model for the changes to take effect, and afterwards you may resume training as usual.
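As an illustration, here is a minimal sketch of the name-based approach. Only the '2nd_lstm' name comes from the answer; the rest of the model (other layer names, the input shape) is a stand-in:

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, LSTM, Dense

# Stand-in model; only '2nd_lstm' is the name from the answer.
model = Sequential([
    Input(shape=(20, 3)),  # assumed input shape
    LSTM(100, return_sequences=True, name='1st_lstm'),
    LSTM(50, return_sequences=True, name='2nd_lstm'),
    LSTM(25, name='3rd_lstm'),
    Dense(1, name='output'),
])

# Freeze the layer whose name matches, leaving the rest trainable.
for layer in model.layers:
    if layer.name == '2nd_lstm':
        layer.trainable = False

# Recompile for the change to take effect during training.
model.compile(loss='mae', optimizer='adam')
```

Explicit names also survive saving and loading, so the same loop works on a model restored with load_model().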

Upvotes: 2
