Julian Carrier

Reputation: 90

Multi threaded quick sort much slower than expected

I've spent several hours researching the best way to implement both a recursive and a multi-threaded quicksort and merge sort (this post is only about quicksort). My goal is to use every logical processor on a given computer and keep them all fully loaded for maximum quicksort speed.

My approach was to divide the problem recursively, creating a new thread at each split, until either the array was sorted or the number of threads reached the number of processors on the CPU. From that point on, the remaining work would no longer be divided onto new threads; each existing thread would finish its part on its own core.

After writing a very rudimentary solution that only worked on my computer, I came across the Fork/Join framework, which I tried to use below, although I don't really understand how it is meant to be used. What I came up with is slower at sorting 10,000,000 random numbers in the range 0 to 999 than its single-threaded counterpart. I still think the framework is interesting, because its docs say it is able to steal work from slower threads, whatever that means.

I also recently heard about thread pools: creating all the threads up front and handing work out to them, because creating new threads is taxing on the system. I never got as far as trying to implement this. Perhaps my understanding of Fork/Join is skewed. Can anyone point me in the right direction or tell me what I'm doing wrong in my current program?

Below you'll find my attempt at a multi-threaded quicksort, followed by the single-threaded quicksort that I'm trying to translate into the multi-threaded one. Any help is appreciated. Cheers!

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;


public class MultithreadedQuicksort {
    public static void main(String[] args) {
         List<Comparable> nums = new ArrayList<Comparable>();
         Random rand = new Random();
         for (int i=0; i<10000000; i++) {
            nums.add(rand.nextInt(1000));
         }

         long start = System.currentTimeMillis();
         Quicksort quickSort = new Quicksort(nums, 0, nums.size() -1);
         ForkJoinPool pool = new ForkJoinPool();
         pool.invoke(quickSort);
         long end = System.currentTimeMillis();
         System.out.println(end - start);
         System.out.println(nums.size());
    }
}

class Quicksort extends RecursiveAction {
    int first;
    int last;
    private List<Comparable> nums;
    Comparable midValue;
    int midIndex;
    int low;
    int high;

    public Quicksort(List<Comparable> nums){
        this.nums=nums;
        this.low = 0;
        this.high = nums.size() - 1;
    }

    public Quicksort(List<Comparable> nums, int first, int last) {
        this.first = first;
        this.last = last;
        this.nums = nums;
        this.low = first;
        this.high = last;
        this.midIndex = (first + last) / 2;
        this.midValue = nums.get(midIndex);
    }


    @Override
    protected void compute() {
        split();
        if (high > first)
            invokeAll(new Quicksort(nums, first, high));
        if (low < last)
            invokeAll(new Quicksort(nums, low, last));
    }

    public void split() {
        while(low < high) {
            while (nums.get(low).compareTo(midValue) < 0) {
                  low++;
            }
            while (nums.get(high).compareTo(midValue) > 0) {
                  high--;
            }
            if (low <= high) {
                swap(low, high);
                low++;
                high--;
            }
        }
    }

    public void swap(int index1, int index2)
    {
        Comparable temp;
        temp = nums.get(index1);
        nums.set(index1, nums.get(index2));
        nums.set(index2, temp);
    }
}

Single Threaded

public static void quickSort(List<Comparable> nums, int first, int last) {
    int low = first;
    int high = last;
    int midIndex = (first + last) / 2;
    Comparable midValue = nums.get(midIndex);

    while(low < high) {
        while (nums.get(low).compareTo(midValue) < 0) {
              low++;
        }
        while (nums.get(high).compareTo(midValue) > 0) {
              high--;
        }
        if (low <= high) {
            swap(nums, low, high);
            low++;
            high--;
        }
    }
    if (high > first)
        quickSort(nums, first, high);
    if (low < last)
        quickSort(nums, low, last);
}

Upvotes: 1

Views: 437

Answers (1)

rcgldr

Reputation: 28911

I don't know Java that well, so the example code below may use Runnable awkwardly for the multiple threads. The example uses 8 threads: qsortmt() does a partition and starts two instances of qsort0(). Each instance of qsort0() does a partition and starts two instances of qsort1(). Each instance of qsort1() does a partition and starts two instances of qsort2(). Each instance of qsort2() calls qsort(). For the 16 million integers used in this example, the 8-thread sort takes about 1 second, while a non-threaded sort takes about 1.6 seconds, so not a huge savings. Part of the issue is that each partition step runs on a single thread before the threads that operate on its sub-partitions are started.

Switching to C++ and Windows native threads, 8 threads took about 0.632 seconds and the non-threaded version about 1.352 seconds. Switching to merge sort (splitting the array into 8 parts, sorting each part, then merging the 8 parts) took about 0.40 seconds, versus about 1.45 seconds single threaded.
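The merge sort variant is not part of the quicksort listing further down, so here is a minimal Java sketch of the same idea: split into 8 parts, sort each part on its own thread, then merge. Several things here are assumptions for illustration only: the class name ChunkedMergeSort, the use of Arrays.sort as a stand-in for the per-part sort, and doing the merge passes on a single thread. It is not the C++ code that produced the timings above.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical sketch: sort 8 slices in parallel, then merge them pairwise.
class ChunkedMergeSort {
    static final int PARTS = 8; // same thread count as the quicksort example

    static void sort(int[] a) {
        int n = a.length;
        // bounds[i]..bounds[i+1] is the half-open range of slice i
        int[] bounds = new int[PARTS + 1];
        for (int i = 0; i <= PARTS; i++)
            bounds[i] = (int) ((long) n * i / PARTS);

        // one thread per slice; Arrays.sort stands in for the per-part sort
        Thread[] workers = new Thread[PARTS];
        for (int i = 0; i < PARTS; i++) {
            final int lo = bounds[i], hi = bounds[i + 1];
            workers[i] = new Thread(() -> Arrays.sort(a, lo, hi));
            workers[i].start();
        }
        try {
            for (Thread t : workers)
                t.join();
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while sorting", ex);
        }

        // merge passes on the calling thread: 8 runs -> 4 -> 2 -> 1
        int[] src = a, dst = new int[n];
        for (int width = 1; width < PARTS; width *= 2) {
            for (int i = 0; i < PARTS; i += 2 * width) {
                int lo = bounds[i];
                int mid = bounds[Math.min(i + width, PARTS)];
                int hi = bounds[Math.min(i + 2 * width, PARTS)];
                merge(src, dst, lo, mid, hi);
            }
            int[] tmp = src; src = dst; dst = tmp;
        }
        if (src != a) // result ended up in the scratch buffer
            System.arraycopy(src, 0, a, 0, n);
    }

    // merge sorted src[lo,mid) and src[mid,hi) into dst[lo,hi)
    static void merge(int[] src, int[] dst, int lo, int mid, int hi) {
        int i = lo, j = mid, k = lo;
        while (i < mid && j < hi)
            dst[k++] = (src[j] < src[i]) ? src[j++] : src[i++];
        while (i < mid) dst[k++] = src[i++];
        while (j < hi) dst[k++] = src[j++];
    }

    public static void main(String[] args) {
        int[] a = new int[16 * 1024 * 1024];
        Random r = new Random(0);
        for (int i = 0; i < a.length; i++)
            a[i] = r.nextInt();
        long bgn = System.currentTimeMillis();
        sort(a);
        long end = System.currentTimeMillis();
        for (int i = 1; i < a.length; i++) {
            if (a[i - 1] > a[i]) {
                System.out.println("failed");
                break;
            }
        }
        System.out.println("milliseconds " + (end - bgn));
    }
}
```

Unlike the quicksort version, all 8 threads start at once here, because the split points are fixed by index and no partition step has to finish first. The cost is the extra scratch array of the same size as the input.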

package x;
import java.util.Random;

public class x {

    class qsort0 implements Runnable
    {
        int[] a;
        int lo;
        int hi;

        private qsort0(int[] a, int lo, int hi)
        {
            this.a = a;
            this.lo = lo;
            this.hi = hi;
        }
        @Override
        public void run()
        {
            if(this.lo >= this.hi)
                return;
            int pi = partition(this.a, this.lo, this.hi);
            Thread lt = new Thread(new qsort1(a, this.lo, pi));
            Thread rt = new Thread(new qsort1(a, pi+1, this.hi));
            lt.start();
            rt.start();
            try {lt.join();} catch (InterruptedException ex){}
            try {rt.join();} catch (InterruptedException ex){}
        }
    }

    class qsort1 implements Runnable
    {
        int[] a;
        int lo;
        int hi;

        private qsort1(int[] a, int lo, int hi)
        {
            this.a = a;
            this.lo = lo;
            this.hi = hi;
        }
        @Override
        public void run()
        {
            if(this.lo >= this.hi)
                return;
            int pi = partition(this.a, this.lo, this.hi);
            Thread lt = new Thread(new qsort2(a, this.lo, pi));
            Thread rt = new Thread(new qsort2(a, pi+1, this.hi));
            lt.start();
            rt.start();
            try {lt.join();} catch (InterruptedException ex){}
            try {rt.join();} catch (InterruptedException ex){}
        }
    }

    class qsort2 implements Runnable
    {
        int[] a;
        int lo;
        int hi;
        private qsort2(int[] a, int lo, int hi)
        {
            this.a = a;
            this.lo = lo;
            this.hi = hi;
        }
        @Override
        public void run() {
            if(this.lo >= this.hi)
                return;
            qsort(this.a, this.lo, this.hi);
        }
    }

    // quicksort multi-threaded
    @SuppressWarnings("empty-statement")
    public static void qsortmt(int[] a, int lo, int hi)
    {
        if(lo >= hi)
            return;
        int pi = partition(a, lo, hi);
        Thread lt = new Thread(new x().new qsort0(a, lo, pi));
        Thread rt = new Thread(new x().new qsort0(a, pi+1, hi));
        lt.start();
        rt.start();
        try {lt.join();} catch (InterruptedException ex){}
        try {rt.join();} catch (InterruptedException ex){}
    }

    @SuppressWarnings("empty-statement")
    public static int partition(int []a, int lo, int hi)
    {
        int  md = lo+(hi-lo)/2;
        int  ll = lo-1;
        int  hh = hi+1;
        int t;
        int p = a[md];
        while(true){
            while(a[++ll] < p);
            while(a[--hh] > p);
            if(ll >= hh)
                break;
            t     = a[ll];
            a[ll] = a[hh];
            a[hh] = t;
        }
        return hh;
    }

    @SuppressWarnings("empty-statement")
    public static void qsort(int[] a, int lo, int hi)
    {
        while(lo < hi){
            int ll = partition(a, lo, hi);
            int hh = ll+1;
            // recurse on smaller part, loop on larger part
            if((ll - lo) <= (hi - hh)){
                qsort(a, lo, ll);
                lo = hh;
            } else {
                qsort(a, hh, hi);
                hi = ll;
            }
        }
    }

    public static void main(String[] args)
    {
        int[] a = new int[16*1024*1024];
        Random r = new Random(0);
        for(int i = 0; i < a.length; i++)
            a[i] = r.nextInt();
        long bgn, end;
        bgn = System.currentTimeMillis();
        qsortmt(a, 0, a.length-1);
        end = System.currentTimeMillis();
        for(int i = 1; i < a.length; i++){
            if(a[i-1] > a[i]){
                System.out.println("failed");
                break;
            }
        }
        System.out.println("milliseconds " + (end-bgn));
    }
}
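For comparison with the Fork/Join attempt in the question: that code calls invokeAll() with one task at a time, so the first half is completely sorted before the second half is even started, and it never falls back to a plain sequential sort for small ranges. A minimal sketch of a Fork/Join version built on the same partition() scheme as above might look like the following. The class name ForkJoinQuicksort and the THRESHOLD value are assumptions picked for illustration, not tuned or measured values.

```java
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

// Hypothetical sketch: quicksort on int[] with Fork/Join and a sequential cutoff.
class ForkJoinQuicksort extends RecursiveAction {
    // assumed cutoff: ranges smaller than this are sorted without forking
    static final int THRESHOLD = 1 << 13;

    private final int[] a;
    private final int lo, hi; // inclusive bounds

    ForkJoinQuicksort(int[] a, int lo, int hi) {
        this.a = a;
        this.lo = lo;
        this.hi = hi;
    }

    @Override
    protected void compute() {
        if (lo >= hi)
            return;
        if (hi - lo < THRESHOLD) {
            sequentialSort(a, lo, hi);
            return;
        }
        int p = partition(a, lo, hi);
        // hand BOTH halves to invokeAll in one call so they can run in parallel
        invokeAll(new ForkJoinQuicksort(a, lo, p),
                  new ForkJoinQuicksort(a, p + 1, hi));
    }

    // Hoare partition around the middle element; returns the last index of the left part
    static int partition(int[] a, int lo, int hi) {
        int p = a[lo + (hi - lo) / 2];
        int ll = lo - 1;
        int hh = hi + 1;
        while (true) {
            while (a[++ll] < p) { }
            while (a[--hh] > p) { }
            if (ll >= hh)
                return hh;
            int t = a[ll];
            a[ll] = a[hh];
            a[hh] = t;
        }
    }

    // single-threaded quicksort: recurse on the smaller part, loop on the larger
    static void sequentialSort(int[] a, int lo, int hi) {
        while (lo < hi) {
            int ll = partition(a, lo, hi);
            int hh = ll + 1;
            if ((ll - lo) <= (hi - hh)) {
                sequentialSort(a, lo, ll);
                lo = hh;
            } else {
                sequentialSort(a, hh, hi);
                hi = ll;
            }
        }
    }

    static void sort(int[] a) {
        // the common pool is sized from the number of available processors by default
        ForkJoinPool.commonPool().invoke(new ForkJoinQuicksort(a, 0, a.length - 1));
    }

    public static void main(String[] args) {
        int[] a = new int[16 * 1024 * 1024];
        Random r = new Random(0);
        for (int i = 0; i < a.length; i++)
            a[i] = r.nextInt();
        long bgn = System.currentTimeMillis();
        sort(a);
        long end = System.currentTimeMillis();
        for (int i = 1; i < a.length; i++) {
            if (a[i - 1] > a[i]) {
                System.out.println("failed");
                break;
            }
        }
        System.out.println("milliseconds " + (end - bgn));
    }
}
```

This version also sorts an int[] rather than a List<Comparable>, which avoids boxing every element and calling compareTo() on each comparison. The first partition of the whole array still runs on one thread, so the limitation described above applies here as well.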

Upvotes: 0
