shaft

How to apply a triplet loss function to ResNet50 for deep ranking

I am trying to create image embeddings for deep ranking using a triplet loss function. The idea is to take a pretrained CNN (e.g. ResNet50 or VGG16), remove the FC layers and add an L2 normalization step, so that the network outputs unit vectors that can then be compared with a distance metric (e.g. cosine similarity). As far as I understand, the vectors produced by a pretrained CNN are not optimal for this, but they are a good start. By adding the triplet loss function we can retrain the network to keep similar pictures 'close' to each other and different pictures 'far' apart in the feature space.

Inspired by this notebook, I tried to set up the following code, but I get the error ValueError: The name "conv1_pad" is used 3 times in the model. All layer names should be unique.

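from keras.applications.resnet50 import ResNet50
from keras.layers import Input, Flatten, Lambda, concatenate
from keras.models import Model
from keras import backend as K

# Anchor, Positive and Negative are NumPy arrays of shape (200, 256, 256, 3); the same holds for the test images

pic_size = 256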
def shared_dnn(inp):
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(3, pic_size, pic_size), 
                          input_tensor=inp)
    x = base_model.output
    x = Flatten()(x)
    x = Lambda(lambda x: K.l2_normalize(x,axis=1))(x)
    for layer in base_model.layers[15:]:
        layer.trainable = False
    return x

anchor_input = Input((3, pic_size,pic_size ), name='anchor_input')
positive_input = Input((3, pic_size,pic_size ), name='positive_input')
negative_input = Input((3, pic_size,pic_size ), name='negative_input')

encoded_anchor = shared_dnn(anchor_input)
encoded_positive = shared_dnn(positive_input)
encoded_negative = shared_dnn(negative_input)

merged_vector = concatenate([encoded_anchor, encoded_positive, encoded_negative], axis=-1, name='merged_layer')

model = Model(inputs=[anchor_input,positive_input, negative_input], outputs=merged_vector)

#ValueError: The name "conv1_pad" is used 3 times in the model. All layer names should be unique.
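# triplet_loss and adam_optim are defined elsewhere (the triplet loss comes from the linked notebook)
model.compile(loss=triplet_loss, optimizer=adam_optim)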

model.fit([Anchor,Positive,Negative],
          y=Y_dummy,
          validation_data=([Anchor_test,Positive_test,Negative_test],Y_dummy2), batch_size=512, epochs=500)
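The embedding idea described in the question (L2-normalize the CNN output, then compare vectors with cosine similarity) can be checked with a small NumPy sketch. Once vectors have unit length, cosine similarity reduces to a dot product, and the squared Euclidean distance equals 2 - 2 * cosine, so both metrics rank pairs the same way. The helper l2_normalize below is a plain NumPy stand-in that mimics what K.l2_normalize(x, axis=1) does in the Lambda layer; it is not the Keras function, and the vectors are made-up examples.

```python
import numpy as np

def l2_normalize(x, axis=1, eps=1e-12):
    """Scale each row to unit length (NumPy stand-in for K.l2_normalize)."""
    sq_sum = np.sum(x ** 2, axis=axis, keepdims=True)
    return x / np.sqrt(np.maximum(sq_sum, eps))

# Made-up 2-d "embeddings": rows 0 and 1 are similar, row 2 points the opposite way
emb = l2_normalize(np.array([[3.0, 4.0], [4.0, 3.0], [-3.0, -4.0]]))

# For unit vectors, the dot product is the cosine similarity
cos_01 = float(emb[0] @ emb[1])
cos_02 = float(emb[0] @ emb[2])

# For unit vectors: ||a - b||^2 = 2 - 2 * cos(a, b)
sq_dist_01 = float(np.sum((emb[0] - emb[1]) ** 2))
```

Because the two quantities are tied this way, training with a squared-distance triplet loss on normalized embeddings also shapes the cosine similarities used at retrieval time.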

I am new to Keras and I am not quite sure how to solve this. The author of the linked notebook builds his own CNN from scratch, but I would like to build on ResNet50 (or VGG16). How can I configure ResNet50 to train with a triplet loss function? (The linked notebook also contains the source code of the triplet loss function.)
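The loss itself lives in the linked notebook and is not reproduced here. As a hedged reference, the standard triplet loss max(||a - p||^2 - ||a - n||^2 + alpha, 0) can be computed on the output of merged_layer by slicing it into three equal parts, since concatenate(..., axis=-1) places the anchor, positive and negative embeddings side by side. The name triplet_loss_np and the margin alpha=0.4 are assumptions for illustration, not values taken from the notebook.

```python
import numpy as np

def triplet_loss_np(y_pred, alpha=0.4):
    """Per-sample triplet loss on a (batch, 3 * d) array laid out as [anchor | positive | negative].

    Hypothetical NumPy reference; alpha is an assumed margin, not the notebook's value.
    """
    d = y_pred.shape[-1] // 3
    anchor = y_pred[:, :d]
    positive = y_pred[:, d:2 * d]
    negative = y_pred[:, 2 * d:]
    pos_dist = np.sum((anchor - positive) ** 2, axis=1)  # squared distance anchor-positive
    neg_dist = np.sum((anchor - negative) ** 2, axis=1)  # squared distance anchor-negative
    # Hinge: the loss is zero once neg_dist >= pos_dist + alpha
    return np.maximum(pos_dist - neg_dist + alpha, 0.0)

# Two triplets of 2-d unit embeddings, concatenated the way merged_layer does it
easy = np.concatenate([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])  # positive matches, negative far away
hard = np.concatenate([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])  # positive far away, negative matches
batch = np.stack([easy, hard])
losses = triplet_loss_np(batch)
```

A Keras loss would perform the same slicing with backend ops (for example K.sum and K.maximum) and take the signature (y_true, y_pred), ignoring y_true; that is why placeholder targets such as Y_dummy are enough for model.fit.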

Answers (1)

Akash Kumar

In your ResNet50 definition, you've written

base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(3, pic_size, pic_size), input_tensor=inp)

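Remove the input_tensor argument and pass the image shape through input_shape instead. With the TensorFlow backend, which you mentioned you are using, Keras expects channels-last input, so the shape should be (pic_size, pic_size, 3), i.e. (256, 256, 3), rather than (3, pic_size, pic_size). This also matches your data, which already has shape (200, 256, 256, 3).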
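The shape mismatch can be seen directly on the data. Keras reports the active convention through keras.backend.image_data_format(), which defaults to 'channels_last'. Arrays of shape (200, 256, 256, 3) are already channels-last, so the shape to pass to Input and to ResNet50(input_shape=...) is everything after the batch axis; channels-first data would need a transpose. A NumPy sketch with small placeholder arrays (not the real data), where to_channels_last is a hypothetical helper name:

```python
import numpy as np

def to_channels_last(batch):
    """Hypothetical helper: (N, C, H, W) channels-first -> (N, H, W, C) channels-last."""
    return np.transpose(batch, (0, 2, 3, 1))

# Placeholder batch with the question's per-image shape (2 images instead of 200)
anchor_batch = np.zeros((2, 256, 256, 3), dtype=np.float32)
img_shape = anchor_batch.shape[1:]  # drop the batch axis -> (256, 256, 3)

# A channels-first batch, as Input((3, 256, 256)) would expect, converted to channels-last
channels_first = np.zeros((2, 3, 256, 256), dtype=np.float32)
converted = to_channels_last(channels_first)
```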

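from keras.applications.resnet50 import ResNet50
from keras.layers import Input, Flatten, Lambda, concatenate
from keras.models import Model
from keras import backend as K

img_shape = (256, 256, 3)  # channels-last: (height, width, channels)

def shared_dnn(input_shape):
    # No input_tensor: the image shape is passed through input_shape
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)
    x = base_model.output
    x = Flatten()(x)
    # Scale every embedding to unit length
    x = Lambda(lambda t: K.l2_normalize(t, axis=1))(x)
    # Note: this freezes every layer from index 15 onward; only the first 15 layers stay trainable
    for layer in base_model.layers[15:]:
        layer.trainable = False
    # Wrap backbone and head into one model that can be called on any input tensor
    return Model(inputs=base_model.input, outputs=x, name='embedding')

# Build the network once and reuse it for all three branches: layer names stay unique
# (no more "conv1_pad is used 3 times") and the three branches share the same weights
embedding_model = shared_dnn(img_shape)

anchor_input = Input(img_shape, name='anchor_input')
positive_input = Input(img_shape, name='positive_input')
negative_input = Input(img_shape, name='negative_input')

encoded_anchor = embedding_model(anchor_input)
encoded_positive = embedding_model(positive_input)
encoded_negative = embedding_model(negative_input)

merged_vector = concatenate([encoded_anchor, encoded_positive, encoded_negative], axis=-1, name='merged_layer')

model = Model(inputs=[anchor_input, positive_input, negative_input], outputs=merged_vector)
# triplet_loss and adam_optim are defined elsewhere (the triplet loss comes from the linked notebook)
model.compile(loss=triplet_loss, optimizer=adam_optim)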

model.fit([Anchor,Positive,Negative],
          y=Y_dummy,
          validation_data=([Anchor_test,Positive_test,Negative_test],Y_dummy2), batch_size=512, epochs=500)

[Figure: plot of the resulting model]
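In both snippets, `for layer in base_model.layers[15:]` freezes every layer from index 15 onward and leaves only the first 15 trainable. If the goal is the usual fine-tuning setup (keep the pretrained early layers fixed and retrain the top), the slice is reversed. The sketch below uses hypothetical stub objects instead of real Keras layers; the helper freeze_all_but_last (also hypothetical) only sets the `trainable` attribute, which real Keras layers have too, so the same loop should work on `base_model.layers`.

```python
class StubLayer:
    """Hypothetical stand-in for a Keras layer; only the `trainable` flag matters here."""
    def __init__(self, name):
        self.name = name
        self.trainable = True

def freeze_all_but_last(layers, n_trainable):
    """Freeze every layer except the last `n_trainable` ones (hypothetical helper)."""
    cutoff = len(layers) - n_trainable
    for i, layer in enumerate(layers):
        layer.trainable = i >= cutoff
    return layers

layers = [StubLayer(f"layer_{i}") for i in range(20)]

# What the loop in the question and answer does: freeze indices 15..19
for layer in layers[15:]:
    layer.trainable = False
trainable_after_slice = [l.name for l in layers if l.trainable]

# Typical fine-tuning instead: keep only the top 5 layers trainable
freeze_all_but_last(layers, 5)
trainable_after_helper = [l.name for l in layers if l.trainable]
```

The trainable flags must be set before model.compile is called for the change to take effect in training.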
