Leockl

Reputation: 2156

What do next() and iter() do in PyTorch's DataLoader()

I have the following code:

import torch
import numpy as np
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader

# Load dataset
df = pd.read_csv(r'../iris.csv')

# Extract features and target
data = df.drop('target',axis=1).values
labels = df['target'].values

# Create tensor dataset
iris = TensorDataset(torch.FloatTensor(data),torch.LongTensor(labels))

# Create random batches
iris_loader = DataLoader(iris, batch_size=105, shuffle=True)

next(iter(iris_loader))

What do next() and iter() do in the above code? I have gone through PyTorch's documentation and still can't quite understand what next() and iter() are doing here. Can anyone help explain this? Many thanks in advance.

Upvotes: 38

Views: 88942

Answers (2)

eric

Reputation: 8019

The accepted answer is right. I just wanted to add a complementary answer, because I was confused about this topic and about iterators versus iterables.

I initially thought the data loader was an iterator, so iter(data_loader) seemed redundant. But a data loader is an iterable, not an iterator. Similarly, a list is an iterable but not an iterator. If you call next(x) directly on a list x, you get TypeError: 'list' object is not an iterator. To step through a list manually, you first get an iterator from it with iter(x); then you can advance that iterator with next().
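A quick pure-Python check of the list example (no PyTorch needed; the variable names are just for illustration) might look like this:

```python
numbers = [10, 20, 30]  # a list is an iterable, not an iterator

# Calling next() directly on the list fails with a TypeError.
try:
    next(numbers)
except TypeError as exc:
    error_message = str(exc)

# iter() returns a separate iterator object that remembers its position.
numbers_iterator = iter(numbers)
first = next(numbers_iterator)   # 10
second = next(numbers_iterator)  # 20

print(error_message)  # 'list' object is not an iterator
print(first, second)  # 10 20
```

The list itself is unchanged; only the iterator's position moves forward.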

The same logic holds for data loaders: they are iterables, not iterators, and you get an iterator from one with iter(data_loader). To step through it, you then call next() on that iterator. You could just as easily break this up into multiple steps:

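from torch.utils.data import DataLoader

# define data loader (iterable); iris is the TensorDataset from the question
iris_loader = DataLoader(iris, batch_size=105, shuffle=True)

# define iterator for use in training
iris_iterator = iter(iris_loader)

# extract the first batch; calling next(iris_iterator) again gives the next one
data_batch = next(iris_iterator)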
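To see the iterable/iterator split without PyTorch installed, here is a minimal sketch using a hypothetical ToyLoader class. It is my own stand-in, not how DataLoader is actually implemented, but like a data loader it hands out a fresh iterator from every iter() call:

```python
import random


class ToyLoader:
    """Hypothetical stand-in for a data loader: an iterable, not an iterator."""

    def __init__(self, samples, batch_size, shuffle=False, seed=0):
        self.samples = list(samples)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def __iter__(self):
        # Each call starts a new pass: a brand-new generator (an iterator).
        order = list(range(len(self.samples)))
        if self.shuffle:
            self.rng.shuffle(order)
        return (
            [self.samples[i] for i in order[start:start + self.batch_size]]
            for start in range(0, len(order), self.batch_size)
        )


loader = ToyLoader(range(10), batch_size=4)

loader_iterator = iter(loader)
batch_1 = next(loader_iterator)  # [0, 1, 2, 3]
batch_2 = next(loader_iterator)  # [4, 5, 6, 7]
batch_3 = next(loader_iterator)  # [8, 9], the smaller final batch

# Once the pass is used up, next() raises StopIteration.
try:
    next(loader_iterator)
    exhausted = False
except StopIteration:
    exhausted = True

# next(iter(loader)) builds a new iterator every time, so it returns the
# first batch of a fresh pass instead of advancing the old one.
again = next(iter(loader))  # [0, 1, 2, 3]
```

The same reasoning applies to the question's one-liner: each next(iter(iris_loader)) call starts a new, reshuffled pass and returns only its first batch. Assuming the usual 150-row iris CSV, one full pass gives a 105-row batch and then a 45-row batch, because DataLoader keeps the smaller last batch unless drop_last=True.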

Upvotes: 15

ScootCork

Reputation: 3676

These are Python built-in functions used for working with iterables.

Basically, iter() calls the __iter__() method on iris_loader, which returns an iterator. next() then calls the __next__() method on that iterator to get the first item, which for a DataLoader is the first batch. Calling next() again on the same iterator gets the second item, and so on.
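A small sketch of which special method each built-in triggers. Squares and SquaresIterator are hypothetical classes (not part of PyTorch) that log every call; splitting the iterable from its iterator mirrors the data loader case:

```python
calls = []  # records which special method ran, in order


class Squares:
    """Iterable: __iter__ returns a new iterator (like a data loader)."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        calls.append("__iter__")
        return SquaresIterator(self.n)


class SquaresIterator:
    """Iterator: __next__ returns the next value or raises StopIteration."""

    def __init__(self, n):
        self.i = 0
        self.n = n

    def __iter__(self):
        # Iterators conventionally return themselves from __iter__.
        return self

    def __next__(self):
        calls.append("__next__")
        if self.i >= self.n:
            raise StopIteration
        value = self.i ** 2
        self.i += 1
        return value


it = iter(Squares(3))  # runs Squares.__iter__
first = next(it)       # runs SquaresIterator.__next__, gives 0
second = next(it)      # gives 1
print(calls)           # ['__iter__', '__next__', '__next__']
```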

This logic often happens behind the scenes, for example when running a for loop. The loop calls __iter__() on the iterable, then calls __next__() on the returned iterator until the iterator is exhausted. At that point __next__() raises StopIteration, which the for loop catches, and the loop stops.
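The for-loop behaviour described above can be written out by hand. This is a simplified sketch (CPython does this internally, and the helper name manual_for is hypothetical), but it follows the same iter()/next()/StopIteration sequence:

```python
def manual_for(iterable):
    """Roughly what `for item in iterable: collected.append(item)` does."""
    collected = []
    iterator = iter(iterable)      # the loop first calls __iter__ on the iterable
    while True:
        try:
            item = next(iterator)  # then calls __next__ on the iterator, repeatedly
        except StopIteration:      # exhaustion is signalled by StopIteration...
            break                  # ...which the loop catches in order to stop
        collected.append(item)
    return collected


species = ["setosa", "versicolor", "virginica"]

manual = manual_for(species)

automatic = []
for item in species:               # same protocol, done implicitly
    automatic.append(item)

print(manual == automatic)         # True
```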

Please see the documentation for further details and some nuances: https://docs.python.org/3/library/functions.html#iter

Upvotes: 43
