Arun

Using PyTorch's DDP for multi-GPU training with mp.spawn() doesn't work

I am trying to implement single-machine, multi-GPU training with PyTorch and DDP (DistributedDataParallel).

My dataset and dataloaders look like this:

# Define transformations using albumentations-
transform_train = A.Compose(
    [
        # A.Resize(width = 32, height = 32),
        # A.RandomCrop(width = 20, height = 20),
        A.Rotate(limit = 40, p = 0.9, border_mode = cv2.BORDER_CONSTANT),
        A.HorizontalFlip(p = 0.5),
        A.VerticalFlip(p = 0.1),
        A.RGBShift(r_shift_limit = 25, g_shift_limit = 25, b_shift_limit = 25, p = 0.9),
        A.OneOf([
            A.Blur(blur_limit = 3, p = 0.5),
            A.ColorJitter(p = 0.5),
        ], p = 1.0),
        A.Normalize(
            # mean = [0.4914, 0.4822, 0.4465],
            # std = [0.247, 0.243, 0.261],
            mean = [0, 0, 0],
            std = [1, 1, 1],
            max_pixel_value = 255,
        ),
        # Unlike torchvision's ToTensor, ToTensorV2 does not divide by 255;
        # the scaling is handled by 'max_pixel_value = 255' in Normalize above-
        ToTensorV2(),
    ]
)

transform_test = A.Compose(
    [
        A.Normalize(
            mean = [0, 0, 0],
            std = [1, 1, 1],
            max_pixel_value = 255
        ),
        ToTensorV2()
    ]
)


class Cifar10Dataset(torchvision.datasets.CIFAR10):
    def __init__(
        self, root = "~/data/cifar10",
        train = True, download = True,
        transform = None
    ):
        super().__init__(
            root = root, train = train,
            download = download, transform = transform
        )

    def __getitem__(self, index):
        image, label = self.data[index], self.targets[index]

        if self.transform is not None:
            transformed = self.transform(image = image)
            image = transformed["image"]

        return image, label



def get_cifar10_data(
    rank, world_size,
    path_to_files, num_workers,
    batch_size = 256, pin_memory = False
    ):
    """
    Split the dataloader

    We can split our dataloader with 'torch.utils.data.distributed.DistributedSampler'.
    The sampler returns a iterator over indices, which are fed into dataloader to bachify.
    The 'DistributedSampler' splits the total indices of the dataset into 'world_size' parts,
    and evenly distributes them to the dataloader in each process without duplication.

    'DistributedSampler' imposes even partition of indices.

    You might set 'num_workers = 0' for distributed training, because creating extra threads in
    the children processes may be problemistic. The author also found 'pin_memory = False' avoids
    many horrible bugs, maybe such things are machine-specific.
    """

    # Define train and test sets-
    train_dataset = Cifar10Dataset(
        root = path_to_files, train = True,
        download = True, transform = transform_train
    )

    test_dataset = Cifar10Dataset(
        root = path_to_files, train = False,
        download = True, transform = transform_test
    )

    train_sampler = DistributedSampler(
        dataset = train_dataset, num_replicas = world_size,
        rank = rank, shuffle = False,
        drop_last = False
        )

    test_sampler = DistributedSampler(
        dataset = test_dataset, num_replicas = world_size,
        rank = rank, shuffle = False,
        drop_last = False
        )

    # Define train and test loaders-
    train_loader = torch.utils.data.DataLoader(
        dataset = train_dataset, batch_size = batch_size,
        pin_memory = pin_memory, shuffle = False,
        num_workers = num_workers, sampler = train_sampler,
        drop_last = False
    )

    test_loader = torch.utils.data.DataLoader(
        dataset = test_dataset, batch_size = batch_size,
        pin_memory = pin_memory, shuffle = False,
        num_workers = num_workers, sampler = test_sampler,
        drop_last = False
    )

    return train_loader, test_loader, train_dataset, test_dataset
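
To illustrate the index splitting described in the docstring above, here is a minimal, standalone sketch (separate from the training code; DummyDataset is just a made-up 10-item dataset) showing which indices each rank's DistributedSampler yields:

from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler

class DummyDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return idx

world_size = 2
for rank in range(world_size):
    sampler = DistributedSampler(
        dataset = DummyDataset(), num_replicas = world_size,
        rank = rank, shuffle = False, drop_last = False
    )
    # Each rank sees a disjoint, evenly sized slice of the indices-
    print(f"rank {rank}: {list(sampler)}")

# rank 0: [0, 2, 4, 6, 8]
# rank 1: [1, 3, 5, 7, 9]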

The rest of the code is:

def setup_process_group(rank, world_size):
    # Function to set up the process group for this rank.

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12344"

    dist.init_process_group("nccl", rank = rank, world_size = world_size)

    return None


def cleanup_process_group():
    dist.destroy_process_group()

def main(rank, world_size):
    # Setup process groups-
    setup_process_group(rank = rank, world_size = world_size)

    # Get distributed datasets and data loaders-
    train_loader, test_loader, train_dataset, test_dataset = get_cifar10_data(
        rank, world_size,
        path_to_files, num_workers = 0,
        batch_size = 256, pin_memory = False
        )


    # Initialize model and move to correct device-
    model = ResNet50(beta = 1.0).to(rank)

    """
    Wrap model in DDP
    'device_id' tells DDP where your model is. 'output_device' tells DDP where to output.
    In this case, it is rank.
    'find_unused_parameters = True' instructs DDP to find unused output of the forward()
    function of any module in the model
    """
    model = DDP(module = model, device_ids = [rank], output_device = rank, find_unused_parameters = True)

    # Define loss function and optimizer-
    loss_fn = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(
        params = model.parameters(), lr = 1e-3,
        momentum = 0.9, weight_decay = 5e-4
    )

    best_acc = 50


    for epoch in range(1, 51):
        # When using DistributedSampler, we have to tell the sampler which epoch this is-
        train_loader.sampler.set_epoch(epoch)

        running_loss = 0.0
        running_corrects = 0.0

        for step, (images, labels) in enumerate(train_loader):
            # Move the batch to this process's GPU-
            images = images.to(rank)
            labels = labels.to(rank)

            optimizer.zero_grad(set_to_none = True)
            output = model(images)
            loss = loss_fn(output, labels)
            loss.backward()
            optimizer.step()

            # Compute model's performance statistics-
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(output, 1)
            running_corrects += torch.sum(predicted == labels.data)

        train_loss = running_loss / len(train_dataset)
        train_acc = (running_corrects.double() / len(train_dataset)) * 100

        print(f"epoch = {epoch}, loss = {train_loss:.5f}, acc = {train_acc:.3f}%")

        if train_acc > best_acc:
            best_acc = train_acc
            print(f"saving best acc model = {best_acc:.3f}%")

            # Save best model-
            torch.save(model.module.state_dict(), "ResNet50_swish_best_trainacc.pth")

    cleanup_process_group()

if __name__ == "__main__":
    # Say, we have 4 GPUs-
    # world_size = 4
    world_size = torch.cuda.device_count()
    print(f"world size = {world_size}")

    mp.spawn(
        main, args = (world_size),
        nprocs = world_size
    )
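
As a side note on the DDP wrapping above, here is a minimal, single-process sketch (separate from the training code, CPU-only with the gloo backend so nothing needs to be spawned) just to illustrate the setup / wrap / cleanup pattern:

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"

# A single-process "group" just for illustration-
dist.init_process_group("gloo", rank = 0, world_size = 1)

model = nn.Linear(10, 2)
ddp_model = DDP(module = model)          # no device_ids needed on CPU
out = ddp_model(torch.randn(4, 10))      # forward pass goes through the DDP wrapper
print(out.shape)                         # torch.Size([4, 2])

dist.destroy_process_group()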

On executing the full code above, I get the following error:

Traceback (most recent call last):
  File "/home/majumdar/Deep_Learning/PyTorch_DDP_Tutorial/PyTorch_DDP_Tutorial.py", line 184, in <module>
    mp.spawn(
  File "/home/majumdar/anaconda3/envs/torch-cuda-new/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/majumdar/anaconda3/envs/torch-cuda-new/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/home/majumdar/anaconda3/envs/torch-cuda-new/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 7 terminated with the following error:
Traceback (most recent call last):
  File "/home/majumdar/anaconda3/envs/torch-cuda-new/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
TypeError: Value after * must be an iterable, not int


Answers (1)

KukJin Kim

In the mp.spawn() call, add a comma after world_size so that args becomes a one-element tuple; change

mp.spawn(
    main, args = (world_size),
    nprocs = world_size
)

to

mp.spawn(
    main, args = (world_size, ),
    nprocs = world_size
)
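
The comma matters because (world_size) is just the integer wrapped in parentheses, while (world_size, ) is a one-element tuple. mp.spawn() invokes the target as fn(rank, *args) (as the traceback shows), so args must be an iterable. A quick standalone check:

world_size = 4
print(type((world_size)))     # <class 'int'>   -> "Value after * must be an iterable, not int"
print(type((world_size, )))   # <class 'tuple'> -> main() receives (rank, world_size)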
