prateek k

Reputation: 137

Custom data loader is returning list in pytorch

I want to get 3 batches of images from 3 different folders, so I wrote a custom dataset in PyTorch. However, the DataLoader is returning a list containing extra elements instead of a single image batch at a time. (Running in Google Colab.)

import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms

# custom dataset (note: the name "set" shadows Python's built-in set)
class set(Dataset):
    def __init__(self, dataset_input, dataset_expertA, dataset_expertB):
        self.dataset1 = dataset_input
        self.dataset2 = dataset_expertA
        self.dataset3 = dataset_expertB

    def __getitem__(self, index):
        x1 = self.dataset1[index]
        x2 = self.dataset2[index]
        x3 = self.dataset3[index]

        return x1, x2, x3

    def __len__(self):
        return len(self.dataset1)

input_path = "/content/gdrive/My Drive/project/input/"

dataset = datasets.ImageFolder(root= input_path, transform=transforms.Compose([
                               transforms.Resize([64,64]),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                               ]))

expertA_path = "/content/gdrive/My Drive/project/expertA/"

datasetA = datasets.ImageFolder(root= expertA_path, transform=transforms.Compose([
                               transforms.Resize([64,64]),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                               ]))


expertB_path = "/content/gdrive/My Drive/project/expertB/"

datasetB = datasets.ImageFolder(root= expertB_path, transform=transforms.Compose([
                               transforms.Resize([64,64]),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                               ]))


data = set(dataset, datasetA, datasetB)
dataloader = torch.utils.data.DataLoader(data, batch_size=64,
                                         shuffle=True, num_workers=2)


for i, (inp, expA, expB) in enumerate(dataloader):

  print(inp.shape)
  break

This prints an error saying that inp is a list, but when I do print(inp[0].shape) I get a proper shape. I think inp contains all the batches, i.e. inp[0], inp[1], ...

What mistake am I making in the data loader code?

Upvotes: 1

Views: 1303

Answers (1)

Michael Jungo

Reputation: 32972

datasets.ImageFolder returns a tuple of (image, label), hence inp is also a tuple (collated into a list by the DataLoader), where inp[0] is the batch of images and inp[1] their corresponding labels. The same applies to expA and expB.

If you only want the images without the labels, you can ignore the labels and just return the images when accessing the data in your custom dataset:

def __getitem__(self, index):
    image1, label1 = self.dataset1[index]
    image2, label2 = self.dataset2[index]
    image3, label3 = self.dataset3[index]

    return image1, image2, image3
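To illustrate the fix end to end, here is a minimal, self-contained sketch. It uses small stand-in datasets (random tensors, a hypothetical FakeImageFolder class, not the asker's real folders) that, like datasets.ImageFolder, yield (image, label) tuples, and a paired dataset that discards the labels so the DataLoader collates the images into plain tensors:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class FakeImageFolder(Dataset):
    """Stand-in for datasets.ImageFolder: yields (image, label) tuples."""
    def __init__(self, n=8):
        self.n = n

    def __getitem__(self, index):
        # A random 3x64x64 "image" and a dummy class label, mimicking ImageFolder.
        return torch.randn(3, 64, 64), 0

    def __len__(self):
        return self.n

class PairedImages(Dataset):
    """Returns only the images, discarding the label from each wrapped dataset."""
    def __init__(self, ds1, ds2):
        self.ds1, self.ds2 = ds1, ds2

    def __getitem__(self, index):
        image1, _ = self.ds1[index]  # unpack and drop the label
        image2, _ = self.ds2[index]
        return image1, image2

    def __len__(self):
        return len(self.ds1)

loader = DataLoader(PairedImages(FakeImageFolder(), FakeImageFolder()), batch_size=4)
inp, expA = next(iter(loader))
print(inp.shape)  # a single tensor batch, e.g. torch.Size([4, 3, 64, 64]), not a list
```

With the labels dropped in __getitem__, the default collate function stacks the images directly, so each element yielded by the loader is a tensor of shape (batch_size, 3, 64, 64).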

Upvotes: 3
