Hannah

Reputation: 79

PyTorch: How to get the first N items from a DataLoader

There are 3000 pictures in my list, but I only want to use the first N of them (say, 1000) for training. How can I achieve this by changing the loop code?

for (image, label) in enumerate(train_loader):

Upvotes: 4

Views: 12836

Answers (2)

Ka Wa Yip

Reputation: 2983

To get the first N items from train_loader, create an iterator over the dataloader with iter() (which calls its __iter__() method), then pull items one at a time with next() (which calls __next__()) inside a for loop:

N = 1000
dataiter = iter(train_loader)

image_list = []
label_list = []
# Assumes a batch size of 1; otherwise, divide N by the batch size.
for _ in range(N):
    # next() raises StopIteration if the loader yields fewer than N batches.
    image, label = next(dataiter)
    image_list.append(image)
    label_list.append(label)
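
A possible variation on the loop above is itertools.islice, which stops after N batches and also ends quietly if the loader runs out early. This is only a sketch: the TensorDataset of random 8x8 "images" is hypothetical stand-in data, not the asker's dataset, and a batch size of 1 is assumed as in the original snippet.

```python
from itertools import islice

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data: 3000 tiny random "images" with integer labels.
images = torch.randn(3000, 3, 8, 8)
labels = torch.arange(3000)
train_loader = DataLoader(TensorDataset(images, labels), batch_size=1, shuffle=False)

N = 1000
image_list = []
label_list = []
# islice yields at most N batches and never materializes the whole loader.
for image, label in islice(train_loader, N):
    image_list.append(image)
    label_list.append(label)
```

With a batch size larger than 1, islice still counts batches, so N would need to be divided by the batch size as noted above.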

Upvotes: 0

DerekG

Reputation: 3938

for i, (image, label) in list(enumerate(train_loader))[:1000]:

This is not a good way to partition training and validation data, though. First, a DataLoader loads lazily (examples are not read into memory until they are needed), whereas converting it to a list forces every batch to be loaded into memory at once, which can trigger an out-of-memory error on a large dataset. Second, if the DataLoader shuffles, the first 1000 elements will not be the same from one pass to the next. In general, a DataLoader does not support indexing, so it is not well suited to selecting a specific subset of a dataset. Converting it to a list works around this, but at the expense of the properties that make the DataLoader useful.

Best practice is to use a separate torch.utils.data.Dataset object for each of the training and validation partitions, or at least to partition the data at the dataset level rather than relying on stopping training after the first 1000 examples. Then create a separate DataLoader for each partition.
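
As a rough illustration of partitioning at the dataset level, torch.utils.data.Subset can wrap an existing dataset with a fixed set of indices. The dataset contents, the split point N, and the batch size here are assumptions made for the example, not details from the question.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Hypothetical stand-in for the asker's 3000 pictures.
full_dataset = TensorDataset(torch.randn(3000, 3, 8, 8), torch.arange(3000))

N = 1000
# Fixed index ranges: the first N samples for training, the rest for validation.
train_set = Subset(full_dataset, range(N))
val_set = Subset(full_dataset, range(N, len(full_dataset)))

# Each partition gets its own loader; shuffling only reorders within a partition.
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

# Sanity check: count the samples the training loader actually yields.
train_samples = sum(label.numel() for _, label in train_loader)
```

Because the subset is fixed by index, the training loader can shuffle freely and still only ever draw from the same 1000 examples.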

Upvotes: 3
