Vandenn

Pre-trained Model on Multiple Inputs

I'm currently developing a model with Keras + TensorFlow that calculates sentence similarity on the STS benchmark (http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark). First, I created a pre-trained model that converts a list of word embedding vectors into a single sentence embedding vector. Now I want to incorporate this pre-trained model into a new model that uses it to transform each input. The following is the code for that new model.

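import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, Lambda
from tensorflow.keras.models import Model, load_model

sentence_encoder = load_model('path/to/model')

input1 = Input(shape=(30, 300), dtype='float32')  # 30 words, 300-dim embeddings
input2 = Input(shape=(30, 300), dtype='float32')
# The same encoder instance is called on both inputs, so both branches share its weights.
x1 = sentence_encoder(input1)
x2 = sentence_encoder(input2)
# Element-wise absolute difference between the two sentence vectors.
abs_diff = Lambda(lambda x: tf.abs(x[0] - x[1]))([x1, x2])
x = Dense(300, activation='relu', kernel_initializer='he_uniform')(abs_diff)
result = Dense(1, activation='sigmoid')(x)

model = Model([input1, input2], result)
model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

model.fit(...)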
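Since load_model('path/to/model') can't be run without the saved file, here is a self-contained sketch of the same architecture. It assumes tf.keras and swaps in a small hypothetical stand-in encoder (average pooling plus one Dense(300) layer), which is not the asker's actual encoder. It shows that calling one encoder instance on both inputs shares its weights, and that those weights are part of the combined model's trainable weights by default.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


def build_stand_in_encoder():
    # Hypothetical substitute for load_model('path/to/model'):
    # maps 30 x 300 word embeddings to one 300-dim sentence vector.
    words = keras.Input(shape=(30, 300))
    pooled = keras.layers.GlobalAveragePooling1D()(words)
    sentence = keras.layers.Dense(300, activation="tanh")(pooled)
    return keras.Model(words, sentence, name="sentence_encoder")


sentence_encoder = build_stand_in_encoder()

input1 = keras.Input(shape=(30, 300), dtype="float32")
input2 = keras.Input(shape=(30, 300), dtype="float32")
# One encoder object applied twice: both branches use the same weights.
x1 = sentence_encoder(input1)
x2 = sentence_encoder(input2)
abs_diff = keras.layers.Lambda(lambda t: tf.abs(t[0] - t[1]))([x1, x2])
x = keras.layers.Dense(300, activation="relu", kernel_initializer="he_uniform")(abs_diff)
result = keras.layers.Dense(1, activation="sigmoid")(x)
model = keras.Model([input1, input2], result)

# Encoder kernel + bias (counted once despite two calls) plus kernel + bias
# for each of the head's two Dense layers -> 6 trainable weight tensors.
print(len(model.trainable_weights))  # 6

model.compile(loss="binary_crossentropy", optimizer="rmsprop", metrics=["accuracy"])

# Smoke test on random sentence pairs (illustrative data only).
pairs = [np.random.rand(4, 30, 300).astype("float32") for _ in range(2)]
print(model.predict(pairs, verbose=0).shape)  # (4, 1)
```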

When I run this, the model builds and trains without errors. What I want to know, however, is whether the sentence_encoder gets trained along with this new model or whether its weights remain unchanged. If possible, I'd like the sentence_encoder's weights to be updated by the training of this new model. If this code doesn't achieve that, how do I go about doing it?

Thank you in advance!

Answers (1)

Mitch Wheat

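If you don't freeze the pre-trained model's layers, their weights will be updated during training, because Keras layers are trainable by default. So, as written, your sentence_encoder is trained together with the new model. That's often not what you want, though: early in training, the large gradient updates coming from the randomly initialized new layers can wreck the features the encoder has already learned.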
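One way to confirm this is to snapshot the encoder's weights, run a short fit, and compare. The sketch below assumes tf.keras, a hypothetical stand-in encoder, and random data in place of real STS pairs; it runs the check once with the encoder left trainable and once with it frozen.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


def build_models():
    # Hypothetical stand-in for the loaded sentence encoder.
    words = keras.Input(shape=(30, 300))
    pooled = keras.layers.GlobalAveragePooling1D()(words)
    encoder = keras.Model(words, keras.layers.Dense(300, activation="tanh")(pooled))

    in1 = keras.Input(shape=(30, 300))
    in2 = keras.Input(shape=(30, 300))
    diff = keras.layers.Lambda(lambda t: tf.abs(t[0] - t[1]))([encoder(in1), encoder(in2)])
    hidden = keras.layers.Dense(300, activation="relu")(diff)
    out = keras.layers.Dense(1, activation="sigmoid")(hidden)
    return encoder, keras.Model([in1, in2], out)


def encoder_weights_changed(freeze):
    encoder, model = build_models()
    encoder.trainable = not freeze  # set before compile() so it takes effect
    model.compile(loss="binary_crossentropy", optimizer="rmsprop")
    before = [w.copy() for w in encoder.get_weights()]

    # Random sentence pairs and 0/1 labels (illustrative data only).
    rng = np.random.default_rng(0)
    a = rng.random((64, 30, 300), dtype=np.float32)
    b = rng.random((64, 30, 300), dtype=np.float32)
    y = rng.integers(0, 2, size=(64, 1)).astype(np.float32)
    model.fit([a, b], y, epochs=1, batch_size=16, verbose=0)

    after = encoder.get_weights()
    return any(not np.array_equal(x, z) for x, z in zip(before, after))


print("default (trainable):", encoder_weights_changed(freeze=False))  # True
print("frozen:", encoder_weights_changed(freeze=True))               # False
```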

It's more common to freeze all but the last few pre-trained layers, add your own layers on top, and train only the unfrozen layers together with the new ones.
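Since the goal here is for the encoder to be influenced by the new task, a common pattern (described in the Keras transfer-learning guide) is to train in two phases: first train only the new head with the encoder frozen via its trainable attribute, then unfreeze it, recompile, and fine-tune everything with a much smaller learning rate. The sketch below assumes tf.keras, a hypothetical stand-in encoder, random data, and a learning rate of 1e-5 chosen only for illustration.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Hypothetical stand-in encoder; in practice this would be load_model(...).
words = keras.Input(shape=(30, 300))
pooled = keras.layers.GlobalAveragePooling1D()(words)
encoder = keras.Model(words, keras.layers.Dense(300, activation="tanh")(pooled))

in1, in2 = keras.Input(shape=(30, 300)), keras.Input(shape=(30, 300))
diff = keras.layers.Lambda(lambda t: tf.abs(t[0] - t[1]))([encoder(in1), encoder(in2)])
hidden = keras.layers.Dense(300, activation="relu", kernel_initializer="he_uniform")(diff)
model = keras.Model([in1, in2], keras.layers.Dense(1, activation="sigmoid")(hidden))

# Random sentence pairs and labels (illustrative data only).
rng = np.random.default_rng(0)
a = rng.random((64, 30, 300), dtype=np.float32)
b = rng.random((64, 30, 300), dtype=np.float32)
y = rng.integers(0, 2, size=(64, 1)).astype(np.float32)

# Phase 1: freeze the encoder and train only the new head.
encoder.trainable = False
model.compile(loss="binary_crossentropy", optimizer="rmsprop")
start = [w.copy() for w in encoder.get_weights()]
model.fit([a, b], y, epochs=2, batch_size=16, verbose=0)
frozen_unchanged = all(np.array_equal(s, w) for s, w in zip(start, encoder.get_weights()))

# Phase 2: unfreeze, recompile so the change takes effect, and fine-tune
# the whole network with a much smaller learning rate.
encoder.trainable = True
model.compile(loss="binary_crossentropy",
              optimizer=keras.optimizers.RMSprop(learning_rate=1e-5))
model.fit([a, b], y, epochs=1, batch_size=16, verbose=0)
finetuned_changed = any(not np.array_equal(s, w) for s, w in zip(start, encoder.get_weights()))

print(frozen_unchanged, finetuned_changed)  # True True
```

The small learning rate in phase 2 is meant to nudge the pre-trained weights toward the similarity task rather than overwrite them.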

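In Keras, you freeze a network by setting its trainable attribute to False. Set it before compiling the new model (or recompile afterwards), because changes to trainable only take effect when the model is compiled: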

sentence_encoder.trainable = False
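To see what freezing does to the combined model, the sketch below (assuming tf.keras and a hypothetical stand-in encoder with a single Dense layer) counts trainable and non-trainable weight tensors after setting sentence_encoder.trainable = False.

```python
import tensorflow as tf
from tensorflow import keras

# Hypothetical stand-in for load_model('path/to/model').
words = keras.Input(shape=(30, 300))
pooled = keras.layers.GlobalAveragePooling1D()(words)
sentence_encoder = keras.Model(words, keras.layers.Dense(300, activation="tanh")(pooled))

sentence_encoder.trainable = False  # freeze before compiling the new model

in1, in2 = keras.Input(shape=(30, 300)), keras.Input(shape=(30, 300))
diff = keras.layers.Lambda(lambda t: tf.abs(t[0] - t[1]))(
    [sentence_encoder(in1), sentence_encoder(in2)])
hidden = keras.layers.Dense(300, activation="relu")(diff)
model = keras.Model([in1, in2], keras.layers.Dense(1, activation="sigmoid")(hidden))

# Only the head's two Dense layers (kernel + bias each) remain trainable;
# the encoder's kernel and bias are now non-trainable.
print(len(model.trainable_weights), len(model.non_trainable_weights))  # 4 2

model.compile(loss="binary_crossentropy", optimizer="rmsprop")
```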

To freeze only a subset of the layers in a model, for example every layer before a given one:

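# Setting the model-level flag propagates to every layer, so do it first.
sentence_encoder.trainable = True
is_trainable = False
for layer in sentence_encoder.layers:
    # 'last layer name' is a placeholder for the first layer to keep trainable:
    # it and every layer after it stay trainable; all earlier layers are frozen.
    if layer.name == 'last layer name':
        is_trainable = True
    layer.trainable = is_trainable
# Recompile the outer model afterwards so the new settings take effect.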
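If you'd rather select layers by position than by name, a variant is to keep only the last N layers trainable. This is a sketch assuming tf.keras; the two-Dense-layer stand-in encoder and the N_TRAINABLE value are hypothetical.

```python
from tensorflow import keras

# Hypothetical stand-in encoder with two Dense layers.
words = keras.Input(shape=(30, 300))
x = keras.layers.GlobalAveragePooling1D()(words)
x = keras.layers.Dense(256, activation="relu")(x)
sentence = keras.layers.Dense(300, activation="tanh")(x)
sentence_encoder = keras.Model(words, sentence)

N_TRAINABLE = 1  # hypothetical: number of trailing layers to keep trainable

# The model-level flag propagates to all layers, so set it before the loops.
sentence_encoder.trainable = True
for layer in sentence_encoder.layers[:-N_TRAINABLE]:
    layer.trainable = False
for layer in sentence_encoder.layers[-N_TRAINABLE:]:
    layer.trainable = True

# Only the last Dense layer's kernel and bias remain trainable.
print(len(sentence_encoder.trainable_weights))  # 2
```

As with the name-based loop, recompile any model that wraps sentence_encoder after changing these flags.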
