I'm currently developing a model using Keras + TensorFlow to compute sentence similarity on the STS benchmark (http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark). I first trained a model that converts a list of word embedding vectors into a single sentence embedding vector. Now I want to incorporate this pre-trained model into a new model that uses it to transform the inputs. The following is the code for that new model:
```python
from tensorflow.keras.layers import Dense, Input, Lambda
from tensorflow.keras.models import Model, load_model

# Pre-trained encoder: (30, 300) word vectors -> one sentence vector
sentence_encoder = load_model('path/to/model')

input1 = Input(shape=(30, 300), dtype='float32')  # 30 words, 300-dim embeddings
input2 = Input(shape=(30, 300), dtype='float32')

# The same encoder instance encodes both sentences, so its weights are shared
x1 = sentence_encoder(input1)
x2 = sentence_encoder(input2)

abs_diff = Lambda(lambda x: abs(x[0] - x[1]))([x1, x2])
x = Dense(300, activation='relu', kernel_initializer='he_uniform')(abs_diff)
result = Dense(1, activation='sigmoid')(x)

model = Model([input1, input2], result)
model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])
model.fit(...)
```
When I run this, the model is built and trains without errors. What I want to know, however, is whether `sentence_encoder` gets trained along with this new model or whether its weights remain unchanged. If possible, I'd like `sentence_encoder`'s weights to be updated by the training of this new model. If this code doesn't achieve that, how do I go about doing it?
Thank you in advance!
Answer:
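If you don't freeze the pre-trained model's layers, their weights will be updated during training along with your new layers. Layers are trainable by default, so with your code as written, `sentence_encoder` is fine-tuned together with the `Dense` layers on top. That's often not what you want, though.
It's more common to freeze all but the last few pre-trained layers, add your own layers on top, and train only the unfrozen layers plus the new ones.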
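A minimal sketch to check this yourself, assuming `tensorflow.keras` and random toy data. The encoder here is a hypothetical stand-in for the loaded model (a `GlobalAveragePooling1D` followed by a `Dense` layer named `projection`); it is wired into the same two-input model as in the question, trained for one step, and its kernel is compared before and after.

```python
import numpy as np
from tensorflow.keras.layers import Dense, GlobalAveragePooling1D, Input, Lambda
from tensorflow.keras.models import Model

# Hypothetical stand-in for the loaded sentence_encoder: average the 30 word
# vectors, then project them to a 300-dim sentence vector.
enc_in = Input(shape=(30, 300))
enc_out = Dense(300, activation='tanh', name='projection')(GlobalAveragePooling1D()(enc_in))
encoder = Model(enc_in, enc_out, name='sentence_encoder')

# Same wiring as in the question: one shared encoder applied to both inputs.
input1 = Input(shape=(30, 300))
input2 = Input(shape=(30, 300))
abs_diff = Lambda(lambda x: abs(x[0] - x[1]))([encoder(input1), encoder(input2)])
result = Dense(1, activation='sigmoid')(abs_diff)
model = Model([input1, input2], result)
model.compile(loss='binary_crossentropy', optimizer='rmsprop')

# Random toy data, only used to run a single training step.
rng = np.random.default_rng(0)
x1 = rng.normal(size=(8, 30, 300)).astype('float32')
x2 = rng.normal(size=(8, 30, 300)).astype('float32')
y = rng.integers(0, 2, size=(8, 1)).astype('float32')

before = encoder.get_layer('projection').get_weights()[0].copy()
model.fit([x1, x2], y, epochs=1, verbose=0)
after = encoder.get_layer('projection').get_weights()[0]

print(np.allclose(before, after))  # False: the encoder's kernel was updated
```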
In Keras, you freeze a network by setting its `trainable` attribute to `False`:
```python
sentence_encoder.trainable = False
```
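The same kind of check with the encoder frozen, again a sketch assuming `tensorflow.keras`, a hypothetical stand-in encoder and random data. Keras applies changes to `trainable` when `compile()` is called, so the flag is set before compiling the outer model; if you change it afterwards, call `compile()` again.

```python
import numpy as np
from tensorflow.keras.layers import Dense, GlobalAveragePooling1D, Input, Lambda
from tensorflow.keras.models import Model

# Hypothetical stand-in for the loaded sentence_encoder.
enc_in = Input(shape=(30, 300))
enc_out = Dense(300, activation='tanh', name='projection')(GlobalAveragePooling1D()(enc_in))
encoder = Model(enc_in, enc_out, name='sentence_encoder')

# Freeze the encoder BEFORE compiling the model that uses it.
encoder.trainable = False

input1 = Input(shape=(30, 300))
input2 = Input(shape=(30, 300))
abs_diff = Lambda(lambda x: abs(x[0] - x[1]))([encoder(input1), encoder(input2)])
result = Dense(1, activation='sigmoid')(abs_diff)
model = Model([input1, input2], result)
model.compile(loss='binary_crossentropy', optimizer='rmsprop')

# Random toy data, only used to run a single training step.
rng = np.random.default_rng(0)
x1 = rng.normal(size=(8, 30, 300)).astype('float32')
x2 = rng.normal(size=(8, 30, 300)).astype('float32')
y = rng.integers(0, 2, size=(8, 1)).astype('float32')

before = encoder.get_layer('projection').get_weights()[0].copy()
model.fit([x1, x2], y, epochs=1, verbose=0)
after = encoder.get_layer('projection').get_weights()[0]

print(np.allclose(before, after))      # True: frozen weights are left unchanged
print(len(encoder.trainable_weights))  # 0
```

Only the new `Dense` layer on top receives gradient updates in this setup.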
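To freeze only a subset of the layers in a model:
```python
sentence_encoder.trainable = True
is_trainable = False
for layer in sentence_encoder.layers:
    # Placeholder name: this layer and every layer after it are unfrozen,
    # all earlier layers stay frozen.
    if layer.name == 'first_trainable_layer_name':
        is_trainable = True
    layer.trainable = is_trainable
# Set these flags before compiling the outer model (or recompile it afterwards).
```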
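To freeze all but the last few layers by position rather than by name, a small helper along these lines should work. This is a sketch assuming `tensorflow.keras`; `freeze_all_but_last`, the layer names and the stand-in encoder are hypothetical. Compile the outer model after calling it.

```python
from tensorflow.keras.layers import Dense, GlobalAveragePooling1D, Input
from tensorflow.keras.models import Model


def build_encoder():
    # Hypothetical stand-in for the loaded sentence_encoder.
    inp = Input(shape=(30, 300))
    x = GlobalAveragePooling1D(name='pool')(inp)
    x = Dense(300, activation='relu', name='dense_a')(x)
    x = Dense(300, activation='relu', name='dense_b')(x)
    out = Dense(300, activation='tanh', name='dense_c')(x)
    return Model(inp, out, name='sentence_encoder')


def freeze_all_but_last(model, n_trainable):
    """Freeze every layer of `model` except the last `n_trainable` ones."""
    model.trainable = True  # make sure the container itself is trainable
    cutoff = len(model.layers) - n_trainable
    for i, layer in enumerate(model.layers):
        layer.trainable = i >= cutoff
    return model


encoder = freeze_all_but_last(build_encoder(), n_trainable=1)
print([(layer.name, layer.trainable) for layer in encoder.layers])
print(len(encoder.trainable_weights))  # 2: only dense_c's kernel and bias
```

Counting from the end means the result does not depend on whether `model.layers` includes the input layer.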