Daniel Arteaga

Reputation: 477

Best way to convert generator into iterator class

Consider the following dummy example:

def common_divisors_generator(n, m):

    # Init code
    factors_n = [i for i in range(1, n + 1) if n%i == 0]
    factors_m = [i for i in range(1, m + 1) if m%i == 0]

    # Iterative code
    for fn in factors_n:
        for fm in factors_m:
            if fn == fm:
                yield fn

# The next line is fast because no code is executed yet
cdg = common_divisors_generator(1537745, 373625435)
# The next loop is slow because the init code is executed on the first iteration
for g in cdg:
    print(g)

The init code, which takes a long time to run, is executed when the generator is iterated for the first time (as opposed to when the generator is created). I would prefer the init code to be executed when the generator is initialized.

For this purpose I convert the generator into an iterator class as follows:

class CommonDivisorsIterator(object):

    def __init__(self, n, m):
        # Init code
        self.factors_n = [i for i in range(1, n + 1) if n%i == 0]
        self.factors_m = [i for i in range(1, m + 1) if m%i == 0]

    def __iter__(self):
        return self

    def __next__(self):
        # Some Pythonic implementation of the iterative code above
        # ...
        return next_common_divisor

Every way I can think of to implement the __next__ method above is cumbersome compared to the simplicity of the iterative code in the generator with the yield keyword.

What would be the most Pythonic way of implementing the __next__ method in the iterator class?

Alternatively, how can I modify the generator so that the init code is executed at init time?
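
For reference, one common pattern (a minimal sketch, not taken from the answers below; the helper name _common_divisors is mine) is to keep the init code in __init__, build the generator there as well, and have __next__ simply delegate to it:

class CommonDivisorsIterator(object):

    def __init__(self, n, m):
        # Init code runs here, at construction time
        self.factors_n = [i for i in range(1, n + 1) if n % i == 0]
        self.factors_m = [i for i in range(1, m + 1) if m % i == 0]
        # Create the internal generator up front; it does no work until iterated
        self._gen = self._common_divisors()

    def _common_divisors(self):
        for fn in self.factors_n:
            for fm in self.factors_m:
                if fn == fm:
                    yield fn

    def __iter__(self):
        return self

    def __next__(self):
        # Delegate to the internal generator; StopIteration propagates naturally
        return next(self._gen)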

Upvotes: 9

Views: 6339

Answers (2)

Boichee

Reputation: 21

The other answers to your specific question are dead on, but if the concern is mainly the initialization time, I'd suggest that in general, the code could be optimized quite a bit.

For example, in the "iterator" portion, you are comparing factors for n to factors for m. This nested loop will have O(n*m) runtime. I know why you're doing this of course: One value could have [2, 3, 5] as it's factors, while the other has just [2, 5], and you don't want to miss the 5 because the 3 doesn't match. But there are far faster ways to accomplish this and I've written up a few of them below.

First, you'll notice I've changed how the factors are found a bit. You don't need to look beyond the square root of a number to find all of its factors because each factor of a number has a "complement" that is also its factor. So, for example, every number x has the factor 1, with x as the complement. Thus, once you know that y is a factor of x you also know that x / y is a factor of x. As such, you can find all the factors of a number x in O(sqrt(x)) time. And this seriously speeds things up for large x.

Second, rather than storing the factors in lists, you should store them in sets. There's no risk of duplicate factors, so a set is ideal. A set can be sorted whenever you need an ordered view, and membership tests are O(1) on average (like a hash map). This way, you only need to iterate over the factors of x to find out which are common with the factors of y. This alone changes your runtime from O(n*m) to O(min(n, m)), where n and m are the sizes of the factor sets.

Finally, if I were implementing this, I'd likely do it lazily, only finding new common factors when needed. This eliminates the cost of an initialization step entirely.

I've included implementations of both approaches below.

from math import sqrt
from typing import Iterator

def find_factors(x: int) -> Iterator[int]:
    """Here's the deal with the sqrt thing. For any value x with factor a, there's a complement b.
    i.e., if x = 8, then we know 2 * 4 = 8. 2 is the first factor we'll find counting from 1 upward,
    but once we find 2, we know that 4 (the complement) is also a factor of 8. The largest
    unique factor whose complement can't be known earlier is the sqrt(x). So to save time,
    we just put aside the complements until later. If we didn't do this, we'd have to iterate
    all the way to x (which could be quite large) even though we'd already know that there are no
    factors between (for example), x / 2 and x.
    This brings the cost of finding all the factors of x down to O(sqrt(x)), which is far
    smaller than the O(x) cost of trial division all the way up to x.

    Args:
        x (int): The value to factor
    """
    complements = []
    for v in range(1, int(sqrt(x)) + 1):
        if x % v == 0:
            if v != x // v:  # avoid yielding sqrt(x) twice when x is a perfect square
                complements.append(x // v)
            yield v

    for v in reversed(complements):
        yield v
   
def common_factors_greedy(n, m):
    """
    This will run in O(min(N, M)) time instead of the O(N*M) time of the nested loops.
    Note that N and M are the sizes of the factor sets, not the values of n and m.
    """
    
    # I'd recommend creating these as sets to begin with so you don't have to waste cycles converting them
    factors_n, factors_m = set(find_factors(n)), set(find_factors(m))
    common_factors = factors_n & factors_m
    for c in common_factors:
        yield c

def common_factors_lazy(n, m):
    """
     Generates common factors of n and m lazily, which means there's no initialization cost up front. You only use
     compute when you actually look for the next common factor. Overall, might be overkill as using the approach
     I wrote up for finding factors, its pretty fast even for large n or m. But still worth thinking about
     for other kinds of problems.
    """
    # Note: factors_n and factors_m are not lists of factors; they are generator objects. They don't
    # actually "know" anything about the factors of n or m until you call next(factors_n) or next(factors_m).
    factors_n, factors_m = find_factors(n), find_factors(m)
    x, y = next(factors_n), next(factors_m)
    x_empty = y_empty = False
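    # Both factor streams yield values in ascending order, so walk them like a merge:
    # advance the smaller side, yield on a match, and stop once one stream is exhausted
    # and its last value is already below the other side's current value.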
    while not (x_empty and y_empty) and not ((x_empty and x < y) or (y_empty and y < x)):
        if x == y:
            yield x
            try:
                x = next(factors_n)
                y = next(factors_m)
            except StopIteration:
                return
        elif x > y:
            try:    
                y = next(factors_m)
            except StopIteration:
                y_empty = True
        else:
            try:
                x = next(factors_n)
            except StopIteration:
                x_empty = True


def main():
    N = 1_220_142
    M = 837_462
    
    for fact in find_factors(N):
        print(f'Factor of N: {fact}')
    for fact in find_factors(M):
        print(f'Factor of M: {fact}')
    
    for com in common_factors_greedy(N, M):
        print(f'Greedy factor: {com}')
    
    for com in common_factors_lazy(N, M):
        print(f'Lazy factor: {com}')


if __name__ == '__main__':
    main()

Upvotes: 2

Aran-Fey

Reputation: 43326

In both cases (whether you use a function or a class), the solution is to split the implementation into two functions: a setup function and a generator function.

Using yield in a function turns it into a generator function, which means that it returns a generator when it's called. But even without using yield, nothing's preventing you from creating a generator and returning it, like so:

def common_divisors_generator(n, m):
    factors_n = [i for i in range(1, n + 1) if n%i == 0]
    factors_m = [i for i in range(1, m + 1) if m%i == 0]

    def gen():
        for fn in factors_n:
            for fm in factors_m:
                if fn == fm:
                    yield fn

    return gen()

And if you're using a class, there's no need to implement a __next__ method. You can just use yield in the __iter__ method:

class CommonDivisorsIterator(object):
    def __init__(self, n, m):
        self.factors_n = [i for i in range(1, n + 1) if n%i == 0]
        self.factors_m = [i for i in range(1, m + 1) if m%i == 0]

    def __iter__(self):
        for fn in self.factors_n:
            for fm in self.factors_m:
                if fn == fm:
                    yield fn
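
With this version, the expensive setup happens at construction time, and iterating only walks the precomputed lists. A quick usage sketch with the values from the question:

cdi = CommonDivisorsIterator(1537745, 373625435)  # slow: init code runs here
for g in cdi:  # the init work is already done; only the comparisons remain
    print(g)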

Upvotes: 7
