Reputation: 2478
I am trying to implement multi-GPU, single-machine training with PyTorch and DDP.
My dataset and dataloader look like this:
# Define transformations using albumentations-
transform_train = A.Compose(
    [
        # A.Resize(width = 32, height = 32),
        # A.RandomCrop(width = 20, height = 20),
        A.Rotate(limit = 40, p = 0.9, border_mode = cv2.BORDER_CONSTANT),
        A.HorizontalFlip(p = 0.5),
        A.VerticalFlip(p = 0.1),
        A.RGBShift(r_shift_limit = 25, g_shift_limit = 25, b_shift_limit = 25, p = 0.9),
        A.OneOf([
            A.Blur(blur_limit = 3, p = 0.5),
            A.ColorJitter(p = 0.5),
        ], p = 1.0),
        A.Normalize(
            # mean = [0.4914, 0.4822, 0.4465],
            # std = [0.247, 0.243, 0.261],
            mean = [0, 0, 0],
            std = [1, 1, 1],
            max_pixel_value = 255,
        ),
        # Unlike PyTorch's ToTensor, this does not divide by 255-
        ToTensorV2(),
    ]
)
transform_test = A.Compose(
    [
        A.Normalize(
            mean = [0, 0, 0],
            std = [1, 1, 1],
            max_pixel_value = 255
        ),
        ToTensorV2()
    ]
)
class Cifar10Dataset(torchvision.datasets.CIFAR10):
    def __init__(
        self, root = "~/data/cifar10",
        train = True, download = True,
        transform = None
    ):
        super().__init__(
            root = root, train = train,
            download = download, transform = transform
        )

    def __getitem__(self, index):
        image, label = self.data[index], self.targets[index]

        if self.transform is not None:
            transformed = self.transform(image = image)
            image = transformed["image"]
        return image, label
def get_cifar10_data(
    rank, world_size,
    path_to_files, num_workers,
    batch_size = 256, pin_memory = False
):
    """
    Split the dataloader.
    We can split our dataloader with 'torch.utils.data.distributed.DistributedSampler'.
    The sampler returns an iterator over indices, which are fed into the dataloader to batchify.
    The 'DistributedSampler' splits the total indices of the dataset into 'world_size' parts,
    and evenly distributes them to the dataloader in each process without duplication.
    'DistributedSampler' imposes an even partition of indices.
    You might set 'num_workers = 0' for distributed training, because creating extra threads in
    the child processes may be problematic. The author also found that 'pin_memory = False' avoids
    many horrible bugs; maybe such things are machine-specific.
    """
    # Define train and test sets-
    train_dataset = Cifar10Dataset(
        root = path_to_files, train = True,
        download = True, transform = transform_train
    )

    test_dataset = Cifar10Dataset(
        root = path_to_files, train = False,
        download = True, transform = transform_test
    )

    train_sampler = DistributedSampler(
        dataset = train_dataset, num_replicas = world_size,
        rank = rank, shuffle = False,
        drop_last = False
    )

    test_sampler = DistributedSampler(
        dataset = test_dataset, num_replicas = world_size,
        rank = rank, shuffle = False,
        drop_last = False
    )

    # Define train and test loaders-
    train_loader = torch.utils.data.DataLoader(
        dataset = train_dataset, batch_size = batch_size,
        pin_memory = pin_memory, shuffle = False,
        num_workers = num_workers, sampler = train_sampler,
        drop_last = False
    )

    test_loader = torch.utils.data.DataLoader(
        dataset = test_dataset, batch_size = batch_size,
        pin_memory = pin_memory, shuffle = False,
        num_workers = num_workers, sampler = test_sampler,
        drop_last = False
    )

    return train_loader, test_loader, train_dataset, test_dataset
The rest of the code is:
def setup_process_group(rank, world_size):
    # Function to set up the process group.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12344"
    dist.init_process_group("nccl", rank = rank, world_size = world_size)
    return None


def cleanup_process_group():
    dist.destroy_process_group()


def main(rank, world_size):
    # Setup process group-
    setup_process_group(rank = rank, world_size = world_size)

    # Get distributed datasets and data loaders-
    train_loader, test_loader, train_dataset, test_dataset = get_cifar10_data(
        rank, world_size,
        path_to_files, num_workers = 0,
        batch_size = 256, pin_memory = False
    )

    # Initialize model and move to correct device-
    model = ResNet50(beta = 1.0).to(rank)

    """
    Wrap model in DDP.
    'device_ids' tells DDP where your model is. 'output_device' tells DDP where to output.
    In this case, it is rank.
    'find_unused_parameters = True' instructs DDP to find unused outputs of the forward()
    function of any module in the model.
    """
    model = DDP(module = model, device_ids = [rank], output_device = rank, find_unused_parameters = True)

    # Define loss function and optimizer-
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        params = model.parameters(), lr = 1e-3,
        momentum = 0.9, weight_decay = 5e-4
    )

    best_acc = 50

    for epoch in range(1, 51):
        # When using DistributedSampler, we have to tell it which epoch this is-
        train_loader.sampler.set_epoch(train_loader)

        running_loss = 0.0
        running_corrects = 0.0

        for step, x in enumerate(train_loader):
            optimizer.zero_grad(set_to_none = True)
            output = model(x[0])
            loss = loss_fn(output, x[1])
            loss.backward()
            optimizer.step()

            # Compute model's performance statistics-
            running_loss += loss.item() * x[0].size(0)
            _, predicted = torch.max(output, 1)
            running_corrects += torch.sum(predicted == x[1].data)

        train_loss = running_loss / len(train_dataset)
        train_acc = (running_corrects.double() / len(train_dataset)) * 100

        print(f"epoch = {epoch}, loss = {train_loss:.5f}, acc = {train_acc:.3f}%")

        if train_acc > best_acc:
            best_acc = train_acc
            print(f"saving best acc model = {best_acc:.3f}%")
            # Save best model-
            torch.save(model.module.state_dict(), "ResNet50_swish_best_trainacc.pth")

    cleanup_process_group()


if __name__ == "__main__":
    # Say, we have 4 GPUs-
    # world_size = 4
    world_size = torch.cuda.device_count()
    print(f"world size = {world_size}")

    mp.spawn(
        main, args = (world_size),
        nprocs = world_size
    )
On executing this, I get the following error:
Traceback (most recent call last):
  File "/home/majumdar/Deep_Learning/PyTorch_DDP_Tutorial/PyTorch_DDP_Tutorial.py", line 184, in <module>
    mp.spawn(
  File "/home/majumdar/anaconda3/envs/torch-cuda-new/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/majumdar/anaconda3/envs/torch-cuda-new/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/home/majumdar/anaconda3/envs/torch-cuda-new/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 7 terminated with the following error:
Traceback (most recent call last):
  File "/home/majumdar/anaconda3/envs/torch-cuda-new/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
TypeError: Value after * must be an iterable, not int
Upvotes: 0
Views: 2144
Reputation: 1
In the mp.spawn call, add a comma after world_size so that args is passed as a one-element tuple. Change
mp.spawn(
    main, args = (world_size),
    nprocs = world_size
)
to
mp.spawn(
    main, args = (world_size, ),
    nprocs = world_size
)
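The reason: mp.spawn unpacks args inside the worker with fn(i, *args) (you can see that call in the traceback), so args must be an iterable. (world_size) is just an int wrapped in parentheses, while (world_size, ) is a one-element tuple. A minimal sketch, using a placeholder world_size = 4, just to illustrate the difference:

# Placeholder value for illustration only.
world_size = 4

# Parentheses alone do not make a tuple; the trailing comma does.
print(type((world_size)))    # <class 'int'>   -> fn(i, *args) fails: int is not iterable
print(type((world_size,)))   # <class 'tuple'> -> unpacks fine, so main(rank, world_size) is called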
Upvotes: 0