user7404408

Python, return statement during backtracking

I was experimenting a little and something happened that I didn't expect. The question is about recursion and the commented-out return statement on line 7 of the function:

def twentyone(nums, stack = [], answer = set()):
    for index, num in enumerate(nums):
        new_stack = stack + [num]
        total = sum(new_stack)
        if total == 21:
            answer.add(tuple(new_stack))
            #return
        elif total < 21:
            twentyone(nums[index + 1:], new_stack, answer)
    return answer

user_input = input()
list_format = [int(x) for x in user_input.split()]
answer = twentyone(list_format)

if len(answer) == 0:
    print("No combination of numbers add to 21")
for solution in answer:
    print("The values ", end = "")
    for number in solution:
        print("{} ".format(number), end = "")
    print("add up to 21")

My question: when I test with the example input "1 9 11 5 6" and the return statement is active, the output is only "The values 1 9 11 add up to 21". Without the return statement, the output is "The values 1 9 11 add up to 21" followed by "The values 1 9 5 6 add up to 21". Can anyone explain why? I thought the return statement would simply "speed up" the end of this recursive call, skipping only lines it would never reach anyway, and since I had already added the tuple to the mutable set, the other solutions would still be added by the other recursive calls, so there would be no problem. But I was of course wrong.

Answers (1)

PM 2Ring

That return statement exits the current call of twentyone, which terminates its for loop prematurely. The remaining numbers at that recursion depth are never tried, so it may not find all possible solutions that share the current prefix.
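A minimal way to see which branches the `return` cuts off is to record every partial stack that gets checked. The helper below, `visited_stacks`, is hypothetical (it is not part of the question's code); its `stop_early` flag mimics the un-commented `return`.

```python
def visited_stacks(nums, stop_early=False, stack=(), seen=None):
    """Hypothetical tracing helper: list every stack whose sum is checked.

    With stop_early=True it mimics the question's un-commented `return`.
    """
    if seen is None:
        seen = []
    for index, num in enumerate(nums):
        new_stack = stack + (num,)
        seen.append(new_stack)
        total = sum(new_stack)
        if total == 21:
            if stop_early:
                return seen  # ends THIS call's loop; the caller keeps looping
        elif total < 21:
            visited_stacks(nums[index + 1:], stop_early, new_stack, seen)
    return seen


full = visited_stacks([1, 9, 11, 5, 6])
early = visited_stacks([1, 9, 11, 5, 6], stop_early=True)
print((1, 9, 5) in full)    # True
print((1, 9, 5) in early)   # False: the call with stack (1, 9) stopped after 11
print((1, 11) in early)     # True: the call with stack (1,) carried on
```

The `return` only leaves the innermost call, so the outer loops keep going; what is lost is every combination that would have extended (1, 9) with a number that comes after 11.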

We can see this by adding an extra list to the function that tracks the for loop indices. Note that I give indices a default value of None; that's the standard way of avoiding the mutable default argument pitfall.

def twentyone(nums, stack = [], answer = set(), indices = None):
    if indices is None:
        indices = []
    for index, num in enumerate(nums):
        new_stack = stack + [num]
        total = sum(new_stack)
        if total == 21:
            print(indices + [index], new_stack)
            answer.add(tuple(new_stack))
            #return
        elif total < 21:
            twentyone(nums[index + 1:], new_stack, answer, indices + [index])

    return answer

user_input = '1 9 11 5 6 3 7'
list_format = [int(x) for x in user_input.split()]
answer = twentyone(list_format)
print('\n', answer)

output

[0, 0, 0] [1, 9, 11]
[0, 0, 1, 0] [1, 9, 5, 6]
[0, 1, 1, 0] [1, 11, 6, 3]
[1, 1, 2] [9, 5, 7]
[2, 2, 0] [11, 3, 7]
[3, 0, 0, 0] [5, 6, 3, 7]

 {(9, 5, 7), (1, 11, 6, 3), (1, 9, 5, 6), (1, 9, 11), (5, 6, 3, 7), (11, 3, 7)}

If we un-comment the return, we get this output:

[0, 0, 0] [1, 9, 11]
[0, 1, 1, 0] [1, 11, 6, 3]
[1, 1, 2] [9, 5, 7]
[2, 2, 0] [11, 3, 7]
[3, 0, 0, 0] [5, 6, 3, 7]

 {(9, 5, 7), (5, 6, 3, 7), (1, 11, 6, 3), (1, 9, 11), (11, 3, 7)}

[0, 0, 1, 0] is missing from the indices, which means that the for loop in the call that found [0, 0, 0] (the call whose stack was [1, 9]) was terminated before it reached index 1.


As I mentioned earlier, it can be dangerous to use mutable objects as default arguments. This is discussed extensively in the Stack Overflow question “Least Astonishment” and the Mutable Default Argument.

In this code you don't run into a problem because only one call of twentyone uses the default args; the recursive calls supply explicit args. But if your calling code called twentyone a second time with another list of user input, the default answer set would still contain the items it collected during the previous call. The stack list is safe since you never mutate it.
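A short sketch of that pitfall, using the question's version of twentyone with two hypothetical inputs: the second call's result still contains the solution found by the first call, because both calls share the one default set.

```python
def twentyone(nums, stack=[], answer=set()):
    # The question's version: `answer` defaults to ONE set object, created
    # when `def` runs and then shared by every call that omits the argument.
    for index, num in enumerate(nums):
        new_stack = stack + [num]
        total = sum(new_stack)
        if total == 21:
            answer.add(tuple(new_stack))
        elif total < 21:
            twentyone(nums[index + 1:], new_stack, answer)
    return answer


first = twentyone([1, 9, 11])
second = twentyone([10, 11])
print(second)           # contains both (1, 9, 11) and (10, 11)
print(first is second)  # True: both calls returned the same default set
```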

Note that the behaviour of default mutable arguments isn't always a pitfall. It's extremely useful for implementing caching; see my answer to Fibonacci in Python for an example.
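As a rough sketch of the caching idea (not the code from that linked answer), a dict used as a default argument survives between calls and can serve as a memo table. The name `fib` and the seeded `cache` are illustrative assumptions.

```python
def fib(n, cache={0: 0, 1: 1}):
    # The default dict is created once and shared by all calls, so values
    # computed by earlier calls are reused instead of recomputed.
    if n not in cache:
        cache[n] = fib(n - 1) + fib(n - 2)
    return cache[n]


print(fib(50))  # 12586269025
```

In current Python, `functools.lru_cache` (or `functools.cache`, added in Python 3.9) provides the same memoization without relying on a shared default.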


FWIW, here's a version of twentyone that doesn't have the mutable default argument problem. I've also changed stack & new_stack into tuples. That saves the tuple call when a solution is stored; creating a new tuple with new_stack = stack + (num,) is no less efficient than creating a new list with new_stack = stack + [num]. Tuples are slightly more efficient than lists (for example, they use a little less memory), and there's not much point in using lists for stack / new_stack since they never get mutated.

def twentyone(nums, stack=(), answer=None):
    if answer is None:
        answer = set()
    for index, num in enumerate(nums):
        new_stack = stack + (num,)
        total = sum(new_stack)
        if total == 21:
            answer.add(new_stack)
        elif total < 21:
            twentyone(nums[index + 1:], new_stack, answer)

    return answer
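Calling this version twice with two hypothetical inputs shows that each top-level call now starts with its own empty set:

```python
def twentyone(nums, stack=(), answer=None):
    if answer is None:
        answer = set()  # a fresh set for every top-level call
    for index, num in enumerate(nums):
        new_stack = stack + (num,)
        total = sum(new_stack)
        if total == 21:
            answer.add(new_stack)
        elif total < 21:
            twentyone(nums[index + 1:], new_stack, answer)
    return answer


first = twentyone([1, 9, 11])
second = twentyone([10, 11])
print(first)   # {(1, 9, 11)}
print(second)  # {(10, 11)}: nothing left over from the first call
```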

Another way to implement this is as a recursive generator. That way we don't need answer; we just yield solutions as we find them.

def twentyone(nums, stack=()):
    for index, num in enumerate(nums):
        new_stack = stack + (num,)
        total = sum(new_stack)
        if total == 21:
            yield new_stack
        elif total < 21:
            yield from twentyone(nums[index + 1:], new_stack)

user_input = '1 9 11 5 6 3 7'
list_format = [int(x) for x in user_input.split()]
for t in twentyone(list_format):
    print(t)

output

(1, 9, 11)
(1, 9, 5, 6)
(1, 11, 6, 3)
(9, 5, 7)
(11, 3, 7)
(5, 6, 3, 7)
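Because a generator only does work when asked for the next value, a caller that wants just the first solution can simply stop asking. This gives the early exit the question was aiming for without cutting off the search inside the recursion. The inputs below are illustrative; `next(gen, None)` returns None when no solution exists.

```python
def twentyone(nums, stack=()):
    # Same recursive generator as above.
    for index, num in enumerate(nums):
        new_stack = stack + (num,)
        total = sum(new_stack)
        if total == 21:
            yield new_stack
        elif total < 21:
            yield from twentyone(nums[index + 1:], new_stack)


solutions = twentyone([1, 9, 11, 5, 6, 3, 7])
print(next(solutions))  # (1, 9, 11); the search pauses here
print(next(solutions))  # (1, 9, 5, 6); it resumes where it left off
print(next(twentyone([1, 2, 3]), None))  # None: no combination adds to 21
```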

The downside of this is that if there are duplicate items in nums, we get duplicate solutions. But we can easily get around that by passing the generator to set():

print(set(twentyone(list_format)))

output

{(9, 5, 7), (1, 11, 6, 3), (1, 9, 5, 6), (1, 9, 11), (5, 6, 3, 7), (11, 3, 7)}
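To see the duplicates, here is the same generator run on two hypothetical inputs with a repeated value. When the repeat produces the same tuple twice, set() collapses it; when the equal values end up in a different order, the tuples differ and set() keeps both.

```python
def twentyone(nums, stack=()):
    # Same recursive generator as above.
    for index, num in enumerate(nums):
        new_stack = stack + (num,)
        total = sum(new_stack)
        if total == 21:
            yield new_stack
        elif total < 21:
            yield from twentyone(nums[index + 1:], new_stack)


# Repeated 11: (10, 11) is found twice, and set() keeps one copy.
print(list(twentyone([10, 11, 11])))  # [(10, 11), (10, 11)]
print(set(twentyone([10, 11, 11])))   # {(10, 11)}

# 10 appears on both sides of 11: (10, 11) and (11, 10) are different
# tuples, so set() would keep both.
print(list(twentyone([10, 11, 10])))  # [(10, 11), (11, 10)]
```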

However, this only eliminates exact duplicates; it doesn't get rid of solutions that are permutations of previous solutions. To make it totally foolproof, we need to sort each output tuple before yielding it.

def twentyone(nums, stack=()):
    for index, num in enumerate(nums):
        new_stack = stack + (num,)
        total = sum(new_stack)
        if total == 21:
            yield tuple(sorted(new_stack))
        elif total < 21:
            yield from twentyone(nums[index + 1:], new_stack)

user_input = '1 9 5 11 5 6 5'
list_format = [int(x) for x in user_input.split()]
print(set(twentyone(list_format)))

output

{(1, 9, 11), (1, 5, 6, 9), (5, 5, 11), (5, 5, 5, 6)}
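A different technique, not the one used in this answer, avoids generating the duplicates in the first place: sort nums once, then at each recursion depth skip a value equal to the one just tried. This is a sketch assuming nums contains positive integers; the name `twentyone_unique` and the `target` parameter are hypothetical. Keeping a running total also avoids calling sum() on every new stack.

```python
def twentyone_unique(nums, target=21):
    """Yield each distinct combination of values summing to target, once.

    Assumes positive integers. Solutions come out in ascending order.
    """
    nums = sorted(nums)

    def search(start, stack, total):
        for index in range(start, len(nums)):
            num = nums[index]
            if index > start and num == nums[index - 1]:
                continue  # this value was already tried at this depth
            new_total = total + num
            if new_total > target:
                break  # sorted input: every later num would overshoot too
            new_stack = stack + (num,)
            if new_total == target:
                yield new_stack
            else:
                yield from search(index + 1, new_stack, new_total)

    yield from search(0, (), 0)


print(list(twentyone_unique([1, 9, 5, 11, 5, 6, 5])))
# [(1, 5, 6, 9), (1, 9, 11), (5, 5, 5, 6), (5, 5, 11)]
```

Because the input is sorted, each solution is already in ascending order, so no per-solution sorted() call or set() is needed, and the `break` prunes the rest of a loop as soon as the total exceeds the target.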
