Reputation: 2217
Is there a way in Python to make a generator produce several values at a time? In particular, I want something like:
my_gen = (i for i in range(11))
and say I have a parameter batch_size = 3. I would want my generator to output:
next(my_gen)
0,1,2
next(my_gen)
3,4,5
next(my_gen)
6,7,8
next(my_gen)
9,10
where the last call yields only two numbers, because only two are left, even though batch_size is 3.
Upvotes: 3
Views: 1540
Reputation: 66
If you expect the iterator/generator to have a number of elements that is a multiple of the batch size, you can simply do:
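gen = iter(range(12))
# iter(callable, sentinel) keeps calling the lambda; the sentinel 1 never matches,
# so the loop ends when next(gen) raises StopIteration on an exhausted gen
for x, y, z in iter(lambda: [next(gen) for _ in range(3)], 1):
    print(x, y, z)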
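One consequence of this trick is worth spelling out: if the length is not a multiple of 3, the final short batch is started, but the StopIteration aborts the list comprehension, so those leftover items are consumed and silently dropped. A small sketch wrapping the same expression (the name full_batches is only an illustrative choice, and it must be given an iterator, not a range):

```python
def full_batches(iterator, n):
    # Same iter(callable, sentinel) trick as above. The sentinel 1 is never
    # returned, so iteration ends only when next(iterator) raises StopIteration.
    return iter(lambda: [next(iterator) for _ in range(n)], 1)


print(list(full_batches(iter(range(11)), 3)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8]]  (9 and 10 were consumed but discarded)
```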
If not, this should suit your needs:
gen = iter(range(11))
# next(gen, None) pads a short final batch with None; an all-None batch means gen is exhausted
for t in iter(lambda: [next(gen, None) for _ in range(3)], [None]*3):
    print(*[x for x in t if x is not None])
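Because the second version pads with None, it would also drop genuine None values from the data. A minimal variant, assuming you want to keep them, pads with a private sentinel object instead; the names chunks and _MISSING are hypothetical:

```python
_MISSING = object()  # unique padding value that cannot appear in the data


def chunks(iterable, n):
    it = iter(iterable)

    def take():
        # next(it, _MISSING) never raises; strip the padding from a short batch
        return [x for x in (next(it, _MISSING) for _ in range(n)) if x is not _MISSING]

    # An empty batch means the source is exhausted, so [] is the sentinel
    return iter(take, [])


for batch in chunks([0, None, 2, 3, None], 2):
    print(batch)
```

This prints [0, None], then [2, 3], then [None], keeping the None values that the filtering version above would discard.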
Pros:
Upvotes: 0
Reputation: 59274
IMO, there is no need for any libraries. You can just define your own batch generator:
def batch_iter(batch_size, iter_):
    yield [next(iter_) for _ in range(batch_size)]
and then just call it on your iterator x:
next(batch_iter(batch_size, x))
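Each call creates a fresh one-shot generator, but because they all share the same iterator, consecutive calls return consecutive batches. Note that on Python 3.7+ (PEP 479), once the iterator runs out, the StopIteration raised inside the generator body is turned into a RuntimeError and any partial batch is lost, which is what an iteration-safe version has to deal with. A quick sketch (the iterator x here is just an example):

```python
def batch_iter(batch_size, iter_):
    yield [next(iter_) for _ in range(batch_size)]


x = iter(range(9))
print(next(batch_iter(3, x)))  # [0, 1, 2]
print(next(batch_iter(3, x)))  # [3, 4, 5]
print(next(batch_iter(3, x)))  # [6, 7, 8]

try:
    next(batch_iter(3, x))  # x is now exhausted
except RuntimeError as exc:
    # PEP 479: StopIteration escaping a generator body becomes RuntimeError
    print("RuntimeError:", exc)
```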
An iteration-safe version would be:
def batch_iter(batch_size, iter_):
    r = []
    for _ in range(batch_size):
        # next() with a default never raises StopIteration; None marks exhaustion
        val = next(iter_, None)
        if val is not None:
            r.append(val)
    yield r
Of course, you may yield tuple(r) instead of just r if you need tuple values. You may also add an else clause and break the loop, since once val is None there are no more values to iterate (assuming the iterator never yields None itself).
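Putting those suggestions together, a sketch of a generator that keeps yielding tuples until the iterator runs dry (so one generator object serves every batch) could look like this; the name iter_batches is hypothetical, and like the version above it assumes the data contains no None values:

```python
def iter_batches(batch_size, iter_):
    while True:
        r = []
        for _ in range(batch_size):
            val = next(iter_, None)
            if val is None:
                break  # iterator exhausted (assumes it never yields None)
            r.append(val)
        if not r:
            return  # nothing left: end the generator cleanly
        yield tuple(r)


my_gen = iter_batches(3, iter(range(11)))
print(next(my_gen))  # (0, 1, 2)
print(list(my_gen))  # [(3, 4, 5), (6, 7, 8), (9, 10)]
```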
Upvotes: 2
Reputation: 7887
If the values are consecutive integers (here 0 to max_size inclusive), you can use a list comprehension inside a generator expression:
batch_size, max_size = 3, 10
# Each outer step starts a new batch; the filter trims the last batch at max_size
my_gen = ([x for x in range(i, i + batch_size) if x <= max_size] for i in range(0, max_size + 1, batch_size))
for x in my_gen:
    print(x)
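Since this approach relies on the values forming a contiguous integer range, the same result can be sketched by slicing a range object directly (slicing a range returns another range). This also works for any sequence that supports len() and slicing, such as lists or strings, but not for arbitrary generators; range_batches is a hypothetical name:

```python
def range_batches(seq, batch_size):
    # Step through start offsets; a slice running past the end is simply shorter
    for start in range(0, len(seq), batch_size):
        yield list(seq[start:start + batch_size])


for batch in range_batches(range(11), 3):
    print(batch)
# [0, 1, 2]
# [3, 4, 5]
# [6, 7, 8]
# [9, 10]
```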
Upvotes: 1
Reputation: 7211
The itertools documentation provides a grouper recipe (it is not part of the module itself, so you define it yourself):
from itertools import zip_longest

def grouper(iterable, n, fillvalue=None):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)
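The recipe works because [iter(iterable)] * n is a list holding the same iterator n times, so each tuple zip_longest builds takes n consecutive items from that one iterator, and fillvalue pads the last tuple. A quick check against the example in the comment:

```python
from itertools import zip_longest


def grouper(iterable, n, fillvalue=None):
    args = [iter(iterable)] * n  # n references to ONE shared iterator
    return zip_longest(*args, fillvalue=fillvalue)


print(["".join(chunk) for chunk in grouper("ABCDEFG", 3, "x")])
# ['ABC', 'DEF', 'Gxx']

args = [iter("ABCDEFG")] * 3
print(args[0] is args[2])  # True: all three entries are the same object
```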
With that, you can create your generator and wrap it with grouper:
for my_tuple in grouper(my_gen, 3):
    print([x for x in my_tuple if x is not None])
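On Python 3.12 and later, itertools also ships batched(iterable, n), which yields tuples and leaves the last one short instead of padding it, so no None filtering is needed. A sketch that uses it when available and otherwise falls back to an islice-based loop (the fallback is an assumed equivalent written for illustration, not part of the standard library):

```python
import itertools
from itertools import islice

try:
    batched = itertools.batched  # available on Python 3.12+
except AttributeError:
    def batched(iterable, n):
        # Fallback for older versions: take up to n items at a time with islice
        if n < 1:
            raise ValueError("n must be at least one")
        it = iter(iterable)
        while True:
            batch = tuple(islice(it, n))
            if not batch:
                return
            yield batch


my_gen = (i for i in range(11))
for batch in batched(my_gen, 3):
    print(batch)
# (0, 1, 2)
# (3, 4, 5)
# (6, 7, 8)
# (9, 10)
```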
Upvotes: 5