Saeed

Reputation: 718

Why does iterating over a PyTorch DataLoader never end?

I have created a dataloader whose length is 50000, and printing its length does return 50000:

import typing as t

import torch

class MyDataLoader(torch.utils.data.Dataset):
    def __init__(self, data_size=50000):
        self.data_size = data_size

    def __len__(self) -> int:
        return self.data_size

    def __getitem__(self, idx) -> t.Tuple[torch.Tensor, torch.Tensor]:
        # my_function is defined elsewhere and returns an (image, label) pair
        image, label = my_function()  # (has_star=True)
        return image[None], label

dl = MyDataLoader()
print(len(dl))
50000

However, when I iterate over it, the loop never terminates:

for j, i in enumerate(dl):
  if j%10000 == 0:
    print(j)
10000
20000
30000
40000
50000
60000
...

How is that possible?

Upvotes: 2

Views: 1444

Answers (1)

hkchengrex

Reputation: 4826

You have created a Dataset, not a DataLoader.

When you iterate over a Dataset directly, Python falls back to the old sequence iteration protocol: it calls __getitem__(0), __getitem__(1), and so on, and only stops when __getitem__ raises IndexError. __len__ is never consulted during iteration. Your __getitem__ ignores idx and never raises, so the loop runs forever.
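Here is a minimal sketch of that fallback protocol, using hypothetical classes unrelated to PyTorch: an object that only defines __getitem__ is iterable, and iteration ends only when __getitem__ raises IndexError.

class Bounded:
    def __getitem__(self, idx):
        # Raising IndexError is what terminates old-style iteration;
        # __len__ plays no role in it.
        if idx >= 3:
            raise IndexError(idx)
        return idx

print(list(Bounded()))  # prints [0, 1, 2]

class Unbounded:
    def __getitem__(self, idx):
        return idx  # never raises IndexError, so list(Unbounded()) would never return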

This should work:

import torch
from torch.utils.data import DataLoader

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, data_size=50000):
        self.data_size = data_size

    def __len__(self) -> int:
        return self.data_size

    def __getitem__(self, idx):
        # print(idx)
        return idx

dataset = MyDataset()

# The default batch size is 1, so there are 50000 batches
dl = DataLoader(dataset)
print(len(dl))  # 50000

for j, i in enumerate(dl):
    if j % 10000 == 0:
        print(j)

# And with a different batch size:
dl = DataLoader(dataset, batch_size=2)
print(len(dl))  # 25000

for j, i in enumerate(dl):
    if j % 10000 == 0:
        print(j)

Note how len(dl) changes when the batch size changes: a DataLoader reports its length in batches, not in samples.
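To make the relationship explicit, here is a quick sketch (a plain list works as a map-style dataset here): with the default drop_last=False the length is rounded up, and with drop_last=True the trailing partial batch is discarded.

import math

from torch.utils.data import DataLoader

dataset = list(range(10))  # any object with __getitem__ and __len__ works

dl = DataLoader(dataset, batch_size=3)  # drop_last=False by default
assert len(dl) == math.ceil(len(dataset) / 3)  # 4 batches

dl = DataLoader(dataset, batch_size=3, drop_last=True)
assert len(dl) == len(dataset) // 3  # 3 batches; the partial batch is dropped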

Upvotes: 4
