GuillaumeA

Reputation: 3545

How to work with large dataset in pytorch

I have a huge dataset that does not fit in memory (150G) and I'm looking for the best way to work with it in pytorch. The dataset is composed of several .npz files of 10k samples each. I tried to build a Dataset class

import os

import numpy as np
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, path):
        self.path = path
        self.files = os.listdir(self.path)
        self.file_length = {}
        for f in self.files:
            # Load file as a memmap to read its length without loading the arrays
            d = np.load(os.path.join(self.path, f), mmap_mode='r')
            self.file_length[f] = len(d['y'])

    def __len__(self):
        # Total number of samples across all files
        return sum(self.file_length.values())

    def __getitem__(self, idx):                
        # Find the file where idx belongs to
        count = 0
        f_key = ''
        local_idx = 0
        for k in self.file_length:
            if count <= idx < count + self.file_length[k]:
                f_key = k
                local_idx = idx - count
                break
            else:
                count += self.file_length[k]
        # Open file as numpy.memmap
        d = np.load(os.path.join(self.path, f_key), mmap_mode='r')
        # Actually fetch the data
        X = np.expand_dims(d['X'][local_idx], axis=1)
        y = np.expand_dims((d['y'][local_idx] == 2).astype(np.float32), axis=1)
        return X, y

but when a sample is actually fetched, it takes more than 30s. It looks like the entire .npz is opened, loaded into RAM, and only then is the right index accessed. How can I be more efficient?

EDIT

It appears to come from a misunderstanding of how .npz files work (see this post), but is there a better approach?
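
In short, mmap_mode only helps with plain .npy files; for .npz archives the whole array is read when a key is accessed. A minimal sketch of the difference (file names are hypothetical):

import numpy as np

# Plain .npy: mmap_mode='r' keeps the data on disk; indexing reads only that slice.
X = np.load('chunk_0_X.npy', mmap_mode='r')   # hypothetical file name
sample = X[42]                                # only this row is read from disk

# .npz archive: mmap_mode is effectively ignored; indexing d['X'] first reads
# (and, if compressed, decompresses) the whole 'X' array into RAM, hence the 30s.
d = np.load('chunk_0.npz', mmap_mode='r')     # hypothetical file name
sample = d['X'][42]                           # loads the entire 'X' array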

SOLUTION PROPOSAL

As proposed by @covariantmonkey, lmdb can be a good choice. For now, since the problem comes from the .npz files and not from memmap itself, I remodelled my dataset by splitting the .npz archives into several .npy files. I can now use the same logic, where memmap makes perfect sense and is really fast (a few milliseconds to load a sample).
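
Roughly, the split looks like this (directory names are placeholders; the 'X' and 'y' keys are the ones used in the Dataset above):

import os
import numpy as np

src_dir, dst_dir = 'data/npz', 'data/npy'   # placeholder paths
os.makedirs(dst_dir, exist_ok=True)

# Each .npz of 10k samples becomes an X/y pair of plain .npy files.
for fname in os.listdir(src_dir):
    d = np.load(os.path.join(src_dir, fname))
    base = os.path.splitext(fname)[0]
    np.save(os.path.join(dst_dir, base + '_X.npy'), d['X'])
    np.save(os.path.join(dst_dir, base + '_y.npy'), d['y'])

# np.load(..., mmap_mode='r') on the resulting .npy files only touches the
# rows you index, so __getitem__ drops to a few milliseconds per sample.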

Upvotes: 5

Views: 2945

Answers (1)

covariantmonkey

Reputation: 223

How large are the individual .npz files? I was in a similar predicament a month ago. Various forum posts and Google searches later, I went the lmdb route. Here is what I did:

  1. Chunk the large dataset into files small enough to fit in GPU memory; each of them is essentially a minibatch. I did not optimize for load time at this stage, only for memory.
  2. Create an lmdb index with key = filename and data = the bytes produced by np.savez_compressed(stuff), as sketched below.

lmdb takes care of the mmap for you and is insanely fast to load.
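
A minimal sketch of what step 2 could look like with the lmdb Python bindings; the database path, map_size and the chunk_iter iterator are assumptions, not part of the original setup:

import io
import lmdb
import numpy as np

# key = chunk filename, value = the compressed bytes of that chunk's arrays
env = lmdb.open('dataset.lmdb', map_size=200 * 1024**3)   # placeholder path and size
with env.begin(write=True) as txn:
    for fname, (X, y) in chunk_iter:   # chunk_iter: whatever yields your minibatch arrays
        buf = io.BytesIO()
        np.savez_compressed(buf, X=X, y=y)
        txn.put(fname.encode(), buf.getvalue())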

Regards,
A

PS: np.savez_compressed requires a file-like object, so you can do something like

import io
import numpy as np

output = io.BytesIO()
np.savez_compressed(output, x=your_np_data)
# cache output.getvalue() in lmdb
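
Reading it back is just the reverse; the key and path here are again placeholders:

import io
import lmdb
import numpy as np

env = lmdb.open('dataset.lmdb', readonly=True, lock=False)   # placeholder path
with env.begin() as txn:
    buf = txn.get(b'chunk_0')        # the key the minibatch was stored under
d = np.load(io.BytesIO(buf))
x = d['x']                           # matches the x= keyword passed to savez_compressed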

Upvotes: 3
