user28565548
user28565548

Reputation: 1

ArcFace Loss: Training loss decreases but validation accuracy drops too

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader
from pytorch_metric_learning.losses import ArcFaceLoss
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
from torch.utils.data import Dataset
import os
import cv2
from torchvision.models import ResNet18_Weights
import torchvision
import multiprocessing
import yaml
import argparse
from tqdm import tqdm
import torch.nn.functional as F
from torchvision.transforms import RandAugment
import random
from torch.optim.lr_scheduler import ReduceLROnPlateau

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class ResNet18ForEmbedding(nn.Module):
    def __init__(self, embedding_size=512):
        super(ResNet18ForEmbedding, self).__init__()
        self.backbone = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
        num_ftrs = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(num_ftrs, embedding_size)

    def forward(self, x):
        embeddings = self.backbone(x)
        return embeddings
    
class FaceImageDataSet(Dataset):
    def __init__(self, data_dir, transform=None):
        self.image_paths = []
        self.data_dir = data_dir
        self.transform = transform
        self.labels = []

        for label in os.listdir(data_dir):
            label_dir = os.path.join(data_dir, label)
            if os.path.isdir(label_dir):
                for img_name in os.listdir(label_dir):
                    img_path = os.path.join(label_dir, img_name)
                    self.image_paths.append(img_path)
                    self.labels.append(int(label))
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        image = cv2.imread(self.image_paths[idx])

        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]

        return image, label

