from segment_anything import sam_model_registry, SamPredictor
import os
import argparse
import time
import datetime
import random
import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
from torch import optim
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from segment_anything.modeling import Sam, PromptEncoder, MaskDecoder, TwoWayTransformer
from efficientvit.models.efficientvit.efficientvim import image_encoder_efficientvim
from segment_any.modeling.image_encoder_vim import ImageEncoderViM
from DataLoader import TrainingDataset, stack_dict_batched
from utils import FocalDiceloss_IoULoss, get_logger, generate_point, setting_prompt_none
from metrics import SegMetrics
# DDP helpers (defined but not used in the current run; the model is
# wrapped with nn.DataParallel in main() instead).
def setup_ddp(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)


def cleanup_ddp():
    dist.destroy_process_group()
def _build_mamba_sam(checkpoint=None):
    prompt_embed_dim = 256
    image_size = 256
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        image_encoder=ImageEncoderViM(img_size=image_size),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)
    return sam
def _build_efficientvim_sam(checkpoint=None):
    prompt_embed_dim = 256
    image_size = 256
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        image_encoder=image_encoder_efficientvim,
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)
    return sam


efficientvim_sam = _build_efficientvim_sam()
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--work_dir", type=str, default="workdir", help="work dir")
    parser.add_argument("--run_name", type=str, default="mamba-sam-med2d", help="run model name")
    parser.add_argument("--epochs", type=int, default=3, help="number of epochs")
    parser.add_argument("--batch_size", type=int, default=60, help="train batch size")
    parser.add_argument("--image_size", type=int, default=256, help="image_size")
    parser.add_argument("--mask_num", type=int, default=5, help="get mask number")
    parser.add_argument("--data_path", type=str, default="/mnt/dataset/SAMed2Dv1", help="train data path")
    parser.add_argument("--metrics", nargs='+', default=['iou', 'dice'], help="metrics")
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
    parser.add_argument("--resume", type=str, default=None, help="load resume")
    parser.add_argument("--sam_checkpoint", type=str, default=None, help="sam checkpoint")
    parser.add_argument("--iter_point", type=int, default=5, help="point iterations")
    parser.add_argument('--lr_scheduler', type=str, default=None, help='lr scheduler')
    parser.add_argument("--point_list", type=list, default=[1, 3, 5, 9], help="point_list")
    parser.add_argument("--multimask", type=bool, default=True, help="output multimask")
    parser.add_argument("--used_amp", type=bool, default=True, help="use amp")
    args = parser.parse_args()
    if args.resume is not None:
        args.sam_checkpoint = None
    return args
def to_device(batch_input, device):
    device_input = {}
    for key, value in batch_input.items():
        if value is not None:
            if key == 'image' or key == 'label':
                device_input[key] = value.float().to(device)
            elif isinstance(value, (list, torch.Size)):
                device_input[key] = value
            else:
                device_input[key] = value.to(device)
        else:
            device_input[key] = value
    return device_input
def prompt_and_decoder(args, batched_input, model, image_embeddings, decoder_iter=False):
    if batched_input["point_coords"] is not None:
        points = (batched_input["point_coords"], batched_input["point_labels"])
    else:
        points = None

    # In the refinement iterations the prompt embeddings are computed
    # without gradient tracking.
    if decoder_iter:
        with torch.no_grad():
            sparse_embeddings, dense_embeddings = model.module.prompt_encoder(
                points=points,
                boxes=batched_input.get("boxes", None),
                masks=batched_input.get("mask_inputs", None),
            )
    else:
        sparse_embeddings, dense_embeddings = model.module.prompt_encoder(
            points=points,
            boxes=batched_input.get("boxes", None),
            masks=batched_input.get("mask_inputs", None),
        )

    low_res_masks, iou_predictions = model.module.mask_decoder(
        image_embeddings=image_embeddings.to(dtype=torch.float16 if args.used_amp else torch.float32),
        image_pe=model.module.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=args.multimask,
    )

    if args.multimask:
        # Keep only the mask with the highest predicted IoU for each sample.
        max_values, max_indices = torch.max(iou_predictions, dim=1)
        iou_predictions = max_values.unsqueeze(1)
        low_res = []
        for i, idx in enumerate(max_indices):
            low_res.append(low_res_masks[i:i+1, idx])
        low_res_masks = torch.stack(low_res, 0)

    masks = F.interpolate(low_res_masks, (args.image_size, args.image_size), mode="bilinear", align_corners=False)
    return masks, low_res_masks, iou_predictions
def train_one_epoch(args, model, optimizer, train_loader, epoch, criterion, scaler):
    train_loader = tqdm(train_loader)
    train_losses = []
    train_iter_metrics = [0] * len(args.metrics)
    last_save_time = time.time()
    for batch, batched_input in enumerate(train_loader):
        try:
            batched_input = stack_dict_batched(batched_input)
            batched_input = to_device(batched_input, args.device)

            # Randomly train this batch with either point or box prompts.
            if random.random() > 0.5:
                batched_input["point_coords"] = None
                flag = "boxes"
            else:
                batched_input["boxes"] = None
                flag = "point"

            for _, value in model.module.image_encoder.named_parameters():
                value.requires_grad = True

            with autocast(enabled=args.used_amp):
                labels = batched_input["label"]
                image_embeddings = model.module.image_encoder(batched_input["image"])

                # Repeat each image embedding mask_num times so every sampled
                # mask shares its image's embedding. Use a separate name so the
                # loop index `batch` is not clobbered by the unpacked shape.
                bs, _, _, _ = image_embeddings.shape
                image_embeddings_repeat = []
                for i in range(bs):
                    image_embed = image_embeddings[i]
                    image_embed = image_embed.repeat(args.mask_num, 1, 1, 1)
                    image_embeddings_repeat.append(image_embed)
                image_embeddings = torch.cat(image_embeddings_repeat, dim=0)

                masks, low_res_masks, iou_predictions = prompt_and_decoder(args, batched_input, model, image_embeddings, decoder_iter=False)
                loss = criterion(masks, labels, iou_predictions)

            scaler.scale(loss).backward(retain_graph=False)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if (batch + 1) % 50 == 0:
                print(f'Epoch: {epoch+1}, Batch: {batch+1}, first {flag} prompt: {SegMetrics(masks, labels, args.metrics)}')

            # Iterative refinement: sample correction points from the error
            # regions and train with the image encoder frozen.
            point_num = random.choice(args.point_list)
            batched_input = generate_point(masks, labels, low_res_masks, batched_input, point_num)
            batched_input = to_device(batched_input, args.device)

            image_embeddings = image_embeddings.detach().clone()
            for n, value in model.named_parameters():
                if "image_encoder" in n:
                    value.requires_grad = False
                else:
                    value.requires_grad = True

            init_mask_num = np.random.randint(1, args.iter_point - 1)
            for iter_num in range(args.iter_point):
                if iter_num == init_mask_num or iter_num == args.iter_point - 1:
                    batched_input = setting_prompt_none(batched_input)

                with autocast(enabled=args.used_amp):
                    masks, low_res_masks, iou_predictions = prompt_and_decoder(args, batched_input, model, image_embeddings, decoder_iter=True)
                    loss = criterion(masks, labels, iou_predictions)
                scaler.scale(loss).backward(retain_graph=True)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                if iter_num != args.iter_point - 1:
                    point_num = random.choice(args.point_list)
                    batched_input = generate_point(masks, labels, low_res_masks, batched_input, point_num)
                    batched_input = to_device(batched_input, args.device)

                if (batch + 1) % 50 == 0:
                    if iter_num == init_mask_num or iter_num == args.iter_point - 1:
                        print(f'Epoch: {epoch+1}, Batch: {batch+1}, mask prompt: {SegMetrics(masks, labels, args.metrics)}')
                    else:
                        print(f'Epoch: {epoch+1}, Batch: {batch+1}, point {point_num} prompt: {SegMetrics(masks, labels, args.metrics)}')

            if (batch + 1) % 200 == 0:
                print(f"epoch:{epoch+1}, iteration:{batch+1}, loss:{loss.item()}")
                save_path = os.path.join(f"{args.work_dir}/models", args.run_name, f"epoch{epoch+1}_batch{batch+1}_sam.pth")
                state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
                torch.save(state, save_path)

            # Time-based checkpoint every three hours.
            current_time = time.time()
            if current_time - last_save_time >= 10800:
                formatted_time = datetime.datetime.fromtimestamp(current_time).strftime('%Y%m%d-%H%M%S')
                print(f"Saving model at epoch {epoch+1}, batch {batch+1}, time {formatted_time} due to time limit.")
                save_filename = f"epoch{epoch+1}_batch{batch+1}_time{formatted_time}.pth"
                save_path = os.path.join(args.work_dir, "models", args.run_name, save_filename)
                state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
                torch.save(state, save_path)
                last_save_time = current_time

            train_losses.append(loss.item())
            gpu_info = {'gpu_name': args.device}
            train_loader.set_postfix(train_loss=loss.item(), gpu_info=gpu_info)

            train_batch_metrics = SegMetrics(masks, labels, args.metrics)
            train_iter_metrics = [train_iter_metrics[i] + train_batch_metrics[i] for i in range(len(args.metrics))]
        except Exception as e:
            print(f"Skipping batch {batch} due to an exception: {e}")
            continue
    return train_losses, train_iter_metrics
def main(args):
    model = nn.DataParallel(_build_mamba_sam(args.sam_checkpoint)).to(args.device)
    for name, param in model.named_parameters():
        print(f"{name}: {param.device}")
    # if torch.cuda.device_count() > 1:
    #     print("Let's use", torch.cuda.device_count(), "GPUs!")
    #     model = nn.DataParallel(model)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    criterion = FocalDiceloss_IoULoss()

    if args.lr_scheduler:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10], gamma=0.5)
        print('*******Use MultiStepLR')

    if args.resume is not None:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print(f"Loaded weights and optimizer state from {args.resume}")

    if args.used_amp:
        print("Use mixed precision")
    else:
        print('*******Do not use mixed precision')

    train_dataset = TrainingDataset(args.data_path, image_size=args.image_size, mode='train', point_num=1, mask_num=args.mask_num, requires_name=False)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=40)
    print('*******Train data:', len(train_dataset))

    loggers = get_logger(os.path.join(args.work_dir, "logs", f"{args.run_name}_{datetime.datetime.now().strftime('%Y%m%d-%H%M')}.log"))

    best_loss = 1e10
    l = len(train_loader)
    for epoch in range(0, args.epochs):
        scaler = GradScaler()
        model.train()
        train_metrics = {}
        start = time.time()
        os.makedirs(os.path.join(f"{args.work_dir}/models", args.run_name), exist_ok=True)

        train_losses, train_iter_metrics = train_one_epoch(args, model, optimizer, train_loader, epoch, criterion, scaler)

        if args.lr_scheduler is not None:
            scheduler.step()

        train_iter_metrics = [metric / l for metric in train_iter_metrics]
        train_metrics = {args.metrics[i]: '{:.4f}'.format(train_iter_metrics[i]) for i in range(len(train_iter_metrics))}
        average_loss = np.mean(train_losses)
        lr = scheduler.get_last_lr()[0] if args.lr_scheduler is not None else args.lr
        loggers.info(f"epoch: {epoch + 1}, lr: {lr}, Train loss: {average_loss:.4f}, metrics: {train_metrics}")

        if average_loss < best_loss:
            best_loss = average_loss
            save_path = os.path.join(args.work_dir, "models", args.run_name, f"epoch{epoch+1}_sam.pth")
            state = {'model': model.float().state_dict(), 'optimizer': optimizer.state_dict()}
            torch.save(state, save_path)

        end = time.time()
        print("Run epoch time: %.2fs" % (end - start))


if __name__ == '__main__':
    args = parse_args()
    main(args)
I'm training inside a Docker container on a remote Linux server equipped with six GPUs, and the environment is set up correctly. I wrap my model in nn.DataParallel. When I run this code, training proceeds smoothly, but nvidia-smi shows that only one GPU is actually being utilized. I'm curious why this is happening.
I've verified that all GPUs are visible on the server (see the quick check below) and checked for version-related issues, finding nothing abnormal. What puzzles me most is that the code runs fine, yet only uses a single GPU. I also noticed that with parser.add_argument('--device', type=str, default='cuda'), training runs on cuda:0; if I change the default to 'cuda:1', training moves to cuda:1.
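This is the kind of minimal check I ran inside the container to confirm GPU visibility (an illustrative sketch; the exact device names will differ on other machines):

import torch

# All six GPUs should be enumerated if CUDA and the NVIDIA container
# runtime are configured correctly.
print(torch.cuda.is_available())   # expected: True
print(torch.cuda.device_count())   # expected: 6 on this server
for i in range(torch.cuda.device_count()):
    print(i, torch.cuda.get_device_name(i))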
So the script is capable of running on any single specified GPU, but it does not automatically distribute the workload across all six available devices. I'm looking for a way to effectively utilize all GPUs to improve training performance and efficiency.