swooders

Reputation: 189

PyTorch default DataLoader gets stuck for a large image classification training set

I am training image classification models in PyTorch and using its default data loader to load my training data. I have a very large training dataset, usually a couple thousand sample images per class. I've trained models with about 200k images in total without issues in the past. However, I've found that when I have over a million images in total, the PyTorch data loader gets stuck.

I believe the code hangs when I call datasets.ImageFolder(...). When I press Ctrl-C, this is consistently the output:

Traceback (most recent call last):
  File "main.py", line 412, in <module>
    main()
  File "main.py", line 122, in main
    run_training(args.group, args.num_classes)
  File "main.py", line 203, in run_training
    train_loader = create_dataloader(traindir, tfm.train_trans, shuffle=True)
  File "main.py", line 236, in create_dataloader
    dataset = datasets.ImageFolder(directory, trans)
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 209, in __init__
    is_valid_file=is_valid_file)
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 94, in __init__
    samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 47, in make_dataset
    for root, _, fnames in sorted(os.walk(d)):
  File "/usr/lib/python3.5/os.py", line 380, in walk
    is_dir = entry.is_dir()
KeyboardInterrupt

I thought there might be a deadlock somewhere; however, based on the stack output from Ctrl-C, it doesn't look like it's waiting on a lock. So then I thought the data loader was just slow because I was trying to load a lot more data. I let it run for about 2 days and it didn't make any progress, and during the last 2 hours of loading I checked that the RAM usage stayed the same. I have also been able to load training datasets with over 200k images in less than a couple of hours in the past. I also tried upgrading my GCP machine to 32 cores, 4 GPUs, and over 100 GB of RAM, but it seems that after a certain amount of memory is used, the data loader just gets stuck.

I'm confused about how the data loader could be getting stuck while looping through the directory, and I'm still unsure whether it's stuck or just extremely slow. Is there some way I can change the PyTorch data loader so it can handle 1 million+ training images? Any debugging suggestions are also appreciated!

Thank you!

Upvotes: 3

Views: 5337

Answers (1)

Szymon Maszke

Reputation: 24691

It's not a problem with DataLoader; it's a problem with torchvision.datasets.ImageFolder and how it works (which is also why it gets much, much worse the more data you have).

It hangs on this line, as indicated by your traceback:

for root, _, fnames in sorted(os.walk(d)): 

The source can be found in torchvision/datasets/folder.py.
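To tell whether the walk is stuck or just very slow, you can scan one class folder yourself and print progress as entries are read. This is a minimal sketch assuming Python 3.5+ (for os.scandir); scan_with_progress is a hypothetical helper, not part of torchvision. Unlike os.walk, which only yields a folder after it has listed all of it, this reports while a single huge folder is still being scanned:

```python
import os
import sys
import time


def scan_with_progress(directory: str, report_every: int = 10000) -> int:
    """Count files below `directory`, printing progress to stderr as entries are read."""
    start = time.time()
    files = 0
    pending = [directory]  # folders still to visit (explicit stack instead of recursion)
    while pending:
        current = pending.pop()
        for entry in os.scandir(current):
            # entry.is_dir() is the call the Ctrl-C traceback stopped in; on some
            # filesystems it needs an extra stat() system call per entry
            if entry.is_dir():
                pending.append(entry.path)
            else:
                files += 1
                if files % report_every == 0:
                    print("{} files after {:.1f}s (in {})".format(
                        files, time.time() - start, current), file=sys.stderr)
    print("done: {} files in {:.1f}s".format(files, time.time() - start), file=sys.stderr)
    return files
```

For example, call scan_with_progress("/path/to/train/some_class") and watch whether the counter keeps moving; if it advances slowly, the filesystem is the bottleneck rather than a deadlock.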

The underlying problem is that it keeps every path and its corresponding label in a giant list; see the code below (a few things removed for brevity):

import os


def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
    images = []
    dir = os.path.expanduser(dir)
    # Iterate over all class subfolders which were found previously
    for target in sorted(class_to_idx.keys()):
        d = os.path.join(dir, target)  # Path to this class subfolder
        # Assuming it is a directory (which is usually the case)
        for root, _, fnames in sorted(os.walk(d, followlinks=True)):
            # Iterate over ALL files in this subdirectory
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                # Assuming it is correctly recognized as an image file
                item = (path, class_to_idx[target])
                # Add it to the list of all images
                images.append(item)

    return images

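Obviously images will end up holding 1 million (fairly long) path strings plus the corresponding class index for each, which definitely is a lot, and how badly it hurts depends on your RAM and CPU. On top of that, sorted(os.walk(d)) has to finish walking the whole class folder (calling is_dir() on every entry) before a single sample is produced.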
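If you want to see what "a lot" means on your machine, here is a rough sketch that builds a list shaped like the one make_dataset returns and sums sys.getsizeof over it. The helper name approx_sample_list_bytes and the path template are made up for illustration, and the result only covers the list, tuples and path strings (label ints are not counted), not the time spent walking and sorting:

```python
import sys


def approx_sample_list_bytes(n: int, path_template: str = "/data/train/class{}/{}.png") -> int:
    """Rough CPython size of a list of n (path, label) tuples like make_dataset builds."""
    # Hypothetical paths; lengths similar to real dataset paths are what matter here
    samples = [(path_template.format(i % 1000, i), i % 1000) for i in range(n)]
    total = sys.getsizeof(samples)  # the list object and its pointer array
    for item in samples:
        # each tuple plus the path string it references (label ints are ignored)
        total += sys.getsizeof(item) + sys.getsizeof(item[0])
    return total
```

Calling print(approx_sample_list_bytes(1000000) / 1e6, "MB") gives an estimate for your own path lengths; adjust path_template to resemble your real paths.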

You can create your own datasets instead (provided you rename your images beforehand), so the dataset does not have to keep a list of paths in memory.

Set up the data structure

Your folder structure should look like this:

root
    class1
    class2
    class3
    ...

Use as many class folders as you have/need.

Each class folder should then contain its images named with consecutive integers starting from 0:

class1
    0.png
    1.png
    2.png
    ...
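If your images currently have arbitrary names, a minimal sketch for renaming them in place could look like the following. rename_to_indices is a hypothetical helper; it assumes every image in the folder uses the same extension, numbers files in sorted filename order, and keeps no record of the original names, so try it on a copy first. Because the dataset below counts every file in the folder, the folder should contain only these images afterwards:

```python
import os


def rename_to_indices(class_dir: str, extension: str = "png") -> int:
    """Rename every *.extension file in class_dir to 0.extension, 1.extension, ...

    Returns the number of renamed files.
    """
    suffix = "." + extension
    names = sorted(n for n in os.listdir(class_dir) if n.endswith(suffix))
    tmp_names = []
    # Pass 1: move to temporary names so an existing "1.png" is never overwritten
    for i, name in enumerate(names):
        tmp = "__tmp_{}{}".format(i, suffix)
        os.rename(os.path.join(class_dir, name), os.path.join(class_dir, tmp))
        tmp_names.append(tmp)
    # Pass 2: give each file its final index-based name
    for i, tmp in enumerate(tmp_names):
        os.rename(os.path.join(class_dir, tmp),
                  os.path.join(class_dir, "{}{}".format(i, suffix)))
    return len(names)
```

Run it once per class folder, e.g. in a loop over os.listdir(root).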

Given that, you can move on to creating datasets.

Create Datasets

The torch.utils.data.Dataset below uses PIL to open images, though you could do it another way:

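import os
import pathlib

import torch
from PIL import Image


class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, root: str, folder: str, klass: int, extension: str = "png", transform=None):
        self._data = pathlib.Path(root) / folder
        self.klass = klass
        self.extension = extension
        # Optional callable applied to each PIL image (e.g. a torchvision transform
        # ending in ToTensor()) so that DataLoader can batch the results
        self.transform = transform
        # Only calculate once how many files are in this folder
        # Could be passed as argument if you precalculate it somehow
        # e.g. ls | wc -l on Linux
        # str() keeps this working on Python 3.5, where os.listdir does not accept Path
        self._length = len(os.listdir(str(self._data)))

    def __len__(self):
        # No need to recalculate this value every time
        return self._length

    def __getitem__(self, index):
        # Images are always named 0..n-1, so you can access them directly
        path = self._data / "{}.{}".format(index, self.extension)
        # convert("RGB") makes PIL read the file and gives a consistent channel layout
        image = Image.open(str(path)).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        # Return the label together with the image, otherwise there are no targets
        return image, self.klass

Now you can create your datasets easily (folder structure assumed to be like the one above):

from torchvision import transforms

root = "/path/to/root/with/images"
# Example transform: the default collate function needs equally sized tensors,
# so resize and convert the PIL images (use whatever your model expects)
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataset = (
    ImageDataset(root, "class1", 0, transform=transform)
    + ImageDataset(root, "class2", 1, transform=transform)
    + ImageDataset(root, "class3", 2, transform=transform)
)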

You can add as many datasets with their class labels as you wish, e.g. in a loop.
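Adding datasets with + produces a torch.utils.data.ConcatDataset, which only stores the cumulative lengths of its parts and uses a binary search (bisect) to decide which per-class dataset a global index belongs to, so the combined dataset still holds no paths. Below is a pure-Python sketch of that lookup; cumulative_sizes and locate are hypothetical helper names, not the PyTorch API, and unlike ConcatDataset this version does not accept negative indices:

```python
import bisect
from itertools import accumulate
from typing import List, Tuple


def cumulative_sizes(lengths: List[int]) -> List[int]:
    """Running totals of dataset lengths, e.g. [3, 2, 4] -> [3, 5, 9]."""
    return list(accumulate(lengths))


def locate(cumulative: List[int], index: int) -> Tuple[int, int]:
    """Map a global index to (dataset number, index inside that dataset)."""
    if not 0 <= index < cumulative[-1]:
        raise IndexError(index)
    # First dataset whose running total is strictly greater than index
    dataset_idx = bisect.bisect_right(cumulative, index)
    start = cumulative[dataset_idx - 1] if dataset_idx > 0 else 0
    return dataset_idx, index - start
```

With three classes of 3, 2 and 4 images, locate(cumulative_sizes([3, 2, 4]), 3) returns (1, 0): the first image of the second class.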

Finally, use torch.utils.data.DataLoader as per usual, e.g.:

dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

Upvotes: 6
