Neil_UK

Reputation: 1083

How can I yield from an arbitrary depth of recursion?

I've written a function to create combinations of inputs of an arbitrary length, so recursion seemed to be the obvious way to do it. While it's OK for a small toy example to return a list of the results, I'd like to yield them instead. I've read about yield from, but don't fully understand how it is used; the examples don't appear to cover my use case, and hoping'n'poking it into my code has not yet produced anything that works. Note that writing this recursive code was at the limit of my Python ability, hence the copious debug print statements.

This is the working list-returning code, with my hopeful but non-working yield commented out.

def allposs(elements, output_length):
    """
    return all ways of padding elements with zeros to exactly output_length, maintaining order

    elements       - a sequence (indexable, e.g. a tuple or list) of length >= 1
    output_length  - an int >= len(elements)

    for instance allposs((3,1), 4) returns
    [[3,1,0,0], [3,0,1,0], [3,0,0,1], [0,3,1,0], [0,3,0,1], [0,0,3,1]]
    """

    output_list = []

    def place_nth_element(nth, start_at, output_so_far):
        # print('entering place_nth_element with nth =', nth,
        #      ', start_at =', start_at,
        #      ', output_so_far =', output_so_far)
        
        last_pos = output_length - len(elements) + nth
        # print('iterating over range',start_at, 'to', last_pos+1)
        for pos in range(start_at, last_pos+1):
            output = list(output_so_far)           
            # print('placing', elements[nth], 'at position', pos)
            output[pos] = elements[nth]

            if nth == len(elements)-1:
                # print('appending output', output)
                output_list.append(output)
                # yield output    
            else:
                # print('making recursive call')
                place_nth_element(nth+1, pos+1, output)
   
    place_nth_element(0, 0, [0]*output_length)
    return output_list

if __name__=='__main__':
    for q in allposs((3,1), 4):
        print(q)

What is the syntax to use yield from to get my list generated a combination at a time?

Upvotes: 1

Views: 395

Answers (1)

Karl Knechtel

Reputation: 61643

Recursive generators are a powerful tool and I'm glad you're putting in the effort to study them.

What is the syntax to use yield from to get my list generated a combination at a time?

You put yield from in front of the expression whose results should be yielded; in your case, the recursive call. Thus: yield from place_nth_element(nth+1, pos+1, output). The idea is that each result from the recursively called generator is iterated over (behind the scenes) and yielded at this point in the process.
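For plain iteration, one level of yield from behaves like an explicit for loop that re-yields every value. (It also forwards send() and throw() and passes along the subgenerator's return value, but none of that matters here.) The following is a minimal sketch using hypothetical countdown functions, not code from the question:

```python
def countdown(n):
    """Yield n, n-1, ..., 1 using recursion and yield from."""
    if n == 0:
        return  # base case: nothing left to yield
    yield n
    # Delegate to the recursive call: every value it yields is passed straight up.
    yield from countdown(n - 1)


def countdown_explicit(n):
    """Same output, spelling out what yield from does for plain iteration."""
    if n == 0:
        return
    yield n
    for value in countdown_explicit(n - 1):
        yield value


print(list(countdown(3)))           # [3, 2, 1]
print(list(countdown_explicit(3)))  # [3, 2, 1]
```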

Note that for this to work:

  • You need to yield the individual results at the base level of the recursion

  • To "collect" the results from the resulting generator, you need to iterate over the result from the top-level call. Fortunately, iteration is built-in in a lot of places; for example, you can just call list and it will iterate for you.
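Both points matter because calling a generator function only creates a generator object. None of its body runs until something iterates it. A recursive call without yield from therefore silently discards its results. This sketch uses hypothetical broken/fixed functions to illustrate that:

```python
def broken(depth):
    """Recursive call without yield from: the inner generator is created but never run."""
    if depth == 0:
        yield "leaf"
    else:
        broken(depth - 1)  # creates a generator object and throws it away


def fixed(depth):
    if depth == 0:
        yield "leaf"
    else:
        yield from fixed(depth - 1)  # iterate the inner generator and re-yield its values


gen = fixed(2)
print(gen)              # a generator object; nothing inside fixed() has run yet
print(list(gen))        # ['leaf']  (list() does the iterating)
print(list(broken(2)))  # []
```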

Rather than nesting the recursive generator inside a wrapper function, I prefer to write it as a separate helper function. Since there is no longer a need to access output_list from the recursion, there is no need to form a closure; and, as the Zen of Python says, flat is better than nested. This does, however, mean that we need to pass elements through the recursion. We don't need to pass output_length because we can recompute it (the length of output_so_far is constant across the recursion).

Also, I find it's helpful, when doing these sorts of algorithms, to think as functionally as possible (in the paradigm sense - i.e., avoid side effects and mutability, and proceed by creating new objects). You had a workable approach using list to make copies (although it is clearer to use the .copy method), but I think there's a cleaner way, as shown below.
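The difference between the two styles is small but visible. The question's approach copies a list and then assigns into the copy. The functional alternative builds a brand-new tuple from slices, so nothing is ever mutated. The helper names below are hypothetical and only serve as an illustration:

```python
# Mutating approach (as in the question): copy the list, then assign into the copy.
def place_by_copy(seq, pos, value):
    output = seq.copy()  # shallow copy so the caller's list is untouched
    output[pos] = value
    return output


# Functional approach: build a new tuple from slices; nothing is mutated.
def place_by_slicing(seq, pos, value):
    return seq[:pos] + (value,) + seq[pos + 1:]


original = (0, 0, 0, 0)
print(place_by_slicing(original, 1, 3))   # (0, 3, 0, 0)
print(original)                           # (0, 0, 0, 0), unchanged
print(place_by_copy([0, 0, 0, 0], 1, 3))  # [0, 3, 0, 0]
```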

All this advice leads us to:

def place_nth_element(elements, nth, start_at, output_so_far):        
    last_pos = len(output_so_far) - len(elements) + nth
    for pos in range(start_at, last_pos+1):
        output = output_so_far[:pos] + (elements[nth],) + output_so_far[pos+1:]
        if nth == len(elements)-1:
            yield output    
        else:
            yield from place_nth_element(elements, nth+1, pos+1, output)


def allposs(elements, output_length):
    return list(place_nth_element(elements, 0, 0, (0,)*output_length))
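The allposs above still collects everything into a list. To hand combinations to the caller one at a time, as the question asks, the wrapper can return the generator itself. The following is a sketch that reuses the helper above; allposs_lazy is a hypothetical name:

```python
def place_nth_element(elements, nth, start_at, output_so_far):
    # Same helper as above: yields each finished tuple once the last element is placed.
    last_pos = len(output_so_far) - len(elements) + nth
    for pos in range(start_at, last_pos + 1):
        output = output_so_far[:pos] + (elements[nth],) + output_so_far[pos + 1:]
        if nth == len(elements) - 1:
            yield output
        else:
            yield from place_nth_element(elements, nth + 1, pos + 1, output)


def allposs_lazy(elements, output_length):
    # Return the generator instead of wrapping it in list(),
    # so the caller receives combinations one at a time.
    return place_nth_element(elements, 0, 0, (0,) * output_length)


if __name__ == '__main__':
    gen = allposs_lazy((3, 1), 4)
    print(next(gen))  # (3, 1, 0, 0)
    for q in gen:     # continues from where next() left off
        print(q)
```

Writing yield from place_nth_element(...) in the body of allposs_lazy would also work. It just makes allposs_lazy a generator function itself instead of a plain function that returns one.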

HOWEVER, I would not solve the problem that way, because the standard library already offers a neat solution: we can use itertools.combinations to pick the indices where the values should go, and then place the values at those indices. Since each combination of indices comes out in increasing order, the values keep their original relative order. Now that we no longer have to think recursively, we can go ahead and mutate values :)

from itertools import combinations

def place_values(positions, values, size):
    result = [0] * size
    for position, value in zip(positions, values):
        result[position] = value
    return tuple(result)


def possibilities(values, size):
    return [
        place_values(positions, values, size)
        for positions in combinations(range(size), len(values))
    ]
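possibilities can also be made lazy by yielding inside the loop instead of building a list. The sketch below shows the index tuples that combinations produces and a generator variant; iter_possibilities is a hypothetical name. For the docstring example it yields the padded tuples in the same order as the recursive version:

```python
from itertools import combinations


def place_values(positions, values, size):
    # Same helper as above: start from zeros and drop each value into its position.
    result = [0] * size
    for position, value in zip(positions, values):
        result[position] = value
    return tuple(result)


def iter_possibilities(values, size):
    # Generator version: yields one padded tuple at a time instead of building a list.
    # combinations() emits increasing index tuples such as (0, 2),
    # so the values keep their original relative order.
    for positions in combinations(range(size), len(values)):
        yield place_values(positions, values, size)


print(list(combinations(range(4), 2)))
# [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
for q in iter_possibilities((3, 1), 4):
    print(q)
```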

Upvotes: 2
