Reputation: 2921
When I was struggling with Problem 14 in Project Euler, I discovered that I could use a technique called memoization to speed up my solution (I let it run for a good 15 minutes, and it still hadn't returned an answer). The thing is, how do I implement it? I've tried, but I get a KeyError (the value being returned is invalid). This bugs me because I'm positive memoization applies here and would make this faster.
lookup = {}
def countTerms(n):
    arg = n
    count = 1
    while n is not 1:
        count += 1
        if not n%2:
            n /= 2
        else:
            n = (n*3 + 1)
        if n not in lookup:
            lookup[n] = count
    return lookup[n], arg
print max(countTerms(i) for i in range(500001, 1000000, 2))
Thanks.
Upvotes: 6
Views: 4705
Reputation: 11
This is my solution to PE14:
memo = {1: 1}  # number of terms in the sequence starting at each known value
def get_collatz(n):
    if n in memo:
        return memo[n]
    if n % 2 == 0:
        terms = get_collatz(n/2) + 1
    else:
        terms = get_collatz(3*n + 1) + 1
    memo[n] = terms
    return terms
compare = 0
for x in xrange(1, 1000000):  # every starting number under one million
    if x not in memo:
        ctz = get_collatz(x)
        if ctz > compare:
            compare = ctz
            culprit = x
print culprit
Upvotes: 0
Reputation: 44634
The point of memoising, for the Collatz sequence, is to avoid calculating parts of the list that you've already done. The remainder of a sequence is fully determined by the current value. So we want to check the table as often as possible, and bail out of the rest of the calculation as soon as we can.
def collatz_sequence(start, table={}):  # cheeky trick: store the (mutable) table as a default argument
    """Returns the Collatz sequence for a given starting number"""
    l = []
    n = start
    while n not in l:  # break if we find ourself in a cycle
                       # (don't assume the Collatz conjecture!)
        if n in table:
            l += table[n]
            break
        elif n%2 == 0:
            l.append(n)
            n = n//2
        else:
            l.append(n)
            n = (3*n) + 1
    table.update({n: l[i:] for i, n in enumerate(l) if n not in table})
    return l
Is it working? Let's spy on it to make sure the memoised elements are being used:
class NoisyDict(dict):
    def __getitem__(self, item):
        print("getting", item)
        return dict.__getitem__(self, item)
def collatz_sequence(start, table=NoisyDict()):
    # etc
In [26]: collatz_sequence(5)
Out[26]: [5, 16, 8, 4, 2, 1]
In [27]: collatz_sequence(5)
getting 5
Out[27]: [5, 16, 8, 4, 2, 1]
In [28]: collatz_sequence(32)
getting 16
Out[28]: [32, 16, 8, 4, 2, 1]
In [29]: collatz_sequence.__defaults__[0]
Out[29]:
{1: [1],
2: [2, 1],
4: [4, 2, 1],
5: [5, 16, 8, 4, 2, 1],
8: [8, 4, 2, 1],
16: [16, 8, 4, 2, 1],
32: [32, 16, 8, 4, 2, 1]}
Edit: I knew it could be optimised! The secret is that there are two places in the function (the two return points) where we know `l` and `table` share no elements. While previously I avoided calling `table.update` with elements already in `table` by testing them, this version of the function instead exploits our knowledge of the control flow, saving lots of time.
`[collatz_sequence(x) for x in range(500001, 1000000)]` now takes around 2 seconds on my computer, while a similar expression with @welter's version clocks in at 400ms. I think this is because the functions don't actually compute the same thing: my version generates the whole sequence, while @welter's just finds its length. So I don't think I can get my implementation down to the same speed.
def collatz_sequence(start, table={}):  # cheeky trick: store the (mutable) table as a default argument
    """Returns the Collatz sequence for a given starting number"""
    l = []
    n = start
    while n not in l:  # break if we find ourself in a cycle
                       # (don't assume the Collatz conjecture!)
        if n in table:
            table.update({x: l[i:] for i, x in enumerate(l)})
            return l + table[n]
        elif n%2 == 0:
            l.append(n)
            n = n//2
        else:
            l.append(n)
            n = (3*n) + 1
    table.update({x: l[i:] for i, x in enumerate(l)})
    return l
PS - spot the bug!
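The remark above about computing lengths rather than whole sequences can be sketched as follows. This is an illustration only, not @welter's code: `collatz_length` and `path` are made-up names, it assumes Python 3, and unlike `collatz_sequence` it assumes every sequence eventually reaches 1.

```python
def collatz_length(start, table={1: 1}):
    """Return the number of terms in the Collatz sequence from `start`.

    `table` maps a value to the length of its sequence. It is kept as a
    mutable default argument, mirroring the trick used in collatz_sequence.
    Assumption: every sequence reaches 1, otherwise this loops forever.
    """
    path = []  # values visited whose lengths are not yet known
    n = start
    # Walk forward until we hit a value whose length is already in the table.
    while n not in table:
        path.append(n)
        n = n // 2 if n % 2 == 0 else 3 * n + 1
    length = table[n]
    # Walk back over the path, recording one integer per value.
    for value in reversed(path):
        length += 1
        table[value] = length
    return table[start]
```

Because only one integer is stored per value, the table stays much smaller than one holding a list for every value.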
Upvotes: 3
Reputation: 166
There is also a nice recursive way to do this. It will probably be slower than poorsod's solution, but it is closer to your initial code, so it may be easier for you to understand.
lookup = {}
def countTerms(n):
    if n not in lookup:
        if n == 1:
            lookup[n] = 1
        elif not n % 2:
            lookup[n] = countTerms(n / 2)[0] + 1
        else:
            lookup[n] = countTerms(n*3 + 1)[0] + 1
    return lookup[n], n
print max(countTerms(i) for i in range(500001, 1000000, 2))
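The version above manages the `lookup` dictionary by hand. As an alternative sketch (not part of this answer, and assuming Python 3), the standard library's `functools.lru_cache` can do the caching instead. The name `count_terms` is made up for the example, and it returns only the count rather than a `(count, n)` pair.

```python
from functools import lru_cache


@lru_cache(maxsize=None)  # unbounded cache: every computed result is kept
def count_terms(n):
    """Number of terms in the Collatz sequence starting at n, counting n and 1."""
    if n == 1:
        return 1
    if n % 2 == 0:
        return count_terms(n // 2) + 1
    return count_terms(3 * n + 1) + 1
```

For example, `count_terms(13)` gives 10, matching the ten-term chain 13, 40, 20, 10, 5, 16, 8, 4, 2, 1. Each uncached step adds a level of recursion, so very long chains may run into Python's recursion limit.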
Upvotes: 5