class FacePairDataSet(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        for label in os.listdir(data_dir):
            label_dir = os.path.join(data_dir, label)
            if os.path.isdir(label_dir):
                for img_name in os.listdir(label_dir):
                    img_path = os.path.join(label_dir, img_name)
                    self.image_paths.append(img_path)
                    self.labels.append(label)

        self.class_to_idx = {label: idx for idx, label in enumerate(sorted(set(self.labels)))}
        self.num_classes = len(self.class_to_idx)

    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path_1 = self.image_paths[idx]
        label_1 = self.labels[idx]
        img_1 = cv2.imread(img_path_1)
    
        if random.random() > 0.5:
            same_label_indices = [i for i, label in enumerate(self.labels) if label == label_1]
            positive_idx = random.choice(same_label_indices)
            img_path_2 = self.image_paths[positive_idx]
            img_2 = cv2.imread(img_path_2)
            label_pair = 1
        else:
            different_label_indices = [i for i, label in enumerate(self.labels) if label != label_1]
            negative_idx = random.choice(different_label_indices)
            img_path_2 = self.image_paths[negative_idx]
            img_2 = cv2.imread(img_path_2)
            label_pair = 0

        if self.transform:
            img_1 = self.transform(img_1)
            img_2 = self.transform(img_2)
            
        return (img_1, img_2), label_pair

def create_dataloaders(train_dir, val_dir, batch_size=32):
    
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((112, 112)),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=(3, 5))], p=0.5),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=30),
        transforms.RandomResizedCrop((112, 112), scale=(0.8, 1.0), ratio=(0.75, 1.33)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        RandAugment(num_ops=2, magnitude=9),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_dataset = FaceImageDataSet(train_dir, transform=transform)
    val_dataset = FacePairDataSet(val_dir, transform=transform)
    
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=multiprocessing.cpu_count()//2, pin_memory=True, prefetch_factor=6, persistent_workers=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=multiprocessing.cpu_count()//2, pin_memory=True, prefetch_factor=6, persistent_workers=True)

    return train_dataloader, val_dataloader

def read_data_config(config_file):
    with open(config_file, 'r') as f:
        config = yaml.safe_load(f)
    return config


def create_save_dir(save_dir,name):
    os.makedirs(save_dir, exist_ok=True)

    if not os.path.isdir(os.path.join(save_dir,name)):
        os.makedirs(os.path.join(save_dir,name))
        return os.path.join(save_dir,name)
    
    counter = 1
    while os.path.exists(os.path.join(save_dir,f'{name}{counter}')):
        counter += 1
    
    os.makedirs(os.path.join(save_dir,f'{name}{counter}'))
    return os.path.join(save_dir,f'{name}{counter}')

def cosine_similarity(embedding1, embedding2):
    cosine_sim = F.cosine_similarity(embedding1, embedding2)
    return cosine_sim

def verification_accuracy(model, dataloader, threshold=0.7):
    model.eval()
    correct = 0
    total = 0
    running_accuracy = 0

    pbar = tqdm(dataloader, desc=f'Val', dynamic_ncols=True)
    with torch.no_grad():
        for (anchor_img, image), label in pbar:
            anchor_img = anchor_img.to(device)
            image = image.to(device)
            label = label.to(device)

            anchor_embedding = model(anchor_img)
            image_embedding = model(image)

            similarity = cosine_similarity(anchor_embedding, image_embedding)

            correct_batch = ((similarity > threshold) & (label == 1)) | ((similarity < threshold) & (label == 0))

            correct += correct_batch.sum().item()
            total += label.size(0) 

            running_accuracy = correct / total
            pbar.set_postfix(accuracy=f'{running_accuracy:.3f}')

    accuracy = correct / total
    print(f'\tAccuracy: {accuracy:.3f}\n')
    return accuracy

def train_model(model, optimizer, scheduler, loss_optimizer, arcface_loss, num_epochs, train_dataloader, val_dataloader, save_dir, name, checkpoint=None):

    checkpoint_path = create_save_dir(save_dir,name)
    dataset_size = len(train_dataloader.dataset)
    start_epoch = 0
    best_accuracy = 0

    if checkpoint and os.path.isfile(checkpoint):
        print(f"Loading checkpoint from {checkpoint}...")
        checkpoint_data = torch.load(checkpoint, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint_data['model_state_dict'])
        optimizer.load_state_dict(checkpoint_data['optimizer_state_dict'])
        loss_optimizer.load_state_dict(checkpoint_data['loss_optimizer_state_dict'])
        arcface_loss.load_state_dict(checkpoint_data['arcface_loss_state_dict']),
        scheduler.load_state_dict(checkpoint_data['scheduler_state_dict'])
        start_epoch = checkpoint_data['epoch'] + 1
        best_accuracy = checkpoint_data['best_accuracy']
        print(f"Resumed from epoch {start_epoch} with accuracy: {best_accuracy:.3f}")

    for epoch in range(start_epoch, num_epochs):
        print('-' * 10)
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print(f'Learning rate: {optimizer.param_groups[0]["lr"]}')
        print('-' * 10)
        model.train()

        running_loss = 0.0

        pbar = tqdm(train_dataloader, desc=f'Train', dynamic_ncols=True)
        for image, label in pbar:
            image = image.to(device)
            label = label.to(device)

            embeddings = model(image)
            loss = arcface_loss(embeddings, label)

            optimizer.zero_grad()
            loss_optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_optimizer.step()

            running_loss += loss.item() * label.size(0)

            pbar.set_postfix({
                'loss': f'{running_loss / dataset_size:.2f}'
            })

        epoch_loss = running_loss / dataset_size

        print(f'\tLoss: {epoch_loss:.2f}\n')

        accuracy = verification_accuracy(model, val_dataloader, 0.7)
        scheduler.step(accuracy)
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss_optimizer_state_dict': loss_optimizer.state_dict(),
                'arcface_loss_state_dict': arcface_loss.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_accuracy': best_accuracy
            }, os.path.join(checkpoint_path, 'best.pth'))
            print(f"Saved best model to {os.path.join(checkpoint_path, 'best.pth')}")
        
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss_optimizer_state_dict': loss_optimizer.state_dict(),
            'arcface_loss_state_dict': arcface_loss.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_accuracy': best_accuracy
        }, os.path.join(checkpoint_path, 'last.pth'))
        print(f"Saved last model to {os.path.join(checkpoint_path, 'last.pth')}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train a face embedding model")
    parser.add_argument("--data_config", type=str, default="data/small_train.yml", help="Path to data config file")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for training")
    parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate")
    parser.add_argument("--embedding_size", type=int, default=512, help="Size of embedding vector")
    parser.add_argument("--patience", type=int, default=5, help="Early stopping patience")
    parser.add_argument("--save_dir", type=str, default="runs/train", help="Directory to save model")
    parser.add_argument("--name", type=str, default='train',help="Name of the current run to save")
    parser.add_argument("--checkpoint", type=str, default=None, help="Path to a checkpoint file to resume training")

    args = parser.parse_args()
    yaml_config = read_data_config(args.data_config)


    train_dataloader, val_dataloader = create_dataloaders(yaml_config.get('train_data_dir', None), yaml_config.get('val_data_dir', None), args.batch_size)


    model = ResNet18ForEmbedding(embedding_size=args.embedding_size)
    model = model.to(device)
    arcface_loss = ArcFaceLoss(num_classes=len(os.listdir(yaml_config.get('train_data_dir', None))), embedding_size=args.embedding_size, margin=28.6, scale=64.0).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=1e-5)
    loss_optimizer = torch.optim.SGD(arcface_loss.parameters(), lr=args.learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=args.patience)
    train_model(model, optimizer, scheduler, loss_optimizer, arcface_loss, args.epochs, train_dataloader, val_dataloader, args.save_dir, args.name, args.checkpoint)

During training of the model the first epoch has a loss of 36 and an accuracy on validation of 0.49. On second epoch it gets an training loss of 34 and an accuracy of 0.74 on validation. From there on the train loss is going down but the validation accuracy drops continuously. The dataset that i am using is digiface1m.

I tried different learning rates but they all produce the same results. I tried to reduce the augmentation, add dropout layers but got the same result.

Upvotes: 0

Views: 26

Answers (0)

Related Questions