Matthew Drury

Reputation: 1095

The product of a generator with itself

I need to iterate over the product of a generator with itself, excluding the diagonal. I'm attempting to use itertools.tee to consume the same generator twice:

import itertools

def pairs_exclude_diagonal(it):
    i1, i2 = itertools.tee(it, 2)
    for x in i1:
        for y in i2:
            if x != y:
                yield (x, y)

This does not work; only the pairs for the first value of x are produced:

In [1]: for (x, y) in pairs_exclude_diagonal(range(3)):
   ...:     print(x, y)
0 1
0 2

The documentation for tee states:

Return n independent iterators from a single iterable.

What's the proper way to do this?

(I'm using Python 3.6.1.)

Upvotes: 6

Views: 1030

Answers (2)

SethMMorton

Reputation: 48745

It looks like you want to use itertools.permutations.

In [1]: import itertools

In [2]: for x, y in itertools.permutations(range(3), 2):
   ...:     print(x, y)
   ...:     
0 1
0 2
1 0
1 2
2 0
2 1
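One caveat: itertools.permutations excludes pairs by position, not by value, so it matches the x != y filter only when the input contains no repeated values. A small sketch with made-up data (the list below is an assumption for illustration, not taken from the question):

```python
import itertools

data = [1, 1, 2]  # illustrative input with a repeated value

# permutations skips pairs that reuse the same position,
# so two equal values at different positions still form a pair.
by_position = list(itertools.permutations(data, 2))

# Filtering the full product on x != y drops equal values instead.
by_value = [(x, y) for x, y in itertools.product(data, repeat=2) if x != y]

print(by_position)  # [(1, 1), (1, 2), (1, 1), (1, 2), (2, 1), (2, 1)]
print(by_value)     # [(1, 2), (1, 2), (2, 1), (2, 1)]
```

For range(3), where every value is distinct, the two approaches give the same pairs.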

If you really want to do it using tee, you will have to turn the second iterator into a list so that it can be traversed again on every pass through the outer for loop:

In [14]: def pairs_exclude_diagonal(it):
    ...:     i1, i2 = itertools.tee(it, 2)
    ...:     l2 = list(i2)
    ...:     for x in i1:
    ...:         for y in l2:
    ...:             if x != y:
    ...:                 yield (x, y)
    ...:                 

In [15]: for (x, y) in pairs_exclude_diagonal(range(3)):
    ...:     print(x, y)
    ...:     
0 1
0 2
1 0
1 2
2 0
2 1

Note that this is pretty pointless, since calling list on an iterator loads it into memory and defeats the purpose of having an iterator in the first place.

Upvotes: 9

Blckknght

Reputation: 104722

The issue is that you're trying to reuse the i2 iterator. After it's been iterated once, it's exhausted and so you won't be able to iterate on it again. When you try, it yields nothing.
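A minimal sketch of that exhaustion, using a bare tee'd iterator outside of any function (the variable names first_pass and second_pass are illustrative only):

```python
import itertools

# tee returns one-shot iterators, not re-iterable containers.
i1, i2 = itertools.tee(range(3), 2)

first_pass = list(i2)   # consumes i2 completely
second_pass = list(i2)  # i2 is already exhausted, so nothing is left

print(first_pass)   # [0, 1, 2]
print(second_pass)  # []
```

This is why the inner loop in the question only runs for the first value of x.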

I think rather than tee (which is not very efficient for this purpose anyway), you should use itertools.product to generate all pairs (before filtering out the ones you want to skip):

import itertools

def pairs_exclude_diagonal(it):
    # product reads `it` once, then pairs every element with every element
    for x, y in itertools.product(it, repeat=2):
        if x != y:
            yield (x, y)
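As a quick check, here is a sketch of the same function fed a one-shot generator rather than a range (the names gen and result are illustrative). Note that itertools.product reads its whole input into memory before yielding anything, so this approach is not lazy with respect to the input either:

```python
import itertools

def pairs_exclude_diagonal(it):
    # product consumes `it` once into an internal pool, so a
    # one-shot generator is fine as input.
    for x, y in itertools.product(it, repeat=2):
        if x != y:
            yield (x, y)

gen = (n for n in range(3))  # a one-shot generator
result = list(pairs_exclude_diagonal(gen))

print(result)  # [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
```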

Upvotes: 3
