Reputation: 98
I have a large set of images (60k) and masks (60k) that need to be loaded into a PyTorch DataLoader for semantic segmentation.
Directory Structure:
- Segmentation
    - images
        - color_left_trajectory_3000_00001.jpg
        - color_left_trajectory_3000_00002.jpg
        ...
    - masks
        - color_segmentation_3000_00001.jpg
        - color_segmentation_3000_00002.jpg
        ...
I want to know the most efficient way to load these into a DataLoader in PyTorch. I was thinking of generating a CSV file with the paths to the images and masks. How would I go about generating it? Any other suggestions are appreciated!
Upvotes: 0
Views: 357
Reputation: 521
I recommend writing a custom subclass of torch.utils.data.Dataset. In __init__, the paths to the images and masks are generated and stored. The files themselves are only opened in __getitem__, so nothing is read from disk until a sample is requested. This is an example:
    import os

    import numpy as np
    import torch
    from PIL import Image
    from torch.utils.data import Dataset, DataLoader
    from torchvision.transforms import functional as F


    class CustomData(Dataset):
        def __init__(self, data_dir='Segmentation', data_transform=None, split='train'):
            self.imgs = []
            self.labels = []
            self.transform = data_transform
            self.data_dir = data_dir
            # If the data is split into subfolders (train/val/...), use these instead:
            # self.imgs_dir = os.path.join(data_dir, split, 'images')
            # self.labels_dir = os.path.join(data_dir, split, 'masks')
            self.imgs_dir = os.path.join(data_dir, 'images')
            self.labels_dir = os.path.join(data_dir, 'masks')
            # sorted() keeps the sample order the same across runs
            for img_name in sorted(os.listdir(self.imgs_dir)):
                img_path = os.path.join(self.imgs_dir, img_name)
                # color_left_trajectory_3000_00001.jpg -> color_segmentation_3000_00001.jpg
                stem, ext = os.path.splitext(img_name)
                label_name = 'color_segmentation_' + '_'.join(stem.split('_')[-2:]) + ext
                label_path = os.path.join(self.labels_dir, label_name)
                self.imgs.append(img_path)
                self.labels.append(label_path)

        def __len__(self):
            return len(self.imgs)

        def __getitem__(self, idx):
            # Files are opened lazily, one sample at a time
            img = Image.open(self.imgs[idx])
            label = Image.open(self.labels[idx])
            if self.transform is not None:
                img, label = self.transform(img, label)
            return img, label


    class ToTensor:
        """Joint transform: converts the image and its mask together."""

        def __call__(self, image, target=None):
            image = F.to_tensor(image)
            if target is not None:
                target = torch.as_tensor(np.array(target), dtype=torch.int64)
            return image, target


    if __name__ == '__main__':
        # Pass an instance of the transform, not the class itself
        data = CustomData(data_transform=ToTensor())
        dataloader = DataLoader(data, batch_size=10)
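If you would still prefer the CSV file mentioned in the question, the same pairing logic can be written out once and read back later. The following is only a minimal sketch using the standard library: the helper names mask_name_for, write_pairs_csv and read_pairs_csv and the column names "image" and "mask" are made up for illustration, and it assumes the images/masks layout and file naming shown in the question.

```python
import csv
import os


def mask_name_for(img_name):
    """Map an image file name to its mask file name (naming assumed from the question)."""
    stem, ext = os.path.splitext(img_name)
    # Keep the last two underscore-separated fields, e.g. "3000_00001"
    suffix = "_".join(stem.split("_")[-2:])
    return "color_segmentation_" + suffix + ext


def write_pairs_csv(data_dir, csv_path):
    """Write one (image, mask) row per image that has a mask; return the row count."""
    imgs_dir = os.path.join(data_dir, "images")
    masks_dir = os.path.join(data_dir, "masks")
    rows = []
    for img_name in sorted(os.listdir(imgs_dir)):
        mask_path = os.path.join(masks_dir, mask_name_for(img_name))
        if not os.path.isfile(mask_path):
            continue  # skip images without a matching mask
        rows.append((os.path.join(imgs_dir, img_name), mask_path))
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["image", "mask"])  # hypothetical column names
        writer.writerows(rows)
    return len(rows)


def read_pairs_csv(csv_path):
    """Read the CSV back into a list of (image_path, mask_path) tuples."""
    with open(csv_path, newline="") as f:
        return [(row["image"], row["mask"]) for row in csv.DictReader(f)]
```

A Dataset could then take the list returned by read_pairs_csv in __init__ instead of calling os.listdir. That mostly pays off when the directory listing is slow (for example on network storage) or when you want a fixed, reviewable list of samples.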
Upvotes: 2