Tim

Reputation: 123

Batch size in DataLoader

I have two tensors:

x[train], y[train]

and their shapes are:

(311, 3, 224, 224), (311)  # the 311 dimension carries no information, it is just the number of samples

I want to use DataLoader to load them batch by batch. The code I wrote is:

from torch.utils.data import Dataset, DataLoader

class KD_Train(Dataset):

    def __init__(self,a,b):
        self.imgs = a
        self.index = b

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self,index):
        return self.imgs, self.index

kdt = KD_Train(x[train], y[train])

train_data_loader = DataLoader(
    kdt,
    batch_size = 64,
    shuffle = True,
    num_workers = 0)

for step, (a, b) in enumerate(train_data_loader):
    print(a.shape)
    break

But it shows:

(64, 311, 3, 224, 224)

The DataLoader just adds a dimension instead of selecting a batch of samples. Does anyone know what I should do?

Upvotes: 0

Views: 2912

Answers (1)

Ivan

Reputation: 40618

Your dataset's __getitem__ method should return a single element:

def __getitem__(self, index):
    return self.imgs[index], self.index[index]
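
The DataLoader's default collate_fn stacks batch_size items returned by __getitem__; because the whole (311, 3, 224, 224) tensor was returned for every index, 64 copies of it were stacked into (64, 311, 3, 224, 224). As a minimal end-to-end sketch with the fix applied (using random stand-in tensors x_train / y_train in place of x[train] / y[train] from the question):

import torch
from torch.utils.data import Dataset, DataLoader

class KD_Train(Dataset):
    def __init__(self, a, b):
        self.imgs = a       # images, shape (N, 3, 224, 224)
        self.index = b      # labels, shape (N,)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        # return one sample; the default collate_fn stacks
        # batch_size of these into a single batch
        return self.imgs[index], self.index[index]

# stand-in tensors with the shapes from the question
x_train = torch.randn(311, 3, 224, 224)
y_train = torch.zeros(311, dtype=torch.long)

kdt = KD_Train(x_train, y_train)
train_data_loader = DataLoader(kdt, batch_size=64, shuffle=True, num_workers=0)

a, b = next(iter(train_data_loader))
print(a.shape)  # torch.Size([64, 3, 224, 224])
print(b.shape)  # torch.Size([64])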

Upvotes: 2
