Bshr Eldebuch

Reputation: 278

Working with multiple large datasets where each dataset contains multiple values - PyTorch

I'm training a neural network and have more than 15 GB of data inside a folder. The folder contains multiple pickle files, and each file holds two lists that each contain multiple values. The layout looks like the following: dataset_folder:\

Each file_*.pickle contains two variable-length lists (list x and list y).

How can I load all of this data to train the model without running into memory issues?
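For concreteness, each pickle file is assumed to hold a dictionary with two equal-length lists, along the lines of the following sketch (the file name and values are made up; the "x"/"y" keys match the format used in the answer below):

    import pickle

    # Hypothetical contents of one file_*.pickle
    dict_object = {
        "x": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],  # input samples
        "y": [0, 1, 0],                              # matching labels
    }
    with open("dataset_folder/file_1.pickle", "wb") as f:
        pickle.dump(dict_object, f)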

Upvotes: 3

Views: 207

Answers (1)

Bshr Eldebuch

Reputation: 278

We can do this by implementing a custom Dataset class for PyTorch. We need to implement three methods so that the PyTorch DataLoader can work with our data:

  • __init__
  • __getitem__
  • __len__

Let's go through how to implement each of them separately.

  • __init__

    def __init__(self):
        # Original data has the following format:
        """
        dict_object =
        {
            "x": [],
            "y": []
        }
        """
        self.directory = "data/raw"
        self.dataset_file_name = os.listdir(self.directory)
        self.dataset_length = 0
        self.prefix_sum_idx = list()
        # Loop over each file and accumulate the length of the overall dataset
        # (you might need to check that file_name is actually a file)
        for file_name in self.dataset_file_name:
            with open(f'{self.directory}/{file_name}', "rb") as openfile:
                dict_object = pickle.load(openfile)
                # x and y are assumed to have the same length (one label per sample),
                # so the number of samples in this file is len(dict_object["x"])
                curr_file_length = len(dict_object["x"])
                self.prefix_sum_idx.append(curr_file_length)
                self.dataset_length += curr_file_length
        # Prefix sum so we know which file each global index falls into.
        for i in range(1, len(self.prefix_sum_idx)):
            self.prefix_sum_idx[i] += self.prefix_sum_idx[i - 1]

        assert self.prefix_sum_idx[-1] == self.dataset_length
        self.x = []
        self.y = []


As you can see above, the main idea is to use a prefix sum so we can treat the whole dataset as one continuous range of indices: whenever you need to access a specific index later, you simply look into prefix_sum_idx to see which file that index falls into.

(prefix sum illustration)

In the image above, let's say we need to access index 150. Thanks to the prefix sum, we now know that index 150 lives in the second .pickle file. We still need a fast way to find where that index falls within prefix_sum_idx; this is handled in __getitem__.
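As a concrete, made-up example of the bookkeeping: suppose the folder holds three files with 100, 80 and 120 samples respectively.

    # Hypothetical per-file sample counts, for illustration only
    file_lengths   = [100, 80, 120]
    prefix_sum_idx = [100, 180, 300]  # running totals after the prefix-sum loop

    # Global index 150 satisfies 100 <= 150 < 180, so it lives in the
    # second file, at local offset 150 - 100 = 50.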

  • __getitem__

    def read_pickle_file(self, idx):
        # Load the idx-th pickle file and keep its contents in memory
        file_name = self.dataset_file_name[idx]
        with open(f'{self.directory}/{file_name}', "rb") as openfile:
            dict_object = pickle.load(openfile)

        self.x = dict_object['x']
        self.y = dict_object['y']
        # Any per-file preprocessing can go here


    def __getitem__(self, idx):
        # Similar to C++ std::upper_bound - O(log n):
        # find which file the global idx falls into
        file_idx = bisect.bisect_right(self.prefix_sum_idx, idx)
        self.read_pickle_file(file_idx)
        # Offset inside that file = idx minus the samples held by all previous files
        local_idx = idx - (self.prefix_sum_idx[file_idx - 1] if file_idx > 0 else 0)

        return self.x[local_idx], self.y[local_idx]


Check the bisect_right() docs for details on how it works; in short, it returns the rightmost position in the sorted list at which the given element could be inserted while keeping the list sorted. In our approach, we're only interested in one question: "which file should I open to get the data for this index?". More importantly, it answers that in O(log n).
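A quick sketch of how bisect_right() behaves with the hypothetical prefix sums from above:

    import bisect

    prefix_sum_idx = [100, 180, 300]                 # hypothetical running totals
    print(bisect.bisect_right(prefix_sum_idx, 0))    # 0 -> first file
    print(bisect.bisect_right(prefix_sum_idx, 99))   # 0 -> first file
    print(bisect.bisect_right(prefix_sum_idx, 150))  # 1 -> second file
    print(bisect.bisect_right(prefix_sum_idx, 180))  # 2 -> third file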

  • __len__

    def __len__(self):
        return self.dataset_length
    

In order to get the length of our dataset, we loop through each file and accumulate the per-file lengths, as shown in __init__.

The full code sample goes like this:

import pickle
import os
import bisect
import torch
from torch.utils.data import Dataset, DataLoader


class dataset(Dataset):
    def __init__(self):
        # Original data has the following format:
        """
        dict_object =
        {
            "x": [],
            "y": []
        }
        """
        self.directory = "data/raw"
        self.dataset_file_name = os.listdir(self.directory)
        self.dataset_length = 0
        self.prefix_sum_idx = list()
        # Loop over each file and accumulate the length of the overall dataset
        # (you might need to check that file_name is actually a file)
        for file_name in self.dataset_file_name:
            with open(f'{self.directory}/{file_name}', "rb") as openfile:
                dict_object = pickle.load(openfile)
                # x and y are assumed to have the same length (one label per sample)
                curr_file_length = len(dict_object["x"])
                self.prefix_sum_idx.append(curr_file_length)
                self.dataset_length += curr_file_length
        # Prefix sum so we know which file each global index falls into.
        for i in range(1, len(self.prefix_sum_idx)):
            self.prefix_sum_idx[i] += self.prefix_sum_idx[i - 1]

        assert self.prefix_sum_idx[-1] == self.dataset_length
        self.x = []
        self.y = []

    def read_pickle_file(self, idx):
        # Load the idx-th pickle file and keep its contents in memory
        file_name = self.dataset_file_name[idx]
        with open(f'{self.directory}/{file_name}', "rb") as openfile:
            dict_object = pickle.load(openfile)

        self.x = dict_object['x']
        self.y = dict_object['y']
        # Any per-file preprocessing can go here

    def __getitem__(self, idx):
        # Similar to C++ std::upper_bound - O(log n):
        # find which file the global idx falls into
        file_idx = bisect.bisect_right(self.prefix_sum_idx, idx)
        self.read_pickle_file(file_idx)
        # Offset inside that file = idx minus the samples held by all previous files
        local_idx = idx - (self.prefix_sum_idx[file_idx - 1] if file_idx > 0 else 0)

        return self.x[local_idx], self.y[local_idx]

    def __len__(self):
        return self.dataset_length


large_dataset = dataset()
train_size = int(0.8 * len(large_dataset))
validation_size = len(large_dataset) - train_size

train_dataset, validation_dataset = torch.utils.data.random_split(large_dataset, [train_size, validation_size])
validation_loader = DataLoader(validation_dataset, batch_size=64, num_workers=4, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=64, num_workers=4, shuffle=False)
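One thing to be aware of: as written, every __getitem__ call reopens and unpickles a file. A possible refinement (not part of the original answer; the _loaded_file_idx attribute below is made up) is to cache which file is currently in memory and only reload when a different one is requested:

    def read_pickle_file(self, idx):
        # Skip the disk read if the requested file is already loaded
        if getattr(self, "_loaded_file_idx", None) == idx:
            return
        file_name = self.dataset_file_name[idx]
        with open(f'{self.directory}/{file_name}', "rb") as openfile:
            dict_object = pickle.load(openfile)
        self.x = dict_object['x']
        self.y = dict_object['y']
        self._loaded_file_idx = idx  # remember which file is in memory

With heavily shuffled access this helps less, since consecutive indices rarely land in the same file; keeping a small cache of several recently opened files is another option in that case.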

Upvotes: 3
