Reputation: 1
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader
from pytorch_metric_learning.losses import ArcFaceLoss
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
from torch.utils.data import Dataset
import os
import cv2
from torchvision.models import ResNet18_Weights
import torchvision
import multiprocessing
import yaml
import argparse
from tqdm import tqdm
import torch.nn.functional as F
from torchvision.transforms import RandAugment
import random
from torch.optim.lr_scheduler import ReduceLROnPlateau
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class ResNet18ForEmbedding(nn.Module):
def __init__(self, embedding_size=512):
super(ResNet18ForEmbedding, self).__init__()
self.backbone = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
num_ftrs = self.backbone.fc.in_features
self.backbone.fc = nn.Linear(num_ftrs, embedding_size)
def forward(self, x):
embeddings = self.backbone(x)
return embeddings
class FaceImageDataSet(Dataset):
def __init__(self, data_dir, transform=None):
self.image_paths = []
self.data_dir = data_dir
self.transform = transform
self.labels = []
for label in os.listdir(data_dir):
label_dir = os.path.join(data_dir, label)
if os.path.isdir(label_dir):
for img_name in os.listdir(label_dir):
img_path = os.path.join(label_dir, img_name)
self.image_paths.append(img_path)
self.labels.append(int(label))
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
image = cv2.imread(self.image_paths[idx])
if self.transform:
image = self.transform(image)
label = self.labels[idx]
return image, label
class FacePairDataSet(Dataset):
def __init__(self, data_dir, transform=None):
self.data_dir = data_dir
self.transform = transform
self.image_paths = []
self.labels = []
for label in os.listdir(data_dir):
label_dir = os.path.join(data_dir, label)
if os.path.isdir(label_dir):
for img_name in os.listdir(label_dir):
img_path = os.path.join(label_dir, img_name)
self.image_paths.append(img_path)
self.labels.append(label)
self.class_to_idx = {label: idx for idx, label in enumerate(sorted(set(self.labels)))}
self.num_classes = len(self.class_to_idx)
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
img_path_1 = self.image_paths[idx]
label_1 = self.labels[idx]
img_1 = cv2.imread(img_path_1)
if random.random() > 0.5:
same_label_indices = [i for i, label in enumerate(self.labels) if label == label_1]
positive_idx = random.choice(same_label_indices)
img_path_2 = self.image_paths[positive_idx]
img_2 = cv2.imread(img_path_2)
label_pair = 1
else:
different_label_indices = [i for i, label in enumerate(self.labels) if label != label_1]
negative_idx = random.choice(different_label_indices)
img_path_2 = self.image_paths[negative_idx]
img_2 = cv2.imread(img_path_2)
label_pair = 0
if self.transform:
img_1 = self.transform(img_1)
img_2 = self.transform(img_2)
return (img_1, img_2), label_pair
def create_dataloaders(train_dir, val_dir, batch_size=32):
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((112, 112)),
transforms.RandomApply([transforms.GaussianBlur(kernel_size=(3, 5))], p=0.5),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(degrees=30),
transforms.RandomResizedCrop((112, 112), scale=(0.8, 1.0), ratio=(0.75, 1.33)),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
RandAugment(num_ops=2, magnitude=9),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_dataset = FaceImageDataSet(train_dir, transform=transform)
val_dataset = FacePairDataSet(val_dir, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=multiprocessing.cpu_count()//2, pin_memory=True, prefetch_factor=6, persistent_workers=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=multiprocessing.cpu_count()//2, pin_memory=True, prefetch_factor=6, persistent_workers=True)
return train_dataloader, val_dataloader
def read_data_config(config_file):
with open(config_file, 'r') as f:
config = yaml.safe_load(f)
return config
def create_save_dir(save_dir,name):
os.makedirs(save_dir, exist_ok=True)
if not os.path.isdir(os.path.join(save_dir,name)):
os.makedirs(os.path.join(save_dir,name))
return os.path.join(save_dir,name)
counter = 1
while os.path.exists(os.path.join(save_dir,f'{name}{counter}')):
counter += 1
os.makedirs(os.path.join(save_dir,f'{name}{counter}'))
return os.path.join(save_dir,f'{name}{counter}')
def cosine_similarity(embedding1, embedding2):
cosine_sim = F.cosine_similarity(embedding1, embedding2)
return cosine_sim
def verification_accuracy(model, dataloader, threshold=0.7):
model.eval()
correct = 0
total = 0
running_accuracy = 0
pbar = tqdm(dataloader, desc=f'Val', dynamic_ncols=True)
with torch.no_grad():
for (anchor_img, image), label in pbar:
anchor_img = anchor_img.to(device)
image = image.to(device)
label = label.to(device)
anchor_embedding = model(anchor_img)
image_embedding = model(image)
similarity = cosine_similarity(anchor_embedding, image_embedding)
correct_batch = ((similarity > threshold) & (label == 1)) | ((similarity < threshold) & (label == 0))
correct += correct_batch.sum().item()
total += label.size(0)
running_accuracy = correct / total
pbar.set_postfix(accuracy=f'{running_accuracy:.3f}')
accuracy = correct / total
print(f'\tAccuracy: {accuracy:.3f}\n')
return accuracy
def train_model(model, optimizer, scheduler, loss_optimizer, arcface_loss, num_epochs, train_dataloader, val_dataloader, save_dir, name, checkpoint=None):
checkpoint_path = create_save_dir(save_dir,name)
dataset_size = len(train_dataloader.dataset)
start_epoch = 0
best_accuracy = 0
if checkpoint and os.path.isfile(checkpoint):
print(f"Loading checkpoint from {checkpoint}...")
checkpoint_data = torch.load(checkpoint, map_location=device, weights_only=False)
model.load_state_dict(checkpoint_data['model_state_dict'])
optimizer.load_state_dict(checkpoint_data['optimizer_state_dict'])
loss_optimizer.load_state_dict(checkpoint_data['loss_optimizer_state_dict'])
arcface_loss.load_state_dict(checkpoint_data['arcface_loss_state_dict']),
scheduler.load_state_dict(checkpoint_data['scheduler_state_dict'])
start_epoch = checkpoint_data['epoch'] + 1
best_accuracy = checkpoint_data['best_accuracy']
print(f"Resumed from epoch {start_epoch} with accuracy: {best_accuracy:.3f}")
for epoch in range(start_epoch, num_epochs):
print('-' * 10)
print(f'Epoch {epoch + 1}/{num_epochs}')
print(f'Learning rate: {optimizer.param_groups[0]["lr"]}')
print('-' * 10)
model.train()
running_loss = 0.0
pbar = tqdm(train_dataloader, desc=f'Train', dynamic_ncols=True)
for image, label in pbar:
image = image.to(device)
label = label.to(device)
embeddings = model(image)
loss = arcface_loss(embeddings, label)
optimizer.zero_grad()
loss_optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_optimizer.step()
running_loss += loss.item() * label.size(0)
pbar.set_postfix({
'loss': f'{running_loss / dataset_size:.2f}'
})
epoch_loss = running_loss / dataset_size
print(f'\tLoss: {epoch_loss:.2f}\n')
accuracy = verification_accuracy(model, val_dataloader, 0.7)
scheduler.step(accuracy)
if accuracy > best_accuracy:
best_accuracy = accuracy
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss_optimizer_state_dict': loss_optimizer.state_dict(),
'arcface_loss_state_dict': arcface_loss.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'best_accuracy': best_accuracy
}, os.path.join(checkpoint_path, 'best.pth'))
print(f"Saved best model to {os.path.join(checkpoint_path, 'best.pth')}")
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss_optimizer_state_dict': loss_optimizer.state_dict(),
'arcface_loss_state_dict': arcface_loss.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'best_accuracy': best_accuracy
}, os.path.join(checkpoint_path, 'last.pth'))
print(f"Saved last model to {os.path.join(checkpoint_path, 'last.pth')}")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Train a face embedding model")
parser.add_argument("--data_config", type=str, default="data/small_train.yml", help="Path to data config file")
parser.add_argument("--batch_size", type=int, default=64, help="Batch size for training")
parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate")
parser.add_argument("--embedding_size", type=int, default=512, help="Size of embedding vector")
parser.add_argument("--patience", type=int, default=5, help="Early stopping patience")
parser.add_argument("--save_dir", type=str, default="runs/train", help="Directory to save model")
parser.add_argument("--name", type=str, default='train',help="Name of the current run to save")
parser.add_argument("--checkpoint", type=str, default=None, help="Path to a checkpoint file to resume training")
args = parser.parse_args()
yaml_config = read_data_config(args.data_config)
train_dataloader, val_dataloader = create_dataloaders(yaml_config.get('train_data_dir', None), yaml_config.get('val_data_dir', None), args.batch_size)
model = ResNet18ForEmbedding(embedding_size=args.embedding_size)
model = model.to(device)
arcface_loss = ArcFaceLoss(num_classes=len(os.listdir(yaml_config.get('train_data_dir', None))), embedding_size=args.embedding_size, margin=28.6, scale=64.0).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=1e-5)
loss_optimizer = torch.optim.SGD(arcface_loss.parameters(), lr=args.learning_rate)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=args.patience)
train_model(model, optimizer, scheduler, loss_optimizer, arcface_loss, args.epochs, train_dataloader, val_dataloader, args.save_dir, args.name, args.checkpoint)
During training of the model the first epoch has a loss of 36 and an accuracy on validation of 0.49. On second epoch it gets an training loss of 34 and an accuracy of 0.74 on validation. From there on the train loss is going down but the validation accuracy drops continuously. The dataset that i am using is digiface1m.
I tried different learning rates but they all produce the same results. I tried to reduce the augmentation, add dropout layers but got the same result.
Upvotes: 0
Views: 26