Reputation: 189
I am training image classification models in PyTorch and using their default data loader to load my training data. I have a very large training dataset, usually a couple thousand sample images per class. I've trained models with about 200k images total without issues in the past. However, I've found that when I have over a million images in total, the PyTorch data loader gets stuck.
I believe the code is hanging when I call datasets.ImageFolder(...). When I Ctrl-C, this is consistently the output:
Traceback (most recent call last):
  File "main.py", line 412, in <module>
    main()
  File "main.py", line 122, in main
    run_training(args.group, args.num_classes)
  File "main.py", line 203, in run_training
    train_loader = create_dataloader(traindir, tfm.train_trans, shuffle=True)
  File "main.py", line 236, in create_dataloader
    dataset = datasets.ImageFolder(directory, trans)
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 209, in __init__
    is_valid_file=is_valid_file)
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 94, in __init__
    samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 47, in make_dataset
    for root, _, fnames in sorted(os.walk(d)):
  File "/usr/lib/python3.5/os.py", line 380, in walk
    is_dir = entry.is_dir()
KeyboardInterrupt
I thought there might be a deadlock somewhere; however, based on the stack trace from Ctrl-C, it doesn't look like it's waiting on a lock. Then I thought the data loader was just slow because I was trying to load a lot more data. I let it run for about 2 days and it made no progress, and during the last 2 hours of loading I checked that the RAM usage stayed the same. In the past I have also been able to load training datasets with over 200k images in less than a couple of hours. I also tried upgrading my GCP machine to 32 cores, 4 GPUs, and over 100GB of RAM; however, it seems that after a certain amount of memory is loaded, the data loader just gets stuck.
I'm confused how the data loader could be getting stuck while looping through the directory, and I'm still unsure whether it's stuck or just extremely slow. Is there some way I can change the PyTorch data loader to handle 1 million+ images for training? Any debugging suggestions are also appreciated!
Thank you!
Upvotes: 3
Views: 5337
Reputation: 24691
It's not a problem with DataLoader, it's a problem with torchvision.datasets.ImageFolder and how it works (and why it works much, much worse the more data you have).
It hangs on this line, as indicated by your error:
for root, _, fnames in sorted(os.walk(d)):
The source can be found in torchvision/datasets/folder.py.
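To see why this looks like a complete hang rather than just a slow loop, note that sorted() has to exhaust the whole os.walk generator before the loop body can run even once; here is a minimal illustration (the path is a placeholder):

import os

# os.walk is a lazy generator: nothing is read from disk yet
walker = os.walk("/path/to/one/class/folder")

# sorted() must consume the ENTIRE generator in order to sort it, so
# this one call blocks until every directory entry has been listed
listing = sorted(walker)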
The underlying problem is that it keeps each path and its corresponding label in one giant list; see the code below (a few things removed for brevity):
def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
    images = []
    dir = os.path.expanduser(dir)
    # Iterate over all subfolders which were found previously
    for target in sorted(class_to_idx.keys()):
        d = os.path.join(dir, target)  # Create path to this subfolder
        # Assuming it is a directory (which usually is the case)
        for root, _, fnames in sorted(os.walk(d, followlinks=True)):
            # Iterate over ALL files in this subdirectory
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                # Assuming it is correctly recognized as an image file
                item = (path, class_to_idx[target])
                # Add to the list of all images
                images.append(item)
    return images
Obviously images will contain 1 million strings (quite lengthy ones at that), each paired with an int class index, which is definitely a lot of memory and strains both RAM and CPU.
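As a rough, hypothetical back-of-envelope (the paths and counts below are made up for illustration, not measured from your setup):

import sys

# One million fake paths of realistic length
paths = ["/home/username/data/class{}/{:07d}.png".format(i % 1000, i)
         for i in range(10**6)]

# A ~45-character str costs ~90 bytes in CPython, plus the list itself
# and one tuple per (path, label) entry built by make_dataset
total = sum(sys.getsizeof(p) for p in paths) + sys.getsizeof(paths)
print("{:.0f} MB just for the path strings".format(total / 1e6))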
You can create your own datasets though (provided you rename your images beforehand; see the renaming sketch below), so that almost no memory will be occupied by the dataset.
Your folder structure should look like this:
root
    class0
    class1
    class2
    ...
Use as many classes as you have or need. Each class folder should then contain its images named consecutively like this:

class0
    0.png
    1.png
    2.png
    ...
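If your files don't follow that naming yet, a hypothetical one-off renaming script could look like the following (the root path is a placeholder; the two-pass rename simply avoids clobbering files that already have numeric names):

import pathlib

root = pathlib.Path("/path/to/root/with/images")
for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
    files = sorted(class_dir.iterdir())
    # First pass: move everything to temporary names
    temps = []
    for i, path in enumerate(files):
        temp = class_dir / "tmp_{}{}".format(i, path.suffix)
        path.rename(temp)
        temps.append(temp)
    # Second pass: assign the final 0..n-1 names
    for i, temp in enumerate(temps):
        temp.rename(class_dir / "{}{}".format(i, temp.suffix))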
Given that, you can move on to creating the datasets.
The torch.utils.data.Dataset below uses PIL to open images; you could do it another way though:
import os
import pathlib
import torch
from PIL import Image


class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, root: str, folder: str, klass: int, extension: str = "png"):
        self._data = pathlib.Path(root) / folder
        self.klass = klass
        self.extension = extension
        # Only calculate once how many files are in this folder
        # Could be passed as an argument if you precalculate it somehow,
        # e.g. ls | wc -l on Linux
        self._length = sum(1 for entry in os.listdir(str(self._data)))

    def __len__(self):
        # No need to recalculate this value every time
        return self._length

    def __getitem__(self, index):
        # Images always follow [0, n-1], so you can access them directly
        path = self._data / "{}.{}".format(index, self.extension)
        # Return the image together with its class label
        return Image.open(path), self.klass
Now you can create your datasets easily (folder structure assumed to be like the one above):
root = "/path/to/root/with/images"
dataset = (
ImageDataset(root, "class0", 0)
+ ImageDataset(root, "class1", 1)
+ ImageDataset(root, "class2", 2)
)
You could add as many datasets with specified classes as you wish; you can also build them in a loop, as sketched below.
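For example, since the + operator on datasets creates a torch.utils.data.ConcatDataset under the hood, you can build one directly in a loop (a sketch assuming folders named class0 through class{N-1} as above; num_classes is a placeholder for however many classes you have):

num_classes = 3  # assumption: match this to your number of class folders
dataset = torch.utils.data.ConcatDataset(
    [ImageDataset(root, "class{}".format(i), i) for i in range(num_classes)]
)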
Finally, use torch.utils.data.DataLoader as per usual, e.g.:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
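One caveat: DataLoader's default collate function cannot batch raw PIL images, so convert them to tensors first. A minimal sketch, assuming torchvision is available (the wrapper class is my addition, not part of the dataset above):

import torchvision.transforms as T

class TransformedDataset(torch.utils.data.Dataset):
    """Applies a transform to the images of a wrapped (image, label) dataset."""

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, label = self.dataset[index]
        return self.transform(image), label


transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
dataloader = torch.utils.data.DataLoader(
    TransformedDataset(dataset, transform), batch_size=64, shuffle=True
)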
Upvotes: 6