I am using code that trains a neural network. It uses PyTorch's DataLoader to load the data at every iteration. The code looks as follows:
for step, data in enumerate(dataloader, 0):
    # ... other per-step code ...
    output = neuralnetwork_model(data)
    # ... other per-step code ...
Here step is an integer index taking the values 0, 1, 2, 3, ..., and data is the batch of samples for that step. The code passes each batch to the neural network in turn.
At step n, I also need access to the data of step n+1. I need something like this:
for step, data in enumerate(dataloader, 0):
    # ... other per-step code ...
    output = neuralnetwork_model(data)
    access = data_of_next_step  # the batch that step n+1 will receive
    # ... other per-step code ...
How can I achieve this?
It is handier to do this at the iteration level than to change the data loader's implementation. Following the answers to "Iterate over n successive elements with overlap", you can achieve this with itertools.tee:
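from itertools import tee

def pairwise(iterable):
    "s -> (s0,s1), (s1,s2), (s2, s3), ..."
    a, b = tee(iterable)  # two independent iterators over the same data
    next(b, None)         # advance the second one by a single element
    return zip(a, b)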
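To see what pairwise yields, here is a small sketch that uses a plain list as a stand-in for the data loader (an assumption for illustration; any iterable works the same way, since tee only iterates the source once). Because zip stops when the shorter iterator runs out, the last batch never shows up as the "current" element.

```python
from itertools import tee


def pairwise(iterable):
    "s -> (s0,s1), (s1,s2), (s2, s3), ..."
    a, b = tee(iterable)  # two independent iterators over the same data
    next(b, None)         # advance the second one by a single element
    return zip(a, b)


# A plain list stands in for the data loader (illustrative assumption);
# each string plays the role of one batch.
batches = ["batch0", "batch1", "batch2", "batch3"]
pairs = list(pairwise(batches))
print(pairs)
# [('batch0', 'batch1'), ('batch1', 'batch2'), ('batch2', 'batch3')]
```

With four batches you get three pairs. On Python 3.10 and later, itertools.pairwise provides the same behaviour without defining the helper yourself.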
You can then iterate over the wrapped data loader:
>>> for batch1, batch2 in pairwise(dataloader):
...     # batch1 is the current batch
...     # batch2 is the batch of the following step
...     output = neuralnetwork_model(batch1)
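One consequence of the pairwise approach is that the final batch is never trained on as batch1. If every batch should still be processed, a manual look-ahead generator is an alternative technique (not itertools.tee). The sketch below uses a hypothetical helper, with_next, that pairs each element with the next one and uses None as the "next" value on the last step. It assumes that no batch is itself None, and it again uses a plain list as a stand-in for the data loader.

```python
from typing import Iterable, Iterator, Optional, Tuple, TypeVar

T = TypeVar("T")


def with_next(iterable: Iterable[T]) -> Iterator[Tuple[T, Optional[T]]]:
    """Yield (current, next) pairs; next is None for the final element.

    Hypothetical helper: unlike pairwise(), the last element is still
    yielded as `current`, so no batch is skipped.
    """
    it = iter(iterable)
    try:
        current = next(it)
    except StopIteration:
        return  # empty input: nothing to yield
    for upcoming in it:
        yield current, upcoming
        current = upcoming
    yield current, None  # the last step has no following batch


# Stand-in for the data loader (illustrative assumption).
batches = ["batch0", "batch1", "batch2"]
for step, (data, next_data) in enumerate(with_next(batches)):
    print(step, data, next_data)
# 0 batch0 batch1
# 1 batch1 batch2
# 2 batch2 None
```

Like pairwise, this walks the loader only once, so it keeps the same batch order as a single pass over the DataLoader, even when shuffling is enabled.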