Ka Mal
Ka Mal

Reputation: 145

K-way merge operation for merge sort

I have k sorted arrays, each with n elements, and need to combine them into a single sorted array of k*n elements.

How do I implement the merge procedure from merge sort for this, i.e. merge the first two arrays, then merge that result with the next array, and so on?

This is what I have so far.

// implementing function to merge arrays (merge procedure for merge sort)   
    /**
     * Placeholder for the merge step. It only allocates the output array,
     * sized for k sub-arrays of n elements each, and returns it unfilled
     * (all zeros); the actual merging is not written yet.
     *
     * @param array k sub-arrays; the first one determines n
     * @return a zero-filled array of length k*n
     */
    public int[] merge(int[][] array) {
        final int arrayCount = array.length;
        final int elementsPerArray = array[0].length;

        // Room for every element of every sub-array; nothing is copied in yet.
        return new int[arrayCount * elementsPerArray];
    }

    /** Runs the merge on four sorted sample arrays of four elements each. */
    public static void main(String[] args) {
        int[][] samples = {
            {2, 9, 15, 20},
            {6, 8, 9, 19},
            {5, 10, 18, 22},
            {8, 12, 15, 26}
        };
        Merge merger = new Merge();
        int[] mergedArrayTest = merger.merge(samples);
        // Printing is still disabled: printArray(mergedArrayTest);
    }

Upvotes: 2

Views: 2953

Answers (1)

Michael Laszlo
Michael Laszlo

Reputation: 12239

Instead of merging the sub-arrays two at a time, you can merge all k at once.

  • Make an array of indices into each sub-array. Initially each index is zero.
  • On each one of k*n iterations to fill the merged array, consider each sub-array's value at its respective index and remember the minimum value. (Skip an index if it has already reached the end of the sub-array.)
  • Increment the index that pointed to the minimum value.

This will do it:

/**
 * K-way merge by linear scan.
 *
 * Keeps one cursor per sub-array. For every slot of the output it scans the
 * element under each unfinished cursor, copies the smallest one, and advances
 * that cursor. On ties the sub-array with the lowest number wins.
 *
 * Sub-arrays may have different lengths (including zero), and the outer array
 * may be empty, in which case an empty array is returned.
 *
 * @param array the sub-arrays to merge, each sorted in ascending order
 * @return a new array holding every element of every sub-array in ascending order
 */
public int[] merge(int[][] array){
  int k = array.length;

  // Size the output from the real lengths rather than assuming k * array[0].length.
  int total = 0;
  for (int[] subArray : array) {
    total += subArray.length;
  }

  int[] mergedArray = new int[total];
  int[] indices = new int[k];
  for (int i = 0; i < mergedArray.length; ++i) {
    // bestIndex == -1 means "nothing chosen yet". The value itself must not be
    // used as that marker, because -1 (or any int) is a legal element.
    int bestValue = 0, bestIndex = -1;
    for (int j = 0; j < k; ++j) {
      int index = indices[j];
      if (index < array[j].length && (bestIndex == -1 || array[j][index] < bestValue)) {
        bestValue = array[j][index];
        bestIndex = j;
      }
    }
    // total counts exactly the unconsumed elements, so a candidate always exists here.
    mergedArray[i] = bestValue;
    indices[bestIndex] += 1;
  }

  return mergedArray;
}

You can make this approach somewhat more efficient by removing indices that have reached the end of their sub-array. However, that still leaves the running time in O(nk²) because O(k) indices are scanned nk times.

We can make an asymptotic improvement in running time by storing the indices in a min-heap that uses the value at each index as the key. With k indices, the size of the heap never exceeds k. In each of nk iterations, we pop the heap and push at most one element back on. These heap operations each cost O(log k), so the total running time is O(nk log k).

import java.lang.*;
import java.util.*;
import java.io.*;

/**
 * One heap entry for the k-way merge: a cursor into a single sub-array
 * together with the element currently under that cursor.
 */
class Candidate {
  // id:    number of the sub-array this cursor belongs to
  // index: position of the cursor inside that sub-array
  // value: element at that position; the heap orders candidates by it
  int id, index, value;
  Candidate(int id, int index, int value) {
    this.id = id;
    this.index = index;
    this.value = value;
  }
}

/**
 * Binary min-heap of {@link Candidate}s ordered by {@code value}, stored in
 * an ArrayList where the children of position p sit at 2p+1 and 2p+2.
 */
