Robin Andrews

Factorial Time Complexity for Permutations

I just want to check whether the following code has factorial time complexity, i.e. O(n!), where n is the number of characters in my_str. From my understanding it does, but I might have missed something.

def perms(a_str):
    stack = list(a_str)
    results = [stack.pop()]
    while stack:
        current = stack.pop()
        new_results = []
        for partial in results:
            for i in range(len(partial) + 1):
                new_results.append(partial[:i] + current + partial[i:])
        results = new_results
    return results


my_str = "ABCDEFGHIJ"
print(perms(my_str))

inordirection

The complexity is actually O((n+1)!), which, although close to O(n!) (the two differ by a factor of n+1), is a strictly greater complexity class.

Putting the algorithm into terms amenable to run-time analysis: it iteratively generates all permutations of every suffix of the input string, and once the last iteration of the while loop completes, it has built the permutations of the entire input string.
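To see the suffix-by-suffix structure concretely, here is a minimal sketch that keeps the question's loop unchanged but records `results` after every step. The name `perms_traced` and the `history` list are hypothetical additions for illustration only.

```python
def perms_traced(a_str):
    """Same loop as perms() in the question, plus a hypothetical `history`
    list that records the permutations built after each step."""
    stack = list(a_str)
    results = [stack.pop()]
    history = [list(results)]  # step 1: permutations of the last character
    while stack:
        current = stack.pop()
        new_results = []
        for partial in results:
            # Insert `current` at every position of each partial permutation.
            for i in range(len(partial) + 1):
                new_results.append(partial[:i] + current + partial[i:])
        results = new_results
        history.append(list(results))  # permutations of the next-longer suffix
    return history


for step, suffix_perms in enumerate(perms_traced("ABC"), start=1):
    print(step, suffix_perms)
```

For "ABC" this prints ['C'], then ['BC', 'CB'], then the six permutations of 'ABC': the permutations of the suffixes 'C', 'BC' and 'ABC' in turn.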

Before the loop, you generate all permutations of the final, length-1 suffix (just the list containing all 1! = 1 permutations of the final character). For simplicity, consider this the 1st iteration of the loop.

Then, during the k-th iteration of the loop, you use all the permutations of the suffix a_str[n-k+1:] built so far to build the permutations of the next-larger suffix, a_str[n-k:], by inserting the character at index n-k at every possible position of each partial permutation. The total work done in each iteration is proportional to the total length of all new partial permutation strings generated during that iteration, which is the length of each new partial permutation, k, times the number of them, k!, for a total of k*k!.
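The k*k! count can be checked directly. The sketch below (the helper name `work_per_iteration` is hypothetical) repeats the same insertion step, written as a list comprehension, and asserts that iteration k produces k! partial permutations of length k each.

```python
from math import factorial


def work_per_iteration(a_str):
    """Return (k, number of partials, total length of partials) per iteration k."""
    stack = list(a_str)
    results = [stack.pop()]
    rows = [(1, len(results), sum(len(p) for p in results))]
    k = 1
    while stack:
        current = stack.pop()
        # Same insertion step as the question's nested loops.
        results = [partial[:i] + current + partial[i:]
                   for partial in results
                   for i in range(len(partial) + 1)]
        k += 1
        rows.append((k, len(results), sum(len(p) for p in results)))
    return rows


for k, count, total_len in work_per_iteration("ABCDE"):
    assert count == factorial(k) and total_len == k * factorial(k)
    print(f"k={k}: {count} partials of length {k}, total length {total_len}")
```

For "ABCDE" the per-iteration totals are 1, 4, 18, 96 and 600, matching 1*1!, 2*2!, ..., 5*5!.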

Considering that k can range from 1 (when generating the initial single permutation of the last character) to n (during the last iteration responsible for generating all of the n! permutations which ultimately appear in the output), the total work done over the course of the entire algorithm can be given by the simple sum:

$$\sum_{k=1}^{n} k \cdot k!$$

When you solve this sum, representing the total length of all partial permutations built over the course of the algorithm, you get:

$$\sum_{k=1}^{n} k \cdot k! = (n+1)! - 1$$
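The closed form (n+1)! - 1 follows from a telescoping argument: since (k+1)! = (k+1)·k!, each term can be rewritten as k·k! = (k+1)! - k!, so every intermediate factorial cancels and only the last and first terms survive:

$$\sum_{k=1}^{n} k \cdot k! = \sum_{k=1}^{n} \bigl[(k+1)! - k!\bigr] = (n+1)! - 1! = (n+1)! - 1$$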

The optimal run-time of a permutation generating algorithm would be O(n*n!), because that is the total length of the output array you need to produce. However, O((n+1)!) = O(n*n!) because:

O((n+1)!) = O((n+1)n!) = O(n*n! + n!) = O(n*n!)
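A quick exact check of this step, using `fractions.Fraction` from the standard library: the ratio (n+1)!/(n*n!) simplifies to (n+1)/n, which is at most 2 for n >= 1 and tends to 1, so the two functions differ only by a bounded constant factor. The helper name `ratio` is hypothetical.

```python
from fractions import Fraction
from math import factorial


def ratio(n):
    """Exact value of (n+1)! / (n * n!), which simplifies to (n+1)/n."""
    return Fraction(factorial(n + 1), n * factorial(n))


for n in (1, 10, 100, 1000):
    r = ratio(n)
    assert r == Fraction(n + 1, n)  # never more than 2 for n >= 1
    print(n, r, float(r))
```

This prints `1 2 2.0`, `10 11/10 1.1`, `100 101/100 1.01` and `1000 1001/1000 1.001`.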

This means the above algorithm is still asymptotically optimal, even if it does do a bit of unnecessary work in building partial permutations which don't themselves directly figure into the final output (such that permutation generating algorithms based on swapping elements rather than iteratively building partial permutations can be marginally faster).
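For comparison, one well-known swap-based approach is Heap's algorithm, sketched iteratively below. This is a swapped-in technique for illustration, not code from the question and not necessarily what this answer had in mind; the name `heap_perms` is hypothetical. Each new arrangement is obtained from the previous one by a single swap, so no shorter partial permutations are ever built. Joining each arrangement into a string still costs O(n), so the total remains O(n*n!).

```python
def heap_perms(a_str):
    """Heap's algorithm (iterative): each permutation after the first is
    produced from the previous one by swapping exactly two characters."""
    chars = list(a_str)
    n = len(chars)
    results = ["".join(chars)]
    c = [0] * n  # per-level loop counters that replace the recursion stack
    i = 1
    while i < n:
        if c[i] < i:
            if i % 2 == 0:
                chars[0], chars[i] = chars[i], chars[0]
            else:
                chars[c[i]], chars[i] = chars[i], chars[c[i]]
            results.append("".join(chars))  # O(n) copy per permutation
            c[i] += 1
            i = 1  # "recurse" back down to the lowest level
        else:
            c[i] = 0  # this level is exhausted; reset it and move up
            i += 1
    return results


print(heap_perms("ABC"))
```

`heap_perms("ABC")` returns `['ABC', 'BAC', 'CAB', 'ACB', 'BCA', 'CBA']`.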

You can check my math with this instrumented version of the algorithm and some test cases:

def perms(a_str):
    total_cost = 0 # total cost of creating all partial permutation strings
    stack = list(a_str)
    results = [stack.pop()]
    total_cost += len(results[0]) # increment total cost
    while stack:
        current = stack.pop()
        new_results = []
        for partial in results:
            for i in range(len(partial) + 1):
                candidate = partial[:i] + current + partial[i:]  # renamed to avoid shadowing built-in next()
                total_cost += len(candidate) # increment total cost
                new_results.append(candidate)
        results = new_results
    return results, total_cost

from math import factorial

def test(string):
    n = len(string)
    print(f'perms({string}):')
    ps, total_cost = perms(string)
    # print(ps)
    print("optimal cost:", n * factorial(n))
    print("total cost:", total_cost)
    print("expected cost (sum):", sum([k * factorial(k) for k in range(1, n+1)]))
    print("expected cost (closed form):", factorial(n + 1) - 1)
    print()

tests = [''.join([chr(i) for i in range(ord('A'), ord('A')+j)]) for j in range(1, ord('I') - ord('A'))]
for t in tests: test(t)
