My NN has to learn image similarity with a custom triplet loss. The positive image is similar to the anchor, while the negative is not.
My task is to predict whether the second or the third image of an unseen triplet is more similar to the anchor.
The triplets for both the train and test sets are given in the task, so I did not have to mine or randomly generate them: they are fixed.
Idea: To improve my model, I use Xception as a frozen, pretrained feature extractor and add a trainable Dense layer on top.
Problem:
When I train the model below with the Xception layers frozen, after 1-2 epochs it learns to simply put every positive image very close to its anchor and every negative image very far away, hence the 100% validation accuracy.
My first thought was overfitting, but I only train one fully connected layer on top. How can I combat this? Or is my triplet loss defined incorrectly?
I don't use data augmentation; could that help?
Oddly, this happens only with the pretrained model. When I use a simple model, I get realistic accuracy.
What am I missing here?
My triplet loss:
```python
from tensorflow.keras import backend as K

def triplet_loss(y_true, y_pred, alpha=0.4):
    """
    Implementation of the triplet loss function

    Arguments:
    y_true -- true labels, required when you define a loss in Keras, you don't need it in this function.
    y_pred -- tensor of shape (batch, 3 * embedding_dim) holding three encodings
              concatenated along the last axis:
        anchor -- the encodings for the anchor data
        positive -- the encodings for the positive data (similar to anchor)
        negative -- the encodings for the negative data (different from anchor)

    Returns:
    loss -- tensor of shape (batch,), the loss for each triplet
    """
    total_length = y_pred.shape.as_list()[-1]
    anchor = y_pred[:, 0:int(total_length * 1 / 3)]
    positive = y_pred[:, int(total_length * 1 / 3):int(total_length * 2 / 3)]
    negative = y_pred[:, int(total_length * 2 / 3):int(total_length * 3 / 3)]

    # squared Euclidean distance between the anchor and the positive
    pos_dist = K.sum(K.square(anchor - positive), axis=1)
    # squared Euclidean distance between the anchor and the negative
    neg_dist = K.sum(K.square(anchor - negative), axis=1)

    # hinge: zero once the negative is at least alpha farther than the positive
    basic_loss = pos_dist - neg_dist + alpha
    loss = K.maximum(basic_loss, 0.0)
    return loss
```
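If you suspect the loss itself, one way to rule it out is to mirror it in NumPy and check it on triplets whose loss you can compute by hand. The sketch below is such a mirror (the name `triplet_loss_np` is hypothetical); it assumes the same [anchor | positive | negative] layout and squared Euclidean distances as the Keras version.

```python
import numpy as np

def triplet_loss_np(y_pred, alpha=0.4):
    """NumPy mirror of the Keras triplet_loss, for sanity checks.

    y_pred has shape (batch, 3 * d): anchor, positive and negative
    embeddings concatenated along the last axis.
    Returns the per-triplet hinge loss, shape (batch,).
    """
    d = y_pred.shape[-1] // 3
    anchor = y_pred[:, :d]
    positive = y_pred[:, d:2 * d]
    negative = y_pred[:, 2 * d:]
    pos_dist = np.sum((anchor - positive) ** 2, axis=1)
    neg_dist = np.sum((anchor - negative) ** 2, axis=1)
    return np.maximum(pos_dist - neg_dist + alpha, 0.0)

# Two toy triplets with 2-D embeddings.
a = np.array([[1.0, 0.0], [1.0, 0.0]])
p = np.array([[1.0, 0.0], [0.0, 1.0]])
n = np.array([[0.0, 1.0], [1.0, 0.0]])
# Triplet 1 is satisfied (0 - 2 + 0.4 < 0, loss 0);
# triplet 2 is violated (2 - 0 + 0.4, loss 2.4).
print(triplet_loss_np(np.concatenate([a, p, n], axis=1)))
```

If the Keras loss returns the same values on the same inputs, the definition is behaving as intended and the problem lies elsewhere.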
Then my model:
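```python
from tensorflow.keras import backend as K
from tensorflow.keras.applications import Xception
from tensorflow.keras.layers import Dense, Flatten, Input, Lambda, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2

def baseline_model():
    input_1 = Input(shape=(256, 256, 3))
    input_2 = Input(shape=(256, 256, 3))
    input_3 = Input(shape=(256, 256, 3))

    # frozen ImageNet backbone, the same object for all three inputs
    pretrained_model = Xception(include_top=False, weights="imagenet")
    for layer in pretrained_model.layers:
        layer.trainable = False

    x1 = pretrained_model(input_1)
    x2 = pretrained_model(input_2)
    x3 = pretrained_model(input_3)

    x1 = Flatten(name='flatten1')(x1)
    x2 = Flatten(name='flatten2')(x2)
    x3 = Flatten(name='flatten3')(x3)

    # one Dense object per input, i.e. three separate sets of weights
    x1 = Dense(128, activation=None, kernel_regularizer=l2(0.01))(x1)
    x2 = Dense(128, activation=None, kernel_regularizer=l2(0.01))(x2)
    x3 = Dense(128, activation=None, kernel_regularizer=l2(0.01))(x3)

    # L2-normalize each embedding
    x1 = Lambda(lambda x: K.l2_normalize(x, axis=-1))(x1)
    x2 = Lambda(lambda x: K.l2_normalize(x, axis=-1))(x2)
    x3 = Lambda(lambda x: K.l2_normalize(x, axis=-1))(x3)

    concat_vector = concatenate([x1, x2, x3], axis=-1, name='concat')
    model = Model([input_1, input_2, input_3], concat_vector)
    # `accuracy` is not defined in this snippet (presumably a custom triplet metric)
    model.compile(loss=triplet_loss, optimizer=Adam(0.00001), metrics=[accuracy])
    model.summary()
    return model
```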
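For scale, here is the size of the trainable head. This assumes that a 256×256 input gives an 8×8×2048 Xception feature map (my reading of Xception's five stride-2 stages; `model.summary()` shows the exact shape), so each Flatten produces 131,072 features.

```python
# Assumed Xception output for a 256x256x3 input (check with model.summary())
h, w, c = 8, 8, 2048
embedding_dim = 128

flat_features = h * w * c                                     # size of each Flatten output
dense_params = flat_features * embedding_dim + embedding_dim  # kernel + bias
total_trainable = 3 * dense_params                            # one Dense layer per input

print(flat_features, dense_params, total_trainable)
```

Under that assumption, the "one" trainable layer is really three Dense layers with roughly 16.8 million parameters each, about 50 million in total.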
Fitting my model:
```python
# gen, X_train, X_val, batch_size and callbacks_list are defined elsewhere (not shown)
model.fit(
    gen(X_train, batch_size=batch_size),
    steps_per_epoch=13281 // batch_size,
    epochs=10,
    validation_data=gen(X_val, batch_size=batch_size),
    validation_steps=1666 // batch_size,
    verbose=1,
    callbacks=callbacks_list
)
model.save_weights('try_6.h5')
```
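Because the triplets are always ordered [anchor, positive, negative], a 100% validation score may come from the slot an image sits in rather than from its content. A hedged sanity check, assuming you can rebuild the validation triplets as an array of image ids (the (n, 3) layout and both helper names below are hypothetical), is to swap the second and third image in a random half of the validation triplets and score the model against the new labels.

```python
import numpy as np

def predict_second_is_closer(embeddings):
    """1 where the 2nd image is closer to the anchor than the 3rd, else 0.

    `embeddings` is the model output, shape (batch, 3 * d), laid out as
    [anchor | second | third] like the 'concat' layer above.
    """
    d = embeddings.shape[-1] // 3
    anchor = embeddings[:, :d]
    second = embeddings[:, d:2 * d]
    third = embeddings[:, 2 * d:]
    d_second = np.sum((anchor - second) ** 2, axis=1)
    d_third = np.sum((anchor - third) ** 2, axis=1)
    return (d_second < d_third).astype(int)

def swap_half(triplets, seed=0):
    """Swap the 2nd and 3rd entry in a random half of the triplets.

    `triplets` is a hypothetical (n, 3) array of image ids ordered
    [anchor, positive, negative]. Returns the reordered triplets and labels:
    1 if the 2nd entry is still the positive, 0 if it was swapped.
    """
    rng = np.random.default_rng(seed)
    swapped = rng.random(len(triplets)) < 0.5
    out = triplets.copy()
    out[swapped, 1] = triplets[swapped, 2]
    out[swapped, 2] = triplets[swapped, 1]
    return out, (~swapped).astype(int)
```

Feed the reordered triplets through your generator and compare `predict_second_is_closer` on the model output with the returned labels. A model that only learned "slot 2 is close, slot 3 is far" should land near 50% on this set.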
Answer:
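Please note that you use a different Dense layer for each input: you define three Dense layers, and each time you create a new Dense object it creates a new layer with its own parameters, independent of the layers you created before. If the input order is consistent, meaning input 1 is always the anchor, input 2 is always the positive and input 3 is always the negative, it is very easy for the model to overfit: it can learn which input an image came from instead of what it shows. For example, the anchor and positive layers can output the same constant vector regardless of the image and the negative layer a different one; then every positive distance is 0, every negative distance is large, and the loss drops to zero without learning anything about similarity. What you should probably do is use a single Dense layer for all 3 inputs.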
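The shortcut described above can be reproduced in a few lines of NumPy. This is a sketch under stated assumptions: random vectors stand in for the flattened Xception features, `dense` is a hypothetical plain affine map standing in for a Keras Dense layer, and the zero kernel is just one degenerate solution, not necessarily the one your training finds. Note that a zero kernel also makes the `l2(0.01)` kernel penalty zero, and the biases are not regularized in the question's code.

```python
import numpy as np

def l2_normalize(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def triplet_loss_and_correct(a, p, n, alpha=0.4):
    """Per-triplet hinge loss and whether the positive is closer than the negative."""
    pos = np.sum((a - p) ** 2, axis=1)
    neg = np.sum((a - n) ** 2, axis=1)
    return np.maximum(pos - neg + alpha, 0.0), pos < neg

def dense(x, kernel, bias):
    """Hypothetical stand-in for a Dense layer with activation=None."""
    return x @ kernel + bias

rng = np.random.default_rng(0)
batch, feat, emb = 8, 16, 4
# Random stand-ins for the flattened Xception features: no similarity signal.
f1, f2, f3 = (rng.normal(size=(batch, feat)) for _ in range(3))

# Three independent heads with a zero kernel: each ignores its input and
# outputs its own bias. Anchor and positive heads share one constant,
# the negative head uses a different one.
zero_kernel = np.zeros((feat, emb))
a = l2_normalize(dense(f1, zero_kernel, np.array([1.0, 0.0, 0.0, 0.0])))
p = l2_normalize(dense(f2, zero_kernel, np.array([1.0, 0.0, 0.0, 0.0])))
n = l2_normalize(dense(f3, zero_kernel, np.array([0.0, 1.0, 0.0, 0.0])))
loss, correct = triplet_loss_and_correct(a, p, n)
print(loss.max(), correct.mean())  # 0.0 1.0: zero loss and "100% accuracy"

# With one shared head the same trick sends all three inputs to the same
# vector: both distances are 0, so every triplet pays the full margin alpha.
s = l2_normalize(dense(f1, zero_kernel, np.array([1.0, 0.0, 0.0, 0.0])))
loss_shared, _ = triplet_loss_and_correct(s, s, s)
print(loss_shared.max())  # 0.4
```

With separate heads the loss reaches zero on inputs that carry no information at all; with a shared head that route is closed, so lowering the loss requires using the image content.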
For example, based on your code, you can define the model like this:
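```python
from tensorflow.keras.applications import Xception
from tensorflow.keras.layers import Dense, Flatten, Input
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2

pretrained_model = Xception(include_top=False, weights="imagenet")
for layer in pretrained_model.layers:
    layer.trainable = False

# one embedding network (Xception -> Flatten -> Dense), built once
general_input = Input(shape=(256, 256, 3))
x = pretrained_model(general_input)
x = Flatten()(x)
x = Dense(128, activation=None, kernel_regularizer=l2(0.01))(x)
base_model = Model([general_input], [x])

input_1 = Input(shape=(256, 256, 3))
input_2 = Input(shape=(256, 256, 3))
input_3 = Input(shape=(256, 256, 3))

# calling the same base_model three times shares its weights across the inputs
x1 = base_model(input_1)
x2 = base_model(input_2)
x3 = base_model(input_3)
# ... continue with your code - normalize, concat, etc.
```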
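Why sharing helps can be checked numerically. In this NumPy sketch a random linear map plus L2 normalization stands in for `base_model` (an assumption; it is not Xception). Because all three slots go through the same function, swapping the second and third inputs exactly swaps the two distances, so the model cannot reward an image for its position.

```python
import numpy as np

rng = np.random.default_rng(1)
feat, emb = 32, 8
kernel = rng.normal(size=(feat, emb))  # one kernel shared by all three inputs
bias = rng.normal(size=emb)

def embed(x):
    """Shared stand-in for base_model followed by l2_normalize."""
    z = x @ kernel + bias
    return z / np.linalg.norm(z, axis=-1, keepdims=True)

def distances(anchor, second, third):
    """Squared distances anchor-second and anchor-third."""
    a, s, t = embed(anchor), embed(second), embed(third)
    return np.sum((a - s) ** 2, axis=1), np.sum((a - t) ** 2, axis=1)

x1, x2, x3 = (rng.normal(size=(4, feat)) for _ in range(3))
d_pos, d_neg = distances(x1, x2, x3)
d_pos_swapped, d_neg_swapped = distances(x1, x3, x2)
print(np.allclose(d_pos, d_neg_swapped), np.allclose(d_neg, d_pos_swapped))  # True True
```

With three separate Dense layers this symmetry does not hold, which is exactly what lets the model exploit the fixed ordering.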