Reputation: 137
I want to get 3 batches of images from 3 different folders. I have written a custom data loader in PyTorch, but it returns a list that contains all the batches instead of a single batch at a time (running in Google Colab).
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms

# custom dataset: returns the sample at the same index from each of the three datasets
class set(Dataset):
    def __init__(self, dataset_input, dataset_expertA, dataset_expertB):
        self.dataset1 = dataset_input
        self.dataset2 = dataset_expertA
        self.dataset3 = dataset_expertB

    def __getitem__(self, index):
        x1 = self.dataset1[index]
        x2 = self.dataset2[index]
        x3 = self.dataset3[index]
        return x1, x2, x3

    def __len__(self):
        return len(self.dataset1)
input_path = "/content/gdrive/My Drive/project/input/"
dataset = datasets.ImageFolder(root=input_path, transform=transforms.Compose([
    transforms.Resize([64, 64]),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]))

expertA_path = "/content/gdrive/My Drive/project/expertA/"
datasetA = datasets.ImageFolder(root=expertA_path, transform=transforms.Compose([
    transforms.Resize([64, 64]),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]))

expertB_path = "/content/gdrive/My Drive/project/expertB/"
datasetB = datasets.ImageFolder(root=expertB_path, transform=transforms.Compose([
    transforms.Resize([64, 64]),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]))
data = set(dataset, datasetA, datasetB)
dataloader = torch.utils.data.DataLoader(data, batch_size=64,
                                         shuffle=True, num_workers=2)

for i, (inp, expA, expB) in enumerate(dataloader):
    print(inp.shape)
    break
This raises an error saying that inp is a list, and when I print(inp[0].shape) I get the proper shape. I think inp contains all the batches, i.e. inp[0], inp[1], ...
What mistake am I making in the data loader code?
Upvotes: 1
Views: 1303
Reputation: 32972
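datasets.ImageFolder returns each sample as a tuple of (image, label), so your dataset returns three such tuples per index. The DataLoader's default collate function batches each tuple into a list [images, labels]; that is why inp is a list, where inp[0] is the batch of images (e.g. shape [64, 3, 64, 64]) and inp[1] holds their corresponding labels. The same applies to expA and expB.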
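A minimal sketch of this behaviour, assuming a hypothetical `FakeImageFolder` dataset that returns `(image_tensor, int_label)` tuples the way `ImageFolder` does (random tensors instead of files on Google Drive). With the question's dataset logic, each of `inp`, `expA`, `expB` comes out of the default collate function as a two-element list. The last loop shows that the labels can also be dropped at unpacking time without changing the dataset:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class FakeImageFolder(Dataset):
    """Hypothetical stand-in for datasets.ImageFolder: returns (image, label)."""

    def __init__(self, n=10):
        self.n = n

    def __getitem__(self, index):
        image = torch.randn(3, 64, 64)  # mimics Resize([64, 64]) + ToTensor()
        label = index % 2               # ImageFolder labels are plain ints
        return image, label

    def __len__(self):
        return self.n


class TripleDataset(Dataset):
    """Same logic as the question's `set` class: returns the raw samples."""

    def __init__(self, d1, d2, d3):
        self.d1, self.d2, self.d3 = d1, d2, d3

    def __getitem__(self, index):
        return self.d1[index], self.d2[index], self.d3[index]

    def __len__(self):
        return len(self.d1)


loader = DataLoader(
    TripleDataset(FakeImageFolder(), FakeImageFolder(), FakeImageFolder()),
    batch_size=4,
)

inp, expA, expB = next(iter(loader))
print(type(inp), len(inp))  # a list with two entries: images and labels
print(inp[0].shape)         # torch.Size([4, 3, 64, 64])
print(inp[1])               # the batch of integer labels as a tensor

# Alternative: keep the dataset as-is and discard the labels in the loop.
for i, ((inp_img, _), (expA_img, _), (expB_img, _)) in enumerate(loader):
    print(inp_img.shape)    # torch.Size([4, 3, 64, 64])
    break
```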
If you only want the images without the labels, you can ignore the labels and just return the images when accessing the data in your custom dataset:
    def __getitem__(self, index):
        image1, label1 = self.dataset1[index]
        image2, label2 = self.dataset2[index]
        image3, label3 = self.dataset3[index]
        return image1, image2, image3
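A self-contained sketch of the fixed dataset, again assuming a hypothetical `FakeImageFolder` in place of the three `ImageFolder` datasets so it runs without the Drive folders. The class name `ImageTripletDataset` is also hypothetical; it avoids shadowing Python's built-in `set`. With the labels discarded in `__getitem__`, `inp`, `expA` and `expB` each arrive as a single image tensor:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class FakeImageFolder(Dataset):
    """Hypothetical stand-in for datasets.ImageFolder: returns (image, label)."""

    def __init__(self, n=10):
        self.n = n

    def __getitem__(self, index):
        return torch.randn(3, 64, 64), index % 2

    def __len__(self):
        return self.n


class ImageTripletDataset(Dataset):
    """Returns only the images from three index-aligned datasets."""

    def __init__(self, dataset_input, dataset_expertA, dataset_expertB):
        self.dataset1 = dataset_input
        self.dataset2 = dataset_expertA
        self.dataset3 = dataset_expertB

    def __getitem__(self, index):
        image1, _ = self.dataset1[index]  # drop the ImageFolder label
        image2, _ = self.dataset2[index]
        image3, _ = self.dataset3[index]
        return image1, image2, image3

    def __len__(self):
        return len(self.dataset1)


data = ImageTripletDataset(FakeImageFolder(), FakeImageFolder(), FakeImageFolder())
dataloader = DataLoader(data, batch_size=4, shuffle=True)

for i, (inp, expA, expB) in enumerate(dataloader):
    print(inp.shape)  # torch.Size([4, 3, 64, 64])
    break
```

`num_workers` is left at its default here so the sketch also runs as a plain script; in Colab, `num_workers=2` as in the question should work the same way. Because the three datasets are indexed with the same `index`, `shuffle=True` keeps the input and expert images paired, provided the three folders list their files in the same order.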
Upvotes: 3