Reputation: 3267
I have something like the following (see self.cache for the interesting bit):
from glob import glob
from pathlib import Path
import os.path as osp

import torch

class DescriptorDataset(torch.utils.data.Dataset):
    def __init__(self, descriptor_dir):
        super().__init__()
        self.file_paths = glob(osp.join(descriptor_dir, '*'))
        self.image_ids = [Path(fp).stem for fp in self.file_paths]
        self.cache = {}  # in-memory cache of loaded descriptors

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, ix):
        file_path = self.file_paths[ix]
        descriptor = self.get_descriptor(file_path)
        return descriptor, Path(file_path).stem

    def get_descriptor(self, file_path):
        descriptor = self.cache.get(file_path, torch.load(file_path))
        self.cache[file_path] = descriptor
        return descriptor
query_loader = torch.utils.data.DataLoader(
    DescriptorDataset(query_dir), batch_size=1, num_workers=0
)
I noticed that the caching mechanism works when num_workers == 0 but not for num_workers > 0. Does PyTorch have an inbuilt way to handle this?
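Here is a minimal, self-contained repro (a toy ProbeDataset, not my actual dataset) of what I mean: each worker process appears to get its own copy of the dataset, so the cache only accumulates in the main process when num_workers == 0.

import torch
from torch.utils.data import DataLoader, Dataset, get_worker_info

class ProbeDataset(Dataset):
    """Toy dataset that just records how many cache writes the serving process has seen."""
    def __init__(self):
        self.cache = {}

    def __len__(self):
        return 8

    def __getitem__(self, ix):
        self.cache[ix] = torch.tensor(float(ix))    # stand-in for caching a loaded descriptor
        worker = get_worker_info()                  # None when running in the main process
        worker_id = -1 if worker is None else worker.id
        return worker_id, len(self.cache)

if __name__ == "__main__":
    for num_workers in (0, 2):
        loader = DataLoader(ProbeDataset(), batch_size=1, num_workers=num_workers)
        sizes = [(int(w), int(n)) for w, n in loader]
        # With 0 workers the reported cache size climbs to 8; with 2 workers each
        # worker process only ever sees its own writes, and they are discarded
        # when the workers exit.
        print(num_workers, sizes)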
Upvotes: 0
Views: 617
Reputation: 131
When I have come across this situation, I have filled the cache during initialization. That way it stays fixed during training/inference and can be reloaded the next time you instantiate the dataset:
import os
import os.path as osp
from glob import glob
from pathlib import Path

import joblib
import torch

class DescriptorDataset(torch.utils.data.Dataset):
    def __init__(self, descriptor_dir, cache_loc=None):
        super().__init__()
        self.file_paths = glob(osp.join(descriptor_dir, '*'))
        self.image_ids = [Path(fp).stem for fp in self.file_paths]
        self.cache = self.make_cache(cache_loc)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, ix):
        file_path = self.file_paths[ix]
        descriptor = self.get_descriptor(file_path)
        return descriptor, Path(file_path).stem

    def get_descriptor(self, file_path):
        # Only hit the disk if the descriptor is not already cached
        if file_path not in self.cache:
            self.cache[file_path] = torch.load(file_path)
        return self.cache[file_path]

    def make_cache(self, cache_loc):
        if cache_loc is not None and os.path.exists(cache_loc):
            return joblib.load(cache_loc)
        else:
            cache = {}
            for p in self.file_paths:
                cache[p] = torch.load(p)
            if cache_loc is not None:
                joblib.dump(cache, cache_loc)  # persist so it can be reloaded next time
            return cache
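Usage might look something like this (the cache path below is just an example name; the first run builds and saves the cache, later runs reload it):

dataset = DescriptorDataset(query_dir, cache_loc='descriptor_cache.joblib')
query_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=2)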
Upvotes: 1
Reputation: 471
Disclaimer: I am not an expert on the internal mechanisms of PyTorch's DataLoader. However, here are my two cents: given that the DataLoader handles the __getitem__ calls using multiprocessing, I wouldn't rule out some weird race condition. I assume your file paths are unique; nevertheless, I'd suggest indexing the cache with the same ix from the __getitem__ call, which is guaranteed to be a unique identifier for that item.
Something like this:
from glob import glob
from pathlib import Path
import os.path as osp

import torch

class DescriptorDataset(torch.utils.data.Dataset):
    def __init__(self, descriptor_dir):
        super().__init__()
        self.file_paths = glob(osp.join(descriptor_dir, '*'))
        self.image_ids = [Path(fp).stem for fp in self.file_paths]
        self.cache = {}

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, ix):
        file_path = self.file_paths[ix]
        descriptor = self.get_descriptor(ix, file_path)
        return descriptor, Path(file_path).stem

    def get_descriptor(self, ix, file_path):
        # Key the cache on the integer index instead of the file path
        if ix not in self.cache:
            self.cache[ix] = torch.load(file_path)
        return self.cache[ix]
Sidenote: the title is a bit misleading: you're not actually storing the state in the DataLoader, but in the Dataset (I guess that was just an oversight).
Upvotes: 0