Reputation: 455
I have a list of tasks: [Task-A, Task-B, Task-C, Task-D, ...].
A task can optionally depend on other tasks.
For example:
A can be dependent on 3 tasks: B, C and D
B can be dependent on 2 tasks: C and E
It's basically a directed acyclic graph, and a task should execute only after all the tasks it depends on have executed.
At any point in time, several tasks may be ready for execution. In that case, we can run them in parallel.
Any idea on how to implement such an execution while having as much parallelism as possible?
class Task{
private String name;
private List<Task> dependentTasks;
public void run(){
// business logic
}
}
Upvotes: 8
Views: 4388
Reputation: 8576
The other answer works fine but is too complicated.
A simpler way is to just execute Kahn's algorithm but in parallel.
The key is to run, in parallel, every task whose dependencies have all finished executing.
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
class DependencyManager {
private final ConcurrentHashMap<String, List<String>> _dependencies = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, List<String>> _reverseDependencies = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, Runnable> _tasks = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, Integer> _numDependenciesExecuted = new ConcurrentHashMap<>();
private final AtomicInteger _numTasksExecuted = new AtomicInteger(0);
private final ExecutorService _executorService = Executors.newFixedThreadPool(16);
private static Runnable getRunnable(DependencyManager dependencyManager, String taskId){
return () -> {
try {
Thread.sleep(2000); // A task takes 2 seconds to finish.
dependencyManager.taskCompleted(taskId);
} catch (InterruptedException e) {
// Restore the interrupt flag instead of swallowing the interruption.
Thread.currentThread().interrupt();
}
};
}
/**
* In case a vertex is disconnected from the rest of the graph.
* @param taskId The task id
*/
public void addVertex(String taskId) {
_dependencies.putIfAbsent(taskId, new ArrayList<>());
_reverseDependencies.putIfAbsent(taskId, new ArrayList<>());
_tasks.putIfAbsent(taskId, getRunnable(this, taskId));
_numDependenciesExecuted.putIfAbsent(taskId, 0);
}
private void addEdge(String dependentTaskId, String dependeeTaskId) {
_dependencies.get(dependentTaskId).add(dependeeTaskId);
_reverseDependencies.get(dependeeTaskId).add(dependentTaskId);
}
public void addDependency(String dependentTaskId, String dependeeTaskId) {
addVertex(dependentTaskId);
addVertex(dependeeTaskId);
addEdge(dependentTaskId, dependeeTaskId);
}
private void taskCompleted(String taskId) {
System.out.println(String.format("%s:: Task %s done!!", Instant.now(), taskId));
// Capture the new counter value so that exactly one thread observes the final count.
int numTasksExecuted = _numTasksExecuted.incrementAndGet();
_reverseDependencies.get(taskId).forEach(nextTaskId -> {
// merge() increments and returns the new count atomically, so a task whose
// dependencies finish at the same time is submitted exactly once.
int numDependenciesExecuted = _numDependenciesExecuted.merge(nextTaskId, 1, Integer::sum);
int numDependencies = _dependencies.get(nextTaskId).size();
if (numDependenciesExecuted == numDependencies) {
// All dependencies have been executed, so we can submit this task to the threadpool.
_executorService.submit(_tasks.get(nextTaskId));
}
});
if (numTasksExecuted == _tasks.size()) {
topoSortCompleted();
}
}
private void topoSortCompleted() {
System.out.println("Topo sort complete!!");
// Every task has finished, so an orderly shutdown is enough; no thread needs to be interrupted.
_executorService.shutdown();
}
public void executeTopoSort() {
System.out.println(String.format("%s:: Topo sort started!!", Instant.now()));
_dependencies.forEach((taskId, dependencies) -> {
if (dependencies.isEmpty()) {
_executorService.submit(_tasks.get(taskId));
}
});
}
}
public class TestParallelTopoSort {
public static void main(String[] args) {
DependencyManager dependencyManager = new DependencyManager();
dependencyManager.addDependency("8", "5");
dependencyManager.addDependency("7", "5");
dependencyManager.addDependency("7", "6");
dependencyManager.addDependency("6", "3");
dependencyManager.addDependency("6", "4");
dependencyManager.addDependency("5", "1");
dependencyManager.addDependency("5", "2");
dependencyManager.addDependency("5", "3");
dependencyManager.addDependency("4", "1");
dependencyManager.executeTopoSort();
// Parallel version takes 8 seconds to execute.
// Serial version would have taken 16 seconds.
}
}
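The directed acyclic graph constructed in this example has the following edges, where X -> Y means that X depends on Y:
8 -> 5
7 -> 5, 6
6 -> 3, 4
5 -> 1, 2, 3
4 -> 1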
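The 8-second and 16-second figures in the comments can be checked by hand. Below is a minimal sketch, assuming every task takes 2 seconds and the pool always has a free thread: the parallel run time is then the length of the longest dependency chain times 2 seconds. The class and method names (CriticalPathDemo, exampleGraph, longestChain) are made up for this illustration and are not part of the answer's code.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class CriticalPathDemo {
    // Builds the same graph as main() above: task id -> ids of the tasks it depends on.
    static Map<String, List<String>> exampleGraph() {
        String[][] edges = {
            {"8", "5"}, {"7", "5"}, {"7", "6"}, {"6", "3"}, {"6", "4"},
            {"5", "1"}, {"5", "2"}, {"5", "3"}, {"4", "1"}
        };
        Map<String, List<String>> deps = new HashMap<>();
        for (String[] edge : edges) {
            deps.computeIfAbsent(edge[0], k -> new ArrayList<>()).add(edge[1]);
            deps.computeIfAbsent(edge[1], k -> new ArrayList<>());
        }
        return deps;
    }

    // Level of a task = 1 + the highest level among its dependencies (1 if it has none).
    static int level(String task, Map<String, List<String>> deps, Map<String, Integer> cache) {
        Integer cached = cache.get(task);
        if (cached != null) {
            return cached;
        }
        int highest = 0;
        for (String dep : deps.get(task)) {
            highest = Math.max(highest, level(dep, deps, cache));
        }
        cache.put(task, highest + 1);
        return highest + 1;
    }

    // Number of tasks on the longest dependency chain in the graph.
    static int longestChain(Map<String, List<String>> deps) {
        Map<String, Integer> cache = new HashMap<>();
        int longest = 0;
        for (String task : deps.keySet()) {
            longest = Math.max(longest, level(task, deps, cache));
        }
        return longest;
    }

    public static void main(String[] args) {
        Map<String, List<String>> deps = exampleGraph();
        int chain = longestChain(deps);
        System.out.println("Longest chain: " + chain + " tasks");
        System.out.println("Parallel estimate: " + (chain * 2) + " s");
        System.out.println("Serial estimate: " + (deps.size() * 2) + " s");
    }
}
```

The longest chain here is 7 -> 6 -> 4 -> 1, four tasks long, which gives 4 x 2 = 8 seconds in parallel against 8 x 2 = 16 seconds serially.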
Upvotes: 9
Reputation: 8576
We can create a DAG where each vertex of the graph is one of the tasks.
After that, we can compute its topological sorted order.
We can then decorate the Task class with a priority field and run a ThreadPoolExecutor with a PriorityBlockingQueue that compares tasks by that field.
The final trick is to override run() so that it first waits for all the tasks it depends on to finish.
Since each task waits indefinitely for its dependencies, we cannot let the thread pool fill up with tasks that come later in the topological order: their dependencies would never get a thread, and the pool would be stuck forever.
To avoid this, we assign priorities to tasks according to the topological order.
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.FutureTask;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
public class Testing {
private static Callable<Void> getCallable(String taskId){
return () -> {
System.out.println(String.format("Task %s result", taskId));
Thread.sleep(100);
return null;
};
}
public static void main(String[] args) throws ExecutionException, InterruptedException {
Callable<Void> taskA = getCallable("A");
Callable<Void> taskB = getCallable("B");
Callable<Void> taskC = getCallable("C");
Callable<Void> taskD = getCallable("D");
Callable<Void> taskE = getCallable("E");
PrioritizedFutureTask<Void> pfTaskA = new PrioritizedFutureTask<>(taskA);
PrioritizedFutureTask<Void> pfTaskB = new PrioritizedFutureTask<>(taskB);
PrioritizedFutureTask<Void> pfTaskC = new PrioritizedFutureTask<>(taskC);
PrioritizedFutureTask<Void> pfTaskD = new PrioritizedFutureTask<>(taskD);
PrioritizedFutureTask<Void> pfTaskE = new PrioritizedFutureTask<>(taskE);
// Build the DAG: B depends on C and E; A depends on B, C and D.
pfTaskB.addDependency(pfTaskC).addDependency(pfTaskE);
pfTaskA.addDependency(pfTaskB).addDependency(pfTaskC).addDependency(pfTaskD);
// Now that we have a graph, we can just get its topological sorted order.
List<PrioritizedFutureTask<Void>> topological_sort = new ArrayList<>();
topological_sort.add(pfTaskE);
topological_sort.add(pfTaskC);
topological_sort.add(pfTaskB);
topological_sort.add(pfTaskD);
topological_sort.add(pfTaskA);
ThreadPoolExecutor executor = new ThreadPoolExecutor(5, 5, 0L, TimeUnit.MILLISECONDS,
new PriorityBlockingQueue<Runnable>(1, new CustomRunnableComparator()));
// It's important to insert the tasks in topological order; otherwise the thread pool may get stuck forever.
for (int i = 0; i < topological_sort.size(); i++) {
PrioritizedFutureTask<Void> pfTask = topological_sort.get(i);
pfTask.setPriority(i);
// The lower the priority, the sooner it will run.
executor.execute(pfTask);
}
// Stop accepting new tasks. Tasks already submitted still run, and the JVM can exit once they finish.
executor.shutdown();
}
}
class PrioritizedFutureTask<T> extends FutureTask<T> implements Comparable<PrioritizedFutureTask<T>> {
private Integer _priority = 0;
private final Callable<T> callable;
private final List<PrioritizedFutureTask<?>> _dependencies = new ArrayList<>();
public PrioritizedFutureTask(Callable<T> callable) {
super(callable);
this.callable = callable;
}
public PrioritizedFutureTask(Callable<T> callable, Integer priority) {
this(callable);
_priority = priority;
}
public Integer getPriority() {
return _priority;
}
public PrioritizedFutureTask<T> setPriority(Integer priority) {
_priority = priority;
return this;
}
public PrioritizedFutureTask<T> addDependency(PrioritizedFutureTask<?> dep) {
this._dependencies.add(dep);
return this;
}
@Override
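public void run() {
// Block until every dependency has finished before running this task.
for (PrioritizedFutureTask<?> dep : _dependencies) {
try {
dep.get();
} catch (InterruptedException e) {
// Restore the interrupt flag and fail this task rather than run it too early.
Thread.currentThread().interrupt();
setException(e);
return;
} catch (ExecutionException e) {
// A dependency failed, so this task must not run.
setException(e);
return;
}
}
super.run();
}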
@Override
public int compareTo(PrioritizedFutureTask<T> other) {
if (other == null) {
throw new NullPointerException();
}
return getPriority().compareTo(other.getPriority());
}
}
class CustomRunnableComparator implements Comparator<Runnable> {
@Override
public int compare(Runnable task1, Runnable task2) {
return ((PrioritizedFutureTask) task1).compareTo((PrioritizedFutureTask) task2);
}
}
Output (tasks that do not depend on each other, such as B and D, may print in a different order):
Task E result
Task C result
Task B result
Task D result
Task A result
PS: Here is a well-tested and simple implementation of topological sort in Python which you can easily port to Java.
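For reference, the topological order that the example above hard-codes could be computed rather than typed in. This is a minimal sketch using Kahn's algorithm, assuming tasks are identified by strings and every task appears as a key in the dependency map. It is not a port of the Python implementation mentioned above, and the names TopologicalSortSketch, exampleGraph and topologicalOrder are invented for this illustration.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class TopologicalSortSketch {
    // The graph from the example above: A depends on B, C, D; B depends on C, E.
    static Map<String, List<String>> exampleGraph() {
        Map<String, List<String>> deps = new LinkedHashMap<>();
        deps.put("A", List.of("B", "C", "D"));
        deps.put("B", List.of("C", "E"));
        deps.put("C", List.of());
        deps.put("D", List.of());
        deps.put("E", List.of());
        return deps;
    }

    /**
     * Kahn's algorithm.
     * @param deps task id -> ids of the tasks it depends on (every task must be a key)
     * @return the tasks ordered so that each one comes after all of its dependencies
     */
    static List<String> topologicalOrder(Map<String, List<String>> deps) {
        Map<String, Integer> remaining = new LinkedHashMap<>();    // unfinished dependencies per task
        Map<String, List<String>> dependents = new LinkedHashMap<>(); // reverse edges
        for (String task : deps.keySet()) {
            remaining.put(task, deps.get(task).size());
            dependents.put(task, new ArrayList<>());
        }
        for (Map.Entry<String, List<String>> entry : deps.entrySet()) {
            for (String dep : entry.getValue()) {
                dependents.get(dep).add(entry.getKey());
            }
        }
        // Start with the tasks that have no dependencies at all.
        Deque<String> ready = new ArrayDeque<>();
        remaining.forEach((task, count) -> {
            if (count == 0) {
                ready.add(task);
            }
        });
        List<String> order = new ArrayList<>();
        while (!ready.isEmpty()) {
            String task = ready.remove();
            order.add(task);
            // Finishing this task may unblock the tasks that depend on it.
            for (String next : dependents.get(task)) {
                if (remaining.merge(next, -1, Integer::sum) == 0) {
                    ready.add(next);
                }
            }
        }
        if (order.size() != deps.size()) {
            throw new IllegalStateException("The graph contains a cycle");
        }
        return order;
    }

    public static void main(String[] args) {
        System.out.println(topologicalOrder(exampleGraph()));
    }
}
```

The position of each task in the returned list could then serve as the value passed to setPriority(i) in the loop above. Any valid topological order works for that purpose, so the result need not match the hand-written E, C, B, D, A sequence.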
Upvotes: 2