Reputation: 4152
I have a ResNet-based siamese network. The idea is to minimize the L2 distance between two images and then apply a sigmoid so that the output is {0: 'same', 1: 'different'}; depending on how far off the prediction is, the gradients flow back through the network. The problem is that the gradient updates are too small, since the distance is only being pushed between {0, 1}. So I thought of using the same architecture, but with Triplet Loss.
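from tensorflow.keras import backend as K
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, Input, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
I1 = Input(shape=image_shape)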
I2 = Input(shape=image_shape)
res_m_1 = ResNet50(include_top=False, weights='imagenet', input_tensor=I1, pooling='avg')
res_m_2 = ResNet50(include_top=False, weights='imagenet', input_tensor=I2, pooling='avg')
x1 = res_m_1.output
x2 = res_m_2.output
# x = Flatten()(x) or use this one if not using any pooling layer
distance = Lambda(lambda tensors: K.abs(tensors[0] - tensors[1]))([x1, x2])
final_output = Dense(1,activation='sigmoid')(distance)
siamese_model = Model(inputs=[I1,I2], outputs=final_output)
siamese_model.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['acc'])
siamese_model.fit_generator(train_gen,steps_per_epoch=1000,epochs=10,validation_data=validation_data)
So how can I change it to use the Triplet Loss function? What adjustments are needed to get this done? One change is that I'll have to compute a third branch:
I3 = Input(shape=image_shape)
res_m_3 = ResNet50(include_top=False, weights='imagenet', input_tensor=I3, pooling='avg')
x3 = res_m_3.output
One thing I found in the TensorFlow Addons docs is the triplet semi-hard loss, which is given as:
tfa.losses.TripletSemiHardLoss()
As shown in the paper, the best results are from triplets known as "Semi-Hard". These are defined as triplets where the negative is farther from the anchor than the positive, but still produces a positive loss. To efficiently find these triplets we utilize online learning and only train from the Semi-Hard examples in each batch.
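To make the definition above concrete, here is a minimal pure-Python sketch of the per-triplet loss and the semi-hard condition. It is an illustration only, not the tfa implementation: the helper names (euclidean, triplet_loss, is_semi_hard) are made up for this example, and the plain L2 distance and the margin of 1.0 are assumptions.

```python
import math

def euclidean(u, v):
    """Plain L2 distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def triplet_loss(anchor, positive, negative, margin=1.0):
    """max(d(a, p) - d(a, n) + margin, 0) for a single triplet."""
    d_ap = euclidean(anchor, positive)
    d_an = euclidean(anchor, negative)
    return max(d_ap - d_an + margin, 0.0)

def is_semi_hard(anchor, positive, negative, margin=1.0):
    """True when the negative is farther than the positive but still inside the margin."""
    d_ap = euclidean(anchor, positive)
    d_an = euclidean(anchor, negative)
    return d_ap < d_an < d_ap + margin

anchor, positive = (0.0, 0.0), (1.0, 0.0)  # d(a, p) = 1.0
semi_hard_neg = (1.5, 0.0)  # d(a, n) = 1.5: farther than the positive, inside the margin
easy_neg = (3.0, 0.0)       # d(a, n) = 3.0: beyond the margin, so the loss is zero
hard_neg = (0.5, 0.0)       # d(a, n) = 0.5: closer to the anchor than the positive

for name, neg in [("semi-hard", semi_hard_neg), ("easy", easy_neg), ("hard", hard_neg)]:
    print(name, triplet_loss(anchor, positive, neg), is_semi_hard(anchor, positive, neg))
```

Only the first negative satisfies the semi-hard condition; the easy one contributes no loss, and the hard one sits closer to the anchor than the positive does.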
Another implementation of Triplet Loss, which I found on Kaggle, is: Triplet Loss Keras
Which one should I use and most importantly, HOW?
P.S.: People also use something like x = Lambda(lambda x: K.l2_normalize(x, axis=1))(x) after model.output. Why is that? What does it do?
Upvotes: 1
Views: 1699
Reputation: 4313
Following this answer of mine, and with the role of TripletSemiHardLoss in mind, we could do the following:
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
from tensorflow.keras import models, layers
BATCH_SIZE = 32
LATENT_DEM = 128
def _normalize_img(img, label):
    # Scale pixel values from [0, 255] to [0, 1]
    img = tf.cast(img, tf.float32) / 255.
    return (img, label)
train_dataset, test_dataset = tfds.load(name="mnist", split=['train', 'test'], as_supervised=True)
# Build your input pipelines
train_dataset = train_dataset.shuffle(1024).batch(BATCH_SIZE)
train_dataset = train_dataset.map(_normalize_img)
test_dataset = test_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.map(_normalize_img)
inputs = layers.Input(shape=(28, 28, 1))
resNet50 = tf.keras.applications.ResNet50(include_top=False, weights=None, input_tensor=inputs, pooling='avg')
outputs = layers.Dense(LATENT_DEM, activation=None)(resNet50.output) # No activation on final dense layer
outputs = layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))(outputs) # L2 normalize embedding
siamese_model = models.Model(inputs=inputs, outputs=outputs)
# Compile the model
siamese_model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tfa.losses.TripletSemiHardLoss())
# Train the network
history = siamese_model.fit(
    train_dataset,
    epochs=3)
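On the P.S. about l2_normalize: it rescales each embedding to unit length, so the distance between two embeddings depends only on their direction, and the squared distance stays within [0, 4]. That is likely one reason a fixed margin remains meaningful during training. A minimal pure-Python sketch of the effect follows; l2_normalize and squared_distance here are hypothetical helpers written for illustration, not the TensorFlow functions.

```python
import math

def l2_normalize(vec, eps=1e-12):
    """Scale a vector to unit Euclidean length (eps guards against division by zero)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]

def squared_distance(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))

a = l2_normalize([3.0, 4.0])    # direction (0.6, 0.8)
b = l2_normalize([30.0, 40.0])  # same direction as a, 10x the magnitude
c = l2_normalize([-4.0, 3.0])   # orthogonal to a
d = l2_normalize([-3.0, -4.0])  # opposite to a

print(squared_distance(a, b))  # close to 0: magnitude no longer matters
print(squared_distance(a, c))  # close to 2: orthogonal unit vectors
print(squared_distance(a, d))  # close to 4: the largest possible value
```

With the model above, the same comparison could be applied to the rows returned by siamese_model.predict(...), treating two images as "same" when their distance falls below a threshold. That threshold is an assumption here and would have to be tuned on validation data; the loss does not provide it.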
Upvotes: 3