Kevin Sun

Reputation: 462

How to speed up the "ImageFolder" for ImageNet

I am at a university where all the file systems are remote: wherever I log in with my account, I can always access my home directory, even when I log into the GPU servers over SSH. This is the setup in which the GPU servers read the data.

Currently I am using PyTorch to train ResNet from scratch on ImageNet. My code only uses the GPUs in a single machine, and I found that "torchvision.datasets.ImageFolder" takes almost two hours.

Could you share any experience on how to speed up "torchvision.datasets.ImageFolder"? Thanks very much.

Upvotes: 5

Views: 4652

Answers (2)

Yann Dubois

Reputation: 1345

If you are sure that the folder structure doesn't change, you can cache the structure (not the data, which is too large) using the following:


import json
from functools import wraps
from pathlib import Path  # needed for Path(directory) below

from torchvision.datasets import ImageNet

def file_cache(filename):
    """Decorator to cache the output of a function to disk."""
    def decorator(f):
        @wraps(f)
        def decorated(self, directory, *args, **kwargs):
            filepath = Path(directory) / filename
            if filepath.is_file():
                out = json.loads(filepath.read_text())
            else:
                out = f(self, directory, *args, **kwargs)
                filepath.write_text(json.dumps(out))
            return out
        return decorated
    return decorator

class CachedImageNet(ImageNet):
    @file_cache(filename="cached_classes.json")
    def find_classes(self, directory, *args, **kwargs):
        classes = super().find_classes(directory, *args, **kwargs)
        return classes

    @file_cache(filename="cached_structure.json")
    def make_dataset(self, directory, *args, **kwargs):
        dataset = super().make_dataset(directory, *args, **kwargs)
        return dataset

Upvotes: 0

Shai

Reputation: 114866

Why does it take so long?
Setting up an ImageFolder can take a long time, especially when the images are stored on a slow remote disk. The reason for this latency is that the dataset's __init__ function goes over all the files in the image folders and checks whether each one is an image file. For ImageNet that can take quite a while, as there are over 1 million files to check.
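
To make the cost concrete, here is a simplified sketch of that kind of scan, written with the standard library only. It is not torchvision's actual code: the function name `scan_image_folder` and the short extension list are assumptions for illustration. The point is that every class directory has to be listed and every file name inspected, and on a remote file system each listing is a network round trip.

```python
import os

# Assumed subset of image extensions, for illustration only.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png")


def scan_image_folder(root):
    """Return (path, class_index) pairs for a root/<class>/<image> layout."""
    # One listing of the root to discover the class directories.
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}

    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        # One or more listings per class; this is the slow part on remote storage.
        for dirpath, _, filenames in sorted(os.walk(class_dir)):
            for fname in sorted(filenames):
                # Only the file name is checked here; the file is not opened.
                if fname.lower().endswith(IMG_EXTENSIONS):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[name]))
    return samples
```

With 1000 classes and more than a million files, the number of metadata requests adds up quickly even though no image is decoded.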

What can you do?
- As Kevin Sun already pointed out, copying the dataset to local (and possibly much faster) storage can speed things up significantly.
- Alternatively, you can create a modified dataset class that does not scan all the files but relies on a cached list of files, which you prepare once in advance and reuse for all runs.
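
A minimal sketch of the cached-list idea, assuming a root/<class>/<image> layout. The helper name `load_or_build_index`, the JSON cache format and the extension list are all hypothetical choices, not part of torchvision. The first call scans the folder and writes the list to disk; later calls only read the cache file.

```python
import json
import os


def load_or_build_index(root, cache_path, extensions=(".jpg", ".jpeg", ".png")):
    """Return [(relative_path, class_index), ...], scanning root only if no cache exists."""
    if os.path.isfile(cache_path):
        # Fast path: a single file read instead of a full directory scan.
        with open(cache_path) as fh:
            return [(path, idx) for path, idx in json.load(fh)]

    # Slow path, run once: list every class directory and collect image files.
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    samples = []
    for idx, name in enumerate(classes):
        for dirpath, _, filenames in sorted(os.walk(os.path.join(root, name))):
            for fname in sorted(filenames):
                if fname.lower().endswith(extensions):
                    full_path = os.path.join(dirpath, fname)
                    # Store paths relative to root so the cache survives moving the dataset.
                    samples.append((os.path.relpath(full_path, root), idx))

    with open(cache_path, "w") as fh:
        json.dump(samples, fh)
    return samples
```

Joining `root` with each relative path gives entries of the same (path, class_index) shape that ImageFolder keeps in its `samples` attribute, so a custom dataset class can load images from this list directly. The cache is never invalidated here: if the folder contents change, delete the cache file and rebuild it.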

Upvotes: 1
