Deshwal

Reputation: 4152

How to apply Triplet Loss for a ResNet50 based Siamese Network in Keras or Tf 2

I have a ResNet-based Siamese network. The idea is to minimize the L2 distance between two images and then apply a sigmoid, so that the output is {0: 'same', 1: 'different'}; gradients then flow back through the network according to how far off the prediction is. The problem is that the gradient updates are too small, because the output only varies between 0 and 1. So I thought of using the same architecture, but with Triplet Loss.

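    from tensorflow.keras import backend as K
    from tensorflow.keras.applications import ResNet50
    from tensorflow.keras.layers import Input, Lambda, Dense
    from tensorflow.keras.models import Model
    from tensorflow.keras.optimizers import Adam

    # One input per image of the pair
    I1 = Input(shape=image_shape)
    I2 = Input(shape=image_shape)

    res_m_1 = ResNet50(include_top=False, weights='imagenet', input_tensor=I1, pooling='avg')
    res_m_2 = ResNet50(include_top=False, weights='imagenet', input_tensor=I2, pooling='avg')

    x1 = res_m_1.output
    x2 = res_m_2.output
    # x = Flatten()(x) or use this one if not using any pooling layer

    # Element-wise absolute difference between the two embeddings
    distance = Lambda(lambda tensors: K.abs(tensors[0] - tensors[1]))([x1, x2])
    final_output = Dense(1, activation='sigmoid')(distance)

    siamese_model = Model(inputs=[I1, I2], outputs=final_output)

    siamese_model.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['acc'])

    # In TF 2.1+, fit() accepts generators directly; fit_generator is deprecated
    siamese_model.fit(train_gen, steps_per_epoch=1000, epochs=10, validation_data=validation_data)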

So how can I change it to use the Triplet Loss function? What adjustments are needed to get this done? One change is that I'll also have to compute a third embedding:

I3 = Input(shape=image_shape)  # third input for the triplet
res_m_3 = ResNet50(include_top=False, weights='imagenet', input_tensor=I3, pooling='avg')
x3 = res_m_3.output

One thing I found in the TensorFlow Addons docs is the triplet semi-hard loss, which is given as:

tfa.losses.TripletSemiHardLoss()

As shown in the paper, the best results are from triplets known as "Semi-Hard". These are defined as triplets where the negative is farther from the anchor than the positive, but still produces a positive loss. To efficiently find these triplets we utilize online learning and only train from the Semi-Hard examples in each batch.
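To make the definition above concrete, here is a minimal plain-Python sketch (not the tfa implementation) that classifies a single triplet as hard, semi-hard, or easy and computes its loss. The helper names `euclidean`, `triplet_loss`, and `triplet_kind` are hypothetical, and both the margin of 1.0 and the use of plain (non-squared) Euclidean distance are assumptions made for illustration.

```python
import math

def euclidean(a, b):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def triplet_loss(anchor, positive, negative, margin=1.0):
    """max(d(a, p) - d(a, n) + margin, 0) for one triplet."""
    d_ap = euclidean(anchor, positive)
    d_an = euclidean(anchor, negative)
    return max(d_ap - d_an + margin, 0.0)

def triplet_kind(anchor, positive, negative, margin=1.0):
    """Label a triplet by where the negative sits relative to the positive."""
    d_ap = euclidean(anchor, positive)
    d_an = euclidean(anchor, negative)
    if d_an <= d_ap:
        return "hard"        # negative is at least as close as the positive
    if d_an < d_ap + margin:
        return "semi-hard"   # farther than the positive, but loss is still > 0
    return "easy"            # far enough that the loss is already 0

anchor = [0.0, 0.0]
positive = [1.0, 0.0]  # d(a, p) = 1.0

for negative in ([0.5, 0.0], [1.5, 0.0], [3.0, 0.0]):
    kind = triplet_kind(anchor, positive, negative)
    loss = triplet_loss(anchor, positive, negative)
    print(negative, kind, loss)
```

With these made-up points the three negatives come out as hard (loss 1.5), semi-hard (loss 0.5), and easy (loss 0.0). Only the middle case matches the "farther than the positive, but still a positive loss" description quoted above.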

Another implementation of Triplet Loss which I found on Kaggle is: Triplet Loss Keras

Which one should I use and most importantly, HOW?

P.S: People also use something like: x = Lambda(lambda x: K.l2_normalize(x,axis=1))(x) after model.output. Why is that? What is this doing?

Upvotes: 1

Views: 1699

Answers (1)

Mr. For Example

Reputation: 4313

Following this answer of mine, and keeping the role of TripletSemiHardLoss in mind, we could do the following:

import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
from tensorflow.keras import models, layers

BATCH_SIZE = 32
LATENT_DEM = 128

def _normalize_img(img, label):
    img = tf.cast(img, tf.float32) / 255.
    return (img, label)

train_dataset, test_dataset = tfds.load(name="mnist", split=['train', 'test'], as_supervised=True)

# Build your input pipelines
train_dataset = train_dataset.shuffle(1024).batch(BATCH_SIZE)
train_dataset = train_dataset.map(_normalize_img)

test_dataset = test_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.map(_normalize_img)

inputs = layers.Input(shape=(28, 28, 1))
resNet50 = tf.keras.applications.ResNet50(include_top=False, weights=None, input_tensor=inputs, pooling='avg')
outputs = layers.Dense(LATENT_DEM, activation=None)(resNet50.output) # No activation on final dense layer
outputs = layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))(outputs) # L2 normalize embedding

siamese_model = models.Model(inputs=inputs, outputs=outputs)

# Compile the model
siamese_model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tfa.losses.TripletSemiHardLoss())

# Train the network
history = siamese_model.fit(
    train_dataset,
    epochs=3)
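On the P.S. in the question: L2 normalization rescales each embedding to unit length, so the distance between two embeddings depends only on their direction and stays within [0, 2]. That keeps the margin of the triplet loss on a fixed scale. Below is a minimal plain-Python sketch of the effect, where `l2_normalize` is a hypothetical single-vector stand-in for `tf.math.l2_normalize` and the example vectors are made up:

```python
import math

def l2_normalize(v, eps=1e-12):
    """Scale v to unit Euclidean length (eps guards against division by zero)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]

def euclidean(a, b):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

a = l2_normalize([3.0, 4.0])    # becomes [0.6, 0.8]
b = l2_normalize([30.0, 40.0])  # same direction as a, 10x the magnitude
c = l2_normalize([-4.0, 3.0])   # orthogonal to a

print(euclidean(a, b))  # close to 0: magnitude no longer matters
print(euclidean(a, c))  # close to sqrt(2): orthogonal unit vectors
```

For unit vectors the squared Euclidean distance equals 2 - 2 * cos(angle), so comparing normalized embeddings by distance is equivalent to comparing them by cosine similarity.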

Upvotes: 3
