Reputation: 123
I have two tensors:
x[train], y[train]
Their shapes are:
(311, 3, 224, 224), (311) # 311 Has No Information
I want to use DataLoader to load them batch by batch. The code I wrote is:
import torch.utils.data as Data
from torch.utils.data import Dataset

class KD_Train(Dataset):
    def __init__(self, a, b):
        self.imgs = a
        self.index = b

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        return self.imgs, self.index

kdt = KD_Train(x[train], y[train])
train_data_loader = Data.DataLoader(
    kdt,
    batch_size=64,
    shuffle=True,
    num_workers=0)

for step, (a, b) in enumerate(train_data_loader):
    print(a.shape)
    break
But it prints:
(64, 311, 3, 224, 224)
The DataLoader just adds a dimension in front instead of picking 64 samples per batch. Does anyone know what I should do?
Upvotes: 0
Views: 2912
Reputation: 40618
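Your dataset's __getitem__ method should return a single element, not the whole tensors. The DataLoader calls __getitem__ once per sample and stacks the results, so returning self.imgs gives you 64 copies of the full (311, 3, 224, 224) tensor:
    def __getitem__(self, index):
        return self.imgs[index], self.index[index]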
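For reference, here is a minimal runnable sketch of the corrected dataset. The tensors x_train and y_train are stand-ins I made up (random data, 8x8 images instead of 224x224 to keep it light, and 10 label classes); they are not the asker's data. The second loader uses torch.utils.data.TensorDataset, which should give the same per-sample indexing without a custom class.

```python
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset


class KD_Train(Dataset):
    def __init__(self, imgs, labels):
        self.imgs = imgs
        self.labels = labels

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        # Return ONE sample; the DataLoader stacks samples into a batch.
        return self.imgs[index], self.labels[index]


# Stand-in data (assumption): 311 random images, shrunk to 8x8,
# with integer labels drawn from 10 hypothetical classes.
x_train = torch.randn(311, 3, 8, 8)
y_train = torch.randint(0, 10, (311,))

loader = DataLoader(KD_Train(x_train, y_train),
                    batch_size=64, shuffle=True, num_workers=0)
a, b = next(iter(loader))
print(a.shape, b.shape)

# Alternative: TensorDataset indexes every tensor along dim 0.
tensor_loader = DataLoader(TensorDataset(x_train, y_train),
                           batch_size=64, shuffle=True)
batch_sizes = [len(imgs) for imgs, _ in tensor_loader]
print(batch_sizes)
```

With 311 samples and batch_size=64, the last batch holds the remaining 55 samples; pass drop_last=True to the DataLoader if you want to skip it.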
Upvotes: 2