JawguyChooser

Reputation: 1926

Python: calling list() on generator object produces incorrect result

I was looking at the accepted solution to this question, which provides a Python implementation of an algorithm for producing unique permutations in lexicographic order. I have a somewhat shortened implementation:

def permutations(seq):
    seq = sorted(seq)
    while True:
        yield seq
        k = l = None
        for k in range(len(seq) - 1):
            if seq[k] < seq[k + 1]:
                l = k + 1
                break
        else:
            return

        (seq[k], seq[l]) = (seq[l], seq[k])
        seq[k + 1:] = seq[-1:k:-1]

What's really strange to me is that if I call list() on the output of this function, I get wrong results. However, if I iterate over the results one at a time and print each one, I get the expected results.

>>> list(permutations((1,2,1)))
[[2, 1, 1], [2, 1, 1], [2, 1, 1]]
>>> for p in permutations((1,2,1)):
...   print(p)
... 
[1, 1, 2]
[1, 2, 1]
[2, 1, 1]

^^^What the?! Another example:

>>> list(permutations((1,2,3)))
[[3, 2, 1], [3, 2, 1], [3, 2, 1], [3, 2, 1]]
>>> for p in permutations((1,2,3)):
...   print(p)
... 
[1, 2, 3]
[2, 3, 1]
[3, 1, 2]
[3, 2, 1]

And list comprehension also yields the incorrect values:

>>> [p for p in permutations((1,2,3))]
[[3, 2, 1], [3, 2, 1], [3, 2, 1], [3, 2, 1]]

I have no idea what's going on here! I've not seen this before. I can write other functions that use generators and I don't run into this:

>>> def seq(n):
...   for i in range(n):
...     yield i
... 
>>> list(seq(5))
[0, 1, 2, 3, 4]

What's going on in my example above that causes this?

Upvotes: 3

Views: 1100

Answers (1)

juanpa.arrivillaga

Reputation: 95883

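You mutate seq inside the generator after you've yielded it, so you keep yielding the same list object and then modifying it. list() only stores references to that one object, so every element of the result shows the list's final state.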

    (seq[k], seq[l]) = (seq[l], seq[k]) # this mutates seq
    seq[k + 1:] = seq[-1:k:-1] # this mutates seq

Note, your list contains the same object multiple times:

In [2]: ps = list(permutations((1,2,1)))

In [3]: ps
Out[3]: [[2, 1, 1], [2, 1, 1], [2, 1, 1]]

In [4]: [hex(id(p)) for p in ps]
Out[4]: ['0x105cb3b48', '0x105cb3b48', '0x105cb3b48']
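
The same effect can be reproduced with a much smaller generator, which may make the aliasing easier to see. This is only a sketch: counter_shared and counter_copied are hypothetical names made up for illustration, not anything from the question.

```python
def counter_shared(n):
    # Hypothetical helper: yields the SAME list object every time,
    # mutating it in place between yields.
    state = [0]
    for i in range(n):
        state[0] = i
        yield state


def counter_copied(n):
    # Hypothetical helper: same logic, but yields a fresh copy each time.
    state = [0]
    for i in range(n):
        state[0] = i
        yield state.copy()


shared = list(counter_shared(3))
copied = list(counter_copied(3))

print(shared)  # [[2], [2], [2]]
print(copied)  # [[0], [1], [2]]
# Every element of `shared` is one and the same object.
print(all(item is shared[0] for item in shared))  # True
```

The first list holds three references to a single list, so all three entries show whatever that list contained after the last mutation.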

So, try yielding a copy:

def permutations(seq):
    seq = sorted(seq)
    while True:
        yield seq.copy()
        k = None
        l = None
        for k in range(len(seq) - 1):
            if seq[k] < seq[k + 1]:
                l = k + 1
                break
        else:
            return

        (seq[k], seq[l]) = (seq[l], seq[k])
        seq[k + 1:] = seq[-1:k:-1]

And, voila:

In [5]: def permutations(seq):
   ...:     seq = sorted(seq)
   ...:     while True:
   ...:         yield seq.copy()
   ...:         k = None
   ...:         l = None
   ...:         for k in range(len(seq) - 1):
   ...:             if seq[k] < seq[k + 1]:
   ...:                 l = k + 1
   ...:                 break
   ...:         else:
   ...:             return
   ...:
   ...:         (seq[k], seq[l]) = (seq[l], seq[k])
   ...:         seq[k + 1:] = seq[-1:k:-1]
   ...:

In [6]: ps = list(permutations((1,2,1)))

In [7]: ps
Out[7]: [[1, 1, 2], [1, 2, 1], [2, 1, 1]]
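
If you'd rather leave the generator as it is, a possible alternative is to take the snapshot on the consuming side, since each yielded value is correct at the moment it is received. The sketch below assumes the question's generator under the hypothetical name permutations_shared and converts each value to a tuple before storing it.

```python
def permutations_shared(seq):
    # The question's generator, renamed for illustration; it yields the
    # same list object on every iteration.
    seq = sorted(seq)
    while True:
        yield seq
        k = l = None
        for k in range(len(seq) - 1):
            if seq[k] < seq[k + 1]:
                l = k + 1
                break
        else:
            return
        seq[k], seq[l] = seq[l], seq[k]
        seq[k + 1:] = seq[-1:k:-1]


# list() stores three references to one list, so only the final state shows.
aliased = list(permutations_shared((1, 2, 1)))

# tuple(p) copies the current contents before the generator resumes.
snapshots = [tuple(p) for p in permutations_shared((1, 2, 1))]

print(aliased)    # [[2, 1, 1], [2, 1, 1], [2, 1, 1]]
print(snapshots)  # [(1, 1, 2), (1, 2, 1), (2, 1, 1)]
```

Copying inside the generator is usually the safer choice, because callers then don't need to know that the yielded object is reused.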

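As to why printing in a for-loop doesn't reveal this behavior: each value is printed before the generator resumes and mutates it, so at that moment in the iteration seq has the "correct" value. Accumulating the yielded objects while printing shows what happens: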

In [10]: result = []
    ...: for i, x in enumerate(permutations((1,2,1))):
    ...:     print("iteration ", i)
    ...:     print(x)
    ...:     result.append(x)
    ...:     print(result)
    ...:
iteration  0
[1, 1, 2]
[[1, 1, 2]]
iteration  1
[1, 2, 1]
[[1, 2, 1], [1, 2, 1]]
iteration  2
[2, 1, 1]
[[2, 1, 1], [2, 1, 1], [2, 1, 1]]
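
A separate observation: the (1,2,3) session in the question prints only four of the six permutations, which suggests the shortened search for k and l skips some orderings regardless of the copying issue. Below is a hedged sketch of the usual next-permutation step (scan for the rightmost k and l instead of the leftmost). The name unique_permutations is hypothetical, and this is not necessarily the code from the linked answer.

```python
def unique_permutations(items):
    # Hypothetical helper: yields each distinct permutation once, in
    # lexicographic order, as an immutable tuple (so no aliasing).
    seq = sorted(items)
    n = len(seq)
    while True:
        yield tuple(seq)
        # Find the rightmost k with seq[k] < seq[k + 1].
        k = n - 2
        while k >= 0 and seq[k] >= seq[k + 1]:
            k -= 1
        if k < 0:
            return  # seq is in non-increasing order: last permutation
        # Find the rightmost l > k with seq[l] > seq[k].
        l = n - 1
        while seq[l] <= seq[k]:
            l -= 1
        seq[k], seq[l] = seq[l], seq[k]
        # Reverse the tail so it is the smallest possible suffix.
        seq[k + 1:] = seq[k + 1:][::-1]


all_six = list(unique_permutations((1, 2, 3)))
with_repeats = list(unique_permutations((1, 2, 1)))

print(len(all_six))   # 6
print(with_repeats)   # [(1, 1, 2), (1, 2, 1), (2, 1, 1)]
```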

Upvotes: 10
