Reputation: 301
How do I load the entire dataset from the DataLoader? I am getting only one batch of the dataset.
This is my code:
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=64)
images, labels = next(iter(dataloader))
Upvotes: 20
Views: 57735
Reputation: 1402
You can set batch_size = len(dataset). Beware, this might require a lot of memory depending upon your dataset.
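For example, a minimal sketch reusing the code from the question (assuming the whole dataset fits in memory):
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=len(dataset))
images, labels = next(iter(dataloader))  # this single batch is the entire dataset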
Upvotes: 14
Reputation: 147
Another option would be to get the entire dataset directly, without using the dataloader, like so:
images, labels = dataset[:]
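Note that this only works if your dataset's __getitem__ accepts a slice; torch.utils.data.TensorDataset does, for example, while a custom Dataset class might not. A minimal sketch with hypothetical tensors:
import torch
from torch.utils.data import TensorDataset

images_t = torch.randn(100, 3, 32, 32)   # hypothetical image tensor
labels_t = torch.randint(0, 10, (100,))  # hypothetical labels
dataset = TensorDataset(images_t, labels_t)

images, labels = dataset[:]  # all 100 samples at once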
Upvotes: 13
Reputation: 3345
I'm not sure whether you want to use the dataset somewhere other than network training (inspecting the images, for example) or whether you want to iterate over the batches during training.
Iterating through the dataset
Either follow Usman Ali's answer (which might overflow your memory), or you could do:
for i in range(len(dataset)):  # or: for i, (images, labels) in enumerate(dataset)
    images, labels = dataset[i]  # or whatever your dataset returns
You are able to write dataset[i] because you implemented __len__ and __getitem__ in your Dataset class (as long as it's a subclass of the PyTorch Dataset class).
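For illustration, a hypothetical minimal Dataset subclass (the class and attribute names here are made up):
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):         # enables len(dataset)
        return len(self.images)

    def __getitem__(self, i):  # enables dataset[i]
        return self.images[i], self.labels[i]

dataset = MyDataset(torch.randn(10, 3, 32, 32), torch.arange(10))
images, labels = dataset[0]  # works because of __getitem__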
Getting all batches from the dataloader
The way I understand your question is that you want to retrieve all batches to train the network with. You should understand that iter gives you an iterator over the dataloader (if you're not familiar with the concept of iterators, see the wikipedia entry). next tells the iterator to give you the next item.
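A short sketch of what that means in practice, reusing the dataloader from the question:
iterator = iter(dataloader)      # a fresh iterator over the batches
images, labels = next(iterator)  # first batch
images, labels = next(iterator)  # second batch, and so on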
So, in contrast to an iterator traversing a list, the dataloader always returns a next item; list iterators stop at some point. I assume that you have something like a number of epochs and a number of steps per epoch. Then your code would look like this:
for i in range(epochs):
    # some code
    for j in range(steps_per_epoch):
        images, labels = next(iter(dataloader))
        prediction = net(images)
        loss = net.loss(prediction, labels)
        ...
Be careful with next(iter(dataloader)). If you wanted to iterate through a list this might also work because Python caches objects, but here you can end up with a new iterator every time, one that starts at index 0 again, so you would always see the same first batch. To avoid this, take the iterator out to the top, like so:
iterator = iter(dataloader)
for i in range(epochs):
    for j in range(steps_per_epoch):
        images, labels = next(iterator)
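One caveat (an assumption about your loop sizes on my part): if epochs * steps_per_epoch exceeds the total number of batches, next(iterator) will raise StopIteration. Recreating the iterator once per epoch avoids that:
for i in range(epochs):
    iterator = iter(dataloader)  # fresh iterator each epoch
    for j in range(steps_per_epoch):
        images, labels = next(iterator)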
Upvotes: 10