titipata

Reputation: 5389

Circular batch generation from list

Basically, I would like to create an infinite generator from a given list l that yields batches of size batch_size. For example, with l = [1, 2, 3, 4, 5] and batch_size = 2, I'd like an infinite sequence of [1, 2], [3, 4], [5, 1], [2, 3], ... (similar to itertools.cycle, but with a batch size).

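My current approach is below. It doesn't give the right result yet: when I reach the end of the list, I pad the last batch with elements from the start of the list, but the next pass then restarts from the beginning instead of continuing after the padding.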

l = [1, 2, 3, 4, 5]

def generator(l, batch_size=2):
    while True:
        for i in range(0, len(l), batch_size):
            batch = l[i:(i + batch_size)]
            if len(batch) < batch_size:
                # Pad a short final batch with items from the start of l
                batch.extend(l[0:batch_size - len(batch)])
            yield batch

gen = generator(l, batch_size=2)
next(gen)  # [1, 2]
next(gen)  # [3, 4]
next(gen)  # [5, 1]
next(gen)  # [1, 2], but I want [2, 3] instead
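The mismatch comes from the inner for loop restarting at index 0 on every pass. A minimal sketch of an index-based fix, assuming l is non-empty, keeps a running start index and wraps every index with modulo, so each batch continues right after the previous one:

```python
def generator(l, batch_size=2):
    """Yield fixed-size batches from l forever, wrapping around the end.

    Sketch only: keeps a running start index instead of restarting at 0
    on each pass, and wraps every index into range with modulo.
    """
    n = len(l)
    if n == 0:
        raise ValueError("l must be non-empty")
    start = 0
    while True:
        # Indices start, start+1, ..., start+batch_size-1, each taken mod n
        yield [l[(start + j) % n] for j in range(batch_size)]
        # Continue right after the last element used, modulo n
        start = (start + batch_size) % n


gen = generator([1, 2, 3, 4, 5], batch_size=2)
batches = [next(gen) for _ in range(6)]  # [[1, 2], [3, 4], [5, 1], [2, 3], [4, 5], [1, 2]]
```

Because only the start index is carried between batches, this also works when batch_size is larger than len(l).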

Is there a way to make the batches wrap around the list continuously, so that each batch picks up where the previous one ended?

Upvotes: 1

Views: 483

Answers (2)

user2390182

Reputation: 73450

This should work:

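import itertools

def generator(l, batch_size=2):
    gen = itertools.cycle(l)  # cycle() already returns an iterator
    while True:
        # Each batch resumes from wherever the shared cycle left off
        yield [next(gen) for _ in range(batch_size)]

gen = generator(l, batch_size=2)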
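The reason this works is that every batch is drawn from one shared itertools.cycle iterator, so each call to next() continues from the last element consumed rather than from the start of the list. A small demonstration of that mechanism, outside the generator:

```python
import itertools

it = itertools.cycle([1, 2, 3, 4, 5])
# cycle() already returns an iterator, so wrapping it in iter() is a no-op
assert iter(it) is it

# Consecutive batches drawn from the same iterator pick up where the last stopped
first = [next(it) for _ in range(2)]   # [1, 2]
second = [next(it) for _ in range(2)]  # [3, 4]
third = [next(it) for _ in range(2)]   # [5, 1]
fourth = [next(it) for _ in range(2)]  # [2, 3]
```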

Upvotes: 3

juanpa.arrivillaga

Reputation: 95873

Yes, you basically want a combination of "take" and cycle:

>>> import itertools
>>> def circle_batch(iterable, batchsize):
...     it = itertools.cycle(iterable)
...     while True:
...         yield list(itertools.islice(it, batchsize))
...
>>> l = [1, 2, 3, 4, 5]
>>> c = circle_batch(l, 2)
>>> next(c)
[1, 2]
>>> next(c)
[3, 4]
>>> next(c)
[5, 1]
>>> next(c)
[2, 3]
>>> next(c)
[4, 5]
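Since circle_batch never terminates, whatever consumes it has to impose the bound. A sketch of two common ways to do that, repeating the function so the snippet runs on its own; the variable names are illustrative only:

```python
import itertools

def circle_batch(iterable, batchsize):
    # Same approach as above: slice fixed-size chunks off one shared cycle
    it = itertools.cycle(iterable)
    while True:
        yield list(itertools.islice(it, batchsize))

# Option 1: islice takes only the first N batches
first_five = list(itertools.islice(circle_batch([1, 2, 3, 4, 5], 2), 5))

# Option 2: zip with a finite range, e.g. a fixed number of steps;
# zip stops as soon as the range is exhausted
steps = list(zip(range(3), circle_batch("abc", 2)))

# Edge case: an empty input makes cycle() produce nothing, so every batch is []
empty = list(itertools.islice(circle_batch([], 2), 3))
```

With an empty input the generator still yields forever, just with empty lists, so a caller that loops on it unbounded will spin rather than stop.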

The itertools recipes in the docs list take as a basic tool, so the same thing can be written using it:

>>> import itertools
>>> from itertools import islice
>>> def take(n, iterable):
...     "Return first n items of the iterable as a list"
...     return list(islice(iterable, n))
...
>>> def cycle_batch(iterable, batchsize):
...     it = itertools.cycle(iterable)
...     while True:
...         yield take(batchsize, it)
...
>>> l = [1, 2, 3, 4, 5]
>>> c = cycle_batch(l, 2)
>>> next(c)
[1, 2]
>>> next(c)
[3, 4]
>>> next(c)
[5, 1]
>>> next(c)
[2, 3]
>>> next(c)
[4, 5]
>>> next(c)
[1, 2]
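The session returns to [1, 2] on the sixth call because batch k starts at index (k * batchsize) mod n, and that start index first comes back to 0 after n // gcd(n, batchsize) batches. A small sketch that checks this, using the take-based generator written with yield; batch_period is a hypothetical helper introduced only for illustration:

```python
import itertools
from itertools import islice
from math import gcd

def take(n, iterable):
    "Return first n items of the iterable as a list"
    return list(islice(iterable, n))

def cycle_batch(iterable, batchsize):
    it = itertools.cycle(iterable)
    while True:
        yield take(batchsize, it)

def batch_period(n, batchsize):
    # Hypothetical helper: number of batches before the sequence repeats.
    # Batch k starts at (k * batchsize) % n, which is 0 again first when
    # k == n // gcd(n, batchsize).
    return n // gcd(n, batchsize)

l = [1, 2, 3, 4, 5]
p = batch_period(len(l), 2)  # 5
batches = take(2 * p, cycle_batch(l, 2))
# The second block of p batches repeats the first block exactly
assert batches[:p] == batches[p:]
```

For example, with a 4-element list and batches of 2 the period drops to 2, since the batches then tile the list exactly.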

Upvotes: 4
