Junuxx

Reputation: 14271

How do I iterate over a large number of tuples of integers in the order of their sum?

I'm using itertools.combinations() to iterate over tuples of integers.

I am interested in the tuple with the lowest sum that satisfies some conditions:

import itertools

def findLowestNiceTuple():
    for tup in itertools.combinations(range(1, 6), 2):
        if niceTuple(tup):  # niceTuple() checks the conditions
            return tup

The generator's default order does not follow the tuples' sums. For example:

>>> itertools.combinations(range(1, 6), 2)

gives a generator which will yield the following elements:

[(1, 2), (1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (3, 5), (4, 5)]

As you can see, the sum of (1, 5) is larger than that of (2, 3). For early termination, I need the tuples in the order ..., (1, 4), (2, 3), (1, 5), ....

For a modest number of combinations, you can get around this by using sorted():

>>> sorted(itertools.combinations(range(1, 6), 2), key=sum)
[(1, 2), (1, 3), (1, 4), (2, 3), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5), (4, 5)]

However, sorted() consumes the whole generator and builds a list that is kept entirely in memory, so it no longer scales. Something like itertools.combinations(range(1, 600), 400) will inevitably produce a MemoryError.
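To put the scale in perspective, the number of tuples that sorted() would have to materialize can be computed with math.comb (available from Python 3.8). This only counts them; nothing is enumerated:

```python
import math

# range(1, 600) has 599 values, so this is the number of 400-element tuples
# that sorted(itertools.combinations(range(1, 600), 400), key=sum) would store.
count = math.comb(599, 400)
print(len(str(count)), "decimal digits")
```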

Is there a more memory-friendly way to achieve the desired result?

PS: I do realize that it would take ages to fully iterate over the last sequence I mentioned, but the tuple I am looking for should be very close to the start. And if I can count on the order, I can terminate early as in the first snippet.

Upvotes: 3

Views: 1873

Answers (2)

Alain T.

Reputation: 42133

You can obtain a pure iterator in O(1) space by going through the range of possible sums rather than generating combinations. For each sum value, yield the sub-range of number pairs that produce it:

def sumPairs(minVal,maxVal):
    for total in range(minVal*2,maxVal*2+1):
        for a in range(max(total-maxVal,minVal),min(total-minVal,maxVal)+1):
            if a >= total-a: continue # to skip permutations
            yield (a,total-a)

output:

for pair in sumPairs(1,6):
    print(pair,sum(pair))

(1, 2) 3
(1, 3) 4
(1, 4) 5
(2, 3) 5
(1, 5) 6
(2, 4) 6
(1, 6) 7
(2, 5) 7
(3, 4) 7
(2, 6) 8
(3, 5) 8
(3, 6) 9
(4, 5) 9
(4, 6) 10
(5, 6) 11
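A minimal variant sketch of sumPairs, assuming the same inclusive bounds: computing the upper bound of a directly as (total-1)//2 avoids generating and then skipping the mirrored pairs. The predicate is_nice is a hypothetical stand-in for the question's niceTuple(), included only to show the early-termination pattern with next():

```python
from itertools import combinations

def sum_pairs(min_val, max_val):
    """Yield pairs (a, b), min_val <= a < b <= max_val, in nondecreasing order of a + b."""
    # Distinct pairs have sums from min_val + (min_val + 1) to (max_val - 1) + max_val.
    for total in range(2 * min_val + 1, 2 * max_val):
        lo = max(min_val, total - max_val)  # keeps b = total - a within max_val
        hi = (total - 1) // 2               # keeps a < b, so no mirrored pairs to skip
        for a in range(lo, hi + 1):
            yield (a, total - a)

def is_nice(pair):
    # Hypothetical stand-in for the question's niceTuple() predicate.
    return pair[0] * pair[1] % 6 == 0

# Early termination: the first qualifying pair is also the one with the lowest sum.
first = next((p for p in sum_pairs(1, 6) if is_nice(p)), None)
print(first)  # (2, 3)
```

Within each sum the pairs come out with a ascending, which matches the order of sorted(combinations(...), key=sum), since that sort is stable.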

[EDIT] generalization for n-tuples

Going beyond pairs while maintaining O(1) space complexity is a little trickier. I managed to obtain a space complexity of O(S), where S is the size of the tuple, so the solution's memory consumption is independent of the range of numbers.

The same strategy is used to traverse the combinations based on the expected sum but, for each sum value, there are multiple combinations that produce the same sum. These combinations can be generated by starting from a base combination that begins with the smallest numbers and ends with the largest value that reaches that total. That is the widest spread of values possible for the given sum. All other combinations are built by gradually pulling down the last value while shifting the smaller ones upward to compensate for the decrease.

# generate offsets in increasing order (>=)
# to produce a total value
def getOffsets(size,total,maxValue):
    if not total: yield [0]*size; return
    if size == 1 and total==maxValue: yield [maxValue]; return
    while total>=0 and size*maxValue>=total:
        for prefix in getOffsets(size-1,total-maxValue,maxValue):
            yield prefix + [maxValue]
        maxValue -= 1

# generate all combinations of a range of values
# that produce a given total
def comboOfSum(total,size,minValue,maxValue):
    if size == 1: yield (total,); return
    base        = list(range(minValue,minValue+size)) # start with smallest(s)
    base[-1]    = min(total-sum(base[:-1]),maxValue)  # end with largest
    maxOffset   = base[-1]-base[-2]-1 # freedom of moving smaller values
    totalOffset = total-sum(base)     # compensate decreasing last
    minLast     = (total + size*(size-1)//2)//size # minimum to reach total
    while base[-1]>base[-2] and base[-1] >= minLast:
        for offsets in getOffsets(size-1,totalOffset,maxOffset):
            yield tuple(b+o for b,o in zip(base,offsets+[0])) # apply offsets
        base[-1]    -= 1 # decrease last value
        totalOffset += 1 # increase total to compensate for decrease
        maxOffset   -= 1 # decrease small values' freedom of movement

# generate combinations in order of target sum  
def comboBySum(size,minValue,maxValue):
    minTotal = minValue*size + size*(size-1)//2
    maxTotal = maxValue*size - size*(size-1)//2
    for total in range(minTotal,maxTotal+1):
        yield from comboOfSum(total,size,minValue,maxValue)
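To make the decomposition described above concrete, here is a deliberately brute-force sketch (not the answer's getOffsets-based code, and much slower): for one target total it starts from the widest spread, pulls the last value down one step at a time, and for each last value scans itertools.combinations for smaller values that make up the rest. The function name combos_of_sum_by_last is hypothetical:

```python
from itertools import combinations

def combos_of_sum_by_last(total, size, min_value, max_value):
    """Brute-force illustration: yield increasing size-tuples from
    [min_value, max_value] summing to total, largest last value first."""
    prefix_min = sum(range(min_value, min_value + size - 1))  # smallest sum of the other size-1 values
    last = min(total - prefix_min, max_value)                  # widest spread for this total
    while last >= min_value + size - 1:
        rest = total - last
        if sum(range(last - size + 1, last)) < rest:
            break  # even the largest values below `last` can no longer reach `rest`
        # The other values are distinct and strictly below `last`.
        for prefix in combinations(range(min_value, last), size - 1):
            if sum(prefix) == rest:
                yield prefix + (last,)
        last -= 1

found = list(combos_of_sum_by_last(12, 3, 1, 9))
print(found)
```

It is only useful as a cross-check on small inputs; the inner combinations() scan is exactly what comboOfSum avoids by generating the offsets directly.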

Validation: (comparing to sorted combinations)

size   = 4
minVal = 10
maxVal = 80

from itertools import combinations

A = list(comboBySum(size,minVal,maxVal))
B = sorted(combinations(range(minVal,maxVal+1),size),key=sum)

print("same content:",set(A)==set(B))               # True
print("order by sum:",[*map(sum,A)]==[*map(sum,B)]) # True

output (small scale):

for combo in comboBySum(2,1,6):print(combo,sum(combo))

(1, 2) 3
(1, 3) 4
(1, 4) 5
(2, 3) 5
(1, 5) 6
(2, 4) 6
(1, 6) 7
(2, 5) 7
(3, 4) 7
(2, 6) 8
(3, 5) 8
(3, 6) 9
(4, 5) 9
(4, 6) 10
(5, 6) 11

output (large scale):

for i,combo in enumerate(comboBySum(400,1,800)):
    print(*combo[:5],"...",*combo[-5:],"sum =",sum(combo))
    if i>20: break

1 2 3 4 5 ... 396 397 398 399 400 sum = 80200
1 2 3 4 5 ... 396 397 398 399 401 sum = 80201
1 2 3 4 5 ... 396 397 398 399 402 sum = 80202
1 2 3 4 5 ... 396 397 398 400 401 sum = 80202
1 2 3 4 5 ... 396 397 398 399 403 sum = 80203
1 2 3 4 5 ... 396 397 398 400 402 sum = 80203
1 2 3 4 5 ... 396 397 399 400 401 sum = 80203
1 2 3 4 5 ... 396 397 398 399 404 sum = 80204
1 2 3 4 5 ... 396 397 398 400 403 sum = 80204
1 2 3 4 5 ... 396 397 398 401 402 sum = 80204
1 2 3 4 5 ... 396 397 399 400 402 sum = 80204
1 2 3 4 5 ... 396 398 399 400 401 sum = 80204
1 2 3 4 5 ... 396 397 398 399 405 sum = 80205
1 2 3 4 5 ... 396 397 398 400 404 sum = 80205
1 2 3 4 5 ... 396 397 398 401 403 sum = 80205
1 2 3 4 5 ... 396 397 399 400 403 sum = 80205
1 2 3 4 5 ... 396 397 399 401 402 sum = 80205
1 2 3 4 5 ... 396 398 399 400 402 sum = 80205
1 2 3 4 5 ... 397 398 399 400 401 sum = 80205
1 2 3 4 5 ... 396 397 398 399 406 sum = 80206
1 2 3 4 5 ... 396 397 398 400 405 sum = 80206
1 2 3 4 5 ... 396 397 398 401 404 sum = 80206

output (large number range):

for i,combo in enumerate(comboBySum(20,12345,1000000)):
    print(*combo[:5],"...",*combo[-5:],"sum =",sum(combo))
    if i>20: break

12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12364 sum = 247090
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12365 sum = 247091
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12366 sum = 247092
12345 12346 12347 12348 12349 ... 12360 12361 12362 12364 12365 sum = 247092
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12367 sum = 247093
12345 12346 12347 12348 12349 ... 12360 12361 12362 12364 12366 sum = 247093
12345 12346 12347 12348 12349 ... 12360 12361 12363 12364 12365 sum = 247093
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12368 sum = 247094
12345 12346 12347 12348 12349 ... 12360 12361 12362 12364 12367 sum = 247094
12345 12346 12347 12348 12349 ... 12360 12361 12362 12365 12366 sum = 247094
12345 12346 12347 12348 12349 ... 12360 12361 12363 12364 12366 sum = 247094
12345 12346 12347 12348 12349 ... 12360 12362 12363 12364 12365 sum = 247094
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12369 sum = 247095
12345 12346 12347 12348 12349 ... 12360 12361 12362 12364 12368 sum = 247095
12345 12346 12347 12348 12349 ... 12360 12361 12362 12365 12367 sum = 247095
12345 12346 12347 12348 12349 ... 12360 12361 12363 12364 12367 sum = 247095
12345 12346 12347 12348 12349 ... 12360 12361 12363 12365 12366 sum = 247095
12345 12346 12347 12348 12349 ... 12360 12362 12363 12364 12366 sum = 247095
12345 12346 12347 12348 12349 ... 12361 12362 12363 12364 12365 sum = 247095
12345 12346 12347 12348 12349 ... 12360 12361 12362 12363 12370 sum = 247096
12345 12346 12347 12348 12349 ... 12360 12361 12362 12364 12369 sum = 247096
12345 12346 12347 12348 12349 ... 12360 12361 12362 12365 12368 sum = 247096
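The minTotal and maxTotal bounds in comboBySum are the sums of the size smallest and the size largest allowed values. A standalone check of that arithmetic (sum_bounds is a hypothetical helper name) reproduces the first sums printed in the outputs above:

```python
def sum_bounds(size, min_value, max_value):
    """Smallest and largest possible sums of `size` distinct integers in [min_value, max_value]."""
    offsets = size * (size - 1) // 2          # 0 + 1 + ... + (size - 1)
    min_total = min_value * size + offsets    # min_value, min_value+1, ..., min_value+size-1
    max_total = max_value * size - offsets    # max_value-size+1, ..., max_value
    return min_total, max_total

print(sum_bounds(400, 1, 800))        # (80200, 240200)
print(sum_bounds(20, 12345, 1000000)) # (247090, 19999810)
```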

Upvotes: 1

Blckknght

Reputation: 104762

Here's how I'd solve it, with a recursive function that finds all combinations that sum to a given value:

def ordered_combinations(pop, n):
    pop = sorted(pop)

    for s in range(sum(pop[:n]), sum(pop[-n:])+1):
        yield from get_sums(pop, s, n)

def get_sums(pop, s, n):
    if n == 1:
        if s in pop:
            yield [s]
        return

    for i, v in enumerate(pop):
        if sum(pop[i:i+n]) > s:
            return
        for rest in get_sums(pop[i+1:], s-v, n-1):
            rest.append(v)
            yield rest

Here's an example of its output:

>>> for c in ordered_combinations(range(1, 8), 4):
    print(c, sum(c))


[4, 3, 2, 1] 10
[5, 3, 2, 1] 11
[6, 3, 2, 1] 12
[5, 4, 2, 1] 12
[7, 3, 2, 1] 13
[6, 4, 2, 1] 13
[5, 4, 3, 1] 13
[7, 4, 2, 1] 14
[6, 5, 2, 1] 14
[6, 4, 3, 1] 14
[5, 4, 3, 2] 14
[7, 5, 2, 1] 15
[7, 4, 3, 1] 15
[6, 5, 3, 1] 15
[6, 4, 3, 2] 15
[7, 6, 2, 1] 16
[7, 5, 3, 1] 16
[6, 5, 4, 1] 16
[7, 4, 3, 2] 16
[6, 5, 3, 2] 16
[7, 6, 3, 1] 17
[7, 5, 4, 1] 17
[7, 5, 3, 2] 17
[6, 5, 4, 2] 17
[7, 6, 4, 1] 18
[7, 6, 3, 2] 18
[7, 5, 4, 2] 18
[6, 5, 4, 3] 18
[7, 6, 5, 1] 19
[7, 6, 4, 2] 19
[7, 5, 4, 3] 19
[7, 6, 5, 2] 20
[7, 6, 4, 3] 20
[7, 6, 5, 3] 21
[7, 6, 5, 4] 22

The combinations are always yielded with the biggest values first, as an artifact of how I'm building them as lists (by appending smaller values to the end rather than concatenating them to the front). If you want them ordered from smallest to largest, replace the rest.append(v) and yield rest lines with yield [v]+rest.
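Applying that change gives the following smallest-first version. It is the answer's code with only the inner yield changed, plus comments; the combinations of each sum come out in the same order as before, each one reversed:

```python
from itertools import combinations

def ordered_combinations(pop, n):
    """Yield n-element combinations of pop in nondecreasing order of their sum."""
    pop = sorted(pop)
    for s in range(sum(pop[:n]), sum(pop[-n:]) + 1):
        yield from get_sums(pop, s, n)

def get_sums(pop, s, n):
    """Yield ascending lists of n values from the sorted list pop that sum to s."""
    if n == 1:
        if s in pop:
            yield [s]
        return
    for i, v in enumerate(pop):
        # pop is sorted, so pop[i:i+n] is the smallest n-combination starting at v;
        # if it already exceeds s, no later starting value can work either.
        if sum(pop[i:i+n]) > s:
            return
        for rest in get_sums(pop[i + 1:], s - v, n - 1):
            yield [v] + rest  # prepend instead of append: smallest value first

result = list(ordered_combinations(range(1, 8), 4))
print(result[:3])  # [[1, 2, 3, 4], [1, 2, 3, 5], [1, 2, 3, 6]]
```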

The code uses the yield from syntax that was introduced with Python 3.3. If you're using an earlier version that doesn't support that, you can use this equivalent code:

for v in get_sums(pop, s, n):
    yield v

The code can even handle an extreme case like the one you described: 400-combinations taken from range(1, 800). Here are the first twenty results of that computation (shown only with their largest 10 values, since the remaining values are always 390 down to 1), and their sums:

>>> for i, v in enumerate(ordered_combinations(range(1, 800), 400)):
    if i >= 20:
        break
    print(v[:10], sum(v))


[400, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80200
[401, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80201
[402, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80202
[401, 400, 398, 397, 396, 395, 394, 393, 392, 391] 80202
[403, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80203
[402, 400, 398, 397, 396, 395, 394, 393, 392, 391] 80203
[401, 400, 399, 397, 396, 395, 394, 393, 392, 391] 80203
[404, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80204
[403, 400, 398, 397, 396, 395, 394, 393, 392, 391] 80204
[402, 401, 398, 397, 396, 395, 394, 393, 392, 391] 80204
[402, 400, 399, 397, 396, 395, 394, 393, 392, 391] 80204
[401, 400, 399, 398, 396, 395, 394, 393, 392, 391] 80204
[405, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80205
[404, 400, 398, 397, 396, 395, 394, 393, 392, 391] 80205
[403, 401, 398, 397, 396, 395, 394, 393, 392, 391] 80205
[403, 400, 399, 397, 396, 395, 394, 393, 392, 391] 80205
[402, 401, 399, 397, 396, 395, 394, 393, 392, 391] 80205
[402, 400, 399, 398, 396, 395, 394, 393, 392, 391] 80205
[401, 400, 399, 398, 397, 395, 394, 393, 392, 391] 80205
[406, 399, 398, 397, 396, 395, 394, 393, 392, 391] 80206

Because it's recursive, this code may fail if you request a 1000-combination, due to Python's default recursion limit (usually 1000). You can raise the limit with sys.setrecursionlimit if necessary.

It may also have memory issues if you go extremely deep with a very large population, since get_sums slices (and therefore copies) the population in each recursive step. If you only use this code with ranges, you can probably fix the memory issue by removing the pop = sorted(pop) line from ordered_combinations, since Python 3's range objects can be sliced efficiently (that is, range(1,100)[10:] is range(11,100)).
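A quick check of the two points above, assuming Python 3: slicing a range returns another range rather than a copied list, and the recursion limit can be read and raised through the sys module. The value 5000 is an arbitrary illustration, not a recommended setting:

```python
import sys

r = range(1, 800)
tail = r[10:]
# Slicing a range yields a new range object; no elements are copied.
print(type(tail).__name__, tail)  # range range(11, 800)
# Membership tests for ints on a range are computed arithmetically, not by scanning.
print(500 in r)                   # True

# Raise the recursion limit only if it is currently lower than needed.
if sys.getrecursionlimit() < 5000:
    sys.setrecursionlimit(5000)
print(sys.getrecursionlimit() >= 5000)  # True
```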

Upvotes: 3
