Bsh

Reputation: 368

Partition of a list of integers into K sublists with equal sum

Similar questions are 1 and 2 but the answers didn't help. Assume we have a list of integers. We want to find K disjoint lists such that they completely cover the given list and all have the same sum. For example, if A = [4, 3, 5, 6, 4, 3, 1] and K = 2 then the answer should be:

[[3, 4, 6], [1, 3, 4, 5]]
or
[[4, 4, 5], [1, 3, 3, 6]]
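To make the requirement concrete, here is a small checker that accepts exactly the kind of answer shown above. It is only a sketch: the name `is_equal_sum_partition` is made up for illustration and is not part of my code below.

```python
from collections import Counter

def is_equal_sum_partition(a, parts):
    """Return True if `parts` uses every element of `a` exactly once
    and all parts have the same sum. (Hypothetical helper.)"""
    if not parts:
        return False
    # Multiset equality: the parts together must cover `a` exactly.
    covered = Counter(x for part in parts for x in part)
    if covered != Counter(a):
        return False
    # All part sums must be identical.
    return len({sum(part) for part in parts}) == 1

A = [4, 3, 5, 6, 4, 3, 1]
print(is_equal_sum_partition(A, [[3, 4, 6], [1, 3, 4, 5]]))  # True
print(is_equal_sum_partition(A, [[4, 4, 5], [1, 3, 3, 6]]))  # True
print(is_equal_sum_partition(A, [[4, 4, 6], [1, 3, 3, 5]]))  # False (sums 14 and 12)
```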

I have written code that only works when K = 2. It works fine with small lists, but with much larger lists the OS terminates the task because of the code's high complexity. My code is:

def subarrays_equal_sum(l):
    from itertools import combinations

    if len(l) < 2 or sum(l) % 2 != 0:
        return []
    l = sorted(l)
    list_sum = sum(l)
    all_combinations = []
    for i in range(1, len(l)):
        all_combinations += (list(combinations(l, i)))

    combinations_list = [i for i in all_combinations if sum(i) == list_sum / 2]
    if not combinations_list:
        return []
    final_result = []
    for i in range(len(combinations_list)):
        for j in range(i + 1, len(combinations_list)):
            first = combinations_list[i]
            second = combinations_list[j]
            concat = sorted(first + second)
            if concat == l and [list(first), list(second)] not in final_result:
                final_result.append([list(first), list(second)])

    return final_result

An answer for any value of K is available here. But if we pass the arguments A = [4, 3, 5, 6, 4, 3, 1] and K = 2, their code only returns [[5, 4, 3, 1],[4, 3, 6]] whereas my code returns all possible lists i.e.,

[[[3, 4, 6], [1, 3, 4, 5]], [[4, 4, 5], [1, 3, 3, 6]]]

My questions are:

  1. How can I improve the complexity and cost of my code?
  2. How can I make my code work with any value of K?

Upvotes: 6

Views: 963

Answers (2)

btilly

Reputation: 46399

Here is a solution that deals with duplicates.

First of all, the problem of finding any solution is, as noted, NP-complete. So there are cases where this will churn for a long time before concluding that there is no solution. I've applied reasonable heuristics to limit how often this happens. The heuristics can be improved. But be warned that there will be cases where simply nothing works.

The first step in this solution is to take a list of numbers and turn it into [(value1, repeat), (value2, repeat), ...]. One of those heuristics requires that the values be sorted first by descending absolute value, and then by decreasing value. That is because I try to use the first elements first, and we expect a bunch of small leftover numbers to still give us sums.
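As a rough illustration of that first step, the grouping and ordering could be sketched like this. The helper name `to_value_counts` and the use of `Counter` with a tuple key are assumptions made for this example; the code further down does the same job with a plain dict and a comparator class.

```python
from collections import Counter

def to_value_counts(raw_elements):
    """Group equal values into (value, count) pairs, ordered by
    descending absolute value, ties broken by descending value.
    (Hypothetical helper, for illustration only.)"""
    counts = Counter(raw_elements)
    return sorted(counts.items(),
                  key=lambda vc: (abs(vc[0]), vc[0]),
                  reverse=True)

print(to_value_counts([4, 3, 5, 6, 4, 3, 1]))
# [(6, 1), (5, 1), (4, 2), (3, 2), (1, 1)]
print(to_value_counts([3, -3, -5, 2, 3]))
# [(-5, 1), (3, 2), (-3, 1), (2, 1)]
```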

Next, I'm going to try to split it into a possible maximal subset with the right target sum, and all remaining elements.

Then I'm going to split the remaining into a possible maximal remaining subset that is no bigger than the first, and the ones that result after that.

Do this recursively and we find a solution, which we yield back up the chain.

But, and here is where it gets tricky, I'm not going to do the split by looking at combinations. Instead I'm going to use dynamic programming like we would for the usual subset-sum pseudo-polynomial algorithm, except I'll use it to construct a data structure from which we can do the split. This data structure will contain the following fields:

  1. value is the value of this element.
  2. repeat is how many times we used it in the subset sum.
  3. skip is how many copies we had and didn't use in the subset sum.
  4. tail is the tail of these solutions.
  5. prev are some other solutions where we did something else.

Here is a class that constructs this data structure, with a method to split elements into a subset and elements still available for further splitting.

from collections import namedtuple

class RecursiveSums (
      namedtuple('BaseRecursiveSums',
                 ['value', 'repeat', 'skip', 'tail', 'prev'])):

    def sum_and_rest(self):
        if self.tail is None:
            if self.skip:
                yield ([self.value] * self.repeat, [(self.value, self.skip)])
            else:
                yield ([self.value] * self.repeat, [])
        else:
            for partial_sum, rest in self.tail.sum_and_rest():
                for _ in range(self.repeat):
                    partial_sum.append(self.value)
                if self.skip:
                    rest.append((self.value, self.skip))
                yield (partial_sum, rest)
        if self.prev is not None:
            yield from self.prev.sum_and_rest()

You might have to look at this a few times to see how it works.
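A hand-built example may help. The sketch below repeats the class so that it runs on its own, then wires up two nodes by hand for the made-up input [(4, 1), (2, 2)] with target sum 4. The node names are invented for this illustration, and the real code never builds the "skip the first element" variant; it is here only to show how `tail` and `prev` are walked.

```python
from collections import namedtuple

class RecursiveSums(
        namedtuple('BaseRecursiveSums',
                   ['value', 'repeat', 'skip', 'tail', 'prev'])):

    def sum_and_rest(self):
        if self.tail is None:
            if self.skip:
                yield ([self.value] * self.repeat, [(self.value, self.skip)])
            else:
                yield ([self.value] * self.repeat, [])
        else:
            for partial_sum, rest in self.tail.sum_and_rest():
                for _ in range(self.repeat):
                    partial_sum.append(self.value)
                if self.skip:
                    rest.append((self.value, self.skip))
                yield (partial_sum, rest)
        if self.prev is not None:
            yield from self.prev.sum_and_rest()

# Way 1: use the single 4, skip both 2s.
use_four = RecursiveSums(4, 1, 0, None, None)
way_one = RecursiveSums(2, 0, 2, use_four, None)

# Way 2: skip the 4, use both 2s. `prev` links back to way 1, so one
# node gives access to every recorded way of reaching the same sum.
skip_four = RecursiveSums(4, 0, 1, None, None)
both_ways = RecursiveSums(2, 2, 0, skip_four, way_one)

for subset, rest in both_ways.sum_and_rest():
    print(subset, rest)
# [2, 2] [(4, 1)]
# [4] [(2, 2)]
```

Each yielded pair is (the subset that reaches the sum, the leftover elements as (value, count) pairs).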

Next, remember I said that I used a heuristic to try to use large elements before small ones. Here is some code that we'll need to do that comparison.

class AbsComparator(int):
    def __lt__ (self, other):
        if abs(int(self)) < abs(int(other)):
            return True
        elif abs(other) < abs(self):
            return False
        else:
            return int(self) < int(other)

def abs_lt (x, y):
    return AbsComparator(x) < AbsComparator(y)

We'll need both forms: the function for direct comparisons, and the class as the key argument to Python's sort function. See "Using a comparator function to sort" for more on the latter.
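A quick sketch of both forms in use, with made-up sample values. The class and function are repeated here, with the `int()` conversions made uniform, so that the snippet runs on its own.

```python
class AbsComparator(int):
    def __lt__(self, other):
        if abs(int(self)) < abs(int(other)):
            return True
        elif abs(int(other)) < abs(int(self)):
            return False
        else:
            return int(self) < int(other)

def abs_lt(x, y):
    return AbsComparator(x) < AbsComparator(y)

# As a sort key: largest absolute value first, and for equal absolute
# values the positive number comes before the negative one.
print(sorted([3, -5, -3, 2, 5], key=AbsComparator, reverse=True))
# [5, -5, 3, -3, 2]

# As a direct comparison.
print(abs_lt(3, -5))   # True:  |3| < |-5|
print(abs_lt(-5, 5))   # True:  equal absolute values, and -5 < 5
print(abs_lt(5, -5))   # False
```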

And now the heart of the method. This finds all ways to split into a subset (that is no larger than bound in the comparison metric we are using) and the remaining elements to split more.

The idea is the same as the dynamic programming approach to subset sum (https://www.geeksforgeeks.org/count-of-subsets-with-sum-equal-to-x/), except with two major differences. The first is that instead of counting the answers we are building up our data structure. The second is that our keys are (partial_sum, bound_index) so we know whether our bound is currently satisfied, and if it is not we know what element to compare next to test it.
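For reference, the plain counting version of that dynamic program can be sketched as follows. This is a generic illustration, not code from the linked page, and the name `count_subsets_with_sum` is made up. It keeps a dict from partial sum to the number of ways to reach it.

```python
def count_subsets_with_sum(values, target):
    """Count subsets (by position) of `values` whose sum is `target`."""
    # ways[s] = number of subsets of the values seen so far summing to s
    ways = {0: 1}
    for v in values:
        next_ways = dict(ways)              # subsets that skip v
        for s, c in ways.items():           # subsets that use v
            next_ways[s + v] = next_ways.get(s + v, 0) + c
        ways = next_ways
    return ways.get(target, 0)

print(count_subsets_with_sum([1, 2, 3], 3))               # 2
print(count_subsets_with_sum([4, 3, 5, 6, 4, 3, 1], 13))  # 10
```

The count of 10 is by position, so the two 3s and the two 4s are treated as distinct. The function that follows keeps the same table shape but stores RecursiveSums nodes instead of counts and adds bound_index to the key.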

def lexically_maximal_subset_rest (elements, target, bound=None):
    """
        elements = [(value, count), (value, count), ...]
            with largest absolute values first.
        target = target sum
        bound = a lexical bound on the maximal subset.
    """
    # First let's deal with all of the trivial cases.
    if 0 == len(elements):
        if 0 == target:
            yield []
    elif bound is None or 0 == len(bound):
        # Set the bound to something that trivially works.
        yield from lexically_maximal_subset_rest(elements, target, [abs(elements[0][0]) + 1])
    elif abs_lt(bound[0], elements[0][0]):
        pass # we automatically use more than the bound.
    else:
        # The trivial checks are done.

        bound_satisfied = (bound[0] != elements[0][0])

        # recurse_by_sum will have a key of (partial_sum, bound_index).
        # If the bound_index is None, the bound is satisfied.
        # Otherwise it will be the last used index in the bound.
        recurse_by_sum = {}
        # Populate it with all of the ways to use the first element at least once.
        (init_value, init_count) = elements[0]
        for i in range(init_count):
            if not bound_satisfied:
                if len(bound) <= i or abs_lt(bound[i], init_value):
                    # Bound exceeded.
                    break
                elif abs_lt(init_value, bound[i]):
                    bound_satisfied = True
            if bound_satisfied:
                key = (init_value * (i+1), None)
            else:
                key = (init_value * (i+1), i)

            recurse_by_sum[key] = RecursiveSums(
                init_value, i+1, init_count-i-1, None, recurse_by_sum.get(key))

        # And now we do the dynamic programming thing.
        for j in range(1, len(elements)):
            value, repeat = elements[j]
            next_recurse_by_sum = {}
            for key, tail in recurse_by_sum.items():
                partial_sum, bound_index = key
                # Record not using this value at all.
                next_recurse_by_sum[key] = RecursiveSums(
                    value, 0, repeat, tail, next_recurse_by_sum.get(key))
                # Now record the rest.
                for i in range(1, repeat+1):
                    if bound_index is not None:
                        # Bounds check.
                        if len(bound) <= bound_index + i:
                            break # bound exceeded.
                        elif abs_lt(bound[bound_index + i], value):
                            break # bound exceeded.
                        elif abs_lt(value, bound[bound_index + i]):
                            bound_index = None # bound satisfied!
                    if bound_index is None:
                        next_key = (partial_sum + value * i, None)
                    else:
                        next_key = (partial_sum + value * i, bound_index + i)

                    next_recurse_by_sum[next_key] = RecursiveSums(
                        value, i, repeat - i, tail, next_recurse_by_sum.get(next_key))
            recurse_by_sum = next_recurse_by_sum

        # We now have all of the answers in recurse_by_sum, but in several keys.
        # Find all that may have answers.
        bound_index = len(bound)
        while 0 < bound_index:
            bound_index -= 1
            if (target, bound_index) in recurse_by_sum:
                yield from recurse_by_sum[(target, bound_index)].sum_and_rest()
        if (target, None) in recurse_by_sum:
            yield from recurse_by_sum[(target, None)].sum_and_rest()

And now we implement the rest.

def elements_split (elements, target, k, bound=None):
    if 0 == len(elements):
        if k == 0:
            yield []
    elif k == 0:
        pass # still have elements left over.
    else:
        for (subset, rest) in lexically_maximal_subset_rest(elements, target, bound):
            for answer in elements_split(rest, target, k-1, subset):
                answer.append(subset)
                yield answer

def subset_split (raw_elements, k):
    total = sum(raw_elements)
    if 0 == (total % k):
        target = total // k
        counts = {}
        for e in sorted(raw_elements, key=AbsComparator, reverse=True):
            counts[e] = 1 + counts.get(e, 0)
        elements = list(counts.items())
        yield from elements_split(elements, target, k)

And here is a demonstration using your list, doubled, which we split into 4 equal parts. On my laptop it finds all 10 solutions in 0.084 seconds.

n = 0
for s in subset_split([4, 3, 5, 6, 4, 3, 1]*2, 4):
    n += 1
    print(n, s)

So...no performance guarantees. But this should usually be able to find splits pretty quickly per split. Of course there are also usually an exponential number of splits. For example if you take 16 copies of your list and try to split into 32 groups, it takes about 8 minutes on my laptop to find all 224082 solutions.

If I didn't try to deal with negatives, this could be sped up quite a bit. (Use cheaper comparisons, drop all partial sums that have exceeded target to avoid calculating most of the dynamic programming table.)
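The pruning idea can be shown in isolation. The sketch below tracks only the set of reachable partial sums, which is a simplification of the real table, and the name `reachable_sums` is made up. It assumes all values are nonnegative, so a partial sum that has passed the target can never come back down.

```python
def reachable_sums(values, target=None):
    """Return the set of subset sums of `values`. If `target` is given,
    drop every partial sum above it (valid for nonnegative values only)."""
    sums = {0}
    for v in values:
        sums |= {s + v for s in sums
                 if target is None or s + v <= target}
    return sums

values = [4, 3, 5, 6, 4, 3, 1] * 2
full = reachable_sums(values)
pruned = reachable_sums(values, target=13)

print(len(full), len(pruned))
# Pruning loses nothing at or below the target.
print(pruned == {s for s in full if s <= 13})  # True
```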

And here is the sped up version. For the case with only nonnegative numbers it is about twice as fast. If there are negative numbers it will produce wrong results.

from collections import namedtuple

class RecursiveSums (
      namedtuple('BaseRecursiveSums',
                 ['value', 'repeat', 'skip', 'tail', 'prev'])):

    def sum_and_rest(self):
        if self.tail is None:
            if self.skip:
                yield ([self.value] * self.repeat, [(self.value, self.skip)])
            else:
                yield ([self.value] * self.repeat, [])
        else:
            for partial_sum, rest in self.tail.sum_and_rest():
                for _ in range(self.repeat):
                    partial_sum.append(self.value)
                if self.skip:
                    rest.append((self.value, self.skip))
                yield (partial_sum, rest)
        if self.prev is not None:
            yield from self.prev.sum_and_rest()

def lexically_maximal_subset_rest (elements, target, bound=None):
    """
        elements = [(value, count), (value, count), ...]
            with largest absolute values first.
        target = target sum
        bound = a lexical bound on the maximal subset.
    """
    # First let's deal with all of the trivial cases.
    if 0 == len(elements):
        if 0 == target:
            yield []
    elif bound is None or 0 == len(bound):
        # Set the bound to something that trivially works.
        yield from lexically_maximal_subset_rest(elements, target, [abs(elements[0][0]) + 1])
    elif bound[0] < elements[0][0]:
        pass # we automatically use more than the bound.
    else:
        # The trivial checks are done.

        bound_satisfied = (bound[0] != elements[0][0])

        # recurse_by_sum will have a key of (partial_sum, bound_index).
        # If the bound_index is None, the bound is satisfied.
        # Otherwise it will be the last used index in the bound.
        recurse_by_sum = {}
        # Populate it with all of the ways to use the first element at least once.
        (init_value, init_count) = elements[0]
        for i in range(init_count):
            if not bound_satisfied:
                if len(bound) <= i or bound[i] < init_value:
                    # Bound exceeded.
                    break
                elif init_value < bound[i]:
                    bound_satisfied = True
            if bound_satisfied:
                key = (init_value * (i+1), None)
            else:
                key = (init_value * (i+1), i)

            recurse_by_sum[key] = RecursiveSums(
                init_value, i+1, init_count-i-1, None, recurse_by_sum.get(key))

        # And now we do the dynamic programming thing.
        for j in range(1, len(elements)):
            value, repeat = elements[j]
            next_recurse_by_sum = {}
            for key, tail in recurse_by_sum.items():
                partial_sum, bound_index = key
                # Record not using this value at all.
                next_recurse_by_sum[key] = RecursiveSums(
                    value, 0, repeat, tail, next_recurse_by_sum.get(key))
                # Now record the rest.
                for i in range(1, repeat+1):
                    if target < partial_sum + value * i:
                        break # these are too big.

                    if bound_index is not None:
                        # Bounds check.
                        if len(bound) <= bound_index + i:
                            break # bound exceeded.
                        elif bound[bound_index + i] < value:
                            break # bound exceeded.
                        elif value < bound[bound_index + i]:
                            bound_index = None # bound satisfied!
                    if bound_index is None:
                        next_key = (partial_sum + value * i, None)
                    else:
                        next_key = (partial_sum + value * i, bound_index + i)

                    next_recurse_by_sum[next_key] = RecursiveSums(
                        value, i, repeat - i, tail, next_recurse_by_sum.get(next_key))
            recurse_by_sum = next_recurse_by_sum

        # We now have all of the answers in recurse_by_sum, but in several keys.
        # Find all that may have answers.
        bound_index = len(bound)
        while 0 < bound_index:
            bound_index -= 1
            if (target, bound_index) in recurse_by_sum:
                yield from recurse_by_sum[(target, bound_index)].sum_and_rest()
        if (target, None) in recurse_by_sum:
            yield from recurse_by_sum[(target, None)].sum_and_rest()

def elements_split (elements, target, k, bound=None):
    if 0 == len(elements):
        if k == 0:
            yield []
    elif k == 0:
        pass # still have elements left over.
    else:
        for (subset, rest) in lexically_maximal_subset_rest(elements, target, bound):
            for answer in elements_split(rest, target, k-1, subset):
                answer.append(subset)
                yield answer

def subset_split (raw_elements, k):
    total = sum(raw_elements)
    if 0 == (total % k):
        target = total // k
        counts = {}
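        # Nonnegative input only: plain descending order matches the
        # AbsComparator order, and this listing does not define that class.
        for e in sorted(raw_elements, reverse=True):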
            counts[e] = 1 + counts.get(e, 0)
        elements = list(counts.items())
        yield from elements_split(elements, target, k)

n = 0
for s in subset_split([4, 3, 5, 6, 4, 3, 1]*16, 32):
    n += 1
    print(n, s)

Upvotes: 4

Alain T.

Reputation: 42143

This has a large number of potential solutions, so reducing the number of eligible patterns to evaluate will be key to improving performance.

Here's an idea: approach it in two steps:

  1. Generate a list of index groups that add up to the target equal sum.
  2. Combine the index groups that don't intersect (so each index is in only one group) so that you get K groups.

The assemble function is a recursive generator that will produce lists of n index combinations (sets) that don't overlap. Given that each group has a sum of total/K, the lists will have full coverage of the original list's elements.

def assemble(combos,n):
    if not n:
        yield []
        return
    if len(combos)<n: return
    for i,idx in enumerate(combos):
        others = [c for c in combos if c.isdisjoint(idx)]
        for rest in assemble(others,n-1):
            yield [idx] + rest
            
def equalSplit(A,K):
    total = sum(A) 
    if total%K: return       # no equal sum solution
    partSum = total//K       # sum of each resulting sub-lists
    combos = [ (0,[]) ]      # groups of indices that form sums <= partSum
    for i,n in enumerate(A): # build the list of sum,patterns 
        combos += [ (tot+n,idx+[i]) for tot,idx in combos
                    if tot+n <= partSum]
    # only keep index sets that add up to the target sum
    combos = [set(idx) for tot,idx in combos if tot == partSum]
    # output assembled lists of K sets that don't overlap (full coverage)
    seen = set()
    for parts in assemble(combos,K):
        sol = tuple(sorted(tuple(sorted(A[i] for i in idx)) for idx in parts))
        if sol in seen: continue # skip duplicate solutions
        yield list(sol)
        seen.add(sol)

Output:

A = [4, 3, 5, 6, 4, 3, 1]
print(*equalSplit(A,2), sep='\n')
# [(1, 3, 4, 5), (3, 4, 6)]
# [(1, 3, 3, 6), (4, 4, 5)]

A = [21,22,27,14,15,16,17,18,19,10,11,12,13]
print(*equalSplit(A,5), sep='\n')
# [(10, 15, 18), (11, 13, 19), (12, 14, 17), (16, 27), (21, 22)]
# [(10, 14, 19), (11, 15, 17), (12, 13, 18), (16, 27), (21, 22)]

This will still take a long time for large lists that are split into few parts, but it should be a bit better than brute force over combinations.
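To see what assemble does on its own, here is a small sketch with hand-written index sets. The input is made up: it corresponds to A = [2, 2, 2, 2] and K = 2, where every pair of indices sums to 4. The function is copied from above with only whitespace changes so that the snippet runs by itself.

```python
def assemble(combos, n):
    if not n:
        yield []
        return
    if len(combos) < n:
        return
    for i, idx in enumerate(combos):
        others = [c for c in combos if c.isdisjoint(idx)]
        for rest in assemble(others, n - 1):
            yield [idx] + rest

# All index pairs of a 4-element list (hypothetical input).
combos = [{0, 1}, {0, 2}, {0, 3}, {1, 2}, {1, 3}, {2, 3}]
results = list(assemble(combos, 2))
for parts in results:
    print(parts)
print(len(results))  # 6
```

Each of the three pairings shows up twice, once in each order, and all six map to the same value-level solution ((2, 2), (2, 2)). That is why equalSplit filters the results through the seen set.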

Upvotes: -1
