arturo salmi

Reputation: 31

Speed up data reading in pytorch dataloader

I am currently training a GAN model in PyTorch using two datasets of 1040x1920 PNG images. I am using this dataloader to load the images during training:


import os

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import functional as F
from PIL import Image

class TrainImageDataset(Dataset):
    def __init__(self, path_real, path_fake, img_size=256):
        super(TrainImageDataset, self).__init__()

        self.real_images = [os.path.join(path_real, x) for x in os.listdir(path_real)]
        self.fake_images = [os.path.join(path_fake, x) for x in os.listdir(path_fake)]

        self.downscale = transforms.RandomCrop(img_size)
        self.hflip = transforms.RandomHorizontalFlip(p=0.5)

    def __getitem__(self, batch_index):
        # Load the images
        real = Image.open(self.real_images[batch_index])
        fake = Image.open(self.fake_images[batch_index])

        # Apply augmentation functions
        fake = self.downscale(fake)
        real = self.downscale(real)
        fake = self.hflip(fake)
        real = self.hflip(real)

        # Convert the images to torch tensors
        real = F.to_tensor(real)
        fake = F.to_tensor(fake)
        return {'fake': fake, 'real': real}

    def __len__(self):
        return len(self.fake_images)

When training, I then pass the dataset into a DataLoader with batch_size=8, num_workers=4, shuffle=True, pin_memory=True, drop_last=True.
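For reference, this is roughly how the loader is constructed (the dataset paths below are placeholders):

    from torch.utils.data import DataLoader

    # Loader setup as described above; 'data/real' and 'data/fake' are placeholder paths
    dataset = TrainImageDataset('data/real', 'data/fake', img_size=256)
    dataloader = DataLoader(dataset, batch_size=8, num_workers=4,
                            shuffle=True, pin_memory=True, drop_last=True)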

I recently switched to a much lighter model, and GPU utilisation dropped from a steady 100% to roughly 27% on average; I assume reading from disk now takes longer than a training iteration. I have tried moving the augmentations to the GPU, but it is not convenient, because the program then has to load the entire 1040x1920 image onto the GPU instead of the 256x256 crop.

Are there any alternatives I could use to speed up the data loading?

Upvotes: 0

Views: 2899

Answers (1)

dinarkino

Reputation: 394

The easiest way to check whether it is a disk-reading problem is to replace the image load with a fixed numpy array. Then you will clearly see whether data loading is the bottleneck. You can do the same for the augmentations and other data-processing steps by turning them off one by one. The PyTorch profiler could also help here.
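For example, such a check could look like this minimal sketch, where FixedArrayDataset is a made-up stand-in that mirrors the 256x256 crops from the question:

    import time

    import numpy as np
    import torch
    from torch.utils.data import Dataset, DataLoader

    class FixedArrayDataset(Dataset):
        """Returns the same pre-generated arrays instead of reading from disk."""
        def __init__(self, length=1000, img_size=256):
            self.length = length
            # One fixed pair of CHW float32 arrays, shaped like the real samples
            self.real = torch.from_numpy(np.random.rand(3, img_size, img_size).astype(np.float32))
            self.fake = torch.from_numpy(np.random.rand(3, img_size, img_size).astype(np.float32))

        def __getitem__(self, index):
            return {'fake': self.fake, 'real': self.real}

        def __len__(self):
            return self.length

    # Time one pass of pure data loading; compare against the real dataset
    loader = DataLoader(FixedArrayDataset(), batch_size=8, num_workers=4)
    start = time.time()
    for batch in loader:
        pass
    print(f"Fixed-array epoch: {time.time() - start:.2f}s")

If the epoch time with fixed arrays is much shorter than with real images, disk reading and decoding are the bottleneck.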

If it is a data loading problem, there are several ways to fix it, such as multi-process loading (the num_workers argument), data caching, using different image-loading libraries, or saving preprocessed labels or images. You can find some of these ideas explained in this answer.
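As an illustration of the caching idea, a simple in-memory cache could look like the sketch below (CachedImageDataset is a made-up name; note that with num_workers > 0 each worker process keeps its own copy of the cache, so memory use multiplies):

    from PIL import Image
    from torch.utils.data import Dataset
    from torchvision.transforms import functional as F

    class CachedImageDataset(Dataset):
        """Caches decoded images in RAM so each file is read and decoded only once."""
        def __init__(self, paths):
            self.paths = paths
            self.cache = {}  # index -> decoded PIL image (per worker process)

        def __getitem__(self, index):
            if index not in self.cache:
                # First access: pay the disk read and PNG decode once
                self.cache[index] = Image.open(self.paths[index]).convert('RGB')
            return F.to_tensor(self.cache[index])

        def __len__(self):
            return len(self.paths)

Whether this pays off depends on how much RAM you have, since 1040x1920 images are fairly large when decoded.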

Also, be careful with pin_memory, since it can cause CPU-side problems depending on the details of your data and hardware. It's better to start with pin_memory=False.

Upvotes: 2
