Alexander Soare

Reputation: 3267

How to maintain state in a DataLoader's Dataset

I have something like the following (see self.cache for the interesting bit):

import os.path as osp
from glob import glob
from pathlib import Path

import torch


class DescriptorDataset(torch.utils.data.Dataset):
    def __init__(self, descriptor_dir):
        super().__init__()
        self.file_paths = glob(osp.join(descriptor_dir, '*'))
        self.image_ids = [Path(fp).stem for fp in self.file_paths]
        self.cache = {}

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, ix):
        file_path = self.file_paths[ix]
        descriptor = self.get_descriptor(file_path)
        return descriptor, Path(file_path).stem

    def get_descriptor(self, file_path):
        descriptor = self.cache.get(file_path, torch.load(file_path))
        self.cache[file_path] = descriptor
        return descriptor

query_loader = torch.utils.data.DataLoader(
    DescriptorDataset(query_dir), batch_size=1, num_workers=0)

I noticed that the caching mechanism works when num_workers == 0 but not when num_workers > 0. Does PyTorch have an inbuilt way to handle this?
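For reference, here's a minimal sketch of what I think is going on (a hypothetical ToyDataset, not my real code): torch.utils.data.get_worker_info() suggests that with num_workers > 0 each worker process serves items from its own copy of the dataset, so whatever the workers cache never shows up in the dataset instance held by the main process.

from torch.utils.data import DataLoader, Dataset, get_worker_info


class ToyDataset(Dataset):
    """Hypothetical stand-in for DescriptorDataset: caches a dummy value per index."""

    def __init__(self):
        self.cache = {}

    def __len__(self):
        return 4

    def __getitem__(self, ix):
        self.cache[ix] = ix  # stand-in for an expensive torch.load
        worker = get_worker_info()  # None in the main process, a WorkerInfo inside workers
        worker_id = worker.id if worker is not None else "main"
        print(f"worker={worker_id} cache_size={len(self.cache)}")
        return ix


if __name__ == "__main__":
    ds = ToyDataset()
    for _ in DataLoader(ds, batch_size=1, num_workers=2):
        pass
    # The workers each printed a growing cache, but this copy never saw any of it.
    print("cache size in the main process:", len(ds.cache))  # prints 0

With num_workers=0 everything runs in the main process (get_worker_info() returns None) and the cache fills up as expected.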

Upvotes: 0

Views: 617

Answers (2)

Tevien

Reputation: 131

When I've come across this situation, I've filled the cache during initialization. That way it stays fixed during training/inference, and it can be saved and reloaded the next time you instantiate the dataset:

import os
import os.path as osp
from glob import glob
from pathlib import Path

import joblib
import torch


class DescriptorDataset(torch.utils.data.Dataset):
    def __init__(self, descriptor_dir, cache_loc=None):
        super().__init__()
        self.file_paths = glob(osp.join(descriptor_dir, '*'))
        self.image_ids = [Path(fp).stem for fp in self.file_paths]
        self.cache = self.make_cache(cache_loc)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, ix):
        file_path = self.file_paths[ix]
        descriptor = self.get_descriptor(file_path)
        return descriptor, Path(file_path).stem

    def get_descriptor(self, file_path):
        # Only hit the disk on a cache miss; dict.get with a default would
        # call torch.load eagerly even when the descriptor is already cached.
        if file_path not in self.cache:
            self.cache[file_path] = torch.load(file_path)
        return self.cache[file_path]

    def make_cache(self, cache_loc):
        # Reload a previously saved cache if one exists; otherwise build it
        # up front and persist it so the next instantiation can reload it.
        if cache_loc is not None and os.path.exists(cache_loc):
            return joblib.load(cache_loc)
        cache = {}
        for p in self.file_paths:
            cache[p] = torch.load(p)
        if cache_loc is not None:
            joblib.dump(cache, cache_loc)
        return cache
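For example, usage could look something like this ('cache.joblib' is just a hypothetical path; query_dir is the directory from the question):

# First run: builds the cache from the descriptor files and saves it to cache.joblib.
query_dataset = DescriptorDataset(query_dir, cache_loc='cache.joblib')
query_loader = torch.utils.data.DataLoader(query_dataset, batch_size=1, num_workers=4)

# Later runs reload the pre-filled cache from cache.joblib; since the cache exists
# before the DataLoader starts its workers, every worker begins with a full copy.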

Upvotes: 1

edornd

Reputation: 471

Disclaimer: I am not an expert on the internal mechanisms of PyTorch's DataLoader.

However, here are my two cents: given that the DataLoader handles the __getitem__ calls via multiprocessing, I wouldn't rule out some weird race conditions. I assume your file paths are unique; nevertheless, I'd suggest trying to index the cache with the same ix passed to __getitem__, which is guaranteed to be a unique identifier for that item.

Something like this:

class DescriptorDataset(torch.utils.data.Dataset):
    def __init__(self, descriptor_dir):
        super().__init__()
        self.file_paths = glob(osp.join(descriptor_dir, '*'))
        self.image_ids = [Path(fp).stem for fp in self.file_paths]
        self.cache = {}

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, ix):
        file_path = self.file_paths[ix]
        descriptor = self.get_descriptor(ix, file_path)
        return descriptor, Path(file_path).stem

    def get_descriptor(self, ix, file_path):
        descriptor = self.cache.get(ix, torch.load(file_path))
        self.cache[ix] = descriptor
        return descriptor

Side note: the title is a bit misleading; you're not storing the state in the DataLoader right now, but in the Dataset (I guess it was just an oversight).

Upvotes: 0
