Mamen

Reputation: 1436

PyTorch DataLoader not splitting data into batches

I have a dataset class like this:

from torch.utils.data import Dataset, DataLoader

class LoadDataset(Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label
    def __len__(self):
        dlen = len(self.data)
        return dlen
    def __getitem__(self, index):
        return self.data, self.label

Then I load my image dataset, which has shape [485, 1, 32, 32]:

train_dataset = LoadDataset(xtrain, ytrain)
print(len(train_dataset))
# output 485

Then I load the data with a DataLoader:

train_loader = DataLoader(train_dataset, batch_size=32)

Then I iterate over the data:

for epoch in range(num_epoch):
    for inputs, labels in train_loader:
        print(inputs.shape)

This prints torch.Size([32, 485, 1, 32, 32]), but it should be torch.Size([32, 1, 32, 32]).

Can anyone help me?

Upvotes: 0

Views: 54

Answers (1)

wyhn.w

Reputation: 201

The __getitem__ method should return a single sample at the given index, but yours returns the entire dataset. The DataLoader then stacks 32 of those full-dataset copies together, which is where the extra dimension of size 485 comes from.

Try this:

from torch.utils.data import Dataset

class LoadDataset(Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label
    def __len__(self):
        dlen = len(self.data)
        llen = len(self.label)  # different here: guard against mismatched lengths
        return min(dlen, llen)  # different here
    def __getitem__(self, index):
        return self.data[index], self.label[index]  # different here: one sample, not the whole tensor
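To see the fix end to end, here is a minimal, self-contained sketch. It uses random tensors with the same shapes as in the question ([485, 1, 32, 32] inputs; the integer labels are an assumption, since the question never shows ytrain's dtype):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class LoadDataset(Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label
    def __len__(self):
        # Guard against mismatched data/label lengths
        return min(len(self.data), len(self.label))
    def __getitem__(self, index):
        # Return one (sample, label) pair; the DataLoader stacks them
        return self.data[index], self.label[index]

# Stand-ins for the real xtrain/ytrain from the question
xtrain = torch.randn(485, 1, 32, 32)
ytrain = torch.zeros(485, dtype=torch.long)

train_loader = DataLoader(LoadDataset(xtrain, ytrain), batch_size=32)
inputs, labels = next(iter(train_loader))
print(inputs.shape)  # torch.Size([32, 1, 32, 32])
print(labels.shape)  # torch.Size([32])
```

With batch_size=32 and 485 samples, the loader yields 15 full batches and a final partial batch of 5, since DataLoader does not drop the remainder unless you pass drop_last=True.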

Upvotes: 1
