muddyfish

Reputation: 3650

Python product of infinite generators

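I'm trying to get the Cartesian product of two infinite generators, but the product function in itertools doesn't allow this sort of behaviour: it consumes its input iterables completely before yielding anything, so with infinite inputs it never returns.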

Example behaviour:

from itertools import count, product
i = count(1)
j = count(1)
x = product(i, j)  # never returns: product() first tries to consume i and j completely

[Killed]

What I want:

x = product(i, j)

((0,0), (0,1), (1,0), (1,1) ...)

The order in which the combinations are returned doesn't matter, as long as every combination is eventually generated given infinite time. In other words, for any given combination of elements there must be a finite index in the returned generator at which that combination appears.
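To make the requirement concrete, here is a sketch of the naive nested-loop order that fails it. The helper names nested_loops and index_within are hypothetical, not from the question: (0, 5) turns up at index 5, but (1, 0) is never reached, because the inner loop over y never ends.

```python
from itertools import count, islice

def nested_loops():
    """Naive order: for each x, loop over *every* y first (so x == 0 never finishes)."""
    for x in count(0):
        for y in count(0):
            yield (x, y)

def index_within(gen, target, limit):
    """Return the 0-based index of target among the first `limit` items of gen, or None."""
    for idx, item in enumerate(islice(gen, limit)):
        if item == target:
            return idx
    return None

print(index_within(nested_loops(), (0, 5), 1000))  # 5
print(index_within(nested_loops(), (1, 0), 1000))  # None
```

The limit only bounds the search; with this order, (1, 0) is absent from every finite prefix, which is exactly what the question rules out.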

Upvotes: 4

Views: 2188

Answers (4)

zoravur

Reputation: 82

Here is my solution. It uses a modified zigzag algorithm, where the next element after (x, y) is (x-1, y+1), except at the end of a diagonal.

The memory complexity of this algorithm is O(n+m), where mn is the cardinality of the set being generated. In the infinite case, generating the k-th element requires storing O(√k) values from each input, because the k-th element lies on roughly the √(2k)-th diagonal.

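from itertools import count, islice

def infinite_cart_prod(i, j):
    """Yield the Cartesian product of two infinite iterables, one diagonal at a time."""
    i, j = iter(i), iter(j)
    i_values = []  # values drawn from i so far, read from newest to oldest
    j_values = []  # values drawn from j so far, read from oldest to newest
    n = 0
    while True:
        n += 1
        i_values.append(next(i))
        j_values.append(next(j))
        # Diagonal n: (i[n-1], j[0]), (i[n-2], j[1]), ..., (i[0], j[n-1])
        for k in range(n):
            yield (i_values[n-1-k], j_values[k])

# Print the first 10 pairs (iterate over gen itself to print forever)
gen = infinite_cart_prod(count(1), count(1))
for pair in islice(gen, 10):
    print(pair)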
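The memory claim can be made concrete. A sketch, assuming 0-based indices and the diagonal-by-diagonal order above (diagonal d holds d + 1 pairs); the helpers diagonal_of and values_cached are hypothetical names, and math.isqrt needs Python 3.8+. It computes which diagonal the k-th pair lies on, and hence how many values infinite_cart_prod has cached from each input when it yields that pair.

```python
from math import isqrt

def diagonal_of(k):
    """0-based diagonal holding the k-th (0-based) pair of a diagonal-by-diagonal enumeration."""
    # Diagonals 0..d-1 hold d*(d+1)//2 pairs; return the largest d with d*(d+1)//2 <= k.
    return (isqrt(8 * k + 1) - 1) // 2

def values_cached(k):
    """Values held from each input when the k-th pair is yielded (diagonal d needs d + 1)."""
    return diagonal_of(k) + 1

print(values_cached(10**6))  # 1414
```

A million pairs in, each list holds only 1414 values, which is the O(√k) growth described above.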

Upvotes: 0

enrico.bacis

Reputation: 31504

tl;dr

The code presented below is now included in the infinite package on PyPI, so you can simply pip install infinite and run:

from itertools import count
from infinite import product

for x, y in product(count(0), count(0)):
    print(x, y)
    if (x, y) == (3, 3):
        break

The zig-zag algorithm

What is needed is a way to iterate over the pairs of numbers so that any specific pair (made of finite numbers) is reached in finite time. One way to do this is the zig-zag scanning algorithm.

[Figure: zig-zag scanning algorithm]

To do this you need to cache previously extracted values, so I created a GenCacher class:

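class GenCacher:
    """Wrap a generator so that its values can be indexed (and re-read) like a list."""
    def __init__(self, generator):
        self._g = generator
        self._cache = []  # every value pulled from the generator so far

    def __getitem__(self, idx):
        # Pull values lazily until position idx is available
        while len(self._cache) <= idx:
            self._cache.append(next(self._g))
        return self._cache[idx]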

After that you just need to implement the zig-zag algorithm:

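def product(gen1, gen2):
    gc1 = GenCacher(gen1)
    gc2 = GenCacher(gen2)
    idx1 = idx2 = 0
    moving_up = True  # direction of travel along the current diagonal

    while True:
        yield (gc1[idx1], gc2[idx2])

        if moving_up and idx1 == 0:
            # End of diagonal: advance idx2 onto the next diagonal and reverse direction
            idx2 += 1
            moving_up = False
        elif not moving_up and idx2 == 0:
            # End of diagonal: advance idx1 onto the next diagonal and reverse direction
            idx1 += 1
            moving_up = True
        elif moving_up:
            idx1, idx2 = idx1 - 1, idx2 + 1
        else:
            idx1, idx2 = idx1 + 1, idx2 - 1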

Example

from itertools import count

for x, y in product(count(0), count(0)):
    print(x, y)
    if x == 2 and y == 2:
        break

This produces the following output:

0 0
0 1
1 0
2 0
1 1
0 2
0 3
1 2
2 1
3 0
4 0
3 1
2 2
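Because each diagonal is finite, every pair has a finite position. A sketch of that position in closed form; the helper zigzag_index is hypothetical, derived from the output above rather than taken from the package. For (2, 2) it gives 12, the 0-based position of the last line printed above.

```python
def zigzag_index(x, y):
    """0-based position at which the zig-zag product yields the pair with indices (x, y)."""
    d = x + y                 # which diagonal the pair lies on
    start = d * (d + 1) // 2  # number of pairs on diagonals 0..d-1
    # Odd diagonals run from (0, d) to (d, 0), so the offset is x;
    # even diagonals run the other way, so the offset is y.
    return start + (x if d % 2 else y)

print(zigzag_index(2, 2))  # 12
```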

Extend the solution to more than 2 generators

We can edit the solution a bit to make it work even for multiple generators. The basic idea is:

  1. iterate over the distance from (0,0) (the sum of the indexes). (0,0) is the only one with distance 0, (1,0) and (0,1) are at distance 1, etc.

  2. generate all the tuples of indexes for that distance

  3. extract the corresponding element
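Step 2 is what keeps every position finite. For k generators, the index tuples at distance d are the k-tuples of non-negative integers summing to d; the standard stars-and-bars count gives how many there are, and summing over distances 0..d (the hockey-stick identity) bounds the position of any tuple at distance d:

```latex
\#\{(i_1,\dots,i_k) \in \mathbb{N}^k : i_1 + \dots + i_k = d\} = \binom{d+k-1}{k-1},
\qquad
\sum_{e=0}^{d} \binom{e+k-1}{k-1} = \binom{d+k}{k}.
```

So a tuple at distance d is yielded before position \binom{d+k}{k}; for k = 2 this is (d+1)(d+2)/2, the number of pairs on the first d + 1 diagonals.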

We still need the GenCacher class, but the code becomes:

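from itertools import count

def summations(sum_to, n=2):
    """Yield every n-tuple of non-negative ints that adds up to sum_to."""
    if n == 1:
        yield (sum_to,)
    else:
        for head in range(sum_to + 1):
            for tail in summations(sum_to - head, n - 1):
                yield (head,) + tail

def product(*gens):
    # A list rather than a lazy map object, so that len() works in Python 3
    gens = [GenCacher(g) for g in gens]

    # Walk the index tuples in order of increasing distance (index sum) from the origin
    for dist in count(0):
        for idxs in summations(dist, len(gens)):
            yield tuple(gen[idx] for gen, idx in zip(gens, idxs))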
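Step 2 can also be done without recursion. A sketch of a drop-in alternative to summations for n ≥ 1; the name summations_bars is hypothetical, and the stars-and-bars construction is my substitution, not the answer's method. It yields the same tuples in the same lexicographic order, since itertools.combinations emits bar positions lexicographically.

```python
from itertools import combinations
from math import comb

def summations_bars(total, n=2):
    """Yield every n-tuple of non-negative ints summing to total, without recursion.

    Stars and bars: pick positions for n - 1 bars among total + n - 1 slots;
    the gaps between consecutive bars are the tuple's entries.
    """
    slots = total + n - 1
    for bars in combinations(range(slots), n - 1):
        parts = []
        prev = -1
        for bar in bars:
            parts.append(bar - prev - 1)  # stars between the previous bar and this one
            prev = bar
        parts.append(slots - prev - 1)    # stars after the last bar
        yield tuple(parts)

print(list(summations_bars(2, 3)))
# [(0, 0, 2), (0, 1, 1), (0, 2, 0), (1, 0, 1), (1, 1, 0), (2, 0, 0)]
print(sum(1 for _ in summations_bars(5, 4)) == comb(8, 3))  # True
```

math.comb needs Python 3.8+; it is only used here to check the count.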

Upvotes: 9

muddyfish

Reputation: 3650

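import itertools

def product(a, b):
    a, a_copy = itertools.tee(a, 2)
    b, b_copy = itertools.tee(b, 2)
    yield (next(a_copy), next(b_copy))
    size = 1
    while True:
        # The next value from each input extends the square by one column and one row
        next_a = next(a_copy)
        next_b = next(b_copy)
        # Fresh copies that replay a and b from the beginning
        a, new_a = itertools.tee(a, 2)
        b, new_b = itertools.tee(b, 2)
        # New column: earlier a values paired with next_b
        yield from ((next(new_a), next_b) for _ in range(size))
        # New row: next_a paired with earlier b values
        yield from ((next_a, next(new_b)) for _ in range(size))
        yield (next_a, next_b)
        size += 1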

A homebrew solution with itertools.tee. It uses a lot of memory, because every value drawn from either input stays buffered in the tee objects.

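This effectively yields the sides of an ever-expanding square. In the grid below, the entry at row r, column c is the position (in hexadecimal) at which the pair (a[r], b[c]) is yielded:

0 1 4 9
2 3 5 a
6 7 8 b
c d e f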

Given infinite time and infinite memory, this implementation should return all possible products.
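That guarantee can be checked with a closed-form position. A sketch with a hypothetical shell_index helper, derived from the grid above rather than taken from the answer: cell (r, c) sits on shell m = max(r, c), and each shell's new column is emitted top to bottom before its new row left to right, so (a[r], b[c]) is always yielded before position (m + 1)².

```python
def shell_index(r, c):
    """0-based position at which the square-shell order yields (a[r], b[c])."""
    m = max(r, c)  # which shell (square side) the cell lies on
    if r < m:
        # New column of shell m, emitted first, top to bottom
        return m * m + r
    # New row of shell m, emitted after the m cells above it, left to right
    return m * m + m + c

# Reproduce the grid above, one row per line, in hexadecimal
for r in range(4):
    print(" ".join(format(shell_index(r, c), "x") for c in range(4)))
```

Running it prints the same four rows as the grid in the answer.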

Upvotes: 1

RemcoGerlich

Reputation: 31260

No matter how you do it, memory will grow a lot: every value from each iterator must be paired with infinitely many values from the other, so it has to be kept around in some ever-growing data structure.

So something like this may work:

def product(i, j):
    """Generate Cartesian product i x j; potentially uses a lot of memory."""
    i, j = iter(i), iter(j)
    earlier_values_i = []
    earlier_values_j = []

    # If either sequence is empty, so is the product. Since Python 3.7 (PEP 479)
    # a StopIteration escaping a generator becomes a RuntimeError, so catch it
    # and return explicitly instead of letting it propagate.
    try:
        next_i = next(i)
        next_j = next(j)
    except StopIteration:
        return
    found_i = found_j = True

    while True:
        if found_i and found_j:
            yield (next_i, next_j)
        elif not found_i and not found_j:
            break  # Both sequences exhausted

        # Pair each new value with every earlier value from the other sequence
        if found_i:
            for jj in earlier_values_j:
                yield (next_i, jj)
        if found_j:
            for ii in earlier_values_i:
                yield (ii, next_j)

        if found_i:
            earlier_values_i.append(next_i)
        if found_j:
            earlier_values_j.append(next_j)

        try:
            next_i = next(i)
            found_i = True
        except StopIteration:
            found_i = False

        try:
            next_j = next(j)
            found_j = True
        except StopIteration:
            found_j = False

This was so simple in my head, but it looks horribly complicated typed out; there must be a simpler way. Still, I think it works.

Upvotes: 0
