Zarathustra

Reputation: 411

Select specific items from generator

Given a generator object, I want to iterate only over the elements at specified indices. However, I have found that when the generator yields "heavy" items (memory-wise), loading each item takes considerably longer than when I iterate over every element in the original order. Here is a minimal example; in my real situation I have no direct access to the generator function (i.e. mygen cannot be modified).

import itertools
import numpy as np 

def mygen(n):
    
    for k in range(n):
        yield k 
        

mygenerator=mygen(10)

indices_ls=[0,5,8]
mask      =np.zeros(10,dtype=int)
mask[indices_ls]=1

compressed_generator=itertools.compress(mygenerator, mask)
for i in compressed_generator:
    print(i) 

This minimal code does the trick: only the selected elements are printed. But is there a way to do this without compressed_generator, i.e. by iterating directly over the original generator while getting ONLY the specified indices? Recall that the generator function (mygen) cannot be modified.

Upvotes: 1

Views: 621

Answers (2)

Alain T.

Reputation: 42133

If the generator takes a long time to load each item, there is not much you can do from the outside to improve performance: a generator has to produce every item that comes before the one you want, even if you discard it.
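To make that point concrete, here is a small sketch. The generator `heavy_gen` and the `produced` list are hypothetical stand-ins (not part of the question's code) used only to record which items actually get built when `itertools.compress` filters the output.

```python
import itertools

produced = []  # records every item the generator actually builds


def heavy_gen(n):
    # Hypothetical stand-in for a generator whose items are expensive to build.
    for k in range(n):
        produced.append(k)  # the "expensive" work happens here, before the yield
        yield k


# Same selection as in the question: keep indices 0, 5 and 8.
mask = [1 if k in (0, 5, 8) else 0 for k in range(10)]
selected = list(itertools.compress(heavy_gen(10), mask))

print(selected)  # [0, 5, 8]
print(produced)  # all ten items were still built
```

Filtering happens after each item has been produced, so it changes what you see but not what the generator computes.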

If you want to process specific indexes, iterate through the list of indices (in ascending order) and skip over the unwanted values in between. You can place that in a function of your own to make it easier to use:

def xslice(gen,indices):
    skipFrom = 0                    # base for skipping unwanted 
    for i in indices:
        for _ in range(i-skipFrom): # skip to previous iteration
            next(gen)
        yield next(gen)             # next after skipping is wanted
        skipFrom = i+1
        
for i in xslice(mygen(10),[0,5,8]):
    print(i)

0
5
8

This will also stop iterating after processing the highest index, which may save a bit of time.
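As a variation on the same idea, here is a minimal sketch built on `enumerate`. The helper name `select_by_index` is made up for illustration. It accepts indices in any order (items still come out in generator order) and simply ends if the generator runs out before an index is reached, whereas calling `next()` on an exhausted generator inside another generator raises a `RuntimeError` on current Python versions.

```python
def select_by_index(iterable, indices):
    """Yield the items of `iterable` whose positions are in `indices`."""
    wanted = set(indices)
    if not wanted:
        return
    last = max(wanted)
    for position, item in enumerate(iterable):
        if position in wanted:
            yield item
        if position >= last:
            return  # stop pulling from the source after the highest wanted index


def mygen(n):
    # Same toy generator as in the question.
    for k in range(n):
        yield k


result = list(select_by_index(mygen(10), [8, 0, 5]))
print(result)  # [0, 5, 8]
```

Like xslice, this still consumes every item up to the highest index, so it does not avoid the loading cost of the skipped items.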

If you have access to the source code of that generator function, you may find that it is designed to let callers control its progression with send(). This depends entirely on the generator's implementation, but if you are lucky there could be a way to work around the problem with .send().

For example, here's a variant of the toy generator that accepts a sent value to reposition the progression (the telltale sign of .send() support is that it uses the value returned by the yield expression):

def mygen(n):
    k = 0
    while k<n:
        skipTo = yield k
        k = k+1 if skipTo is None else skipTo

You could leverage this to skip to the desired indices (although the first one is tricky because you can't send() a value before the generator has been started with a first next()).

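gen     = mygen(10)
indices = [0,5,8]
for i in indices:
    if i==indices[0] and i>0:
        next(gen)            # start the generator before the first send()
    if i>0:
        gen.send(i-1)        # reposition to i-1 (yields item i-1, which is discarded)
    value=next(gen)          # advance to item i
    print(value)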
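Since send() itself returns the next yielded value, the loop can also be written so that it sends the target index directly and uses what comes back. The sketch below assumes the send()-aware `mygen` variant shown above and ascending indices; the helper `pick` is a hypothetical name. Item 0 is always produced once, because the generator has to be started with `next()` before it can receive a value.

```python
def mygen(n):
    # Same send()-aware toy generator as above.
    k = 0
    while k < n:
        skip_to = yield k
        k = k + 1 if skip_to is None else skip_to


def pick(gen, indices):
    """Hypothetical helper: collect the items of `gen` at ascending `indices`."""
    values = []
    first = next(gen)  # start the generator; this yields item 0
    for i in indices:
        # send(i) sets k = i inside the generator and returns the next yield, i.e. item i
        values.append(first if i == 0 else gen.send(i))
    return values


picked = pick(mygen(10), [0, 5, 8])
print(picked)  # [0, 5, 8]
```

With this toy generator, sending an index that is not below n ends the generator, so send() raises StopIteration; real code would need to handle that case.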

Upvotes: 2

Andrej Kesely

Reputation: 195553

One solution might be to "patch" the function the generator relies on. In this toy case that is range(); in your real code you would need to patch whatever functions your generator calls:

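import builtins
from contextlib import contextmanager


def mygen(n):
    for k in range(n):
        yield k

@contextmanager
def patch_range():
    def my_range(x):
        # stand-in for range(): yields only the wanted indices
        yield 0
        yield 5
        yield 8

    # save the original before the try block so the finally clause can always restore it
    original_range = builtins.range
    builtins.range = my_range
    try:
        yield
    finally:
        builtins.range = original_range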

with patch_range():
    for i in mygen(10):
        print(i)

Prints:

0
5
8
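The same patching idea can be sketched with the standard library's `unittest.mock.patch`, which restores the original on exit. The replacement function `selected_range` is a hypothetical name used only for this illustration, and the approach assumes the generator looks up `range` when it runs (as the toy `mygen` does).

```python
from unittest import mock


def mygen(n):
    for k in range(n):  # `range` is looked up at call time, so the patch takes effect
        yield k


def selected_range(n):
    # Hypothetical replacement for range(): yields only the wanted indices below n.
    for k in (0, 5, 8):
        if k < n:
            yield k


with mock.patch("builtins.range", selected_range):
    patched = list(mygen(10))  # consumed while the patch is active

unpatched = list(mygen(3))  # range() is back to normal here

print(patched)    # [0, 5, 8]
print(unpatched)  # [0, 1, 2]
```

While the patch is active, every caller of range() in the process sees the replacement, so the patched block should be kept as small as possible.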

Upvotes: 1
