Diane M

Reputation: 1512

How can I efficiently implement this subset enumeration problem?

I have a list of numbers Sn = [a, b, c, d, ...] and a set of non-overlapping intervals Si = {I1, I2, I3, ...}. Given these, I need to find the list L of all subsets of Sn whose element sum falls within at least one interval in Si.
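To make the membership condition concrete: because the intervals do not overlap, sorting them by lower bound lets a binary search find the only interval that could contain a given sum. This is a minimal sketch, assuming closed intervals given as (lower, upper) tuples; the helper name `make_interval_test` is hypothetical.

```python
import bisect


def make_interval_test(intervals):
    """Hypothetical helper: return a function that checks whether a sum lies
    in any closed interval (lo, hi). Assumes the intervals do not overlap."""
    ordered = sorted(intervals)           # sorted by lower bound
    starts = [lo for lo, _ in ordered]

    def contains(total):
        # Index of the last interval whose lower bound is <= total.
        i = bisect.bisect_right(starts, total) - 1
        return i >= 0 and total <= ordered[i][1]

    return contains


Si = {(12, 12), (18, 24), (30, 48)}
in_si = make_interval_test(Si)
print([t for t in (6, 12, 18, 25, 36, 48, 60) if in_si(t)])
```

Each check costs O(log k) for k intervals instead of a scan over all of them.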

My approach right now is to enumerate all the subsets of Sn and filter them based on whether their sum fits an interval. It is correct, but inefficient, since it examines all 2^n - 1 non-empty subsets.

import itertools

Sn = [12, 30, 60, 6, 6]
Si = {(12, 12), (18, 24), (30, 48)}

def enumerate_sets():
    sets = []
    for i in range(len(Sn)):
        for comb in itertools.combinations(Sn, i + 1):
            for interval in Si:
                if interval[0] <= sum(comb) <= interval[1]:
                    sets.append(comb)
                    break
    return sets

print(enumerate_sets())
# [(12,), (30,), (12, 30), (12, 6), (12, 6), (30, 6), (30, 6), (6, 6), (12, 30, 6), (12, 30, 6), (12, 6, 6), (30, 6, 6)]
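If all numbers in Sn are non-negative (an assumption; the question does not say so), the search can be pruned: with Sn sorted ascending, once a partial sum exceeds the largest upper bound in Si, neither extending that subset nor trying any larger element at the same position can produce a valid sum. The sketch below is a backtracking variant under that assumption; the name `enumerate_pruned` is hypothetical. It yields the same subsets as the filter above (as sorted tuples, in a different order).

```python
def enumerate_pruned(Sn, Si):
    """Hypothetical sketch: yield every subset (by index, so duplicates in Sn
    give repeated tuples, as with itertools.combinations) whose sum lies in
    some interval of Si. ASSUMES every number in Sn is non-negative."""
    values = sorted(Sn)
    intervals = sorted(Si)
    if not intervals:
        return
    max_hi = max(hi for _, hi in intervals)

    def fits(total):
        return any(lo <= total <= hi for lo, hi in intervals)

    def extend(start, chosen, total):
        for j in range(start, len(values)):
            new_total = total + values[j]
            if new_total > max_hi:
                # values is sorted and non-negative, so every later j (and
                # every extension of this subset) overshoots as well.
                break
            chosen.append(values[j])
            if fits(new_total):
                yield tuple(chosen)
            yield from extend(j + 1, chosen, new_total)
            chosen.pop()

    yield from extend(0, [], 0)


Sn = [12, 30, 60, 6, 6]
Si = {(12, 12), (18, 24), (30, 48)}
print(sorted(enumerate_pruned(Sn, Si)))
```

In the worst case (all subsets stay below the largest bound) this still visits every subset, but the pruning can skip large parts of the search when the bounds are small relative to the numbers.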

How can I efficiently implement this subset problem? Answers in Python are preferred, but any (pseudo)language will do.

Upvotes: 1

Views: 85

Answers (2)

Georgina Skibinski

Reputation: 13387

OK, as for the programming side, here is a Python solution that will do the trick:

from itertools import combinations
import numpy as np

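def getL(Sn:np.ndarray, Si:np.ndarray):
    # Sort a copy so the caller's array is not modified; this only affects output order.
    Sn = np.sort(Sn)
    for i in range(1, len(Sn)+1):
        for s in combinations(Sn, i):
            su = sum(s)
            # Vectorized check: is su inside any [lower, upper] row of Si?
            if np.logical_and(Si[:, 0] <= su, Si[:, 1] >= su).any():
                yield s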

Sample output:

>>> # to convert your input into numpy:
>>> Sn = [12, 30, 60, 6, 6]
>>> Si = {(12, 12), (18, 24), (30, 48)}
>>> Sn = np.array(Sn)
>>> Si = np.array(list(Si))
>>> print(list(getL(Sn, Si)))

[(12,), (30,), (6, 6), (6, 12), (6, 30), (6, 12), (6, 30), (12, 30), (6, 6, 12), (6, 6, 30), (6, 12, 30), (6, 12, 30)]

A few notes: to save memory, use a generator. It should generally be faster too, because instead of accumulating everything into a list, each result is returned as soon as it is found and then forgotten. Use NumPy: vectorizing the comparison against all intervals speeds up scanning through them, especially when Si contains many intervals.
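Pushing the NumPy idea further, here is a sketch (the name `getL_batched` is hypothetical) that builds all k-element index combinations at once and checks their sums against every interval via broadcasting. It still examines every subset, so the exponential cost is unchanged; it only moves the inner loops into NumPy, and it gives up some of the generator's memory savings by materializing all C(n, k) combinations of one size at a time.

```python
from itertools import combinations

import numpy as np


def getL_batched(Sn, Si):
    """Hypothetical sketch: yield qualifying subsets, testing all k-subsets
    of one size as a single NumPy batch."""
    values = np.asarray(Sn)
    bounds = np.asarray(list(Si))          # shape (n_intervals, 2)
    lo, hi = bounds[:, 0], bounds[:, 1]
    for k in range(1, len(values) + 1):
        # Every k-subset as a row of indices: shape (m, k), m = C(n, k).
        idx = np.array(list(combinations(range(len(values)), k)))
        sums = values[idx].sum(axis=1)     # shape (m,)
        # Compare every sum with every interval at once: shape (m, n_intervals).
        inside = (sums[:, None] >= lo) & (sums[:, None] <= hi)
        # Keep the rows whose sum falls in at least one interval.
        for row in idx[inside.any(axis=1)]:
            yield tuple(values[row].tolist())


print(list(getL_batched([12, 30, 60, 6, 6], {(12, 12), (18, 24), (30, 48)})))
```

Because it walks index combinations of the unsorted input, it yields tuples in the same order as the question's `enumerate_sets`.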

Upvotes: 1

AKX

Reputation: 168996

Since the sum of a combination alone determines whether it is valid, you can cache the sums already found to be valid in a set and skip the interval scan whenever such a sum comes up again.

enumerate_sets_2_iter benchmarks as (more than) 2x faster than the original here:

import timeit
import itertools
import collections

Sn = [12, 30, 60, 6, 6, 92, 443, -8, 112, 96]
Si = {(12, 12), (18, 24), (30, 48)}


def enumerate_sets_orig(Sn, Si):
    sets = []
    for i in range(len(Sn)):
        for comb in itertools.combinations(Sn, i + 1):
            for interval in Si:
                if interval[0] <= sum(comb) <= interval[1]:
                    sets.append(comb)
                    break
    return sets


def enumerate_sets_2_iter(Sn, Si):
    # Sums already known to fall inside some interval.
    valid_sums = set()
    for i in range(len(Sn)):
        for comb in itertools.combinations(Sn, i + 1):
            comb_sum = sum(comb)
            # Cache hit: skip the interval scan entirely.
            if comb_sum in valid_sums:
                yield comb
                continue
            for a, b in Si:
                if a <= comb_sum <= b:
                    valid_sums.add(comb_sum)
                    yield comb
                    break


def make_comparable(result):
    # Make an enumerate_sets_* result comparable by sorting and removing duplicates
    return set(tuple(sorted(int(v) for v in i)) for i in result)


def t(f):
    # Wrap the function to evaluate generators and to pass in the args
    fw = lambda: list(f(Sn, Si))
    # Validate this solution before benchmarking
    assert expected == make_comparable(fw())
    # Benchmark
    count, time_taken = timeit.Timer(fw).autorange()
    print(f"{f.__name__:<25} {count / time_taken:>10.2f} iter/s")

expected = make_comparable(enumerate_sets_orig(Sn, Si))

t(enumerate_sets_orig)
t(enumerate_sets_2_iter)
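A possible variant of the same caching idea, not benchmarked here: memoize the verdict for every distinct sum, valid or not, so that invalid sums are not rescanned either. The name `enumerate_sets_3_iter` is hypothetical; `functools.lru_cache` serves as the memo table.

```python
import functools
import itertools


def enumerate_sets_3_iter(Sn, Si):
    """Hypothetical variant: cache the interval check per distinct sum."""
    intervals = tuple(Si)

    @functools.lru_cache(maxsize=None)
    def is_valid(total):
        # Each distinct sum is checked against the intervals at most once.
        return any(a <= total <= b for a, b in intervals)

    for i in range(len(Sn)):
        for comb in itertools.combinations(Sn, i + 1):
            if is_valid(sum(comb)):
                yield comb
```

It can be passed to the `t()` harness above to compare it with the other two; whether it wins depends on how often invalid sums repeat for a given input.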

Upvotes: 1
