Reputation: 278
I'm training a neural network and have more than 15 GB of data inside a folder. The folder contains multiple pickle files, and each file holds two lists, each with many values. The layout looks like the following:
dataset_folder/
    file_1.pickle
    file_2.pickle
    ...
Each file_*.pickle contains a variable-length list x and a variable-length list y.
How can I load all the data to train the model without running into memory issues?
Upvotes: 3
Views: 207
Reputation: 278
By subclassing PyTorch's Dataset class, we need to implement three methods so the PyTorch data loader can work with our data:
__init__
__len__
__getitem__
Let's go through how to implement each one of them separately.
__init__
def __init__(self):
    # Original data has the following format
    """
    dict_object =
    {
        "x":[],
        "y":[]
    }
    """
    self.directory = "data/raw"
    self.dataset_file_name = os.listdir(self.directory)
    self.dataset_length = 0
    self.prefix_sum_idx = list()
    # Loop over each file and accumulate the overall number of samples
    # (you might need to check that file_name is actually a file)
    for file_name in self.dataset_file_name:
        with open(f'{self.directory}/{file_name}', "rb") as openfile:
            dict_object = pickle.load(openfile)
        # x and y are paired, so each file contributes len(x) samples
        # (this assumes len(x) == len(y) in every file)
        curr_page_sum = len(dict_object["x"])
        self.prefix_sum_idx.append(curr_page_sum)
        self.dataset_length += curr_page_sum
    # Prefix sum so we know which file each global index falls into
    for i in range(1, len(self.prefix_sum_idx)):
        self.prefix_sum_idx[i] += self.prefix_sum_idx[i - 1]
    assert self.prefix_sum_idx[-1] == self.dataset_length
    self.x = []
    self.y = []
As you can see above, the main idea is to use a prefix sum to "treat" the whole dataset as one, so whenever you need to access a specific index later, you simply look into prefix_sum_idx to see which file that index falls into. Say we need to access index 150: thanks to the prefix sum, we can tell that index 150 lives in the second .pickle file. We still need a fast way to locate that index within prefix_sum_idx; this is explained in __getitem__ below.
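To make the prefix-sum idea concrete, here is a small worked example; the per-file sample counts below are made up purely for illustration:
# hypothetical per-file sample counts (for illustration only)
file_lengths = [100, 150, 200]

# running prefix sums: [100, 250, 450]
prefix_sum_idx = []
total = 0
for n in file_lengths:
    total += n
    prefix_sum_idx.append(total)

# global index 150 is >= 100 but < 250, so it lives in the second file,
# at local offset 150 - 100 = 50 inside that file's x and y lists
print(prefix_sum_idx)   # [100, 250, 450]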
__getitem__
def read_pickle_file(self, idx):
    file_name = self.dataset_file_name[idx]
    with open(f'{self.directory}/{file_name}', "rb") as openfile:
        dict_object = pickle.load(openfile)
    self.x = dict_object['x']
    self.y = dict_object['y']
    # any extra preprocessing of x and y can go here

def __getitem__(self, idx):
    # Similar to C++ std::upper_bound - O(log n)
    file_idx = bisect.bisect_right(self.prefix_sum_idx, idx)
    self.read_pickle_file(file_idx)
    # Offset inside that file: subtract the samples of all previous files
    local_idx = idx if file_idx == 0 else idx - self.prefix_sum_idx[file_idx - 1]
    return self.x[local_idx], self.y[local_idx]
Check the bisect_right() docs for details on how it works, but in short it returns the right-most position in a sorted list at which the given element could be inserted while keeping the list sorted. In our approach we only care about one question: "which file should I open in order to get the requested sample?", and bisect_right answers it in O(log n).
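A quick sanity check of how that lookup behaves, reusing the hypothetical prefix sums from the sketch above:
import bisect

prefix_sum_idx = [100, 250, 450]

for idx in (0, 99, 100, 150, 249, 250, 449):
    file_idx = bisect.bisect_right(prefix_sum_idx, idx)
    local_idx = idx if file_idx == 0 else idx - prefix_sum_idx[file_idx - 1]
    print(idx, "-> file", file_idx, "offset", local_idx)
# e.g. 150 -> file 1 offset 50, and 250 -> file 2 offset 0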
__len__
def __len__(self):
    return self.dataset_length
In order to get the length of our dataset, we loop through each file and accumulate the result, as shown in __init__.
The full code sample goes like this:
import pickle
import os
import bisect

import torch
from torch.utils.data import Dataset, DataLoader


class dataset(Dataset):
    def __init__(self):
        # Original data has the following format
        """
        dict_object =
        {
            "x":[],
            "y":[]
        }
        """
        self.directory = "data/raw"
        self.dataset_file_name = os.listdir(self.directory)
        self.dataset_length = 0
        self.prefix_sum_idx = list()
        # Loop over each file and accumulate the overall number of samples
        # (you might need to check that file_name is actually a file)
        for file_name in self.dataset_file_name:
            with open(f'{self.directory}/{file_name}', "rb") as openfile:
                dict_object = pickle.load(openfile)
            # x and y are paired, so each file contributes len(x) samples
            # (this assumes len(x) == len(y) in every file)
            curr_page_sum = len(dict_object["x"])
            self.prefix_sum_idx.append(curr_page_sum)
            self.dataset_length += curr_page_sum
        # Prefix sum so we know which file each global index falls into
        for i in range(1, len(self.prefix_sum_idx)):
            self.prefix_sum_idx[i] += self.prefix_sum_idx[i - 1]
        assert self.prefix_sum_idx[-1] == self.dataset_length
        self.x = []
        self.y = []

    def read_pickle_file(self, idx):
        file_name = self.dataset_file_name[idx]
        with open(f'{self.directory}/{file_name}', "rb") as openfile:
            dict_object = pickle.load(openfile)
        self.x = dict_object['x']
        self.y = dict_object['y']
        # any extra preprocessing of x and y can go here

    def __getitem__(self, idx):
        # Similar to C++ std::upper_bound - O(log n)
        file_idx = bisect.bisect_right(self.prefix_sum_idx, idx)
        self.read_pickle_file(file_idx)
        # Offset inside that file: subtract the samples of all previous files
        local_idx = idx if file_idx == 0 else idx - self.prefix_sum_idx[file_idx - 1]
        return self.x[local_idx], self.y[local_idx]

    def __len__(self):
        return self.dataset_length


large_dataset = dataset()
train_size = int(0.8 * len(large_dataset))
validation_size = len(large_dataset) - train_size
train_dataset, validation_dataset = torch.utils.data.random_split(large_dataset, [train_size, validation_size])
validation_loader = DataLoader(validation_dataset, batch_size=64, num_workers=4, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=64, num_workers=4, shuffle=False)
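For completeness, a minimal sketch of consuming the loaders. Since the question does not specify what the x and y entries actually contain, how the default collate_fn batches them will depend on your data, and variable-length samples may need a custom collate_fn:
# sanity check: pull one batch from each loader and inspect it
for batch_x, batch_y in train_loader:
    print(type(batch_x), type(batch_y))
    break

for batch_x, batch_y in validation_loader:
    print(type(batch_x), type(batch_y))
    break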
Upvotes: 3