Reputation: 801
It's my first time using PyTorch. I built a custom dataset class and load its tensors with a DataLoader, like this:
train_loader = DataLoader(dataset_train, batch_size=6, drop_last=True)
But at the following line:
for i, train_batch in enumerate(train_loader):
I get this error: TypeError: __getitem__() takes 1 positional argument but 2 were given
Any help would be great; I'm stuck on it. My concern is that it could depend on the library versions I'm using: matplotlib 3.5.2, numpy 1.23.0, opencv-python 4.6.0.66, torch 1.12.0, torch-tb-profiler 0.4.0, torchaudio 0.12.0, torchvision 0.13.0.
Thank you.
Upvotes: 0
Views: 727
Reputation: 11
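I believe the problem lies in how you define the __getitem__ method in your custom dataset class. The DataLoader fetches each sample with dataset[idx], which Python turns into a call to __getitem__ with two positional arguments (self and idx). If your method is declared as def __getitem__(self):, it accepts only one, hence the TypeError. Make sure your __getitem__ method takes an idx argument, like so: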
def __getitem__(self, idx):
    # your code: load and return the sample at position idx
    ...
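The error can be reproduced without PyTorch at all, which suggests the library versions are not the cause. A minimal sketch in plain Python, assuming the dataset's __getitem__ was declared without an index parameter; BrokenDataset and FixedDataset are hypothetical names:

```python
# Plain-Python reproduction: obj[i] calls type(obj).__getitem__(obj, i),
# i.e. two positional arguments (self and the index).

class BrokenDataset:
    def __len__(self):
        return 3

    # Hypothetical bug: no index parameter, only `self` is accepted.
    def __getitem__(self):
        return 0


class FixedDataset:
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    # `idx` is the position requested by the caller (e.g. a DataLoader).
    def __getitem__(self, idx):
        return self.data[idx]


try:
    BrokenDataset()[0]
except TypeError as err:
    print(err)  # message ends with "takes 1 positional argument but 2 were given"

print(FixedDataset([10, 20, 30])[1])  # 20
```

Recent Python versions prefix the message with the class name (BrokenDataset.__getitem__()), but the argument count is the same.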
Upvotes: 1
Reputation: 801
I solved the problem by reading the PyTorch documentation in detail. I suggest that anyone with the same issue read through the PyTorch classes in the GitHub source code. Here is the datasets documentation: https://pytorch.org/vision/stable/datasets.html
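The dataset classes in that documentation follow the map-style torch.utils.data.Dataset interface: a __len__ method plus a __getitem__ that takes an index. A minimal sketch of a custom dataset following that interface, used with the same DataLoader call as in the question; TensorPairDataset and the random tensors are hypothetical stand-ins for the asker's own data:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TensorPairDataset(Dataset):
    """Hypothetical map-style dataset wrapping pre-built feature/label tensors."""

    def __init__(self, features, labels):
        assert len(features) == len(labels)
        self.features = features
        self.labels = labels

    def __len__(self):
        # The default sampler uses this to produce indices 0 .. len-1.
        return len(self.features)

    def __getitem__(self, idx):
        # Called as dataset[idx] for every index in a batch.
        return self.features[idx], self.labels[idx]


features = torch.randn(20, 3)  # made-up data: 20 samples, 3 features each
labels = torch.arange(20)

dataset_train = TensorPairDataset(features, labels)
train_loader = DataLoader(dataset_train, batch_size=6, drop_last=True)

# 3 batches of 6; drop_last=True discards the 2 leftover samples.
for i, (x, y) in enumerate(train_loader):
    print(i, x.shape, y.shape)
```

The default collate function stacks the (feature, label) tuples, so each batch unpacks into a features tensor and a labels tensor.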
Upvotes: 0
Reputation: 40728
I believe you meant to enumerate your DataLoader rather than the dataset itself:
for i, train_batch in enumerate(dataloader):
    # train loop
    ...
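To illustrate the difference between the two loops, a sketch using torch.utils.data.TensorDataset with made-up data instead of the asker's own dataset class: iterating the dataset yields one sample at a time, while iterating the DataLoader yields collated batches.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 8 samples with 2 features each, labels 0..7.
dataset = TensorDataset(torch.randn(8, 2), torch.arange(8))
dataloader = DataLoader(dataset, batch_size=4)  # shuffle=False by default

# Enumerating the dataset: each item is a single (features, label) sample.
for i, (x, y) in enumerate(dataset):
    print("sample", i, tuple(x.shape), y.item())

# Enumerating the DataLoader: each item is a batch of 4 stacked samples.
for i, (x, y) in enumerate(dataloader):
    print("batch", i, tuple(x.shape), y.tolist())
```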
Upvotes: 1