Aakanksha W.S

Reputation: 301

How to get entire dataset from dataloader in PyTorch

How do I load the entire dataset from the DataLoader? I am only getting one batch of the dataset.

This is my code:

dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=64)
images, labels = next(iter(dataloader))

Upvotes: 20

Views: 57735

Answers (3)

asymptote

Reputation: 1402

You can set batch_size = len(dataset). Beware, this might require a lot of memory depending upon your dataset.
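
A minimal sketch of that idea, assuming dataset is the same map-style dataset from the question and that it fits into memory in one go:

import torch

full_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=len(dataset))
images, labels = next(iter(full_loader))   # a single batch containing every sample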

Upvotes: 14

Jean B.

Reputation: 147

Another option would be to get the entire dataset directly, without using the dataloader, like so:

images, labels = dataset[:]
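
Note that this relies on the dataset's __getitem__ accepting a slice. A short sketch with a TensorDataset (the toy data here is made up for illustration), which supports slicing because it simply indexes its underlying tensors; a custom Dataset whose __getitem__ only handles integer indices would raise an error instead:

import torch
from torch.utils.data import TensorDataset

# hypothetical toy data: 100 RGB images of size 32x32 with integer labels
dataset = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,)))
images, labels = dataset[:]   # images: (100, 3, 32, 32), labels: (100,)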

Upvotes: 13

Florian Blume

Reputation: 3345

I'm not sure whether you want to use the dataset for something other than network training (inspecting the images, for example) or whether you want to iterate over the batches during training.

Iterating through the dataset

Either follow Usman Ali's answer (which might overflow your memory), or you could do:

for i in range(len(dataset)):  # or: for i, (image, label) in enumerate(dataset)
    image, label = dataset[i]  # or whatever your dataset returns for a single index

You can write dataset[i] because you implemented __len__ and __getitem__ in your Dataset class (as long as it's a subclass of the PyTorch Dataset class).
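
For reference, a minimal sketch of such a subclass (the class name and attributes here are assumptions for illustration):

from torch.utils.data import Dataset

class MyImageDataset(Dataset):  # hypothetical name, for illustration only
    def __init__(self, images, labels):
        self.images = images              # e.g. a tensor of shape (N, C, H, W)
        self.labels = labels              # e.g. a tensor of shape (N,)

    def __len__(self):
        return len(self.images)           # makes len(dataset) work

    def __getitem__(self, i):
        return self.images[i], self.labels[i]   # makes dataset[i] work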

Getting all batches from the dataloader

The way I understand your question is that you want to retrieve all batches to train the network with. You should understand that iter gives you an iterator over the dataloader (if you're not familiar with the concept of iterators, see the Wikipedia entry). next tells the iterator to give you the next item.

So, in contrast to a single iterator traversing a list, which stops at some point, next(iter(dataloader)) always returns an item, because each call builds a fresh iterator. I assume that you have something like a number of epochs and a number of steps per epoch. Then your code would look like this:

for i in range(epochs):
    # some code
    for j in range(steps_per_epoch):
        images, labels = next(iter(dataloader))
        prediction = net(images)
        loss = net.loss(prediction, labels)
        ...

Be careful with next(iter(dataloader)): every call creates a new iterator that starts at the first batch again, so you keep getting the same batch over and over (which is exactly why you are only seeing one batch). To avoid this, create the iterator once at the top, like so:

iterator = iter(dataloader)
for i in range(epochs):
    for j in range(steps_per_epoch):
        images, labels = next(iterator)
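
One caveat with the snippet above: next(iterator) will raise StopIteration once the dataloader has been fully consumed, so epochs * steps_per_epoch must not exceed the number of batches. A common alternative, sketched here rather than taken from the original answer, is to loop over the dataloader directly once per epoch, which builds a fresh iterator each time (and re-shuffles if shuffle=True):

for i in range(epochs):
    # some code
    for images, labels in dataloader:   # one full pass over the dataset per epoch
        prediction = net(images)
        loss = net.loss(prediction, labels)
        ...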

Upvotes: 10
