gruszczy

Reputation: 42168

PyTorch: Speed up data loading

I am using densenet121 to do cat/dog detection on a Kaggle dataset. I enabled CUDA and training itself appears to be very fast. However, the data loading (or perhaps the processing) appears to be very slow. Are there some ways to speed it up? I tried to play with the batch size, but that didn't help much. I also changed num_workers from 0 to some positive numbers: going from 0 to 2 reduces the loading time by perhaps a third, but increasing it further has no additional effect. Are there other ways I can speed up data loading?

This is my rough code (I am focused on learning, so it's not very organized):

import matplotlib.pyplot as plt

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models

data_dir = 'Cat_Dog_data'

train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                       transforms.RandomResizedCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.5, 0.5, 0.5],
                                                            [0.5, 0.5, 0.5])])
test_transforms = transforms.Compose([transforms.Resize(255),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor()])

# Pass transforms in here, then run the next cell to see how the transforms look
train_data = datasets.ImageFolder(data_dir + '/train',
                                  transform=train_transforms)
test_data = datasets.ImageFolder(data_dir + '/test', transform=test_transforms)

trainloader = torch.utils.data.DataLoader(train_data, batch_size=64,
                                          num_workers=16, shuffle=True,
                                          pin_memory=True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=64,
                                         num_workers=16)

model = models.densenet121(pretrained=True)

# Freeze parameters so we don't backprop through them
for param in model.parameters():
    param.requires_grad = False

from collections import OrderedDict

classifier = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(1024, 500)),
    ('relu', nn.ReLU()),
    ('fc2', nn.Linear(500, 2)),
    ('output', nn.LogSoftmax(dim=1))
]))

model.classifier = classifier
model.cuda()
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.003)

epochs = 30
steps = 0

import time

device = torch.device('cuda:0')

train_losses, test_losses = [], []
for e in range(epochs):
    running_loss = 0
    count = 0
    total_start = time.time()
    for images, labels in trainloader:
        start = time.time()
        images = images.cuda()
        labels = labels.cuda()

        optimizer.zero_grad()

        log_ps = model(images)
        loss = criterion(log_ps, labels)
        loss.backward()
        optimizer.step()
        elapsed = time.time() - start

        if count % 20 == 0:
            print("Optimized elapsed: ", elapsed, "count:", count)
            print("Total elapsed ", time.time() - total_start)
            total_start = time.time()
        count += 1

        running_loss += loss.item()
    else:
        test_loss = 0
        accuracy = 0
        for images, labels in testloader:
            images = images.cuda()
            labels = labels.cuda()
            with torch.no_grad():
                model.eval()
                log_ps = model(images)
                test_loss += criterion(log_ps, labels)
                ps = torch.exp(log_ps)
                top_p, top_class = ps.topk(1, dim=1)
                compare = top_class == labels.view(*top_class.shape)
                accuracy += compare.type(torch.FloatTensor).mean()
        model.train()
        train_losses.append(running_loss / len(trainloader))
        test_losses.append(test_loss / len(testloader))

        print("Epoch: {}/{}.. ".format(e + 1, epochs),
              "Training Loss: {:.3f}.. ".format(
                  running_loss / len(trainloader)),
              "Test Loss: {:.3f}.. ".format(test_loss / len(testloader)),
              "Test Accuracy: {:.3f}".format(accuracy / len(testloader)))

Upvotes: 10

Views: 32333

Answers (2)

RUser4512

Reputation: 1074

Late answer, but I was able to achieve a 2x speed-up on one of my PyTorch image analysis projects. I hope this can be of help!

Basically, the main ideas were to:

  • load all the images into RAM (this may not be possible depending on your hardware and dataset size, though),
  • factor out all the transforms that can be factored out (i.e. resizing, ToTensor, ... but do not factor out the data augmentation parts).

The class below loads all the images into an array when __init__ is called and applies the factored transformations (ToTensor and Resize):

from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from tqdm import tqdm

class Sentinel2Dataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):

        factored_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Resize(334, antialias=False, interpolation=InterpolationMode.NEAREST_EXACT)])

        self.file_paths = file_paths
        self.images = []
        for file_path in tqdm(self.file_paths):
            # load_and_convert_tiff is my own helper that reads a .tiff file (not shown here)
            image = load_and_convert_tiff(file_path)
            transformed_image = factored_transform(image)
            self.images.append(transformed_image)

        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label
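
Usage then looks like any other Dataset. In the sketch below, file_paths, labels and the augmentation transform are placeholders, and applying random transforms to the cached tensors assumes torchvision >= 0.8:

from torch.utils.data import DataLoader
from torchvision import transforms

# Random augmentations stay in `transform`, so they are re-drawn on every access,
# while the cached tensors were only resized/converted once in __init__
augmentations = transforms.Compose([transforms.RandomHorizontalFlip(),
                                    transforms.RandomVerticalFlip()])

dataset = Sentinel2Dataset(file_paths, labels, transform=augmentations)
loader = DataLoader(dataset, batch_size=64, shuffle=True)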

Extra speedup can be achieved by tuning these flags:

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

These ideas are explained in more detail here: speed up PyTorch loading

Upvotes: 1

Szymon Maszke

Reputation: 24691

torchvision version 0.8.0 or greater

Actually, torchvision now supports batches and GPU when it comes to transformations (these are performed on torch.Tensors instead of PIL images), so one should use it as an initial improvement.

See here for more info about this release. Also, those transforms act as torch.nn.Module, hence they can be used inside a model, for example:

import torch
import torchvision.transforms as T

transforms = torch.nn.Sequential(
    T.RandomCrop(224),
    T.RandomHorizontalFlip(p=0.3),
    T.ConvertImageDtype(torch.float),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
)

Furthermore, those operations could be JIT-compiled with torch.jit.script, possibly improving the performance even further.
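
For instance, a minimal sketch of scripting the Sequential above and applying it to a whole batch on the GPU (the batch here is just random uint8 data standing in for decoded images, and a CUDA device is assumed):

scripted_transforms = torch.jit.script(transforms)

# Fake batch of decoded images in (N, C, H, W) uint8 layout
batch = torch.randint(0, 256, (64, 3, 256, 256), dtype=torch.uint8, device="cuda")

# Crop, flip, convert to float and normalize the whole batch at once on the GPU
out = scripted_transforms(batch)  # -> (64, 3, 224, 224), float32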

torchvision < 0.8.0 (original answer)

Increasing batch_size won't help, as torchvision performs the transforms on a single image while it's being loaded from your disk.

There are a couple of ways one could speed up data loading, with increasing levels of difficulty:

  • Improve image loading times
  • Load & normalize images and cache in RAM (or on disk)
  • Produce transformations and save them to disk
  • Apply non-cacheable transforms (rotations, flips, crops) in a batched manner
  • Prefetching

1. Improve image loading

Easy improvements can be gained by installing Pillow-SIMD instead of the original pillow. It is a drop-in replacement and could be faster (or so it is claimed, at least for Resize, which you are using).

Alternatively, you could create your own data loading and processing with OpenCV, as some say it's faster, or check albumentations (though I can't tell you whether those will improve performance; it might be a lot of time wasted for no gain except the learning experience).
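
If you want to try the OpenCV route, a rough sketch of a Dataset decoding with cv2 could look like the following (the (path, label) list and the fixed resize are placeholders, and I can't promise it beats PIL on your machine):

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

class OpenCVCatDogDataset(Dataset):
    def __init__(self, samples, size=224):
        self.samples = samples  # list of (image_path, label) pairs, built elsewhere
        self.size = size

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = cv2.imread(path, cv2.IMREAD_COLOR)      # HWC, uint8, BGR
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)    # to RGB
        img = cv2.resize(img, (self.size, self.size))
        img = img.astype(np.float32) / 255.0          # scale to [0, 1]
        img = (img - 0.5) / 0.5                       # same normalization as in the question
        return torch.from_numpy(img).permute(2, 0, 1), label  # CHW tensor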

2. Load & normalize images & cache

You can use Python's LRU Cache functionality to cache some outputs.

You can also use torchdata, which acts almost exactly like PyTorch's torch.utils.data.Dataset but allows caching to disk or in RAM (or mixed modes) with a simple cache() call on torchdata.Dataset (see the github repository; disclaimer: I'm the author).

Remember: you have to load and normalize the images, cache them, and only after that use RandomRotation, RandomResizedCrop and RandomHorizontalFlip (as those change each time they are run).
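
A minimal sketch of that split with Python's functools.lru_cache (paths and labels are placeholders; applying the random transforms to tensors needs torchvision >= 0.8, and with num_workers > 0 each worker process keeps its own separate copy of the cache):

from functools import lru_cache

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

# Deterministic part: identical output on every call, hence safe to cache
base_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5],
                                                          [0.5, 0.5, 0.5])])

# Random part: must be re-drawn on every access, so it stays outside the cache
augment = transforms.Compose([transforms.RandomResizedCrop(224),
                              transforms.RandomHorizontalFlip()])

@lru_cache(maxsize=None)  # caches every image it sees; watch your RAM
def load_normalized(path):
    with Image.open(path) as img:
        return base_transform(img.convert("RGB"))

class CachedCatDogDataset(Dataset):
    def __init__(self, samples):
        self.samples = samples  # list of (path, label) pairs, built elsewhere

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        return augment(load_normalized(path)), label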

3. Produce transformations and save them to disk

You would have to perform a lot of transformations on the images, save them to disk and use this enhanced dataset afterwards. Once again, that could be done with torchdata, but it's really wasteful when it comes to I/O and hard drive space, and a very inelegant solution. Furthermore, it's "static", so the data would only last you for X epochs; it wouldn't be an "infinite" generator with augmentations.
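
For illustration only, a bare-bones version of this with plain torch.save/torch.load (the file layout and the fixed number of augmented copies per image are arbitrary choices, which is exactly the "static for X epochs" limitation mentioned above):

import os

import torch
from torch.utils.data import Dataset

def dump_augmented_copies(dataset, out_dir, copies_per_image=5):
    # `dataset` is assumed to apply random augmentations in __getitem__
    # (e.g. train_data from the question), so each call draws a new augmentation
    os.makedirs(out_dir, exist_ok=True)
    for idx in range(len(dataset)):
        for copy in range(copies_per_image):
            image, label = dataset[idx]
            torch.save((image, label), os.path.join(out_dir, f"{idx}_{copy}.pt"))

class PrecomputedDataset(Dataset):
    def __init__(self, out_dir):
        self.files = sorted(os.path.join(out_dir, name)
                            for name in os.listdir(out_dir) if name.endswith(".pt"))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        return torch.load(self.files[idx])  # no transform work at training time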

4. Batched transformations

torchvision does not support it, so you would have to write those functions on your own. See this issue for justification. AFAIK no other 3rd party library provides it either. For large batches it should speed things up, but the implementation is an open question I think (correct me if I'm wrong).
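
As a taste of what such a hand-written batched transform could look like, here is a sketch of a random horizontal flip applied to a whole (N, C, H, W) batch at once; rotations and resized crops would take more work:

import torch

def batched_random_hflip(images: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    # Draw one coin per image in the batch
    flip_mask = torch.rand(images.size(0), device=images.device) < p
    flipped = torch.flip(images, dims=[-1])  # flip along the width axis
    # Pick the flipped version where the mask is True, the original otherwise
    return torch.where(flip_mask[:, None, None, None], flipped, images)

# e.g. run it on the GPU right after moving the batch:
# images = batched_random_hflip(images.cuda())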

5. Prefetch

IMO this would be the hardest to implement (though it's a really good idea for a project, come to think of it). Basically, you load the data for the next iteration while your model trains. torch.utils.data.DataLoader does provide it, though there are some concerns (like workers pausing after their data has been loaded). You can read the PyTorch thread about it (I'm not sure, as I didn't verify it on my own). Also, a lot of valuable insight is provided by this comment and this blog post (though I'm not sure how up to date those are).
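
The low-effort starting point is the prefetching DataLoader already does on its own: with num_workers > 0 the workers decode ahead of the training loop, prefetch_factor (available since PyTorch 1.7) controls how many batches each worker keeps ready, and pin_memory together with non_blocking copies lets the host-to-device transfer overlap with compute. Roughly (the numbers are illustrative):

trainloader = torch.utils.data.DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
    num_workers=4,             # decode/transform in background processes
    pin_memory=True,           # page-locked host memory enables async copies
    prefetch_factor=4,         # batches kept ready per worker (PyTorch >= 1.7)
    persistent_workers=True,   # don't re-spawn workers every epoch (PyTorch >= 1.7)
)

for images, labels in trainloader:
    # non_blocking=True overlaps the copy with GPU computation
    images = images.cuda(non_blocking=True)
    labels = labels.cuda(non_blocking=True)
    ...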

All in all, to substantially improve data loading you would need to get your hands quite dirty (or maybe there are libraries already doing some of this for PyTorch; if so, I would love to know about them).

Also remember to profile your changes; see torch.utils.bottleneck.

EDIT: The DALI project might be worth checking out, though AFAIK it has some problems with RAM usage growing linearly with the number of epochs.

Upvotes: 26
