Reputation: 3545
I have a huge dataset that does not fit in memory (150 GB) and I'm looking for the best way to work with it in PyTorch. The dataset is composed of several .npz files of 10k samples each. I tried to build a Dataset class:
import os
import numpy as np
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, path):
        self.path = path
        self.files = os.listdir(self.path)
        self.file_length = {}
        for f in self.files:
            # Load the file as a memmap
            d = np.load(os.path.join(self.path, f), mmap_mode='r')
            self.file_length[f] = len(d['y'])

    def __len__(self):
        # Total number of samples across all files
        return sum(self.file_length.values())

    def __getitem__(self, idx):
        # Find the file the global index idx belongs to
        count = 0
        f_key = ''
        local_idx = 0
        for k in self.file_length:
            if count <= idx < count + self.file_length[k]:
                f_key = k
                local_idx = idx - count
                break
            else:
                count += self.file_length[k]
        # Open the file as a numpy memmap
        d = np.load(os.path.join(self.path, f_key), mmap_mode='r')
        # Actually fetch the data
        X = np.expand_dims(d['X'][local_idx], axis=1)
        y = np.expand_dims((d['y'][local_idx] == 2).astype(np.float32), axis=1)
        return X, y
but when a sample is actually fetched, it takes more than 30 seconds. It looks like the entire .npz file is opened, loaded into RAM, and only then is the right index accessed.
How can I be more efficient?
EDIT
It appears to be a misunderstanding of how .npz files work (see post), but is there a better approach?
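For reference, here is a minimal sketch (file names and shapes are made up) of why mmap_mode does nothing useful for an .npz archive, while a plain .npy file really is memory-mapped:

import numpy as np

arr = np.random.rand(10000, 128).astype(np.float32)
np.savez('batch.npz', X=arr)    # zip archive: members cannot be memory-mapped
np.save('batch_X.npy', arr)     # flat .npy file: can be memory-mapped

d = np.load('batch.npz', mmap_mode='r')   # mmap_mode is silently ignored for .npz
x1 = d['X'][0]                  # reads the whole 'X' array into RAM first

m = np.load('batch_X.npy', mmap_mode='r')
x2 = m[0]                       # touches only the pages that contain row 0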
SOLUTION PROPOSAL
As proposed by @covariantmonkey, lmdb can be a good choice. For now, since the problem comes from the .npz files and not from memmap itself, I remodelled my dataset by splitting the .npz archives into several .npy files. I can now use the same logic, where memmap makes complete sense and is really fast (a few ms to load a sample).
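In case it helps someone, here is a minimal sketch of the conversion (directory names are made up; the 'X'/'y' keys match the layout above):

import os
import numpy as np

src_dir, dst_dir = 'npz_data', 'npy_data'
os.makedirs(dst_dir, exist_ok=True)
for fname in os.listdir(src_dir):
    d = np.load(os.path.join(src_dir, fname))
    stem = os.path.splitext(fname)[0]
    # One flat .npy file per array, so each one can be memory-mapped later
    np.save(os.path.join(dst_dir, stem + '_X.npy'), d['X'])
    np.save(os.path.join(dst_dir, stem + '_y.npy'), d['y'])

__getitem__ then opens the .npy files with mmap_mode='r', and only the requested rows are read from disk.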
Upvotes: 5
Views: 2945
Reputation: 223
How large are the individual .npz files? I was in a similar predicament a month ago. After various forum posts and Google searches, I went the lmdb route. Here is what I did:
key = filename and data = np.savez_compressed(stuff)
lmdb takes care of the mmap for you and is insanely fast to load.
Regards,
A
PS: savez_compressed requires a bytes-like file object, so you can do something like
import io
output = io.BytesIO()
np.savez_compressed(output, x=your_np_data)
# cache output.getvalue() in lmdb
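For completeness, a minimal sketch of the whole pattern (database path, key, and map_size are made up):

import io
import lmdb
import numpy as np

env = lmdb.open('samples.lmdb', map_size=10 * 1024**3)  # max database size, here 10 GB

# Write: one compressed blob per key
with env.begin(write=True) as txn:
    buf = io.BytesIO()
    np.savez_compressed(buf, X=np.random.rand(100, 10), y=np.zeros(100))
    txn.put(b'file_000', buf.getvalue())

# Read: load the blob back as an .npz archive
with env.begin() as txn:
    d = np.load(io.BytesIO(txn.get(b'file_000')))
    X, y = d['X'], d['y']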
Upvotes: 3