Reputation: 25
I have a dataset in which both the inputs and the labels/targets are images. The folder structure is as follows:
DATASET/
    TRAIN/
        image_xx.png
        label_xx.png
    TEST/
        image_xx.png
        label_xx.png
I have tried using ImageFolder from torchvision.datasets to load the images:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

TRAIN_PATH = '/path/to/dataset/DATASET'
train_data = datasets.ImageFolder(root=TRAIN_PATH, transform=transforms.ToTensor())
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
However, the output of this loop shows the problem:
for img, label in train_loader:
    print(img.shape)
    print(label.shape)
    break

Output:
torch.Size([16, 3, 128, 128])
torch.Size([16])
The labels aren't images but rather indices or something similar. Is there a convenient way to load a dataset with the structure above?
Upvotes: 1
Views: 2321
Reputation: 8527
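The ImageFolder dataset is meant for the case where each image has a discrete, scalar class. It expects a directory structure in which each subdirectory holds the images of one class, and it returns that class's integer index as the label. With root pointing at DATASET, it treats TRAIN and TEST as two classes, which is why your labels come back as a tensor of shape [16] instead of images.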
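To make the label values concrete, here is a small stdlib-only sketch that mimics how ImageFolder discovers classes: every immediate subdirectory of root becomes a class, sorted by name and numbered from 0. The function name find_classes_like_imagefolder is hypothetical; it approximates torchvision's behaviour and is not its actual source.

```python
import os
import tempfile


def find_classes_like_imagefolder(root):
    """Approximate ImageFolder's class discovery: each subdirectory of
    `root` is a class, sorted by name and assigned an index from 0."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


# Recreate the question's layout: DATASET/TRAIN and DATASET/TEST
with tempfile.TemporaryDirectory() as dataset_root:
    for split in ("TRAIN", "TEST"):
        os.makedirs(os.path.join(dataset_root, split))
        # The files inside each split folder do not affect class discovery
        open(os.path.join(dataset_root, split, "image_01.png"), "wb").close()
        open(os.path.join(dataset_root, split, "label_01.png"), "wb").close()
    classes, class_to_idx = find_classes_like_imagefolder(dataset_root)
    print(classes)       # ['TEST', 'TRAIN']
    print(class_to_idx)  # {'TEST': 0, 'TRAIN': 1}
```

Under this scheme both image_xx.png and label_xx.png files under TRAIN are loaded as inputs with target 1, and those under TEST with target 0.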
For your case, you can simply define your own subclass of torch.utils.data.Dataset. This tutorial may be helpful.
A simple example (I have not tried running it to see if it works correctly):
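import os

import cv2
import torch
from torchvision import transforms


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, root_path, transform=None):
        files = sorted(os.listdir(root_path))
        # os.listdir returns bare file names, so join them with root_path for cv2.imread.
        # Sorting keeps image_xx.png and label_xx.png aligned by their shared xx part.
        self.data_paths = [os.path.join(root_path, f) for f in files if f.startswith("image")]
        self.label_paths = [os.path.join(root_path, f) for f in files if f.startswith("label")]
        assert len(self.data_paths) == len(self.label_paths), "every image needs a label"
        self.transform = transform

    def __getitem__(self, idx):
        # cv2.imread returns a BGR uint8 array (H x W x C); convert to RGB
        img = cv2.cvtColor(cv2.imread(self.data_paths[idx]), cv2.COLOR_BGR2RGB)
        label = cv2.cvtColor(cv2.imread(self.label_paths[idx]), cv2.COLOR_BGR2RGB)
        if self.transform:
            # Apply the (deterministic) transform to both, so the label is also a C x H x W tensor
            img = self.transform(img)
            label = self.transform(label)
        return img, label

    def __len__(self):
        return len(self.data_paths)


TRAIN_PATH = '/path/to/dataset/DATASET/TRAIN/'
train_data = MyDataset(root_path=TRAIN_PATH, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=True)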
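Pairing images and labels by sorted position silently misaligns them if a single label_xx.png is missing. A slightly more defensive option, sketched here with only the standard library, matches files on the shared xx part of the name and raises if any file lacks a partner. The helper name pair_by_suffix is hypothetical, and the image_/label_ prefixes are assumed from the file names in the question; the two lists it returns could stand in for data_paths and label_paths.

```python
import os
import tempfile


def pair_by_suffix(root_path, image_prefix="image_", label_prefix="label_"):
    """Hypothetical helper: return (image_paths, label_paths) matched on the
    part of the file name after the prefix, e.g. image_07.png <-> label_07.png."""
    images, labels = {}, {}
    for name in os.listdir(root_path):
        if name.startswith(image_prefix):
            images[name[len(image_prefix):]] = os.path.join(root_path, name)
        elif name.startswith(label_prefix):
            labels[name[len(label_prefix):]] = os.path.join(root_path, name)
    # Suffixes present on only one side mean an image or label is missing
    unmatched = set(images) ^ set(labels)
    if unmatched:
        raise ValueError(f"files without a partner: {sorted(unmatched)}")
    keys = sorted(images)  # deterministic order
    return [images[k] for k in keys], [labels[k] for k in keys]


with tempfile.TemporaryDirectory() as train_dir:
    for name in ("label_10.png", "image_02.png", "image_10.png", "label_02.png"):
        open(os.path.join(train_dir, name), "wb").close()
    image_paths, label_paths = pair_by_suffix(train_dir)
    pairs = [(os.path.basename(i), os.path.basename(l)) for i, l in zip(image_paths, label_paths)]
    print(pairs)  # [('image_02.png', 'label_02.png'), ('image_10.png', 'label_10.png')]
```

Raising on a mismatch is a design choice: it surfaces a broken dataset at construction time instead of letting the model train on shifted image/label pairs.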
Upvotes: 2