I have been stuck on this problem for a long time while trying to reproduce the implementation from this paper (https://arxiv.org/pdf/1703.07737) in TensorFlow 2.10.0.
The issue is that when I use the same loss and data generation with a simple FCNN model on the MNIST dataset, the model does not collapse. But when I try it on the CelebA dataset with a pretrained ResNet50 followed by two dense layers of 1024 and 128 units (as described in the paper above), the model collapses: every image is mapped to essentially the same point in the embedding space, and the loss never goes below the margin.
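For reference, the batch-hard loss from the paper, which the class below is meant to implement (P identities per batch, K images per identity, D the embedding distance), is roughly:

\mathcal{L}_{BH} = \sum_{i=1}^{P} \sum_{a=1}^{K} \Big[\, m + \max_{p=1 \dots K} D\big(f(x^i_a), f(x^i_p)\big) - \min_{\substack{j=1 \dots P,\ n=1 \dots K \\ j \neq i}} D\big(f(x^i_a), f(x^j_n)\big) \Big]_{+}

Note that when all embeddings collapse to a single point, both the hardest-positive and the hardest-negative distances are 0, so the loss sits at exactly the margin m, which matches what I am seeing.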
import tensorflow as tf

class BatchHardTripletLoss(tf.keras.losses.Loss):
    def __init__(self, margin=0.5, squared=False):
        super().__init__()
        self.margin = margin
        self.squared = squared

    def call(self, mask, embeddings):
        # The tensor passed as y_true is the positive mask and the negative
        # mask concatenated along the last axis; split them back apart.
        size = tf.shape(mask)[-1] // 2
        mask_anchor_positive, mask_anchor_negative = mask[:, :size], mask[:, size:]
        pairwise_dist = self._pairwise_distances(embeddings, squared=self.squared)

        # Hardest positive: largest distance among same-identity pairs.
        mask_anchor_positive = tf.cast(mask_anchor_positive, dtype=tf.float32)
        anchor_positive_dist = tf.multiply(mask_anchor_positive, pairwise_dist)
        hardest_positive_dist = tf.reduce_max(anchor_positive_dist, axis=1, keepdims=True)

        # Hardest negative: smallest distance among different-identity pairs
        # (invalid entries are pushed up by the row maximum before taking the min).
        mask_anchor_negative = tf.cast(mask_anchor_negative, dtype=tf.float32)
        max_anchor_negative_dist = tf.reduce_max(pairwise_dist, axis=1, keepdims=True)
        anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
        hardest_negative_dist = tf.reduce_min(anchor_negative_dist, axis=1, keepdims=True)

        triplet_loss = tf.maximum((hardest_positive_dist - hardest_negative_dist) + self.margin, 0.0)
        triplet_loss = tf.reduce_mean(triplet_loss)
        return triplet_loss

    @staticmethod
    def _pairwise_distances(embeddings, squared=False):
        dot_product = tf.matmul(embeddings, tf.transpose(embeddings))
        square_norm = tf.linalg.diag_part(dot_product)
        distances = tf.expand_dims(square_norm, 0) - 2.0 * dot_product + tf.expand_dims(square_norm, 1)
        distances = tf.maximum(distances, 0.0)
        if not squared:
            # Add a tiny epsilon where the distance is exactly zero so the
            # gradient of sqrt stays finite, then zero those entries out again.
            mask = tf.cast(tf.equal(distances, 0.0), dtype=tf.float32)
            distances = distances + mask * 1e-16
            distances = tf.sqrt(distances)
            distances = distances * (1.0 - mask)
        return distances
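As a quick sanity check (not part of my training pipeline), the loss can be exercised on a dummy batch. The mask layout here, positive mask concatenated with the negative mask along the last axis, is assumed from the train_step further down, and the masks are built inline instead of with the helper functions used there:

# Dummy batch of 4 samples from 2 identities, just to confirm the loss runs.
labels = tf.constant([0, 0, 1, 1])
dummy_embeddings = tf.random.normal((4, 128))

labels_equal = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
not_self = tf.logical_not(tf.cast(tf.eye(4), tf.bool))
positive_mask = tf.logical_and(labels_equal, not_self)   # same identity, different image
negative_mask = tf.logical_not(labels_equal)             # different identity

mask = tf.concat((tf.cast(positive_mask, tf.float32),
                  tf.cast(negative_mask, tf.float32)), axis=-1)
print(BatchHardTripletLoss(margin=0.5)(mask, dummy_embeddings))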
base_model = tf.keras.applications.ResNet50(
    include_top=False,
    weights="imagenet",
    input_shape=(224, 224, 3),
    pooling='avg',
)
base_model.trainable = False
# for layer in base_model.layers[:-8]:
#     layer.trainable = False
# Keeping only the last few layers trainable (and freezing the rest) did not help either.
inputs = tf.keras.Input(shape=(224, 224, 3))
x = base_model(inputs, training = False)
x = tf.keras.layers.Dense(1024, activation='relu')(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.ReLU()(x)
output = tf.keras.layers.Dense(128)(x)
embeddings = tf.keras.Model(inputs, output)
class SiameseModel(tf.keras.models.Model):
    def __init__(self, embeddings):
        super().__init__()
        self.embeddings = embeddings
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

    def train_step(self, data):
        X, label = data
        positive_mask = _get_anchor_positive_triplet_mask(label)
        negative_mask = _get_anchor_negative_triplet_mask(label)
        mask = tf.concat((positive_mask, negative_mask), -1)
        with tf.GradientTape() as tape:
            y_pred = self(X)  # Forward pass
            loss = self.loss(mask, y_pred)
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def call(self, X):
        return self.embeddings(X)
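The two mask helpers used in train_step are not shown here; they are a plausible reconstruction along the lines of the reference TensorFlow triplet-mining implementation that this code appears to follow (the function names match it):

# Hypothetical reconstruction of the mask helpers; not the exact code I run,
# but the standard way these masks are built from a vector of labels.
def _get_anchor_positive_triplet_mask(labels):
    # True where labels[i] == labels[j] and i != j.
    indices_not_equal = tf.logical_not(tf.cast(tf.eye(tf.shape(labels)[0]), tf.bool))
    labels_equal = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
    return tf.logical_and(indices_not_equal, labels_equal)

def _get_anchor_negative_triplet_mask(labels):
    # True where labels[i] != labels[j].
    labels_equal = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
    return tf.logical_not(labels_equal)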
generator = DataGenerator(eval_split[0], identity_images, img_dir, P=18, K=4, dim=(224, 224, 3))
model.compile(optimizer='adam', loss=BatchHardTripletLoss(0.5))
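For completeness, the remaining wiring is not included above; presumably it looks roughly like this (DataGenerator, eval_split, identity_images and img_dir are my own objects, and the epoch count is just a placeholder):

# Presumed remaining wiring: the SiameseModel wraps the embedding network
# defined earlier and is trained on the P*K batch generator.
model = SiameseModel(embeddings)   # created before the model.compile(...) call above
model.fit(generator, epochs=10)    # epochs value is a placeholder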