Reputation: 43487
I wanted to write a backtracking solution for this question, which asks to find the most distinct odd numbers that sum up to a given n.
I hacked together this Python code:
import sys
sys.setrecursionlimit(10000)
stop = False
def solve(n, used, current_sum):
    global stop
    if stop:
        return
    if current_sum == n:
        print(used)
        stop = True
        return
    start = 1 if len(used) == 0 else (used[-1] + 2)
    for i in range(start, n + 1, 2):
        if current_sum + i <= n and not stop:
            used.append(i)
            solve(n, used, current_sum + i)
            used.pop()
        else:
            return
solve(100000000, [], 0)
Which, unfortunately, does not print anything for me. As far as I can tell, it never gets to that if condition. If I print current_sum at each step, it seems to just stop at around 16000000, when the entire program quits with no error.
I tried increasing the recursion limit, no luck.
I tested it on Idle and Eclipse, in Python 3.4, 64 bit, under Windows 8.1. I have 16 GB of RAM.
If I reduce n, then I get a solution (remove a zero for example).
This did not make sense to me, so I wanted to see if maybe I have better luck in C. I hacked this in C:
#include <stdbool.h>
#include <stdio.h>

int sol[100000];
bool done = false;

void solve(int n, int k, int sum)
{
    if (done)
        return;
    if (sum == n)
    {
        done = true;
        for (int i = 0; i < k; ++i)
        {
            printf("%d ", sol[i]);
        }
        printf("\n");
        return;
    }
    int start = 1;
    if (k > 0)
        start = sol[k - 1] + 2;
    for (int i = start; i <= n; i += 2)
        if (sum + i <= n && !done)
        {
            sol[k] = i;
            solve(n, k + 1, sum + i);
        }
        else
            return;
}

int main()
{
    solve(100000000, 0, 0);
    return 0;
}
Which works great even if I add another zero!
What is the deal with Python and how can I get this working for large values as well?
The execution time for lower values is comparable with the C code, it just quits on me for higher values.
Upvotes: 4
Views: 1311
Reputation: 11174
What is the deal with Python and how can I get this working for large values as well?
I rewrote your code to make it work. You will need to raise the recursion limit when you increase the n parameter. I used Python 2.7.6. The idea was to do it the same way as the C code you wrote: the second parameter passed is an integer and not a list.
import sys
sys.setrecursionlimit(100000)
sol = []
stop = False
def solve(n, k, current_sum):
    global stop
    if stop:
        return
    if current_sum == n:
        stop = True
        for i in xrange(0, k, 1):
            print(sol[i]),
        print
        return
    start = 1 if len(sol) == 0 else (sol[k-1] + 2)
    for i in xrange(start, n + 1, 2):
        if current_sum + i <= n and not stop:
            sol.append(0)
            sol[k] = i
            solve(n, k + 1, current_sum + i)
        else:
            return
solve(100000000, 0, 0)
I measured the memory usage of the Python code you wrote. I had to set n = 100,000 in order to get a result at all: 370 MB. Adding a 0 made my operating system kill the program. (On Mac OS X I received a memory error.)
Here is the code I used on Linux:
import os
import sys
sys.setrecursionlimit(100000)
_proc_status = '/proc/%d/status' % os.getpid()
_scale = {'kB': 1024.0, 'mB': 1024.0*1024.0,
          'KB': 1024.0, 'MB': 1024.0*1024.0}
def _VmB(VmKey):
    '''Private.
    '''
    global _proc_status, _scale
    # get pseudo file /proc/<pid>/status
    try:
        t = open(_proc_status)
        v = t.read()
        t.close()
    except:
        return 0.0  # non-Linux?
    # get VmKey line e.g. 'VmRSS: 9999 kB\n ...'
    i = v.index(VmKey)
    v = v[i:].split(None, 3)  # whitespace
    if len(v) < 3:
        return 0.0  # invalid format?
    # convert Vm value to bytes
    return float(v[1]) * _scale[v[2]]
def memory(since=0.0):
    '''Return memory usage in bytes.
    '''
    return _VmB('VmSize:') - since
stop = False
def solve(n, used, current_sum):
    global stop
    if stop:
        return
    if current_sum == n:
        print(used)
        stop = True
        return
    start = 1 if len(used) == 0 else (used[-1] + 2)
    for i in range(start, n + 1, 2):
        if current_sum + i <= n and not stop:
            used.append(i)
            solve(n, used, current_sum + i)
            used.pop()
        else:
            return
m0 = memory()
solve(100000, [], 0)
m1 = memory(m0)
print(m1/(1024*1024))
In comparison to this result, the improved (corrected) code I wrote only uses 4 MB with the parameter n set to 100,000,000. That's a huge difference indeed.
I am not sure exactly why this is. Note in particular that your loop contains a recursive call, so you recurse several times from the same branch.
If you insist on using recursive calls, you may want to redesign your program. Recursive calls with memoization can be faster than loops in some cases. See this link for example.
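One possible redesign, offered only as a sketch and not as the code from either post, is to drop recursion and let the used list itself serve as the stack. The function name solve_iterative is hypothetical. It visits candidates in the same order as the recursive solve, so it returns the same first solution, but its depth is bounded by memory rather than by the interpreter stack.

```python
def solve_iterative(n):
    """Return the first list of increasing distinct odd numbers summing to n, or None.

    Hypothetical non-recursive variant of solve(): `used` doubles as the stack.
    """
    used = []        # current partial solution (strictly increasing odd numbers)
    total = 0        # sum of the numbers in `used`
    candidate = 1    # next odd number to try at the current depth
    while True:
        if total == n:
            return used
        if total + candidate <= n:
            # The candidate fits: take it and go one level deeper.
            used.append(candidate)
            total += candidate
            candidate += 2
        else:
            # The candidate is too large, so every larger one is too: backtrack.
            if not used:
                return None
            last = used.pop()
            total -= last
            candidate = last + 2


if __name__ == "__main__":
    result = solve_iterative(100000000)
    print(len(result), result[:3], result[-1])
```

No call to sys.setrecursionlimit is needed here, because no Python frames are stacked.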
Upvotes: 1
Reputation: 149085
I ran some tests with your code using Python 3.4 (64-bit) on Windows 7.
It breaks the same way.
I also tried it in a 32-bit FreeBSD 10.1 virtual machine with as little as 512 MB of memory, and I got a segmentation fault shortly after 8000 iterations.
I think it is a bug in the CPython interpreter with deep recursion, because I added some traces and in all my tests it broke during the initial phase of adding elements to the list, before reaching the sum.
I would have accepted any error about abuse of recursion (which it is :-) ), but a segmentation fault is really bad: it looks as if Python itself fails to check a bound.
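If the crash is indeed the interpreter running out of C stack (an assumption, not something verified in this thread), one workaround that neither post uses is to run the recursion in a thread created with a larger stack via threading.stack_size. The helper names depth and run_with_big_stack below are hypothetical, and the stack size and limit are arbitrary example values. This is a minimal sketch:

```python
import sys
import threading


def depth(k):
    """Recurse k levels deep and return the depth reached (stand-in for solve)."""
    return 0 if k == 0 else 1 + depth(k - 1)


def run_with_big_stack(func, *args, stack_bytes=128 * 1024 * 1024, limit=200000):
    """Run func(*args) in a new thread that has a larger stack.

    stack_bytes and limit are example values, not tuned recommendations.
    """
    result = []
    sys.setrecursionlimit(limit)        # Python-level limit, process-wide
    threading.stack_size(stack_bytes)   # applies to threads created afterwards
    worker = threading.Thread(target=lambda: result.append(func(*args)))
    worker.start()
    worker.join()
    return result[0]


if __name__ == "__main__":
    print(run_with_big_stack(depth, 20000))
```

The maximum stack size a thread may request depends on the platform, so threading.stack_size can raise ValueError for values the system rejects.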
Once we notice that the sum of the first n odd numbers is n² (mathematically trivial), it is easy to start right next to the final solution. That alone would be enough for this example, because 100000000 == 10000². But in the general case, varying only the last number of the used list changes current_sum in steps of 2, so we would still miss every second value. If we go back one more step and start from there, we again change current_sum in steps of 2, but this time we hit the other values.
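The identity used above can be checked quickly. This is only an illustrative sketch: first_odds is a hypothetical helper, and math.isqrt assumes Python 3.8 or later.

```python
import math


def first_odds(k):
    """Return the first k odd numbers: 1, 3, ..., 2k - 1."""
    return list(range(1, 2 * k, 2))


# 1 + 3 + ... + (2k - 1) == k * k
for k in (1, 2, 10, 10000):
    assert sum(first_odds(k)) == k * k

n = 100000000
k = math.isqrt(n)  # integer square root (Python 3.8+)
print(k, k * k == n, first_odds(k)[-1])
```

For n = 100000000 the first 10000 odd numbers already sum to n exactly, which is why the recursion in the original code has to go about 10000 levels deep before it finds its answer.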
So here is a slight variation of the original code that works:
import math
import sys
sys.setrecursionlimit(100000)
stop = False
def solve(n, used, current_sum):
    global stop
    if stop:
        return
    # TRACES
    ## print (len(used), used[-1] if len(used) > 0 else '', end=' ')
    ## if (len(used) % 10) == 0:
    ##     print('')
    if current_sum == n:
        print(used)
        stop = True
        return
    if current_sum > n:  # simple optimisation, no need to go further
        return
    # the trick: the sum of the first l odd numbers is l*l, and we must start 2 steps before
    if current_sum == 0:
        l = int(math.sqrt(n)) - 2
        if l > 0:  # for tiny n there is nothing to skip, so search normally
            current_sum = l * l
            used = list(range(1, l*2, 2))
            solve(n, used, current_sum)
            return
    start = 1 if len(used) == 0 else (used[-1] + 2)
    for i in range(start, n + 1, 2):
        if current_sum + i <= n and not stop:
            used.append(i)
            solve(n, used, current_sum + i)
            used.pop()
        else:
            return
solve(100000000, [], 0)
Upvotes: 1