Basil C.

Reputation: 147

Algorithm for itertools.combinations in Python

I was solving a programming puzzle involving combinations. It led me to the wonderful itertools.combinations function, and I'd like to know how it works under the hood. The documentation says that the algorithm is roughly equivalent to the following:

def combinations(iterable, r):
    # combinations('ABCD', 2) --> AB AC AD BC BD CD
    # combinations(range(4), 3) --> 012 013 023 123
    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    indices = list(range(r))
    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            return
        indices[i] += 1
        for j in range(i+1, r):
            indices[j] = indices[j-1] + 1
        yield tuple(pool[i] for i in indices)
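A quick way to check that this recipe really behaves like the built-in is to run both side by side. The sketch below copies the recipe under the hypothetical name `combinations_recipe` (so it does not shadow the real function) and compares it with `itertools.combinations` for every r from 0 to one past the input length, which also covers the `r > n` case.

```python
from itertools import combinations


def combinations_recipe(iterable, r):
    # Copy of the documented pure-Python equivalent, renamed (hypothetical name)
    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    indices = list(range(r))
    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            return
        indices[i] += 1
        for j in range(i + 1, r):
            indices[j] = indices[j - 1] + 1
        yield tuple(pool[i] for i in indices)


# Compare against the built-in for several inputs, including r > n
for data in ["", "A", "ABCD", range(5)]:
    for r in range(len(data) + 2):
        assert list(combinations_recipe(data, r)) == list(combinations(data, r))
```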

I get the idea: we start with the most obvious combination (the first r consecutive elements). Then we change one (the last) item to get each subsequent combination.

The thing I'm struggling with is the conditional inside the for loop.

for i in reversed(range(r)):
    if indices[i] != i + n - r:
        break

This expression is very terse, and I suspect it's where all the magic happens. Please give me a hint so I can figure it out.

Upvotes: 13

Views: 3563

Answers (3)

user2725093

Reputation: 221

The C source code has some additional comments explaining what is going on.

The yield statement before the while loop yields the trivial combination (simply the first r elements of A, (A[0], ..., A[r-1])) and prepares indices for the steps that follow. Let's say that we have A='ABCDE' and r=3. Then, after the first step, indices is [0, 1, 2], which points to ('A', 'B', 'C').

Let's look at the source code of the loop in question:

    /* Scan indices right-to-left until finding one that is not
       at its maximum (i + n - r). */
    for (i=r-1 ; i >= 0 && indices[i] == i+n-r ; i--)
        ;

This loop searches for the rightmost element of indices that hasn't reached its maximum value yet. After the very first yield statement, indices is [0, 1, 2]. Since indices[2] = 2 is below its maximum 2 + 5 - 3 = 4, the for loop stops immediately with i = 2.
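The C `for` loop with an empty body can be transliterated into Python as a `while` loop. This is a sketch with a hypothetical helper name, `find_pivot`; like the C code, it leaves i at -1 when every index is at its maximum, which signals that there are no more combinations.

```python
def find_pivot(indices, n):
    """Return the rightmost position i with indices[i] < i + n - r, or -1."""
    r = len(indices)
    i = r - 1
    # Mirrors: for (i=r-1; i >= 0 && indices[i] == i+n-r; i--) ;
    while i >= 0 and indices[i] == i + n - r:
        i -= 1
    return i


print(find_pivot([0, 1, 2], 5))  # 2: the last index can still grow
print(find_pivot([0, 1, 4], 5))  # 1: position 2 already holds its maximum, 4
print(find_pivot([2, 3, 4], 5))  # -1: every position is at its maximum
```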

Next, the following code increments the ith element of indices:

    /* Increment the current index which we know is not at its
       maximum.  Then move back to the right setting each index
       to its lowest possible value (one higher than the index
       to its left -- this maintains the sort order invariant). */
    indices[i]++;

As a result, we get index combination [0, 1, 3], which points to ('A', 'B', 'D').

Then every index to the right of i is reset to its smallest legal value, one more than the index to its left:

    for (j=i+1 ; j<r ; j++)
        indices[j] = indices[j-1] + 1;
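Putting the three pieces together (scan, increment, reset), one step of the algorithm can be written as a standalone Python function. The name `next_indices` is hypothetical; this sketch returns a new list, or `None` when the input is already the last index combination.

```python
def next_indices(indices, n):
    """Return the index list that follows `indices`, or None if it is the last."""
    r = len(indices)
    nxt = list(indices)
    # Scan right-to-left for the first index not at its maximum i + n - r
    i = r - 1
    while i >= 0 and nxt[i] == i + n - r:
        i -= 1
    if i < 0:
        return None  # every index is at its maximum: no more combinations
    # Increment that index
    nxt[i] += 1
    # Reset everything to its right to the smallest legal value
    for j in range(i + 1, r):
        nxt[j] = nxt[j - 1] + 1
    return nxt


print(next_indices([0, 1, 2], 5))  # [0, 1, 3]
print(next_indices([0, 1, 4], 5))  # [0, 2, 3]
print(next_indices([0, 3, 4], 5))  # [1, 2, 3]
print(next_indices([2, 3, 4], 5))  # None
```

The [0, 3, 4] case shows that more than one index can change in a single step: position 0 is incremented and both positions to its right are reset.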

Indices increase step by step:

step  indices
1     (0, 1, 2)
2     (0, 1, 3)
3     (0, 1, 4)
4     (0, 2, 3)
5     (0, 2, 4)
6     (0, 3, 4)
7     (1, 2, 3)
...
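Because combinations over range(n) yields exactly these index tuples, the table can be reproduced directly with the built-in. For n = 5 and r = 3 there are C(5, 3) = 10 rows; the snippet checks that count with `math.comb`, which assumes Python 3.8 or later.

```python
import math
from itertools import combinations

n, r = 5, 3
steps = list(combinations(range(n), r))
for step, idx in enumerate(steps, start=1):
    print(step, idx)

# One row per r-element subset of range(n)
assert len(steps) == math.comb(n, r) == 10
```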

Upvotes: 2

Uriel

Reputation: 16174

This for loop does a simple thing: it checks whether the algorithm should terminate.

The algorithm starts with the first r items and advances until it reaches the last r items of the iterable, which are [S(n-r+1), ..., S(n-1), S(n)] if we let S be the iterable and count positions from 1.

Now, the algorithm scans the indices from right to left and checks whether each one still has room to move: it verifies that the i-th index is not n - r + i, which, by the previous paragraph, is the position of the i-th item of the last combination (the +1 disappears because lists are 0-based).

If all of these indices are at the last r positions, the loop finishes without a break, so the else clause runs: it executes the return and terminates the algorithm.
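The Python feature doing the work here is that a for loop's else clause runs only when the loop ends without hitting break. The hypothetical describe_scan below isolates that behaviour using the same test as the recipe.

```python
def describe_scan(indices, n):
    r = len(indices)
    for i in reversed(range(r)):
        if indices[i] != i + n - r:
            break  # found a position that can still be increased
    else:
        # Reached only if the loop never hit `break`
        return "else ran: all indices at their maximum"
    return f"break at i = {i}"


print(describe_scan([0, 1, 4], 5))  # break at i = 1
print(describe_scan([2, 3, 4], 5))  # else ran: all indices at their maximum
```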


We could implement the same termination check with

if indices == list(range(n-r, n)): return
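The two termination tests really are equivalent, and a brute-force comparison over every index tuple for small n makes that concrete. Both helper names below are hypothetical.

```python
from itertools import combinations


def is_last_by_scan(indices, n):
    # The recipe's test: every position holds its maximum i + n - r
    r = len(indices)
    for i in reversed(range(r)):
        if indices[i] != i + n - r:
            return False
    return True


def is_last_by_compare(indices, n):
    # The simpler test: indices equals the last index-list
    r = len(indices)
    return list(indices) == list(range(n - r, n))


# Check both tests on every index tuple for a few (n, r) pairs
for n in range(6):
    for r in range(n + 1):
        for combo in combinations(range(n), r):
            assert is_last_by_scan(combo, n) == is_last_by_compare(combo, n)
```

The comparison is shorter, but as the next paragraph explains, it throws away the position where the scan stopped.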

but the main reason for this "mess" (using reversed and break) is that the first index from the end that doesn't match is saved in i and is used by the next step of the algorithm, which increments this index and resets the ones after it.


You could check this by replacing the yields with

print('Combination: {}  Indices: {}'.format(tuple(pool[i] for i in indices), indices))
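Instead of replacing the yields with print, a variant can yield the indices alongside each combination, so the trace can be both printed and inspected. combinations_traced is a hypothetical name; it yields a copy of indices because the list is mutated in place.

```python
def combinations_traced(iterable, r):
    """Like the itertools recipe, but yields (combination, indices) pairs."""
    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    indices = list(range(r))
    yield tuple(pool[i] for i in indices), list(indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            return
        indices[i] += 1
        for j in range(i + 1, r):
            indices[j] = indices[j - 1] + 1
        # Copy indices so later mutation does not change what was yielded
        yield tuple(pool[i] for i in indices), list(indices)


for combo, idx in combinations_traced("ABCDE", 3):
    print("Combination: {}  Indices: {}".format(combo, idx))
```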

Upvotes: 3

user2390182

Reputation: 73450

The loop has two purposes:

  1. Terminating if the last index-list has been reached
  2. Determining the right-most position in the index-list that can be legally increased. This position is then the starting point for resetting all indices to its right.

Let us say you have an iterable of 5 elements and want combinations of length 3. What you essentially need for this is to generate lists of indices. The juicy part of the above algorithm generates the next such index-list from the current one:

# obvious 
index-pool:       [0,1,2,3,4]
first index-list: [0,1,2]
                  [0,1,3]
                  ...
                  [1,3,4]
last index-list:  [2,3,4]

i + n - r is the max value for index i in the index-list:

 index 0: i + n - r = 0 + 5 - 3 = 2 
 index 1: i + n - r = 1 + 5 - 3 = 3
 index 2: i + n - r = 2 + 5 - 3 = 4
 # compare last index-list above
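The per-position maxima can be computed directly, and they are exactly the last index-list, list(range(n - r, n)). The reasoning: position i is followed by r - 1 - i strictly larger indices, all at most n - 1, so it can hold at most n - r + i. max_indices is a hypothetical helper name.

```python
def max_indices(n, r):
    # Position i can hold at most i + n - r, leaving room for the
    # r - 1 - i strictly larger indices that must follow it
    return [i + n - r for i in range(r)]


print(max_indices(5, 3))  # [2, 3, 4]
assert max_indices(5, 3) == list(range(5 - 3, 5))
```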

=>

for i in reversed(range(r)):
    if indices[i] != i + n - r:
        break
else:
    return

This loops backwards through the current index-list and stops at the first position that doesn't hold its maximum index-value. If all positions hold their maximum index-value, there is no further index-list, so the generator returns.

Take [0,1,4] as an example: one can verify that the next list should be [0,2,3]. The loop stops at position 1 (position 2 already holds its maximum, 4), and the subsequent code

indices[i] += 1

increments the value of indices[i] (1 -> 2). Finally

for j in range(i+1, r):
    indices[j] = indices[j-1] + 1

resets all positions > i to the smallest legal index-values, each 1 larger than its predecessor.
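Because the reset always makes each index one larger than its predecessor, every index-list stays strictly increasing and within its per-position maximum. The brute-force sketch below uses a hypothetical step helper that performs one iteration of the loop in place, checks that invariant for small inputs, and confirms the generated sequence matches itertools.combinations(range(n), r).

```python
from itertools import combinations


def step(indices, n):
    """Advance indices in place; return False when there is no next list."""
    r = len(indices)
    for i in reversed(range(r)):
        if indices[i] != i + n - r:
            break
    else:
        return False
    indices[i] += 1
    for j in range(i + 1, r):
        indices[j] = indices[j - 1] + 1
    return True


def all_index_lists(n, r):
    indices = list(range(r))
    result = [tuple(indices)]
    while step(indices, n):
        result.append(tuple(indices))
    return result


for n in range(1, 7):
    for r in range(1, n + 1):
        lists = all_index_lists(n, r)
        for idx in lists:
            # Strictly increasing, and each position within its maximum
            assert all(idx[k] < idx[k + 1] for k in range(r - 1))
            assert all(idx[i] <= i + n - r for i in range(r))
        assert lists == list(combinations(range(n), r))
```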

Upvotes: 5
