Evgenij Reznik

Reputation: 18614

Fork/Join: Collecting results

I'm playing around with fork/join and thought of the following example:

  • App1: two nested for-loops generate random numbers into an ArrayList and pass it to a fork.

  • MyTask (the fork): iterates through the ArrayList, adds up all numbers, and returns the sum.

import java.util.ArrayList;
import java.util.concurrent.ForkJoinPool;

public class App1 {

    ArrayList list = new ArrayList();

    static final ForkJoinPool mainPool = new ForkJoinPool();

    public App1() {
        for (int i = 0; i < 10; i++) {
            list.clear();
            for (int j = 1000; j <= 100000; j++) {
                int random = 1 + (int)(Math.random() * ((100 - 1) + 1));
                list.add(random);
            }
            mainPool.invoke(new MyTask(list));
        }
        // At the end showing all results
        // System.out.println (result1 + result2 + result3...);
    }

    public static void main(String[] args) {
        App1 app = new App1();
    }
}


import java.util.ArrayList;
import java.util.concurrent.RecursiveTask;

public class MyTask extends RecursiveTask<Integer> {

    ArrayList list = new ArrayList();
    int result;

    public MyTask(ArrayList list) {
        this.list = list;
    }

    @Override
    protected Integer compute() {
        for (int i = 0; i < list.size(); i++) { // i < size(): list.get(list.size()) would throw IndexOutOfBoundsException
            result += (int) list.get(i); // adding up all numbers
        }
        return result;
    }
}

I'm not sure if I'm on the right track. I also don't know how to collect all the results from the forks.
Could anybody please have a look at my code?

Upvotes: 3

Views: 4302

Answers (4)

Ortwin Angermeier

Reputation: 6213

You are not using RecursiveTask the right way. The task has to split its work and create sub-tasks of itself recursively; take a look at a possible solution below.

In the end, all the work is recursively split into portions smaller than THRESHOLD.
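To make "recursively split into portions smaller than THRESHOLD" concrete, here is a small sketch that only replays the splitting rule of the RecursiveSum class in this answer (same THRESHOLD = 1000, same midpoint formula) on a 10,000-element range, without summing anything. The class name SplitTrace is hypothetical and not part of the original answer.

```java
import java.util.ArrayList;
import java.util.List;

public class SplitTrace {
    // Same cut-off and same midpoint formula as RecursiveSum in this answer.
    static final int THRESHOLD = 1000;

    // Replays only the splitting decision of RecursiveSum.compute():
    // a range smaller than THRESHOLD becomes a leaf, anything else is halved.
    static void split(int begin, int end, List<int[]> leaves) {
        if (end - begin < THRESHOLD) {
            leaves.add(new int[] {begin, end});
        } else {
            int middle = begin + ((end - begin) / 2);
            split(begin, middle, leaves);
            split(middle, end, leaves);
        }
    }

    public static void main(String[] args) {
        List<int[]> leaves = new ArrayList<>();
        split(0, 10000, leaves);
        System.out.println("leaf tasks: " + leaves.size()); // 16
        for (int[] range : leaves) {
            System.out.println("[" + range[0] + ", " + range[1] + ")"); // 625 elements each
        }
    }
}
```

With 10,000 elements the range is halved four times (10000, 5000, 2500, 1250, 625), so 16 leaf tasks of 625 elements each do the actual summing.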

Update:

You might want to take a look at Java Concurrent Animated: just download the jar and run it; it has a nice visual explanation of how things work.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class RecursiveListSum {

  private static class RecursiveSum extends RecursiveTask<Long> {
    private static final long serialVersionUID = 1L;
    private static final int THRESHOLD = 1000;
    private final List<Integer> list;
    private final int begin;
    private final int end;

    public RecursiveSum(List<Integer> list, int begin, int end) {
      super();
      this.list = list;
      this.begin = begin;
      this.end = end;
    }

    @Override
    protected Long compute() {
      final int size = end - begin;
      // if the work to be done is below some threshold, just compute directly.
      if (size < THRESHOLD) {
        long sum = 0;
        for (int i = begin; i < end; i++)
          sum += list.get(i);
        return sum;
      } else {
        // split the work to other tasks - recursive (that's why it is called recursive task!)
        final int middle = begin + ((end - begin) / 2);
        RecursiveSum sum1 = new RecursiveSum(list, begin, middle);
        // fork the first half: it is queued and may be picked up by another worker thread in the pool
        sum1.fork();
        RecursiveSum sum2 = new RecursiveSum(list, middle, end);
        // compute the second half in this thread, then block in join() until the first half is done
        return sum2.compute() + sum1.join();
      }
    }
  }

  public static void main(String[] args) {
    // First fill the list
    List<Integer> list = new ArrayList<>();
    long expectedSum = 0;
    for (int i = 0; i < 10000; i++) {
      int random = 1 + (int) (Math.random() * ((100 - 1) + 1));
      list.add(random);
      expectedSum += random;
    }
    System.out.println("expected sum: " + expectedSum);

    // now let the RecursiveTask calc the sum again.
    final ForkJoinPool forkJoinPool = new ForkJoinPool(Runtime.getRuntime().availableProcessors());
    final RecursiveSum recursiveSum = new RecursiveSum(list, 0, list.size());
    long recSum = forkJoinPool.invoke(recursiveSum);
    System.out.println("recursive-sum: " + recSum);
  }

}
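For the "collecting results" part of the question: forkJoinPool.invoke(...) already returns the task's result (recSum above). If the ten lists from the question should be summed at the same time, one option, sketched here as an assumption rather than as part of the answer, is to submit() each root task to the pool and join() the returned ForkJoinTask handles afterwards. ChunkSum and sumAll are hypothetical names; ChunkSum repeats the RecursiveSum idea so the example compiles on its own.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

public class CollectResults {

    // Same divide-and-conquer idea as RecursiveSum above.
    static class ChunkSum extends RecursiveTask<Long> {
        private static final int THRESHOLD = 1000;
        private final List<Integer> list;
        private final int begin, end; // half-open range [begin, end)

        ChunkSum(List<Integer> list, int begin, int end) {
            this.list = list;
            this.begin = begin;
            this.end = end;
        }

        @Override
        protected Long compute() {
            if (end - begin < THRESHOLD) {
                long sum = 0;
                for (int i = begin; i < end; i++) sum += list.get(i);
                return sum;
            }
            int middle = begin + (end - begin) / 2;
            ChunkSum left = new ChunkSum(list, begin, middle);
            left.fork();                                             // may run on another worker
            long right = new ChunkSum(list, middle, end).compute();  // runs in this thread
            return right + left.join();
        }
    }

    // Starts one root task per list, then collects the results in the same order.
    static List<Long> sumAll(List<List<Integer>> lists, ForkJoinPool pool) {
        List<ForkJoinTask<Long>> pending = new ArrayList<>();
        for (List<Integer> l : lists) {
            pending.add(pool.submit(new ChunkSum(l, 0, l.size()))); // returns immediately
        }
        List<Long> results = new ArrayList<>();
        for (ForkJoinTask<Long> t : pending) {
            results.add(t.join()); // waits for this task and returns its result
        }
        return results;
    }

    public static void main(String[] args) {
        List<List<Integer>> lists = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            List<Integer> l = new ArrayList<>(); // a fresh list per round, not list.clear()
            for (int j = 1000; j <= 100000; j++) l.add(1 + (int) (Math.random() * 100));
            lists.add(l);
        }
        List<Long> results = sumAll(lists, new ForkJoinPool());
        long total = 0;
        for (long r : results) total += r;
        System.out.println("per-list sums: " + results);
        System.out.println("total: " + total);
    }
}
```

Because the tasks now run concurrently, each round needs its own list; clearing and refilling one shared list while earlier tasks are still reading it would be a data race.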

Upvotes: 0

TwoThe

Reputation: 14309

While your code looks fine, you are using ForkJoinPool the wrong way. It is a tool for tasks that can be split into independent sub-tasks and sped up through multi-threading.

Your task probably isn't big enough to really benefit from multi-threading, but putting that aside since it is a learning exercise, you still need to split the main task into sub-tasks; all you do now is sum the entire list once, in a single task.

Things you could do in your code:

  • Use a different mode of multi-threading that fits better.

  • Fork the task and hand over a start and end position inside the array, then sum up the results once those sub-tasks are done.

Since you are probably interested in the latter, here is an example of how to do it:

import java.util.ArrayList;
import java.util.concurrent.RecursiveTask;

public class MyTask extends RecursiveTask<Integer> {

  final ArrayList<Integer> list;
  // inclusive bounds: sums list[start..end], so start with new MyTask(list, 0, list.size() - 1)
  final int start, end;

  public MyTask(ArrayList<Integer> list, int start, int end) {
    this.list = list;
    this.start = start;
    this.end = end;
  }

  @Override
  protected Integer compute() {
    if (end - start > 10) { // is this task big enough to justify more threading?
      final int half = (end + start) / 2;
      final MyTask firstHalf = new MyTask(list, start, half);
      final MyTask secondHalf = new MyTask(list, half + 1, end);
      invokeAll(firstHalf, secondHalf);
      // join() returns the result without the checked exceptions that get() declares
      return firstHalf.join() + secondHalf.join();
    } else {
      int result = 0;
      for (int i = start; i <= end; i++) {
        result += list.get(i);
      }
      return result;
    }
  }
}
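For the first option in the list above, the answer does not name a specific alternative; one plausible reading (an assumption) is a plain ExecutorService where each list is summed sequentially by its own Callable and the results are collected from the returned Futures. ExecutorSums and sumEach are hypothetical names, and the lambda requires Java 8 or later.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class ExecutorSums {

    // One Callable per list: each list is summed sequentially, the lists run in parallel.
    static List<Long> sumEach(List<List<Integer>> lists)
            throws InterruptedException, ExecutionException {
        ExecutorService executor =
                Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        try {
            List<Callable<Long>> jobs = new ArrayList<>();
            for (List<Integer> list : lists) {
                jobs.add(() -> {
                    long sum = 0;
                    for (int x : list) sum += x;
                    return sum;
                });
            }
            List<Long> results = new ArrayList<>();
            // invokeAll waits until every job is done; the futures are in the same order as jobs
            for (Future<Long> f : executor.invokeAll(jobs)) {
                results.add(f.get());
            }
            return results;
        } finally {
            executor.shutdown();
        }
    }

    public static void main(String[] args) throws Exception {
        List<List<Integer>> lists = new ArrayList<>();
        lists.add(Arrays.asList(1, 2, 3));
        lists.add(Arrays.asList(10, 20, 30, 40));
        System.out.println(sumEach(lists)); // [6, 100]
    }
}
```

This fits the question's shape (several independent lists, each summed once) without any recursive splitting; fork/join only starts to pay off when a single large list is divided as in the MyTask example.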

Upvotes: 1

Trying

Reputation: 14278

You need to extend RecursiveAction or RecursiveTask, depending on whether your task needs to return a result.

In either class you override the abstract compute() method:

protected abstract void compute(); // in RecursiveAction
protected abstract V compute();    // in RecursiveTask<V>
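As a minimal sketch of the RecursiveAction side (a hypothetical DoubleInPlace task, not from the original answer): compute() returns nothing, so the task writes its result into a shared array, and the sub-tasks work on disjoint index ranges so no synchronization is needed.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

public class DoubleInPlace extends RecursiveAction {
    private static final int THRESHOLD = 4; // deliberately tiny so this small demo actually splits
    private final int[] data;
    private final int from, to; // half-open range [from, to)

    public DoubleInPlace(int[] data, int from, int to) {
        this.data = data;
        this.from = from;
        this.to = to;
    }

    @Override
    protected void compute() { // void: results are written into data, nothing is returned
        if (to - from <= THRESHOLD) {
            for (int i = from; i < to; i++) {
                data[i] *= 2;
            }
        } else {
            int mid = from + (to - from) / 2;
            // run both halves and wait for both; they touch disjoint index ranges
            invokeAll(new DoubleInPlace(data, from, mid), new DoubleInPlace(data, mid, to));
        }
    }

    public static void main(String[] args) {
        int[] data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
        new ForkJoinPool().invoke(new DoubleInPlace(data, 0, data.length));
        System.out.println(Arrays.toString(data)); // [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
    }
}
```

A summing task like the one in the question, by contrast, has to hand a value back to its parent, which is why RecursiveTask<V> with its V compute() is the better fit there.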

Below is a quicksort modified to use fork/join. Study it and try implementing your own version:

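import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

public class ForkJoinQuicksortTask extends RecursiveAction {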
    static final int SEQUENTIAL_THRESHOLD = 10000;

    private final int[] a;
    private final int left;
    private final int right;

    public ForkJoinQuicksortTask(int[] a) {
        this(a, 0, a.length - 1);
    }

    private ForkJoinQuicksortTask(int[] a, int left, int right) {
        this.a = a;
        this.left = left;
        this.right = right;
    }

    @Override
    protected void compute() {
        if (serialThresholdMet()) {
            Arrays.sort(a, left, right + 1);
        } else {
            int pivotIndex = partition(a, left, right);
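            ForkJoinQuicksortTask t1 = new ForkJoinQuicksortTask(a, left, pivotIndex - 1);
            ForkJoinQuicksortTask t2 = new ForkJoinQuicksortTask(a, pivotIndex + 1, right);
            t1.fork();    // left part may run on another worker thread
            t2.compute(); // right part runs in the current thread
            t1.join();    // wait for the left part to finish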
        }
    }
    int partition(int[] a, int p, int r){
        int i=p-1;
        int x=a[r];
        for(int j=p;j<r;j++){
            if(a[j]<x){
                i++;
                swap(a, i, j);
            }
        }
        i++;
        swap(a, i, r);
        return i;
    }

    void swap(int[] a, int p, int r){
        int t=a[p];
        a[p]=a[r];
        a[r]=t;
    }

    private boolean serialThresholdMet() {
        return right - left < SEQUENTIAL_THRESHOLD;
    }
    public static void main(String[] args){
        ForkJoinPool fjPool = new ForkJoinPool();
        int[] a=new int[3333344];
        for(int i=0;i<a.length;i++){
            int k=(int)(Math.random()*22222);
            a[i]=k;
        }
        ForkJoinQuicksortTask forkJoinQuicksortTask=new ForkJoinQuicksortTask(a, 0, a.length-1);
        long start=System.nanoTime();
        fjPool.invoke(forkJoinQuicksortTask);
        System.out.println("Time (ns): " + (System.nanoTime() - start));
    }
}

Upvotes: 0

user1907906


Just use the result of invoke, which returns the value computed by the task. Inside the for-loop of App1:

Integer result = mainPool.invoke(new MyTask(list));
System.out.println(i + "\t" + result);
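Put together, a self-contained version of the question's App1 that keeps every invoke() result might look like the sketch below. App1Collect and run() are hypothetical names; MyTask is the question's class with generics and the loop bound i < list.size(). It still sums each list in a single task, so it only shows how results are collected, not how to split the work.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class App1Collect {

    static final ForkJoinPool mainPool = new ForkJoinPool();

    // The question's MyTask with generics and the loop bound i < list.size().
    static class MyTask extends RecursiveTask<Integer> {
        private final List<Integer> list;

        MyTask(List<Integer> list) {
            this.list = list;
        }

        @Override
        protected Integer compute() {
            int result = 0;
            for (int i = 0; i < list.size(); i++) {
                result += list.get(i);
            }
            return result;
        }
    }

    // Runs `rounds` rounds like App1's constructor and keeps every invoke() result.
    static List<Integer> run(int rounds, int size) {
        List<Integer> results = new ArrayList<>();
        List<Integer> list = new ArrayList<>();
        for (int i = 0; i < rounds; i++) {
            list.clear(); // reuse is safe here: invoke() blocks until the task has finished
            for (int j = 0; j < size; j++) {
                list.add(1 + (int) (Math.random() * 100)); // 1..100, as in the question
            }
            Integer result = mainPool.invoke(new MyTask(list));
            System.out.println(i + "\t" + result);
            results.add(result);
        }
        return results;
    }

    public static void main(String[] args) {
        // 99001 numbers per round, matching j = 1000..100000 in the question
        List<Integer> results = run(10, 99001);
        long total = 0;
        for (int r : results) {
            total += r;
        }
        System.out.println("all results: " + results);
        System.out.println("total: " + total);
    }
}
```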

Upvotes: 0
