marcotama

How to create a multi-threaded custom move factory in OptaPlanner?

I am solving a tough problem in OptaPlanner. The best algorithm I have found so far uses a custom move factory, and a computationally intensive one. After noticing that I was using only a single CPU core, I discovered that OptaPlanner spreads only the score calculation across multiple threads, while move generation runs on a single thread.

To mitigate the problem, I implemented multi-threading in my move factory via the following abstract class, which I then extend with the actual logic (I did it this way because I have three computationally expensive custom move factories):

package my.solver.move;

import lombok.AllArgsConstructor;
import lombok.NonNull;
import org.optaplanner.core.impl.domain.solution.descriptor.SolutionDescriptor;
import org.optaplanner.core.impl.heuristic.move.CompositeMove;
import org.optaplanner.core.impl.heuristic.move.Move;
import org.optaplanner.core.impl.heuristic.selector.move.factory.MoveIteratorFactory;
import org.optaplanner.core.impl.score.director.ScoreDirector;

import java.util.Iterator;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

public abstract class MultiThreadedMoveFactory<T> implements MoveIteratorFactory<T> {

    private final ThreadPoolExecutor threadPoolExecutor;

    public MultiThreadedMoveFactory(
            @NonNull String threadPrefix
    ) {
        int availableProcessorCount = Runtime.getRuntime().availableProcessors();
        int resolvedThreadCount = Math.max(1, availableProcessorCount);
        ThreadFactory threadFactory = new SolverThreadFactory(threadPrefix);
        threadPoolExecutor = (ThreadPoolExecutor) Executors.newFixedThreadPool(resolvedThreadCount, threadFactory);
    }

    @AllArgsConstructor
    public class MoveGeneratorData {
        T solution;
        SolutionDescriptor<T> solutionDescriptor;
        Random random;
        BlockingQueue<Move<T>> generatedMoves;
    }

    protected abstract int getNumMoves();

    @Override
    public long getSize(ScoreDirector<T> scoreDirector) {
        return getNumMoves();
    }

    protected class MovesIterator implements Iterator<Move<T>> {

        private final BlockingQueue<Move<T>> generatedMoves = new ArrayBlockingQueue<>(getNumMoves());

        public MovesIterator(
                @NonNull T solution,
                @NonNull SolutionDescriptor<T> solutionDescriptor,
                @NonNull Random random,
                @NonNull Function<MoveGeneratorData, Runnable> moveGeneratorFactory
        ) {
            MoveGeneratorData moveGeneratorData = new MoveGeneratorData(solution, solutionDescriptor, random, generatedMoves);
            for (int i = 0; i < getNumMoves(); i++) {
                threadPoolExecutor.submit(moveGeneratorFactory.apply(moveGeneratorData));
            }
        }

        @Override
        public boolean hasNext() {
            if (!generatedMoves.isEmpty()) {
                return true;
            }
            while (threadPoolExecutor.getActiveCount() > 0) {
                try {
                    //noinspection BusyWait
                    Thread.sleep(50);
                } catch (InterruptedException e) {
                    return false;
                }
            }
            return !generatedMoves.isEmpty();
        }

        @Override
        public Move<T> next() {
            //noinspection unchecked
            return Objects.requireNonNullElseGet(generatedMoves.poll(), CompositeMove::new);
        }
    }


    private static final AtomicInteger poolNumber = new AtomicInteger(1);

    private static class SolverThreadFactory implements ThreadFactory {

        private final ThreadGroup group;
        private final AtomicInteger threadNumber = new AtomicInteger(1);
        private final String namePrefix;

        public SolverThreadFactory(String threadPrefix) {
            SecurityManager s = System.getSecurityManager();
            group = (s != null) ? s.getThreadGroup() : Thread.currentThread().getThreadGroup();
            namePrefix = "MyPool-" + poolNumber.getAndIncrement() + "-" + threadPrefix + "-";
        }

        @Override
        public Thread newThread(@NonNull Runnable r) {
            Thread t = new Thread(group, r, namePrefix + threadNumber.getAndIncrement(), 0);
            if (t.isDaemon()) {
                t.setDaemon(false);
            }
            if (t.getPriority() != Thread.NORM_PRIORITY) {
                t.setPriority(Thread.NORM_PRIORITY);
            }
            return t;
        }

    }

    @Override
    public Iterator<? extends Move<T>> createOriginalMoveIterator(ScoreDirector<T> scoreDirector) {
        return createMoveIterator(scoreDirector, new Random());
    }

    @Override
    public Iterator<? extends Move<T>> createRandomMoveIterator(ScoreDirector<T> scoreDirector, Random workingRandom) {
        return createMoveIterator(scoreDirector, workingRandom);
    }

    public abstract Iterator<? extends Move<T>> createMoveIterator(ScoreDirector<T> scoreDirector, Random random);
}

However, the solver seems to hang after a while. The debugger shows that it is blocked on innerQueue.take() in OrderByMoveIndexBlockingQueue. My move factory is the cause: if I revert the above and go back to my previous, single-threaded implementation, the problem goes away.

I do not quite understand where the problem is, so the question is: how can I fix it?

Answers (2)

marcotama

I was able to make the factory work by removing all just-in-time behavior from hasNext: the method now blocks until all moves have been generated, only then returns true, and keeps returning true until all moves have been consumed.

        // Set once all moves have been generated; only touched by the thread that consumes this iterator
        private boolean generationComplete = false;

        @Override
        public boolean hasNext() {
            while (!generationComplete && generatedMoves.size() < getNumMoves()) {
                try {
                    // IntelliJ warns about busy-waiting because the awaited event may happen before the sleep ends,
                    // wasting up to 50 ms; that cost is negligible here, so the warning is suppressed
                    //noinspection BusyWait
                    Thread.sleep(50);
                } catch (InterruptedException e) {
                    return false;
                }
            }
            generationComplete = true;
            return !generatedMoves.isEmpty();
        }

To the best of my understanding, the solution I am using not only works, but is also the best one I have found in a few months of iterating.
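If the 50 ms polling loop is a concern, the same "block until every move has been generated" idea can be expressed with a CountDownLatch that each generator task counts down when it finishes. This is a minimal standalone sketch, not OptaPlanner API: LatchedMovesIterator is a hypothetical name, and moves come from an arbitrary IntFunction rather than real Move objects.

```java
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.IntFunction;

// Hypothetical iterator: generates numMoves items on a thread pool, then hands them out one by one
public class LatchedMovesIterator<M> implements Iterator<M> {

    private final ConcurrentLinkedQueue<M> generatedMoves = new ConcurrentLinkedQueue<>();
    private final CountDownLatch remaining;

    public LatchedMovesIterator(int numMoves, IntFunction<M> generator, ExecutorService pool) {
        remaining = new CountDownLatch(numMoves);
        for (int i = 0; i < numMoves; i++) {
            final int index = i;
            pool.submit(() -> {
                try {
                    generatedMoves.add(generator.apply(index));
                } finally {
                    // Count down even if generation throws, so hasNext() can never block forever
                    remaining.countDown();
                }
            });
        }
    }

    @Override
    public boolean hasNext() {
        try {
            remaining.await(); // blocks until every generator task has finished, without polling
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
        return !generatedMoves.isEmpty();
    }

    @Override
    public M next() {
        M move = hasNext() ? generatedMoves.poll() : null;
        if (move == null) {
            throw new NoSuchElementException();
        }
        return move;
    }

    public static void main(String[] args) {
        ExecutorService pool = Executors.newFixedThreadPool(4);
        LatchedMovesIterator<String> it = new LatchedMovesIterator<>(10, i -> "move-" + i, pool);
        int count = 0;
        while (it.hasNext()) {
            it.next();
            count++;
        }
        pool.shutdown();
        System.out.println(count); // prints 10
    }
}
```

Unlike the size check, the latch also releases when a generator task fails, so a single failed move produces a shorter iteration instead of a hang.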

Geoffrey De Smet

No, no, no. This approach is doomed. I think. (Prove me wrong.)

JIT selection

First, learn about just-in-time (JIT) selection of moves (see the docs). Instead of generating all moves (which can be billions) at the beginning of each step, generate only those that will actually be evaluated. Most local search (LS) algorithms evaluate only a few moves per step.

Watch the TRACE log to see how many milliseconds it takes to start a step. Typically you want to evaluate 10000 moves per second (about 0.1 ms per move), so starting a step should take 0 or 1 milliseconds (the log only shows whole milliseconds).
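To make the difference between eager and JIT selection concrete, here is a plain-Java sketch (not OptaPlanner's selector API; SwapMove and both helper methods are hypothetical). The eager version builds every swap up front, while the JIT version builds a random swap only when the caller asks for the next move.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Random;

public class JitSelectionSketch {

    // Hypothetical move: swap the values at two positions
    record SwapMove(int left, int right) {}

    // Eager (cached) selection: all n * (n - 1) / 2 moves exist before the first one is evaluated
    static List<SwapMove> generateAllMoves(int size) {
        List<SwapMove> moves = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            for (int j = i + 1; j < size; j++) {
                moves.add(new SwapMove(i, j));
            }
        }
        return moves;
    }

    // JIT random selection: a move is only built when next() is called; the iterator never ends
    static Iterator<SwapMove> randomMoveIterator(int size, Random random) {
        return new Iterator<>() {
            @Override
            public boolean hasNext() {
                return size >= 2;
            }

            @Override
            public SwapMove next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                int left = random.nextInt(size);
                int right = random.nextInt(size - 1);
                if (right >= left) {
                    right++; // skip left itself, so right is uniform over the other positions
                }
                return new SwapMove(left, right);
            }
        };
    }

    public static void main(String[] args) {
        System.out.println(generateAllMoves(1_000).size()); // prints 499500
        // A step that evaluates only 5 moves builds only 5 moves
        Iterator<SwapMove> jit = randomMoveIterator(1_000, new Random(42));
        for (int evaluated = 0; evaluated < 5 && jit.hasNext(); evaluated++) {
            System.out.println(jit.next());
        }
    }
}
```

The cost of starting a step with the JIT iterator is constant, whereas the eager list grows quadratically with the number of positions.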

Multithreaded solving

Then learn about moveThreadCount, which enables multithreaded solving. See this blog post. Note that this still does the move selection on a single thread, for reproducibility reasons, but the move evaluation is spread across threads.
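To see why keeping selection on one thread keeps runs reproducible, here is a plain-Java sketch of the general pattern, not OptaPlanner's implementation: the calling thread picks moves from a seeded Random in a fixed order, a thread pool evaluates them, and the results are read back in move-index order. The evaluate method and the integer "moves" are hypothetical stand-ins.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class ParallelEvaluationSketch {

    // Hypothetical stand-in for an expensive score calculation (higher is better)
    static long evaluate(int move) {
        long penalty = 0;
        for (int i = 0; i < 10_000; i++) {
            penalty += (move * 31L + i) % 7;
        }
        return -penalty;
    }

    // Selects moves on the calling thread, evaluates them on the pool, returns the best move
    static int pickBestMove(long seed, int moveCount, ExecutorService pool) {
        if (moveCount <= 0) {
            throw new IllegalArgumentException("moveCount must be positive");
        }
        Random random = new Random(seed); // only this thread uses the Random, so selection order is fixed
        List<Integer> moves = new ArrayList<>();
        List<Future<Long>> scores = new ArrayList<>();
        for (int index = 0; index < moveCount; index++) {
            int move = random.nextInt(1_000);
            moves.add(move);
            scores.add(pool.submit(() -> evaluate(move)));
        }
        // Read results in move-index order; ties resolve to the lowest index, whatever the thread timing
        int bestIndex = 0;
        long bestScore = Long.MIN_VALUE;
        for (int index = 0; index < moveCount; index++) {
            long score;
            try {
                score = scores.get(index).get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while waiting for a score", e);
            } catch (ExecutionException e) {
                throw new IllegalStateException("Move evaluation failed", e);
            }
            if (score > bestScore) {
                bestScore = score;
                bestIndex = index;
            }
        }
        return moves.get(bestIndex);
    }

    public static void main(String[] args) {
        ExecutorService onePool = Executors.newFixedThreadPool(1);
        ExecutorService fourPool = Executors.newFixedThreadPool(4);
        int a = pickBestMove(7L, 200, onePool);
        int b = pickBestMove(7L, 200, fourPool);
        onePool.shutdown();
        fourPool.shutdown();
        System.out.println(a == b); // prints true
    }
}
```

Because only the evaluation is parallel, pickBestMove returns the same move for the same seed whether the pool has one thread or four; parallelizing the selection itself would make the order depend on thread scheduling.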

Caching for move selection

But your custom moves are smart, so the move selection must be smart? First determine what "solution state" query information you need to generate the moves - for example a Map<Employee, List<Shift>> - then cache that:

  • either calculate that map at the beginning of each step, if it doesn't take too long (but this won't scale, because it doesn't do deltas),
  • or use a shadow variable (@InverseRelationShadowVariable works fine in this case), because these are updated through deltas. But it also applies the deltas for every move and undo move.
  • or hack in an actual new MoveSelector, which can listen to stepEnded() events and apply only the delta of the last step to that Map, without doing any of the deltas for every move and undo move. We should probably standardize this approach and make it part of our public API some day.
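For the third option, the cache bookkeeping itself is simple. Below is a plain-Java sketch of just the cache, contrasting a full recalculation (the first option) with applying a single step's delta (the third option). Employee, Shift, onStepEnded and shiftsOf are hypothetical names; wiring onStepEnded to stepEnded() events inside a custom MoveSelector is not shown, because that part depends on OptaPlanner internals.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class ShiftCacheSketch {

    // Hypothetical problem fact
    record Employee(String name) {}

    // Hypothetical planning entity; `employee` plays the role of the planning variable
    static final class Shift {
        final String id;
        Employee employee;

        Shift(String id, Employee employee) {
            this.id = id;
            this.employee = employee;
        }
    }

    private final Map<Employee, List<Shift>> shiftsByEmployee = new HashMap<>();

    // Option 1: rebuild the whole map, e.g. at the start of each step; cost grows with the number of shifts
    void recalculate(List<Shift> shifts) {
        shiftsByEmployee.clear();
        for (Shift shift : shifts) {
            if (shift.employee != null) {
                shiftsByEmployee.computeIfAbsent(shift.employee, e -> new ArrayList<>()).add(shift);
            }
        }
    }

    // Option 3: apply only the delta of the step just taken (hypothetical hook, called once per step
    // after the chosen move is applied, never for the moves that were merely evaluated and undone)
    void onStepEnded(Shift changedShift, Employee previousEmployee) {
        if (Objects.equals(previousEmployee, changedShift.employee)) {
            return;
        }
        if (previousEmployee != null) {
            List<Shift> previousShifts = shiftsByEmployee.get(previousEmployee);
            if (previousShifts != null) {
                previousShifts.remove(changedShift);
                if (previousShifts.isEmpty()) {
                    shiftsByEmployee.remove(previousEmployee);
                }
            }
        }
        if (changedShift.employee != null) {
            shiftsByEmployee.computeIfAbsent(changedShift.employee, e -> new ArrayList<>()).add(changedShift);
        }
    }

    List<Shift> shiftsOf(Employee employee) {
        return shiftsByEmployee.getOrDefault(employee, List.of());
    }

    public static void main(String[] args) {
        Employee ann = new Employee("Ann");
        Employee bob = new Employee("Bob");
        Shift s1 = new Shift("s1", ann);
        Shift s2 = new Shift("s2", ann);
        Shift s3 = new Shift("s3", bob);
        ShiftCacheSketch cache = new ShiftCacheSketch();
        cache.recalculate(List.of(s1, s2, s3));

        // The step's chosen move reassigns s2 from Ann to Bob; only that delta is applied to the cache
        Employee previous = s2.employee;
        s2.employee = bob;
        cache.onStepEnded(s2, previous);

        System.out.println(cache.shiftsOf(ann).size() + " " + cache.shiftsOf(bob).size()); // prints 1 2
    }
}
```

A move that changes several shifts (a swap, for example) would call onStepEnded once per changed shift; the per-step cost then depends on the size of the move, not on the size of the solution.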
