I'm currently working on a school project that asks me to write code for different sorting algorithms. The most difficult part was writing an iterative version of merge sort for an input array of length 2^N. I used a required helper method called merge to do the iterative merging.
My structure was as follows. Given an array of length 2^N (let us use an array of length 16 to explain my method), I iterated through the array looking at each pair of integers and using merge() to swap them if the first was greater than the second. In a length-16 array this happens 8 times. I then iterated through the array looking at each block of 4 integers (4 blocks), using merge() to combine the two ordered pairs in each block. Then I looked at blocks of 8 integers, and so on. My code is posted here:
```java
public static void MergeSortNonRec(long[] a) {
    //======================
    //FILL IN YOUR CODE HERE
    //======================
    /*
    System.out.print("Our array is: ");
    printArray(a);
    System.out.println('\n');
    */
    int alength = a.length;
    int counter = 2;
    // the counter iterates through the block sizes 2, 4, 8, 16, 32, etc.
    int pointtracker = 0;
    // the point tracker keeps track of the position in the array
    while (counter <= alength) {
        // note: this allocates a new full-length aux array on every merge call
        long [] aux = new long [alength];
        int low = pointtracker;
        int high = pointtracker + counter - 1;
        int mid = (low + high)/2;
        merge(a, aux, low, mid, high);
        if (high < alength - 1) {
            pointtracker += counter;
            // move to the next block
        }
        else {
            // if our high point is at the end of the array
            counter *= 2;
            pointtracker = 0;
            // start over at a[0], with a doubled counter
        }
    }
    /*
    System.out.print("Final array is: ");
    printArray(a);
    System.out.println('\n');
    */
}//MergeSortNonRec()
```
My merge method is as follows:
```java
private static void merge(long[] a, long[] aux, int lo, int mid, int hi) {
    // copy to aux[]
    for (int k = lo; k <= hi; k++) {
        aux[k] = a[k];
    }
    // merge back to a[]
    int i = lo, j = mid+1;
    for (int k = lo; k <= hi; k++) {
        if (i > mid) a[k] = aux[j++];
        else if (j > hi) a[k] = aux[i++];
        else if (aux[j] < aux[i]) a[k] = aux[j++];
        else a[k] = aux[i++];
    }
}
```
The recursive solution is much more elegant:
```java
private static void sort(long[] a, long[] aux, int lo, int hi) {
    if (hi <= lo) return;
    int mid = lo + (hi - lo) / 2;
    sort(a, aux, lo, mid);
    sort(a, aux, mid + 1, hi);
    merge(a, aux, lo, mid, hi);
}

public static void MergeSort(long[] a) {
    long[] aux = new long[a.length];
    sort(a, aux, 0, a.length-1);
}
```
My issue is with runtime. My professor said that because we are only inputting arrays of length 2^N, the iterative version of merge sort should run faster than the recursive version. However, my iterative version runs slower than the recursive version on large inputs. Here is an example of my timing output:
![runtime](https://i.sstatic.net/J8Auh.jpg "sorting algorithms")
What can I do to reduce the time of my iterative mergesort?
EDIT: I've figured it out. I moved the instantiation of aux outside of the while loop, and this decreased the time dramatically. Thanks all!
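Why the fix in the EDIT matters so much: Java zero-fills every new array, so allocating a full-length aux inside the loop costs O(n) per merge call. For n = 2^N there are n/2 + n/4 + ... + 1 = n - 1 merge calls, so the original loop spends O(n^2) time just allocating arrays, which swamps the O(n log n) merging. Below is a minimal sketch of the question's approach with aux allocated once and the pointtracker/counter loop rewritten as two nested for loops. It assumes, as the assignment does, that the length is a power of 2; the class name IterativeMergeSortHoisted is hypothetical, and merge() is copied unchanged from the question.

```java
import java.util.Arrays;

class IterativeMergeSortHoisted {

    // The question's merge(): merges a[lo..mid] and a[mid+1..hi], using aux as scratch.
    private static void merge(long[] a, long[] aux, int lo, int mid, int hi) {
        for (int k = lo; k <= hi; k++) {
            aux[k] = a[k];
        }
        int i = lo, j = mid + 1;
        for (int k = lo; k <= hi; k++) {
            if (i > mid) a[k] = aux[j++];
            else if (j > hi) a[k] = aux[i++];
            else if (aux[j] < aux[i]) a[k] = aux[j++];
            else a[k] = aux[i++];
        }
    }

    // Assumes a.length is a power of 2 (as in the assignment).
    public static void MergeSortNonRec(long[] a) {
        int n = a.length;
        long[] aux = new long[n];                      // allocated once, reused by every merge
        for (int size = 2; size <= n; size *= 2) {     // block sizes 2, 4, 8, ...
            for (int low = 0; low < n; low += size) {  // each block of this size
                int high = low + size - 1;
                int mid = low + size / 2 - 1;          // last index of the left half
                merge(a, aux, low, mid, high);
            }
        }
    }

    public static void main(String[] args) {
        long[] a = {5, 3, 8, 1, 9, 2, 7, 4};
        MergeSortNonRec(a);
        System.out.println(Arrays.toString(a));
    }
}
```

The computed mid equals the question's (low + high)/2 for every even block size, so each merge call sees exactly the same ranges as before; only the allocation has moved.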
> What can I do to reduce the time of my iterative mergesort?
Wikipedia has a simplified example of an iterative (bottom-up) merge sort:
https://en.wikipedia.org/wiki/Merge_sort#Bottom-up_implementation
To reduce the time, allocate the aux[] array only once, and instead of copying data back on each merge pass, swap the references to the two arrays after each pass:
```java
long [] t = a;      // swap references
a = aux;
aux = t;
```
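If the number of passes is odd (for example, when the size of the array is an odd power of 2), the sorted data ends up in aux[] instead of a[]. In that case you'll need to either copy the array back one time, or do the first pass as in-place swaps of pairs instead of a merge pass, so that the remaining number of merge passes is even.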
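As an illustration of the two points above, here is a minimal sketch of a bottom-up merge sort on long[] that allocates one scratch array, swaps references after every pass, and takes the "copy the array one time" option at the end. The width-doubling loop with Math.min bounds follows the general shape of the Wikipedia bottom-up example rather than the answer's code, so it also handles lengths that are not powers of 2. The class name PingPongMergeSort and its methods are hypothetical, and this merge() takes half-open ranges [lo, mid) and [mid, hi), unlike the question's.

```java
import java.util.Arrays;

class PingPongMergeSort {

    // Merges src[lo..mid) and src[mid..hi) into dst[lo..hi); src is not modified.
    // Taking from the left run on ties keeps the sort stable.
    static void merge(long[] src, long[] dst, int lo, int mid, int hi) {
        int i = lo, j = mid;
        for (int k = lo; k < hi; k++) {
            if (i < mid && (j >= hi || src[i] <= src[j])) dst[k] = src[i++];
            else dst[k] = src[j++];
        }
    }

    static void sort(long[] a) {
        int n = a.length;
        long[] src = a;
        long[] dst = new long[n];                 // the only extra allocation
        for (int width = 1; width < n; width *= 2) {
            for (int lo = 0; lo < n; lo += 2 * width) {
                int mid = Math.min(lo + width, n);
                int hi = Math.min(lo + 2 * width, n);
                merge(src, dst, lo, mid, hi);     // a lone left run is simply copied
            }
            long[] t = src;                       // swap references instead of copying back
            src = dst;
            dst = t;
        }
        // After an odd number of passes the sorted data is in the scratch array,
        // so copy it back into a once.
        if (src != a) {
            System.arraycopy(src, 0, a, 0, n);
        }
    }

    public static void main(String[] args) {
        long[] a = {5, 3, 8, 1, 9, 2, 7, 4};      // 8 = 2^3: three passes, so one copy-back
        sort(a);
        System.out.println(Arrays.toString(a));
    }
}
```

The answer's code below takes the other option: when the pass count is odd it replaces the first merge pass with in-place pair swaps, so no final copy is needed.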
> iterative merge sort should run faster than recursive merge sort
Assuming reasonably optimized versions of both, iterative merge sort will usually be faster, but the relative difference shrinks as the array size increases, because most of the time is spent in the merge() function, which can be identical for both the iterative and recursive versions.
There are trade-offs. The recursive version will push and pop length - 2 or 2*length - 2 pairs of indices to and from the stack, while the iterative version generates indices on the fly, and they can be kept in registers. It might seem that the recursive version is more cache friendly during the deeper levels of recursion, because it operates on a small portion of the array while the iterative version operates across the entire array on every pass, but I've never seen a case where this made recursive merge sort faster overall. Most caches on a PC are 4-way or higher set associative, so the two cache lines used for input and the one line used for output during a merge can all stay cached at the same time. In my testing, a multi-threaded iterative merge sort is much faster than a single-threaded one, so merge sort on the systems I've tested is not memory-bandwidth limited.
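To make the stack-traffic count concrete, here is a small sketch that instruments the question's recursive sort() with a call counter. The class name RecursionCallCount and the static calls field are hypothetical additions for illustration; sort() and merge() are otherwise the question's code. Because sort() recurses all the way down to single-element ranges, the recursion tree has n leaves and n - 1 internal nodes, so it makes 2n - 1 calls in total, 2n - 2 of them recursive, each passing a (lo, hi) pair.

```java
class RecursionCallCount {

    static long calls = 0;  // number of sort() invocations, including the top-level one

    // The question's merge(), unchanged.
    static void merge(long[] a, long[] aux, int lo, int mid, int hi) {
        for (int k = lo; k <= hi; k++) {
            aux[k] = a[k];
        }
        int i = lo, j = mid + 1;
        for (int k = lo; k <= hi; k++) {
            if (i > mid) a[k] = aux[j++];
            else if (j > hi) a[k] = aux[i++];
            else if (aux[j] < aux[i]) a[k] = aux[j++];
            else a[k] = aux[i++];
        }
    }

    // The question's recursive sort(), plus one counter increment per call.
    static void sort(long[] a, long[] aux, int lo, int hi) {
        calls++;
        if (hi <= lo) return;
        int mid = lo + (hi - lo) / 2;
        sort(a, aux, lo, mid);
        sort(a, aux, mid + 1, hi);
        merge(a, aux, lo, mid, hi);
    }

    public static void main(String[] args) {
        for (int n : new int[] {16, 1024, 1 << 20}) {
            long[] a = new long[n];
            calls = 0;
            sort(a, new long[n], 0, n - 1);
            System.out.println("n = " + n + ", calls = " + calls + ", 2n - 1 = " + (2L * n - 1));
        }
    }
}
```

The iterative versions make no calls at all for index bookkeeping; they only call merge() n - 1 times, the same as the recursive version.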
Here is a somewhat optimized example of an iterative (bottom-up) merge sort, along with a test program:
```java
package jsortbu;
import java.util.Random;

public class jsortbu {
    static void MergeSort(int[] a)          // entry function
    {
        if(a.length < 2)                    // if size < 2 return
            return;
        int[] b = new int[a.length];
        BottomUpMergeSort(a, b);
    }

    static void BottomUpMergeSort(int[] a, int[] b)
    {
        int n = a.length;
        int s = 1;                          // run size
        if(1 == (GetPassCount(n)&1)){       // if odd number of passes
            for(s = 1; s < n; s += 2)       // swap in place for 1st pass
                if(a[s] < a[s-1]){
                    int t = a[s];
                    a[s] = a[s-1];
                    a[s-1] = t;
                }
            s = 2;
        }
        while(s < n){                       // while not done
            int ee = 0;                     // reset end index
            while(ee < n){                  // merge pairs of runs
                int ll = ee;                // ll = start of left run
                int rr = ll+s;              // rr = start of right run
                if(rr >= n){                // if only left run
                    do                      //   copy it
                        b[ll] = a[ll];
                    while(++ll < n);
                    break;                  //   end of pass
                }
                ee = rr+s;                  // ee = end of right run
                if(ee > n)
                    ee = n;
                Merge(a, b, ll, rr, ee);
            }
            {                               // swap references
                int[] t = a;
                a = b;
                b = t;
            }
            s <<= 1;                        // double the run size
        }
    }

    static void Merge(int[] a, int[] b, int ll, int rr, int ee) {
        int o = ll;                         // b[] index
        int l = ll;                         // a[] left index
        int r = rr;                         // a[] right index
        while(true){                        // merge data
            if(a[l] <= a[r]){               // if a[l] <= a[r]
                b[o++] = a[l++];            //   copy a[l]
                if(l < rr)                  //   if not end of left run
                    continue;               //     continue (back to while)
                do                          //   else copy rest of right run
                    b[o++] = a[r++];
                while(r < ee);
                break;                      //     and return
            } else {                        // else a[l] > a[r]
                b[o++] = a[r++];            //   copy a[r]
                if(r < ee)                  //   if not end of right run
                    continue;               //     continue (back to while)
                do                          //   else copy rest of left run
                    b[o++] = a[l++];
                while(l < rr);
                break;                      //     and return
            }
        }
    }

    static int GetPassCount(int n)          // return # passes
    {
        int i = 0;
        for(int s = 1; s < n; s <<= 1)
            i += 1;
        return(i);
    }

    public static void main(String[] args) {
        int[] a = new int[10000000];
        Random r = new Random();
        for(int i = 0; i < a.length; i++)
            a[i] = r.nextInt();
        long bgn, end;
        bgn = System.currentTimeMillis();
        MergeSort(a);
        end = System.currentTimeMillis();
        for(int i = 1; i < a.length; i++){
            if(a[i-1] > a[i]){
                System.out.println("failed");
                break;
            }
        }
        System.out.println("milliseconds " + (end-bgn));
    }
}
```