Anthony Labarre

Reputation: 2794

Generating all subsets with pairing constraints

I need to generate all k-subsets of an n-set, with the additional constraint that some pairs of elements must be selected either together or not at all. To model this constraint, I thought about explicitly pairing those elements as 2-tuples and keeping the others as 1-tuples. For instance, suppose I need all 3-element subsets of {1, 2, 3, 4, 5}, with the additional constraint that elements 3 and 4 must be selected together. Then my new set is:

{(1,), (2,), (3, 4), (5,)}
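A minimal sketch of building that representation from the original elements and a list of constrained pairs. `make_groups` is a hypothetical helper, not part of any library; it assumes each element appears in at most one pair (the restriction stated at the end of this question) and raises `ValueError` otherwise.

```python
def make_groups(elements, pairs):
    """Return one 1-tuple or 2-tuple per group (hypothetical helper)."""
    partner = {}
    for a, b in pairs:
        # Assumption: each element may be paired with at most one other element.
        if a in partner or b in partner:
            raise ValueError(f"element paired more than once: {(a, b)}")
        partner[a] = b
        partner[b] = a

    groups, seen = [], set()
    for x in elements:
        if x in seen:
            continue  # already emitted as the second member of a pair
        if x in partner:
            groups.append((x, partner[x]))
            seen.update((x, partner[x]))
        else:
            groups.append((x,))
            seen.add(x)
    return groups

print(make_groups([1, 2, 3, 4, 5], [(3, 4)]))  # [(1,), (2,), (3, 4), (5,)]
```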

and the function I want to write would need to generate:

{1, 2, 5}, {1, 3, 4}, {2, 3, 4}, {3, 4, 5}.

Is there a simple way to use itertools (or other Python modules I may not know about) to obtain this result? I do not care about the order in which the subsets are generated.

In case this simplifies things: an element cannot be paired with more than one other element (so (3, 5) for instance could not have appeared as an additional constraint in my example).

Upvotes: 0

Views: 219

Answers (1)

Alex Hall

Reputation: 36033

Solution:

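from itertools import combinations, chain

def faster(pairs, others, k):
    """Yield every k-subset (as a tuple) in which each pair appears whole or not at all."""
    # Try every possible number of pairs in the subset; each pair uses 2 of the k slots.
    for npairs in range(k // 2 + 1):
        for pairs_comb in combinations(pairs, npairs):
            # Fill the remaining slots with unconstrained elements.
            for others_comb in combinations(others, k - npairs * 2):
                yield tuple(chain(others_comb, *pairs_comb))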

Explanation:

Go through every possible number of pairs in the outcome. For example, if k = 5, there can be no pairs and 5 unconstrained elements (others), 1 pair and 3 other elements, or 2 pairs and 1 other element. For each case, the combinations of pairs and the combinations of others can be generated independently and then combined.
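The same case split also gives the number of results directly: with p pairs and o unconstrained elements, the count is the sum over j = 0 .. floor(k/2) of C(p, j) * C(o, k - 2j). A small sketch, assuming Python 3.8+ for `math.comb`; `count_constrained` is a hypothetical name used only for illustration.

```python
from math import comb

def count_constrained(num_pairs, num_others, k):
    """Count k-subsets in which each of num_pairs pairs is taken whole or not at all."""
    # j = number of pairs used; they fill 2*j slots, others fill the remaining k - 2*j.
    # math.comb returns 0 when asked to choose more items than are available.
    return sum(comb(num_pairs, j) * comb(num_others, k - 2 * j)
               for j in range(k // 2 + 1))

print(count_constrained(1, 3, 3))  # 4, matching the four subsets in the question
print(count_constrained(3, 5, 4))  # 38, the number of subsets for the test arguments below
```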

Test:

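def brute_force(pairs, others, k):
    # Generate every k-subset of all elements, then keep only those in which
    # each pair is either fully present or fully absent.
    return [c for c in combinations(chain(others, *pairs), k)
            if all((p1 in c) == (p2 in c) for p1, p2 in pairs)]

def normalise(combs):
    # Sort each subset, then the list of subsets, so that results can be
    # compared regardless of generation order.
    return sorted(map(sorted, combs))

# Arguments: pairs, others, k
args = ([(3, 4), (1, 2), (6, 7)], [5, 8, 9, 10, 11], 4)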
assert normalise(brute_force(*args)) == normalise(faster(*args))

print(normalise(faster(*args)))
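The question describes the input as a single list of 1-tuples and 2-tuples rather than separate `pairs` and `others` lists. A minimal sketch of an adapter, assuming that representation; `subsets_from_groups` is a hypothetical name, and it repeats the same loop as `faster` so that the block runs on its own.

```python
from itertools import chain, combinations

def subsets_from_groups(groups, k):
    """Yield each k-subset as a tuple, treating 2-tuples as inseparable pairs.

    Assumes every group is a 1-tuple or a 2-tuple, as in the question.
    """
    pairs = [g for g in groups if len(g) == 2]
    others = [g[0] for g in groups if len(g) == 1]
    # Same strategy as `faster`: fix the number of pairs, then fill the
    # remaining slots with unconstrained elements.
    for npairs in range(k // 2 + 1):
        for pairs_comb in combinations(pairs, npairs):
            for others_comb in combinations(others, k - 2 * npairs):
                yield tuple(chain(others_comb, *pairs_comb))

groups = [(1,), (2,), (3, 4), (5,)]
print(sorted(sorted(s) for s in subsets_from_groups(groups, 3)))
# [[1, 2, 5], [1, 3, 4], [2, 3, 4], [3, 4, 5]]
```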

Upvotes: 1
