Albert G Lieu

Reputation: 911

How does the loop work in a decorator (memoization)?

Memoization is a powerful tool. I'm trying to understand the fundamental mechanism, but it doesn't seem to work the way I thought. Could anyone explain in detail how it works in the following code?

def memoize(f):
    memo = {}
    def helper(x):
        if x not in memo:            
            memo[x] = f(x)
        print(memo)
        return memo[x]

    return helper

@memoize
def fib(n):
    if n < 2:
        return n
    else:
        return fib(n-1) + fib(n-2)

What really confuses me is when the decorator memoize comes into play in this example. According to a tutorial, it seems the whole decorated function runs inside the decorator; here that function is fib(n). If so, how is the loop in fib(n) handled by the decorator memoize(f)?

Let's take fib(4) as an example to demystify the process:

In [1]: fib(4)
{1: 1}
{1: 1, 0: 0}
{1: 1, 0: 0, 2: 1}
{1: 1, 0: 0, 2: 1}
{1: 1, 0: 0, 2: 1, 3: 2}
{1: 1, 0: 0, 2: 1, 3: 2}
{1: 1, 0: 0, 2: 1, 3: 2, 4: 3}

Why is the first value printed in memoize(f) {1: 1}? I expected memoize(f) to store memo = {4: f(4)} at the very beginning, even though the value of f(4) was not known yet at that moment. I know I was wrong. Could anyone explain how we get this output and how the loop in fib(n) works in memoize(f)?

Thanks a lot.

Upvotes: 0

Views: 230

Answers (2)

Sylvaus

Reputation: 884

First, the easiest way to understand decorators (without parameters) is through the following equivalence:

@memoize
def f():
    ...

# is the same as

def f():
    ...
f = memoize(f)
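To see when each part runs, here is a minimal sketch using a hypothetical `trace_decorator` (not from the answer) that logs when the decorator itself is applied versus when the wrapper it returns is called. The decorator body runs once, at definition time; only the returned wrapper runs on each call.

```python
events = []  # hypothetical event log, for illustration only

def trace_decorator(f):
    events.append(f"decorating {f.__name__}")  # runs once, when @ is applied
    def wrapper(x):
        events.append(f"calling wrapper({x})")  # runs on every call
        return f(x)
    return wrapper

@trace_decorator
def square(x):
    return x * x

# Only the decoration has happened so far; square's body hasn't run.
assert events == ["decorating square"]

square(3)
square(4)
# events is now:
# ["decorating square", "calling wrapper(3)", "calling wrapper(4)"]
```

The same holds for memoize: `memo = {}` is created once, when `@memoize` is applied, and `helper` is what runs on every call to `fib`.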

Thus the code

@memoize
def fib(n):
    if n < 2:
        return n
    else:
        return fib(n-1) + fib(n-2)

is equivalent to

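memo = {}  # the dict that helper's closure captured

def not_decorated_fib(n):
    if n < 2:
        return n
    else:
        # `fib` here is the wrapper defined below, not this function
        return fib(n-1) + fib(n-2)

def fib(x):
    if x not in memo:
        memo[x] = not_decorated_fib(x)
    print(memo)
    return memo[x]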
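A runnable version of the hand-expanded form, with the print removed and a hypothetical `body_calls` list added to record which arguments the undecorated body actually runs for. The key detail is that `fib` inside `not_decorated_fib` is looked up as a global name at call time, so every recursive call goes back through the caching wrapper:

```python
memo = {}        # plays the role of the dict captured by helper's closure
body_calls = []  # hypothetical: records each n the undecorated body runs for

def not_decorated_fib(n):
    body_calls.append(n)
    if n < 2:
        return n
    # `fib` is resolved globally when this line runs, so these calls
    # go through the caching wrapper below, not directly back here.
    return fib(n - 1) + fib(n - 2)

def fib(x):
    if x not in memo:
        memo[x] = not_decorated_fib(x)
    return memo[x]

result = fib(4)
# body_calls == [4, 3, 2, 1, 0]: each n reaches the real body once;
# the later fib(1) and fib(2) calls are answered from memo.
```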

This means the call stack evolves as follows:

  • Stack level 0: fib(4) calls not_decorated_fib(4) since memo is empty
  • Stack level 1: not_decorated_fib(4) calls fib(3) (fib(2) is only called after fib(3) has returned)
  • Stack level 2: fib(3) calls not_decorated_fib(3) since memo is still empty
  • Stack level 3: not_decorated_fib(3) calls fib(2)
  • Stack level 4: fib(2) calls not_decorated_fib(2) since memo is still empty
  • Stack level 5: not_decorated_fib(2) calls fib(1)
  • Stack level 6: fib(1) calls not_decorated_fib(1) since memo is still empty
  • Stack level 7: not_decorated_fib(1) returns 1 (end condition)
  • Stack level 6: fib(1): memo[1] is set to the result of not_decorated_fib(1), which produces the first print {1: 1}, and fib(1) returns 1
  • Stack level 5: not_decorated_fib(2) calls fib(0) (evaluating the second term of fib(1) + fib(0))
  • Stack level 6: fib(0) calls not_decorated_fib(0) since 0 is not yet in memo
  • Stack level 7: not_decorated_fib(0) returns 0 (end condition)
  • Stack level 6: fib(0): memo[0] is set to the result of not_decorated_fib(0), which produces the second print {1: 1, 0: 0}, and fib(0) returns 0

The rest of the computation proceeds in a similar fashion.
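The ordering in that walkthrough can be checked mechanically. This sketch instruments the original memoize with two hypothetical lists: `starts` records when each call to helper begins, and `stores` records when each key is actually written into memo. Calls start top-down (4 first), but keys are stored bottom-up, because a store only happens after f(x) has returned:

```python
starts = []  # hypothetical: order in which helper is entered
stores = []  # hypothetical: order in which keys are written to memo

def memoize(f):
    memo = {}
    def helper(x):
        starts.append(x)
        if x not in memo:
            memo[x] = f(x)    # f(x) must fully return before the store happens
            stores.append(x)
        return memo[x]
    return helper

@memoize
def fib(n):
    if n < 2:
        return n
    else:
        return fib(n-1) + fib(n-2)

fib(4)
# starts == [4, 3, 2, 1, 0, 1, 2]
# stores == [1, 0, 2, 3, 4]
```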

Upvotes: 1

Samwise

Reputation: 71477

The memo cache doesn't get populated until the function call returns:

memo[x] = f(x)

Since fib is recursive, many more calls to f happen before that first f(4) finishes returning and populates the cache. The first of those calls to actually return is f(1), followed by f(0), etc. (as seen in your print statements).
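This follows from Python's assignment order: in `memo[x] = f(x)`, the right-hand side is evaluated completely before the subscript assignment takes place, so no `{4: ...}` placeholder can appear while f(4) is still running. A minimal standalone demonstration (the names `d`, `compute`, and `snapshots` are illustrative only):

```python
d = {}
snapshots = []  # what d looked like while the right-hand side was running

def compute():
    # The key "answer" has not been assigned yet at this point.
    snapshots.append(dict(d))
    return 42

d["answer"] = compute()
# snapshots == [{}] and d == {"answer": 42}
```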

If you were to add another print at the start of helper (before calling f), you'd see the recursive calls nest like a sandwich, with f(4) starting first but finishing last.

Here's how you could modify the print statements to show the recursion depth as well:

def memoize(f):
    memo = {}
    depth = [0]
    def helper(x):
        print(f"{'  '*depth[0]}Calling f({x})...")
        depth[0] += 1
        if x not in memo:            
            memo[x] = f(x)
        print(f"{'  '*depth[0]}Cached: {memo}")
        depth[0] -= 1
        print(f"{'  '*depth[0]}Finished f({x})!")
        return memo[x]

    return helper

@memoize
def fib(n):
    if n < 2:
        return n
    else:
        return fib(n-1) + fib(n-2)

Calling fib(4) then prints:

Calling f(4)...
  Calling f(3)...
    Calling f(2)...
      Calling f(1)...
        Cached: {1: 1}
      Finished f(1)!
      Calling f(0)...
        Cached: {1: 1, 0: 0}
      Finished f(0)!
      Cached: {1: 1, 0: 0, 2: 1}
    Finished f(2)!
    Calling f(1)...
      Cached: {1: 1, 0: 0, 2: 1}
    Finished f(1)!
    Cached: {1: 1, 0: 0, 2: 1, 3: 2}
  Finished f(3)!
  Calling f(2)...
    Cached: {1: 1, 0: 0, 2: 1, 3: 2}
  Finished f(2)!
  Cached: {1: 1, 0: 0, 2: 1, 3: 2, 4: 3}
Finished f(4)!
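The payoff shows up on the next call. A sketch with a hypothetical `calls` counter attached to helper: the first fib(4) enters the wrapper 7 times (matching the seven "Calling" lines above), while a second fib(4) is answered straight from memo in a single call:

```python
def memoize(f):
    memo = {}
    def helper(x):
        helper.calls += 1   # count every entry into the wrapper
        if x not in memo:
            memo[x] = f(x)
        return memo[x]
    helper.calls = 0        # hypothetical attribute, for illustration only
    return helper

@memoize
def fib(n):
    return n if n < 2 else fib(n - 1) + fib(n - 2)

first = fib(4)
calls_first = fib.calls                  # 7 wrapper calls: 4, 3, 2, 1, 0, 1, 2
second = fib(4)
calls_second = fib.calls - calls_first   # 1: answered directly from memo
```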

Upvotes: 3