class Heap {
  ArrayList<Candidate> stack = new ArrayList<Candidate>();

  /**
   * Inserts a candidate and restores heap order by moving it up past every
   * ancestor that holds a larger value.
   */
  void push(Candidate current) {
    // Add to last position in heap.
    stack.add(current);
    // Bubble up.
    int pos = stack.size() - 1;
    while (pos != 0) {
      int parent = (pos - 1) / 2;
      if (stack.get(parent).value <= current.value) {
        return;
      }
      stack.set(pos, stack.get(parent));
      stack.set(parent, current);
      // Follow the element to its new slot; without this step it could
      // never rise more than one level and the root might not be the minimum.
      pos = parent;
    }
  }

  /**
   * Removes and returns the candidate with the smallest value.
   *
   * @return the minimum candidate, or null when the heap is empty
   */
  Candidate pop() {
    // Get top of heap.
    if (stack.size() == 0) {
      return null;
    }
    Candidate result = stack.get(0);
    int n = stack.size();
    if (n == 1) {
      stack.remove(0);
      return result;
    }
    // Move the last element to the top; n becomes the new heap size.
    stack.set(0, stack.get(--n));
    Candidate current = stack.get(0);
    stack.remove(n);
    // Bubble down.
    int pos = 0;
    while (true) {
      int left = 2 * pos + 1;
      if (left >= n) {
        // No children: current has reached a leaf.
        return result;
      }
      // right == n means the right child does not exist.
      int right = left + 1,
          swapTo = -1;
      if (current.value <= stack.get(left).value) {
        if (right == n || current.value <= stack.get(right).value) {
          // Neither child is smaller: heap order holds again.
          return result;
        }
        swapTo = right;
      } else {
        // Left child is smaller than current; descend towards the smaller child.
        if (right != n && stack.get(left).value > stack.get(right).value) {
          swapTo = right;
        } else {
          swapTo = left;
        }
      }
      stack.set(pos, stack.get(swapTo));
      stack.set(swapTo, current);
      pos = swapTo;
    }
  }
}

public class Merge {

  /**
   * K-way merge driven by a min-heap.
   *
   * The heap holds one Candidate per sub-array that still has unconsumed
   * elements, keyed on the element under that sub-array's cursor. Each step
   * pops the smallest candidate, appends its value to the output, advances its
   * cursor, and pushes it back if its sub-array is not exhausted.
   *
   * Sub-arrays may have different lengths (including zero), and the outer
   * array may be empty, in which case an empty array is returned.
   *
   * @param array the sub-arrays to merge, each sorted in ascending order
   * @return a new array holding every element of every sub-array in ascending order
   */
  public int[] merge(int[][] array) {
    int k = array.length;

    // Size the output from the real lengths rather than assuming k * array[0].length.
    int total = 0;
    for (int[] subArray : array) {
      total += subArray.length;
    }
    int[] mergedArray = new int[total];

    // Initialize heap with subarray number, index, and value.
    // Empty sub-arrays contribute nothing and have no first element to push.
    Heap indexHeap = new Heap();
    for (int i = 0; i < k; ++i) {
      if (array[i].length > 0) {
        indexHeap.push(new Candidate(i, 0, array[i][0]));
      }
    }

    for (int i = 0; i < mergedArray.length; ++i) {
      // Get the minimum value from the heap and augment the merged array.
      // The heap cannot be empty here: total counts exactly the elements
      // that have not been emitted yet.
      Candidate best = indexHeap.pop();
      mergedArray[i] = best.value;
      // Increment the index. If it's still valid, push it back onto the heap.
      if (++best.index < array[best.id].length) {
        best.value = array[best.id][best.index];
        indexHeap.push(best);
      }
    }

    return mergedArray;
  }

  /** Merges four sample arrays and prints the result on one line. */
  public static void main(String[] args) {
    Merge merge = new Merge();
    int[][] data = new int[][]{{2, 9, 15, 20},
                               {6, 8, 9, 19},
                               {5, 10, 18, 22},
                               {8, 12, 15, 26}};
    int[] mergedArrayTest = merge.merge(data);

    // Print out the merged array for testing purposes.
    for (int i = 0; i < mergedArrayTest.length; ++i) {
      System.out.print(mergedArrayTest[i] + " ");
    }
    System.out.println();
  }
}

Upvotes: 4

Related Questions