Sajith 17
Sajith 17

Reputation: 11

Vector collapse with batch hard triplet mining (Siamese Network)

I have been stuck on this problem for a long time while trying to reproduce the implementation from this paper (https://arxiv.org/pdf/1703.07737) in TensorFlow 2.10.0.

The problem is this: when I apply the same loss and data generation to a simple fully connected network on the MNIST dataset, the model does not collapse. But when I train on the CelebA dataset with a pretrained ResNet50 followed by two dense layers of 1024 and 128 units (as described in the paper above), the model collapses. The embeddings of all images cluster into a single point in the vector space, and the loss never drops below the margin.

class BatchHardTripletLoss(tf.keras.losses.Loss):
    """Batch-hard triplet loss (Hermans et al., "In Defense of the Triplet Loss").

    For every anchor in the batch, the loss picks the farthest positive and the
    closest negative. It then penalizes
    ``max(d(a, p) - d(a, n) + margin, 0)``, or
    ``softplus(d(a, p) - d(a, n))`` when ``soft_margin=True``.
    The per-anchor values are averaged into a scalar.

    The ``y_true`` argument is not a label vector. It is a 0/1 mask made by
    concatenating an anchor-positive mask and an anchor-negative mask along
    the last axis, as built in ``SiameseModel.train_step``. Each half must
    broadcast against the ``(batch, batch)`` pairwise-distance matrix.
    """

    def __init__(self, margin=0.5, squared=False, *, soft_margin=False, **kwargs):
        """Configure the loss.

        Args:
            margin: Hinge margin added to ``d(a, p) - d(a, n)``. Ignored when
                ``soft_margin`` is True.
            squared: If True, use squared Euclidean distances; otherwise use
                Euclidean distances.
            soft_margin: If True, use the paper's soft-margin formulation
                ``softplus(d(a, p) - d(a, n))`` instead of the fixed hinge.
            **kwargs: Forwarded to ``tf.keras.losses.Loss`` (e.g. ``name``,
                ``reduction``).
        """
        # Must be ``__init__`` (not ``init``). Otherwise Keras' constructor runs
        # instead and a positional ``margin`` is misread as ``reduction``.
        super().__init__(**kwargs)
        self.margin = margin
        self.squared = squared
        self.soft_margin = soft_margin

    def call(self, mask, embeddings):
        """Return the mean batch-hard triplet loss for ``embeddings``.

        Args:
            mask: Concatenation ``[positive_mask, negative_mask]`` along the
                last axis; each half is cast to float32.
            embeddings: Embedding matrix; row ``i`` is the embedding of
                sample ``i``.
        """
        size = tf.shape(mask)[-1] // 2
        mask_anchor_positive = tf.cast(mask[:, :size], dtype=tf.float32)
        mask_anchor_negative = tf.cast(mask[:, size:], dtype=tf.float32)

        pairwise_dist = self._pairwise_distances(embeddings, squared=self.squared)

        # Hardest positive: zeroing out non-positives is safe for a max
        # because distances are clamped to be >= 0.
        anchor_positive_dist = mask_anchor_positive * pairwise_dist
        hardest_positive_dist = tf.reduce_max(anchor_positive_dist, axis=1, keepdims=True)

        # Hardest negative: shift non-negatives up by the row maximum so that
        # they can never be selected by the min.
        max_anchor_negative_dist = tf.reduce_max(pairwise_dist, axis=1, keepdims=True)
        anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
        hardest_negative_dist = tf.reduce_min(anchor_negative_dist, axis=1, keepdims=True)

        gap = hardest_positive_dist - hardest_negative_dist
        if self.soft_margin:
            triplet_loss = tf.math.softplus(gap)
        else:
            triplet_loss = tf.maximum(gap + self.margin, 0.0)
        return tf.reduce_mean(triplet_loss)

    @staticmethod
    def _pairwise_distances(embeddings, squared=False):
        """Return the matrix of pairwise (squared) Euclidean distances.

        Uses ``||a||^2 - 2 a.b + ||b||^2``. Negative values caused by rounding
        are clamped to 0.
        """
        dot_product = tf.matmul(embeddings, tf.transpose(embeddings))
        square_norm = tf.linalg.diag_part(dot_product)
        distances = tf.expand_dims(square_norm, 0) - 2.0 * dot_product + tf.expand_dims(square_norm, 1)
        distances = tf.maximum(distances, 0.0)
        if not squared:
            # sqrt has an infinite gradient at 0, so nudge exact zeros before
            # taking the root, then zero those entries again afterwards.
            mask = tf.cast(tf.equal(distances, 0.0), dtype=tf.float32)
            distances = distances + mask * 1e-16
            distances = tf.sqrt(distances)
            distances = distances * (1.0 - mask)
        return distances

# Frozen ImageNet ResNet50 backbone with global average pooling.
base_model = tf.keras.applications.ResNet50(
    include_top=False,
    weights="imagenet",
    input_shape=(224, 224, 3),
    pooling="avg",
)
base_model.trainable = False
# Unfreezing some of the top ResNet layers (e.g. everything except
# ``base_model.layers[:-8]``) did not prevent the collapse either.

inputs = tf.keras.Input(shape=(224, 224, 3))
# NOTE(review): ImageNet ResNet50 weights expect inputs processed with
# tf.keras.applications.resnet50.preprocess_input. Confirm that DataGenerator
# applies it; otherwise the backbone features are degraded.
x = base_model(inputs, training=False)
# Embedding head from the paper: FC(1024) -> BN -> ReLU -> FC(128).
# No activation on the Dense layer; the ReLU is applied after BN.
x = tf.keras.layers.Dense(1024)(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.ReLU()(x)
output = tf.keras.layers.Dense(128)(x)
embeddings = tf.keras.Model(inputs, output)

class SiameseModel(tf.keras.models.Model):
    """Wrap an embedding network and train it with a mask-based triplet loss.

    ``train_step`` converts the batch labels into anchor-positive and
    anchor-negative masks. It concatenates them along the last axis and
    passes the result to the compiled loss in place of ``y_true``.
    """

    def __init__(self, embeddings):
        super().__init__()
        self.embeddings = embeddings
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        # Listing the tracker here lets Keras reset it at the start of every
        # epoch. Otherwise the reported "loss" is a running mean over all
        # epochs so far.
        return [self.loss_tracker]

    def train_step(self, data):
        """Run one optimization step on a ``(images, labels)`` batch."""
        X, label = data
        # Mask helpers are defined elsewhere in the project.
        positive_mask = _get_anchor_positive_triplet_mask(label)
        negative_mask = _get_anchor_negative_triplet_mask(label)
        mask = tf.concat((positive_mask, negative_mask), -1)

        with tf.GradientTape() as tape:
            # training=True is required in a custom train_step. Without it,
            # BatchNormalization in the head runs in inference mode and never
            # updates its moving statistics.
            y_pred = self(X, training=True)
            loss = self.loss(mask, y_pred)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def call(self, X, training=None):
        """Return the embeddings of ``X``, forwarding the training flag."""
        return self.embeddings(X, training=training)

# P identities x K images per identity per batch (PK sampling from the paper).
generator = DataGenerator(eval_split[0], identity_images, img_dir, P=18, K=4, dim=(224, 224, 3))
model = SiameseModel(embeddings)
model.compile(optimizer="adam", loss=BatchHardTripletLoss(0.5))

Upvotes: 1

Views: 60

Answers (0)

Related Questions