Tetramputechture

Reputation: 2921

Python - Memoization and Collatz Sequence

When I was struggling with Problem 14 in Project Euler, I discovered that I could use a technique called memoization to speed up my program (I let it run for a good 15 minutes, and it still hadn't returned an answer). The thing is, how do I implement it? I've tried to, but I get a KeyError (the value being returned is invalid). This bugs me because I am positive I can apply memoization to this and make it faster.

lookup = {}

def countTerms(n):
   arg = n
   count = 1
   while n is not 1:
      count += 1
      if not n%2:
         n /= 2
      else:
         n = (n*3 + 1)
      if n not in lookup:
         lookup[n] = count

   return lookup[n], arg

print max(countTerms(i) for i in range(500001, 1000000, 2)) 

Thanks.

Upvotes: 6

Views: 4705

Answers (3)

kaffuffle

Reputation: 11

This is my solution to PE14:

memo = {1:1}
def get_collatz(n):
    if n in memo:
        return memo[n]

    if n % 2 == 0:
        terms = get_collatz(n/2) + 1
    else:
        terms = get_collatz(3*n + 1) + 1

    memo[n] = terms
    return terms

compare = 0
for x in xrange(1, 999999):
    if x not in memo:
        ctz = get_collatz(x)
        if ctz > compare:
            compare = ctz
            culprit = x

print culprit

Upvotes: 0

Benjamin Hodgson

Reputation: 44634

The point of memoising, for the Collatz sequence, is to avoid calculating parts of the list that you've already done. The remainder of a sequence is fully determined by the current value. So we want to check the table as often as possible, and bail out of the rest of the calculation as soon as we can.

def collatz_sequence(start, table={}):  # cheeky trick: store the (mutable) table as a default argument
    """Returns the Collatz sequence for a given starting number"""
    l = []
    n = start

    while n not in l:  # break if we find ourself in a cycle
                       # (don't assume the Collatz conjecture!)
        if n in table:
            l += table[n]
            break
        elif n%2 == 0:
            l.append(n)
            n = n//2
        else:
            l.append(n)
            n = (3*n) + 1

    table.update({n: l[i:] for i, n in enumerate(l) if n not in table})

    return l

Is it working? Let's spy on it to make sure the memoised elements are being used:

class NoisyDict(dict):
    def __getitem__(self, item):
        print("getting", item)
        return dict.__getitem__(self, item)

def collatz_sequence(start, table=NoisyDict()):
    # etc



In [26]: collatz_sequence(5)
Out[26]: [5, 16, 8, 4, 2, 1]

In [27]: collatz_sequence(5)
getting 5
Out[27]: [5, 16, 8, 4, 2, 1]

In [28]: collatz_sequence(32)
getting 16
Out[28]: [32, 16, 8, 4, 2, 1]

In [29]: collatz_sequence.__defaults__[0]
Out[29]: 
{1: [1],
 2: [2, 1],
 4: [4, 2, 1],
 5: [5, 16, 8, 4, 2, 1],
 8: [8, 4, 2, 1],
 16: [16, 8, 4, 2, 1],
 32: [32, 16, 8, 4, 2, 1]}
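
For anyone who wants to rerun the check above outside IPython, here is a minimal self-contained sketch, assuming Python 3. `RecordingDict` and its `hits` attribute are names made up for this example; instead of printing, the table records every key that is read, so the lookups can be inspected afterwards. The function body is meant to mirror the first version above.

```python
class RecordingDict(dict):
    """Hypothetical helper: a dict that remembers which keys were read."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hits = []

    def __getitem__(self, item):
        self.hits.append(item)
        return super().__getitem__(item)


def collatz_sequence(start, table=RecordingDict()):
    """Same logic as the first version above, with the spying table as default."""
    l = []
    n = start
    while n not in l:
        if n in table:  # membership test, does not go through __getitem__
            l += table[n]
            break
        l.append(n)
        n = n // 2 if n % 2 == 0 else 3 * n + 1
    table.update({x: l[i:] for i, x in enumerate(l) if x not in table})
    return l


collatz_sequence(5)   # nothing cached yet, so no lookups
collatz_sequence(5)   # reads table[5]
collatz_sequence(32)  # reads table[16]
print(collatz_sequence.__defaults__[0].hits)  # [5, 16]
```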

Edit: I knew it could be optimised! The secret is that there are two places in the function (the two return points) where we know l and table share no elements. While previously I avoided calling table.update with elements already in table by testing them, this version of the function instead exploits our knowledge of the control flow, saving lots of time.

[collatz_sequence(x) for x in range(500001, 1000000)] now takes around 2 seconds on my computer, while a similar expression with @welter's version clocks in at 400ms. I think this is because the functions don't actually compute the same thing: my version generates the whole sequence, while @welter's just finds its length. So I don't think I can get my implementation down to the same speed.

def collatz_sequence(start, table={}):  # cheeky trick: store the (mutable) table as a default argument
    """Returns the Collatz sequence for a given starting number"""
    l = []
    n = start

    while n not in l:  # break if we find ourself in a cycle
                       # (don't assume the Collatz conjecture!)
        if n in table:
            table.update({x: l[i:] for i, x in enumerate(l)})
            return l + table[n]
        elif n%2 == 0:
            l.append(n)
            n = n//2
        else:
            l.append(n)
            n = (3*n) + 1

    table.update({x: l[i:] for i, x in enumerate(l)})
    return l

PS - spot the bug!

Upvotes: 3

welter

Reputation: 166

There is also a nice recursive way to do this. It will probably be slower than poorsod's solution, but it is closer to your initial code, so it may be easier for you to understand.

lookup = {}

def countTerms(n):
   if n not in lookup:
      if n == 1:
         lookup[n] = 1
      elif not n % 2:
         lookup[n] = countTerms(n / 2)[0] + 1
      else:
         lookup[n] = countTerms(n*3 + 1)[0] + 1

   return lookup[n], n

print max(countTerms(i) for i in range(500001, 1000000, 2))
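
On Python 3, `n / 2` yields a float and `print` is a function, so the code above needs small changes there. Here is a minimal sketch of the same memoised length count, assuming Python 3 and swapping the recursion for a loop with a list of pending values (which also sidesteps the recursion limit). `collatz_length` and `longest_start` are names chosen for this example, not part of the answer above.

```python
lookup = {1: 1}  # number of terms in the sequence starting at each key


def collatz_length(n):
    """Return the number of terms in the Collatz sequence starting at n."""
    pending = []
    # Walk forward until we reach a value whose length is already known.
    while n not in lookup:
        pending.append(n)
        n = n // 2 if n % 2 == 0 else 3 * n + 1
    length = lookup[n]
    # Unwind: each pending value is one term longer than its successor.
    for m in reversed(pending):
        length += 1
        lookup[m] = length
    return length


def longest_start(limit):
    """Return (length, start) for the longest sequence with 1 <= start < limit."""
    return max((collatz_length(i), i) for i in range(1, limit))


print(longest_start(10))  # (20, 9)
```

Calling `longest_start(1000000)` would cover the range the Project Euler problem asks about.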

Upvotes: 5
