Reputation: 8778
I am writing a function that should output all k-way partitions of a list A. This problem is clearly recursive, and the implementation should be straightforward:
def gen_partition_k_group(A, k):
    if len(A) == 0:
        # EDITED FOLLOWING SUGGESTION
        yield [[] for _ in xrange(k)]
    else:
        # all k-partitions of the list of size N-1
        for ss in gen_partition_k_group(A[:-1], k):
            assert sum(len(gg) for gg in ss) == len(A) - 1
            for ii in xrange(k):
                tt = list(ss)
                print tt
                tt[ii].append(A[-1])
                print tt
                assert sum(len(gg) for gg in tt) == len(A)
                yield tt

A = range(3)
k = 2
[xx for xx in gen_partition_k_group(A, k)]
Output
AssertionError:
[[], []]
[[0], [0]]
I don't understand the output. It should be [[0], []] instead of [[0], [0]]. What am I missing?
NB: I know how to write a different function, without append, that outputs the correct result (see the first answer to "Iterator over all partitions into k groups?"). What I don't understand is the behaviour of this particular function.
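For comparison, here is one possible way to write the generator without append. It is a sketch of my own, not necessarily the code from the linked answer, and the name gen_partition_k_group_pure is hypothetical. Like the original recursion, it enumerates every assignment of the items of A to k labeled groups (empty groups allowed), so it yields k ** len(A) results. Instead of mutating a copy, each step builds a fresh outer list and a fresh, longer group list by concatenation.

```python
def gen_partition_k_group_pure(A, k):
    """Yield every assignment of the items of A to k labeled groups.

    Empty groups are allowed, so there are k ** len(A) results.
    Lists are never modified in place: each step builds new lists
    with concatenation instead of calling append().
    """
    if len(A) == 0:
        # one way to split nothing: k distinct empty groups
        yield [[] for _ in range(k)]
    else:
        last = A[-1]
        for ss in gen_partition_k_group_pure(A[:-1], k):
            for ii in range(k):
                # new outer list in which group ii is replaced by a new,
                # longer list; the other groups are shared, never modified
                yield ss[:ii] + [ss[ii] + [last]] + ss[ii + 1:]


if __name__ == "__main__":
    for p in gen_partition_k_group_pure([0, 1, 2], 2):
        print(p)
```

Because the groups that are not extended are shared between results, a caller that wants to modify a yielded partition in place should copy it first.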
Upvotes: 0
Views: 343
Reputation: 8778
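OK, so the problem was the line tt = list(ss), which only makes a shallow copy of the list: tt is a new outer list, but its elements are the same inner lists as in ss, so tt[ii].append(A[-1]) also modifies ss. Using tt = copy.deepcopy(ss) (with import copy) solved the problem.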
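A minimal sketch of the corrected generator, assuming the deepcopy change is the only fix needed: the debugging prints are dropped, and range replaces xrange so that it also runs on Python 3.

```python
import copy


def gen_partition_k_group(A, k):
    """Yield all assignments of the items of A to k labeled groups."""
    if len(A) == 0:
        # k distinct empty lists (not [[]] * k, which aliases one list)
        yield [[] for _ in range(k)]
    else:
        # all k-group assignments of the first len(A) - 1 items
        for ss in gen_partition_k_group(A[:-1], k):
            assert sum(len(gg) for gg in ss) == len(A) - 1
            for ii in range(k):
                # deep copy: a new outer list AND new inner lists,
                # so appending to tt[ii] cannot change ss
                tt = copy.deepcopy(ss)
                tt[ii].append(A[-1])
                assert sum(len(gg) for gg in tt) == len(A)
                yield tt


print(list(gen_partition_k_group(list(range(3)), 2)))
```

copy.deepcopy also copies the items themselves. Since only the group lists are mutated here, copying one level with tt = [list(gg) for gg in ss] would be enough as well, and it avoids copying the items.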
Upvotes: 0
Reputation: 353059
One problem is probably that [ [] ] * k doesn't do what you think it does. It doesn't make k empty lists; it makes one new empty list and k references to it. For example:
>>> [[]]*3
[[], [], []]
>>> a = [[]]*3
>>> a
[[], [], []]
>>> a[0].append(1)
>>> a
[[1], [1], [1]]
>>> id(a[0]), id(a[1]), id(a[2])
(25245744, 25245744, 25245744)
>>> a[0] is a[1]
True
>>> a[0] is a[2]
True
To make multiple distinct new lists, you could do something like:
>>> a = [[] for _ in xrange(3)]
>>> a
[[], [], []]
>>> id(a[0]), id(a[1]), id(a[2])
(41563560, 41564064, 41563056)
I don't think this by itself will fix your program (I still get an assert tripping), but it should help.
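The assert that still trips can be explained by the line tt = list(ss) in the question's loop: even when ss holds distinct empty lists, list(ss) copies only the outer list, so appending to tt[ii] also changes ss, which the loop then reuses for the next ii. A short sketch of the difference between a shallow and a deep copy, using illustrative values:

```python
import copy

ss = [[], []]          # two distinct empty lists, as the list comprehension makes
tt = list(ss)          # shallow copy: new outer list, same inner lists

print(tt is ss)        # False: the outer lists are different objects
print(tt[0] is ss[0])  # True: the inner lists are shared

tt[0].append(0)
print(ss)              # [[0], []]: the append is visible through ss too

uu = copy.deepcopy(ss) # deep copy: the inner lists are copied as well
uu[1].append(1)
print(ss)              # [[0], []]: ss is unchanged
print(uu)              # [[0], [1]]
```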
Upvotes: 1