I'm working on a program that determines the number of steps it takes for a number to reach 1 under the Collatz Conjecture (if n is odd, 3n+1; if n is even, n/2). The program increases the number being calculated by one each time it completes a calculation, and measures how many numbers it can calculate in 60 seconds. Here is the working program I currently have:
public class Collatz {
    static long numSteps = 0;
    public static long calculate(long c){
        if(c == 1){
            return numSteps;
        }
        else if(c % 2 == 0){
            numSteps++;
            calculate(c / 2);
        }
        else if(c % 2 != 0){
            numSteps++;
            calculate(c * 3 + 1);
        }
        return numSteps;
    }
    public static void main(String args[]){
        int n = 1;
        long startTime = System.currentTimeMillis();
        while(System.currentTimeMillis() < startTime + 60000){
            calculate(n);
            n++;
            numSteps = 0;
        }
        System.out.println("The highest number was: " + n);
    }
}
It can currently calculate about 100 million numbers in a minute, but I'm looking for advice on how to further optimize the program so that it can calculate more numbers in a minute. Any and all advice would be appreciated :).
You can:
- optimise the calculate method. If c % 2 == 0 is false, then c % 2 != 0 must be true, so the second test is unnecessary. Also, when c is odd, c * 3 + 1 is always even, so you can calculate (c * 3 + 1) / 2 directly and add two to numSteps. Finally, use a loop instead of recursion, as Java doesn't have tail-call optimisation.
- get a bigger improvement by using memoisation. For each number, remember the result you get, and if the number has been calculated before, just return that value. You might want to place an upper bound on what you memoise, e.g. no higher than the last number you want to calculate. If you don't, the cache has to cover intermediate values, and some of those are many times larger than the largest starting number.
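The first point (a loop instead of recursion, plus folding the halving that always follows an odd step) can be sketched as below. This is a minimal sketch: the class and method names are hypothetical, and it omits an overflow check because values reached from starting numbers below 100,000 stay far below Long.MAX_VALUE. The naiveSteps method applies the rule one step at a time only to confirm that the shortcut produces the same counts.

```java
public class CollatzLoop {
    // Reference count: one step per application of the Collatz rule.
    static int naiveSteps(long n) {
        int steps = 0;
        while (n != 1) {
            n = (n % 2 == 0) ? n / 2 : 3 * n + 1;
            steps++;
        }
        return steps;
    }

    // Loop version with the odd-number shortcut: for odd c, 3c + 1 is even,
    // so the halving that must follow is done immediately and counted as a second step.
    static int fastSteps(long c) {
        int steps = 0;
        while (c != 1) {
            if (c % 2 == 0) {
                c /= 2;
                steps++;
            } else {
                c = (3 * c + 1) / 2;
                steps += 2;
            }
        }
        return steps;
    }

    public static void main(String[] args) {
        // Both counting methods must agree for every starting number checked.
        for (long n = 1; n <= 100_000; n++) {
            if (naiveSteps(n) != fastSteps(n)) {
                throw new AssertionError("mismatch at " + n);
            }
        }
        System.out.println(fastSteps(27)); // 111
    }
}
```

The shortcut saves one loop iteration (and one parity test) per odd number without changing the step count.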
For your interest
public class Collatz {
    // 2 billion ints is about 8 GB, so the JVM needs a larger heap (set with -Xmx).
    static final int[] CALC_CACHE = new int[2_000_000_000];
    static int calculate(long n) {
        int numSteps = 0;
        long c = n;
        while (c != 1) {
            // Reuse a previously computed count if c falls inside the cache.
            if (c < CALC_CACHE.length) {
                int steps = CALC_CACHE[(int) c];
                if (steps > 0) {
                    numSteps += steps;
                    break;
                }
            }
            if (c % 2 == 0) {
                numSteps++;
                c /= 2;
            } else {
                // c is odd, so 3c + 1 is even: take both steps at once.
                numSteps += 2;
                if (c > Long.MAX_VALUE / 3)
                    throw new IllegalStateException("c is too large " + c);
                c = (c * 3 + 1) / 2;
            }
        }
        // Only the starting number is stored in the cache.
        if (n < CALC_CACHE.length) {
            CALC_CACHE[(int) n] = numSteps;
        }
        return numSteps;
    }
    public static void main(String[] args) {
        long n = 1, maxN = 0, maxSteps = 0;
        long startTime = System.currentTimeMillis();
        while (System.currentTimeMillis() < startTime + 60000) {
            // Check the clock only once every 10 numbers.
            for (int i = 0; i < 10; i++) {
                int steps = calculate(n);
                if (steps > maxSteps) {
                    maxSteps = steps;
                    maxN = n;
                }
                n++;
            }
            if (n % 10000000 == 1)
                System.out.printf("%,d%n", n);
        }
        System.out.printf("The highest number was: %,d, maxSteps: %,d for: %,d%n", n, maxSteps, maxN);
    }
}
prints
The highest number was: 1,672,915,631, maxSteps: 1,000 for: 1,412,987,847
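The upper bound on the cache matters because a trajectory can climb far above its starting number: 27 reaches 9,232, more than 340 times its starting value, before it falls back to 1. A minimal sketch to see this for yourself (the class and method names are hypothetical; it uses the plain one-step rule and no cache):

```java
public class CollatzPeak {
    // Highest value reached on the way from n down to 1.
    static long peak(long n) {
        long max = n;
        while (n != 1) {
            n = (n % 2 == 0) ? n / 2 : 3 * n + 1;
            max = Math.max(max, n);
        }
        return max;
    }

    public static void main(String[] args) {
        long limit = 1_000;
        long worstStart = 1, worstPeak = 1;
        // Find the starting number below the limit whose trajectory climbs highest.
        for (long n = 1; n < limit; n++) {
            long p = peak(n);
            if (p > worstPeak) {
                worstPeak = p;
                worstStart = n;
            }
        }
        System.out.printf("Below %,d, %,d climbs to %,d%n", limit, worstStart, worstPeak);
        System.out.println(peak(27)); // 9232
    }
}
```

A cache sized to cover every such peak would be far larger than one sized to the starting numbers, which is why the code above only caches indices below CALC_CACHE.length.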
A more advanced answer would be to use multiple threads, here via a parallel LongStream. In this case, using recursion with memoisation was easier to implement.
import java.util.stream.LongStream;
public class Collatz {
    // Integer.MAX_VALUE - 8 shorts is about 4.3 GB, so the JVM needs a large heap.
    static final short[] CALC_CACHE = new short[Integer.MAX_VALUE-8];
    // Recursive version: caches every value below CALC_CACHE.length that it passes through.
    public static int calculate(long c) {
        if (c == 1) {
            return 0;
        }
        int steps;
        if (c < CALC_CACHE.length) {
            steps = CALC_CACHE[(int) c];
            if (steps > 0)
                return steps;
        }
        if (c % 2 == 0) {
            steps = calculate(c / 2) + 1;
        } else {
            // c is odd, so 3c + 1 is even: take both steps in one call.
            steps = calculate((c * 3 + 1) / 2) + 2;
        }
        if (c < CALC_CACHE.length) {
            if (steps > Short.MAX_VALUE)
                throw new AssertionError();
            // Threads may write the same slot concurrently, but always with the same value.
            CALC_CACHE[(int) c] = (short) steps;
        }
        return steps;
    }
    // Iterative version that caches only the starting number (not used by main).
    static int calculate2(long n) {
        int numSteps = 0;
        long c = n;
        while (c != 1) {
            if (c < CALC_CACHE.length) {
                int steps = CALC_CACHE[(int) c];
                if (steps > 0) {
                    numSteps += steps;
                    break;
                }
            }
            if (c % 2 == 0) {
                numSteps++;
                c /= 2;
            } else {
                numSteps += 2;
                if (c > Long.MAX_VALUE / 3)
                    throw new IllegalStateException("c is too large " + c);
                c = (c * 3 + 1) / 2;
            }
        }
        if (n < CALC_CACHE.length) {
            CALC_CACHE[(int) n] = (short) numSteps;
        }
        return numSteps;
    }
    public static void main(String[] args) {
        long maxN = 0, maxSteps = 0;
        long startTime = System.currentTimeMillis();
        // Each worker keeps its own {maxSteps, n} array; the arrays are merged at the end.
        long[] res = LongStream.range(1, 6_000_000_000L).parallel().collect(
                () -> new long[2],
                (long[] arr, long n) -> {
                    int steps = calculate(n);
                    if (steps > arr[0]) {
                        arr[0] = steps;
                        arr[1] = n;
                    }
                },
                (a, b) -> {
                    if (a[0] < b[0]) {
                        a[0] = b[0];
                        a[1] = b[1];
                    }
                });
        maxN = res[1];
        maxSteps = res[0];
        long time = System.currentTimeMillis() - startTime;
        System.out.printf("After %.3f seconds, maxSteps: %,d for: %,d%n", time / 1e3, maxSteps, maxN);
    }
}
prints
After 52.461 seconds, maxSteps: 1,131 for: 4,890,328,815
Note: If I change the second calculate call to
steps = calculate(c * 3 + 1) + 1;
it prints
After 63.065 seconds, maxSteps: 1,131 for: 4,890,328,815
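The collect call in the parallel version is a three-argument mutable reduction, LongStream.collect(Supplier, ObjLongConsumer, BiConsumer): it creates one container per chunk of the range, folds each number into its chunk's container, then merges the containers pairwise. A stripped-down sketch without the shared cache, useful for checking the reduction on a small range before starting a long run (the class and method names are hypothetical; steps are counted with the (3c + 1) / 2 shortcut):

```java
import java.util.stream.LongStream;

public class ParallelMax {
    // Step count using the odd-number shortcut, no cache.
    static int steps(long c) {
        int steps = 0;
        while (c != 1) {
            if (c % 2 == 0) {
                c /= 2;
                steps++;
            } else {
                c = (3 * c + 1) / 2;
                steps += 2;
            }
        }
        return steps;
    }

    // Returns {maxSteps, n} for n in [from, to); ties go to the smaller n.
    static long[] longest(long from, long to) {
        return LongStream.range(from, to).parallel().collect(
                () -> new long[] {-1, 0},          // one container per chunk; -1 marks "empty"
                (long[] best, long n) -> {         // fold one number into a container
                    int s = steps(n);
                    if (s > best[0] || (s == best[0] && n < best[1])) {
                        best[0] = s;
                        best[1] = n;
                    }
                },
                (a, b) -> {                        // merge container b into a
                    if (b[0] > a[0] || (b[0] == a[0] && b[1] < a[1])) {
                        a[0] = b[0];
                        a[1] = b[1];
                    }
                });
    }

    public static void main(String[] args) {
        long[] r = longest(1, 1_000_000);
        System.out.printf("maxSteps: %,d for: %,d%n", r[0], r[1]);
    }
}
```

The explicit tie-break means equal step counts always report the smallest n, however the range happens to be split between threads.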