Reputation: 441
I saved an LSTM with multiple layers. Now, I want to load it and just fine-tune the last LSTM layer. How can I target this layer and change its parameters?
Example of a simple model trained and saved:
from keras.models import Sequential
from keras.layers import LSTM, Dense

model = Sequential()
# first LSTM layer: 100 units
model.add(LSTM(100, return_sequences=True, input_shape=(X.shape[1], X.shape[2])))
model.add(LSTM(50, return_sequences=True))
model.add(LSTM(25))
model.add(Dense(1))
model.compile(loss='mae', optimizer='adam')
I can load and retrain it, but I can't find a way to target a specific layer and freeze all the other layers.
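For reference, this is roughly how I reload the saved model before retraining (the file name model.h5 is just a placeholder for wherever the model was saved):

from keras.models import load_model

# reload the previously trained and saved model
model = load_model('model.h5')  # placeholder path; use your own saved file
model.summary()                 # inspect layer names and indices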
Upvotes: 1
Views: 2981
Reputation: 33420
If you have previously built and saved the model and now want to load it and fine-tune only the last LSTM layer, then you need to set the trainable property of all the other layers to False. First, find the name of the layer (or its index, counting from zero starting at the top) using the model.summary() method. For example, this is the output produced for one of my models:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_10 (InputLayer)        (None, 400, 16)           0
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 400, 32)           4128
_________________________________________________________________
lstm_2 (LSTM)                (None, 32)                8320
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 33
=================================================================
Total params: 12,481
Trainable params: 12,481
Non-trainable params: 0
_________________________________________________________________
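If you prefer not to read the summary table, a short sketch that lists each layer's index, name and current trainable flag programmatically:

# print index, name and trainable flag for every layer in the loaded model
for i, layer in enumerate(model.layers):
    print(i, layer.name, layer.trainable)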
Then set the trainable attribute of all the layers except the LSTM layer to False.
Approach 1:
for layer in model.layers:
    if layer.name != 'lstm_2':
        layer.trainable = False
Approach 2:
for layer in model.layers:
    layer.trainable = False
model.layers[2].trainable = True  # set the LSTM layer to be trainable

# to make sure 2 is the index of the LSTM layer
print(model.layers[2].name)  # prints 'lstm_2'
Don't forget to compile the model again to apply these changes.
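As a sketch of that last step, reusing the loss and optimizer from the question (substitute whatever you originally compiled with, and your own training data; X_train, y_train and the fit arguments here are hypothetical):

# recompile so the new trainable flags take effect, then continue training;
# only the weights of the still-trainable LSTM layer will be updated
model.compile(loss='mae', optimizer='adam')
model.fit(X_train, y_train, epochs=10, batch_size=32)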
Upvotes: 0
Reputation: 3473
An easy solution would be to name each layer, e.g.
model.add(LSTM(50, return_sequences=True, name='2nd_lstm'))
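Applied to the model from the question, a sketch with every layer named (the names themselves are arbitrary choices):

from keras.models import Sequential
from keras.layers import LSTM, Dense

model = Sequential()
model.add(LSTM(100, return_sequences=True, name='1st_lstm',
               input_shape=(X.shape[1], X.shape[2])))
model.add(LSTM(50, return_sequences=True, name='2nd_lstm'))
model.add(LSTM(25, name='3rd_lstm'))
model.add(Dense(1, name='output'))
model.compile(loss='mae', optimizer='adam')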
Then, upon loading the model you can iterate over the layers and freeze the ones matching a name condition:
for layer in model.layers:
    if layer.name == '2nd_lstm':
        layer.trainable = False
Then you need to recompile your model for the changes to take effect, and afterwards you may resume training as usual.
Upvotes: 2