Gili
Gili

Reputation: 90101

Non-blocking rate-limited ThreadPoolExecutor

I am hitting an HTTP server concurrently with multiple connections. I'd like to throttle the clients in response to servers indicating that requests are coming in too fast. I do not wish to change the HTTP library I am using but rather I'd like to extend it.

To that end, how do I implement a ThreadPoolExecutor with the following constraints?

What I've Looked Into

Upvotes: 9

Views: 3508

Answers (2)

Mykhaylo Adamovych
Mykhaylo Adamovych

Reputation: 20966

Usage:

/**
 * Demo driver: submits 1000 tiny printing tasks to a rate-limited pool and waits for them to finish.
 *
 * @param args ignored
 * @throws InterruptedException if interrupted while waiting for the pool to terminate
 */
public static void main(String[] args) throws InterruptedException {
    // corePoolSize = 3, maximumPoolSize = 5, keep-alive = 1 second, bounded work queue of 100 entries.
    RateFriendlyThreadPoolExecutor executor = new RateFriendlyThreadPoolExecutor(
            3, 5, 1, SECONDS, new LinkedBlockingDeque<>(100));
    // Starting rate first, then the bounds it may move between.
    executor.setRate(10);
    executor.setMinRate(1);
    executor.setMaxRate(100);

    int submitted = 0;
    while (submitted < 1000) {
        // Lambdas may only capture effectively-final locals, hence the per-iteration copy.
        final int lap = submitted++;
        Runnable report = () -> System.out.printf("%03d (%s) - %s - %s%n",
                lap, executor.getRate(), LocalDateTime.now(), Thread.currentThread().getName());
        executor.execute(report);
    }

    executor.shutdown();
    executor.awaitTermination(60, SECONDS);
}

Output:

002 (10) - 2023-05-27T23:03:37.659658800 - pool-1-thread-3
000 (11) - 2023-05-27T23:03:37.659658800 - pool-1-thread-1
001 (11) - 2023-05-27T23:03:37.744859100 - pool-1-thread-2
105 (11) - 2023-05-27T23:03:37.930152500 - main
103 (12) - 2023-05-27T23:03:38.037876400 - pool-1-thread-4
104 (12) - 2023-05-27T23:03:38.130058800 - pool-1-thread-5
003 (12) - 2023-05-27T23:03:38.221655300 - pool-1-thread-3
004 (12) - 2023-05-27T23:03:38.314020700 - pool-1-thread-1
005 (12) - 2023-05-27T23:03:38.406202700 - pool-1-thread-2
006 (12) - 2023-05-27T23:03:38.573508200 - pool-1-thread-4
007 (13) - 2023-05-27T23:03:38.665875900 - pool-1-thread-5
008 (13) - 2023-05-27T23:03:38.742695200 - pool-1-thread-3

Implementation:

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReentrantLock;

import static java.lang.Math.min;
import static java.lang.System.nanoTime;

/**
 * A {@link ThreadPoolExecutor} that spaces out task starts according to an adaptive rate, expressed in task
 * starts per second.
 * <p>
 * Every successful task nudges the rate up by one (capped at the maximum rate). A task that fails with a
 * "throttled" exception (see {@link #isThrottled(Exception)}) drops the rate to the minimum rate and is
 * resubmitted. A rate of {@code 0} disables rate limiting.
 * <p>
 * When the pool and its queue are saturated the submitting thread runs the task itself, unless the executor has
 * been shut down, in which case the task is silently discarded. Note that this also applies to a throttled task
 * that is resubmitted after {@link #shutdown()}.
 */
public class RateFriendlyThreadPoolExecutor extends ThreadPoolExecutor {

    private static final long NANOS_PER_SECOND = 1_000_000_000L;
    /** Waits longer than this are parked; shorter ones are spun because parking is too coarse for them. */
    private static final long SPIN_THRESHOLD_NANOS = 1_000_000L;

    private final AtomicInteger rate = new AtomicInteger();
    private final AtomicInteger minRate = new AtomicInteger();
    private final AtomicInteger maxRate = new AtomicInteger();
    // System.nanoTime() has an arbitrary (possibly negative) origin, so 0 is not a meaningful "never" value.
    // Start one second in the past instead: rates are whole numbers per second, so the first task never waits.
    private final AtomicLong leapTime = new AtomicLong(nanoTime() - NANOS_PER_SECOND);
    private final ReentrantLock rateLock = new ReentrantLock();

    /**
     * Creates the executor. The arguments have the same meaning as in
     * {@link ThreadPoolExecutor#ThreadPoolExecutor(int, int, long, TimeUnit, BlockingQueue)}; the rejection policy
     * is fixed to "run in the submitting thread unless shut down".
     *
     * @param corePoolSize    the number of threads to keep in the pool
     * @param maximumPoolSize the maximum number of threads to allow in the pool
     * @param keepAliveTime   how long excess idle threads wait for new tasks before terminating
     * @param unit            the time unit of {@code keepAliveTime}
     * @param workQueue       the queue holding tasks before they are executed
     */
    public RateFriendlyThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, (r, e) -> overflow(r, (RateFriendlyThreadPoolExecutor) e));
    }

    /**
     * Submits {@code command} wrapped so that it passes through the rate limiter right before it runs.
     *
     * @param command the task to execute
     * @throws NullPointerException if {@code command} is null
     */
    @Override
    public void execute(Runnable command) {
        if (command == null)
            throw new NullPointerException("command");
        super.execute(() -> executeWithDelay(command));
    }

    /**
     * Waits for the rate limiter, runs {@code runnable} and adapts the rate to the outcome.
     * <p>
     * On success the rate is raised by one, up to the maximum rate, provided no other thread changed it in the
     * meantime. On a throttling failure the rate is reset to the minimum rate and the task is resubmitted. Any
     * other exception is rethrown unchanged.
     *
     * @param runnable the task to run
     */
    protected void executeWithDelay(Runnable runnable) {
        int rateSnapshot = rate.get();
        limitRate(rateSnapshot, leapTime, rateLock);
        try {

            runnable.run();

            // CAS so that a concurrent throttle-reset is not overwritten by a stale increment.
            rate.compareAndSet(rateSnapshot, min(rateSnapshot + 1, maxRate.get()));
        } catch (Exception e) {
            if (!isThrottled(e))
                throw e;
            System.out.println("throttled at rate " + rateSnapshot);
            rate.set(minRate.get());
            execute(runnable);
        }
    }

    /**
     * Blocks the calling thread until at least {@code 1s / rate} has elapsed since the previous caller was
     * released, then records the release time in {@code leapTime}. Callers are serialized by {@code rateLock}.
     * <p>
     * If the calling thread is interrupted while parked, the wait is cut short; the interrupt status is left set.
     *
     * @param rate     the permitted number of releases per second; {@code 0} disables limiting
     * @param leapTime holds the {@link System#nanoTime()} reading at which the previous caller was released
     * @param rateLock the lock guarding {@code leapTime}
     */
    // works for parallel streams like a charm
    public static void limitRate(int rate, AtomicLong leapTime, ReentrantLock rateLock) {
        if (rate == 0)
            return;
        long targetLeapTime = NANOS_PER_SECOND / rate;
        rateLock.lock();
        try {
            long timeSnapshot = nanoTime();
            long waitTime = targetLeapTime - (timeSnapshot - leapTime.get());
            if (waitTime <= 0) {
                leapTime.set(timeSnapshot);
                return;
            }

            long deadline = timeSnapshot + waitTime;
            long remaining;
            // nanoTime values must be compared by subtraction; parkNanos may also return early (spuriously),
            // so keep waiting until the deadline has really passed.
            while ((remaining = deadline - nanoTime()) > 0) {
                if (remaining > SPIN_THRESHOLD_NANOS) {
                    // parkNanos returns immediately while the interrupt flag is set; stop waiting rather than spin.
                    if (Thread.currentThread().isInterrupted())
                        break;
                    LockSupport.parkNanos(remaining);
                }
                /* else: busy wait */
            }
            leapTime.set(deadline);
        } finally {
            rateLock.unlock();
        }
    }

    /**
     * Rejection policy: runs the rejected task in the submitting thread unless the executor was shut down.
     *
     * @param r the rejected task; always the wrapper created by {@link #execute(Runnable)}
     * @param e the executor that rejected it
     */
    private static void overflow(Runnable r, RateFriendlyThreadPoolExecutor e) {
        // r already routes through executeWithDelay, so run it directly; calling executeWithDelay(r) here would
        // rate-limit the task twice and bump the rate twice.
        if (!e.isShutdown())
            r.run();
    }

    /**
     * @param e an exception thrown by a task
     * @return true if the exception message indicates that the remote side asked for a lower request rate
     */
    private boolean isThrottled(Exception e) {
        // getMessage() is null for exceptions created without a message; those are never throttling signals.
        String message = e.getMessage();
        return message != null && message.contains("Reduce your rate");
    }

    /**
     * @return the live counter holding the current rate (task starts per second)
     */
    public AtomicInteger getRate() {
        return rate;
    }

    /**
     * Sets the current rate. If the minimum or maximum rate is still unset ({@code 0}) it is initialized to the
     * same value.
     *
     * @param rate the new rate in task starts per second; {@code 0} disables rate limiting
     */
    public void setRate(int rate) {
        this.rate.set(rate);
        minRate.compareAndSet(0, rate);
        maxRate.compareAndSet(0, rate);
    }

    /**
     * @return the live counter holding the rate that is applied after a throttling failure
     */
    public AtomicInteger getMinRate() {
        return minRate;
    }

    /**
     * @param minRate the rate to fall back to after a throttling failure
     */
    public void setMinRate(int minRate) {
        this.minRate.set(minRate);
    }

    /**
     * @return the live counter holding the upper bound of the rate
     */
    public AtomicInteger getMaxRate() {
        return maxRate;
    }

    /**
     * @param maxRate the upper bound that successful tasks may raise the rate to
     */
    public void setMaxRate(int maxRate) {
        this.maxRate.set(maxRate);
    }
}

Upvotes: 0

Gili
Gili

Reputation: 90101

Answering my own question:

  • It isn't possible to have a solution that is completely non-blocking. Even ScheduledThreadPoolExecutor keeps at least one thread around waiting for the queue to return a new task.
  • ThreadPoolExecutor sits on top of a BlockingQueue. When there are no tasks left, it blocks on BlockingQueue.take()
  • The solution has 3 moving pieces:
  1. A rate limiter.
  2. A BlockingQueue that hides elements until the rate limiter allows their consumption.
  3. A ThreadPoolExecutor that sits on top of the BlockingQueue.

The Rate Limiter

I provide my own rate limiter, based on the Token Bucket algorithm, to overcome RateLimiter's limitations. The source code can be found here.


The BlockingQueue

I implemented a BlockingDeque (which extends BlockingQueue) because in the future I want to try pushing failed tasks back to the front of the queue.

RateLimitedBlockingDeque.java

import java.time.Duration;
import java.util.Collection;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import static org.bitbucket.cowwoc.requirements.core.Requirements.requireThat;

/**
 * A blocking deque of elements, in which an element can only be taken when the deque-wide rate limit allows it.
 * <p>
 * The optional capacity bound constructor argument serves as a way to prevent excessive expansion. The capacity, if
 * unspecified, is equal to {@link Integer#MAX_VALUE}.
 * <p>
 * Even though methods that take elements, such as {@code take} or {@code poll}, respect the deque-wide rate limit,
 * the remaining methods treat them as normal elements. For example, the {@code size} method returns the count of
 * all elements regardless of whether the rate limit currently allows taking them, and {@code remove(Object)} is
 * not rate-limited.
 * <p>
 * The blocking {@code take} methods acquire their permit <em>before</em> waiting for an element to arrive.
 * <p>
 * This class and its iterator implement all of the <em>optional</em> methods of the {@link Collection} and {@link
 * Iterator} interfaces. Equality is identity-based, as it is for {@link LinkedBlockingDeque}.
 *
 * @param <E> the type of elements in the deque
 * @author Gili Tzabari
 */
public final class RateLimitedBlockingDeque<E> implements BlockingDeque<E>
{
    private final int capacity;
    private final LinkedBlockingDeque<E> delegate;
    private final Bucket rateLimit = new Bucket();

    /**
     * Creates a {@code RateLimitedBlockingDeque} with a capacity of {@link Integer#MAX_VALUE}.
     */
    public RateLimitedBlockingDeque()
    {
        this.capacity = Integer.MAX_VALUE;
        this.delegate = new LinkedBlockingDeque<>();
    }

    /**
     * Creates a {@code RateLimitedBlockingDeque} with the given (fixed) capacity.
     *
     * @param capacity the capacity of this deque
     * @throws IllegalArgumentException if {@code capacity} is less than 1
     */
    public RateLimitedBlockingDeque(int capacity)
    {
        this.capacity = capacity;
        this.delegate = new LinkedBlockingDeque<>(capacity);
    }

    /**
     * @return the capacity of the deque
     */
    public int getCapacity()
    {
        return capacity;
    }

    /**
     * Indicates the rate at which elements may be taken from the queue.
     *
     * @param elements the number of elements that may be taken per {@code period}
     * @param period   indicates how often elements may be taken
     * @throws NullPointerException     if {@code period} is null
     * @throws IllegalArgumentException if the requested rate is greater than one element per nanosecond
     */
    public void setRate(long elements, Duration period)
    {
        synchronized (rateLimit)
        {
            Limit newLimit = new Limit(elements, period, 0, Long.MAX_VALUE);
            // The bucket holds at most one limit: install it the first time, swap it afterwards.
            if (rateLimit.getLimits().isEmpty())
                rateLimit.addLimit(newLimit);
            else
            {
                Limit oldLimit = rateLimit.getLimits().iterator().next();
                rateLimit.replaceLimit(oldLimit, newLimit);
            }
        }
    }

    /**
     * Allows consumption of elements without limit.
     */
    public void removeRate()
    {
        synchronized (rateLimit)
        {
            rateLimit.removeAllLimits();
        }
    }

    // ---- Insertion is never rate-limited: these methods delegate directly. ----

    @Override
    public void addFirst(E e)
    {
        delegate.addFirst(e);
    }

    @Override
    public void addLast(E e)
    {
        delegate.addLast(e);
    }

    @Override
    public boolean offerFirst(E e)
    {
        return delegate.offerFirst(e);
    }

    @Override
    public boolean offerLast(E e)
    {
        return delegate.offerLast(e);
    }

    @Override
    public void putFirst(E e) throws InterruptedException
    {
        delegate.putFirst(e);
    }

    @Override
    public void putLast(E e) throws InterruptedException
    {
        delegate.putLast(e);
    }

    @Override
    public boolean offerFirst(E e, long timeout, TimeUnit unit) throws InterruptedException
    {
        return delegate.offerFirst(e, timeout, unit);
    }

    @Override
    public boolean offerLast(E e, long timeout, TimeUnit unit) throws InterruptedException
    {
        return delegate.offerLast(e, timeout, unit);
    }

    // ---- Taking elements is rate-limited. ----

    /**
     * Removes the first element if one is present and the rate limit allows it right now.
     *
     * @throws NoSuchElementException if the deque is empty or the rate limit does not allow a removal
     */
    @Override
    public E removeFirst()
    {
        // Check for emptiness first so that a permit is not spent when there is nothing to take.
        if (!delegate.isEmpty() && rateLimit.tryConsume())
            return delegate.removeFirst();
        throw new NoSuchElementException();
    }

    /**
     * Removes the last element if one is present and the rate limit allows it right now.
     *
     * @throws NoSuchElementException if the deque is empty or the rate limit does not allow a removal
     */
    @Override
    public E removeLast()
    {
        if (!delegate.isEmpty() && rateLimit.tryConsume())
            return delegate.removeLast();
        throw new NoSuchElementException();
    }

    /**
     * @return the first element, or {@code null} if the deque is empty or the rate limit does not allow a
     *         removal right now
     */
    @Override
    public E pollFirst()
    {
        if (!delegate.isEmpty() && rateLimit.tryConsume())
            return delegate.pollFirst();
        return null;
    }

    /**
     * @return the last element, or {@code null} if the deque is empty or the rate limit does not allow a
     *         removal right now
     */
    @Override
    public E pollLast()
    {
        if (!delegate.isEmpty() && rateLimit.tryConsume())
            return delegate.pollLast();
        return null;
    }

    /**
     * Waits for the rate limit, then waits for an element to become available and removes it from the front.
     */
    @Override
    public E takeFirst() throws InterruptedException
    {
        rateLimit.consume();
        return delegate.takeFirst();
    }

    /**
     * Waits for the rate limit, then waits for an element to become available and removes it from the back.
     */
    @Override
    public E takeLast() throws InterruptedException
    {
        rateLimit.consume();
        return delegate.takeLast();
    }

    /**
     * Waits up to {@code timeout} overall for both the rate limit and an element, then removes the first element.
     *
     * @return the first element, or {@code null} if the timeout elapsed first
     */
    @Override
    public E pollFirst(long timeout, TimeUnit unit) throws InterruptedException
    {
        long deadline = System.nanoTime() + unit.toNanos(timeout);
        if (!rateLimit.consume(1, timeout, unit))
            return null;
        // Only the time left after waiting for the permit may be spent waiting for an element; a non-positive
        // remainder makes the delegate poll without blocking.
        return delegate.pollFirst(deadline - System.nanoTime(), TimeUnit.NANOSECONDS);
    }

    /**
     * Waits up to {@code timeout} overall for both the rate limit and an element, then removes the last element.
     *
     * @return the last element, or {@code null} if the timeout elapsed first
     */
    @Override
    public E pollLast(long timeout, TimeUnit unit) throws InterruptedException
    {
        long deadline = System.nanoTime() + unit.toNanos(timeout);
        if (!rateLimit.consume(1, timeout, unit))
            return null;
        return delegate.pollLast(deadline - System.nanoTime(), TimeUnit.NANOSECONDS);
    }

    // ---- Inspection and removal of specific elements are not rate-limited. ----

    @Override
    public E getFirst()
    {
        return delegate.getFirst();
    }

    @Override
    public E getLast()
    {
        return delegate.getLast();
    }

    @Override
    public E peekFirst()
    {
        return delegate.peekFirst();
    }

    @Override
    public E peekLast()
    {
        return delegate.peekLast();
    }

    @Override
    public boolean removeFirstOccurrence(Object o)
    {
        return delegate.removeFirstOccurrence(o);
    }

    @Override
    public boolean removeLastOccurrence(Object o)
    {
        return delegate.removeLastOccurrence(o);
    }

    // ---- BlockingQueue view: queue operations map onto the deque operations above. ----

    @Override
    public boolean add(E e)
    {
        return delegate.add(e);
    }

    @Override
    public boolean offer(E e)
    {
        return delegate.offer(e);
    }

    @Override
    public void put(E e) throws InterruptedException
    {
        putLast(e);
    }

    @Override
    public boolean offer(E e, long timeout, TimeUnit unit) throws InterruptedException
    {
        return delegate.offer(e, timeout, unit);
    }

    @Override
    public E remove()
    {
        return removeFirst();
    }

    @Override
    public E poll()
    {
        return pollFirst();
    }

    @Override
    public E take() throws InterruptedException
    {
        return takeFirst();
    }

    @Override
    public E poll(long timeout, TimeUnit unit) throws InterruptedException
    {
        return pollFirst(timeout, unit);
    }

    @Override
    public E element()
    {
        return getFirst();
    }

    @Override
    public E peek()
    {
        return peekFirst();
    }

    @Override
    public int remainingCapacity()
    {
        return delegate.remainingCapacity();
    }

    /**
     * Moves as many elements into {@code c} as the rate limit allows right now, without blocking.
     *
     * @return the number of elements transferred
     * @throws NullPointerException     if {@code c} is null
     * @throws IllegalArgumentException if {@code c} is this deque
     */
    @Override
    public int drainTo(Collection<? super E> c)
    {
        return drainTo(c, Integer.MAX_VALUE);
    }

    /**
     * Moves at most {@code maxElements} elements into {@code c}, limited further by what the rate limit allows
     * right now, without blocking.
     *
     * @return the number of elements transferred
     * @throws NullPointerException     if {@code c} is null
     * @throws IllegalArgumentException if {@code c} is this deque
     */
    @Override
    public int drainTo(Collection<? super E> c, int maxElements)
    {
        if (c == null)
            throw new NullPointerException("c may not be null");
        if (c == this)
            throw new IllegalArgumentException("c may not be this deque");
        int result = 0;
        while (result < maxElements)
        {
            E next = pollFirst();
            if (next == null)
                break;
            c.add(next);
            ++result;
        }
        return result;
    }

    @Override
    public void push(E e)
    {
        addFirst(e);
    }

    @Override
    public E pop()
    {
        return removeFirst();
    }

    @Override
    public boolean remove(Object o)
    {
        return removeFirstOccurrence(o);
    }

    @Override
    public int size()
    {
        return delegate.size();
    }

    @Override
    public boolean contains(Object o)
    {
        return delegate.contains(o);
    }

    @Override
    public Object[] toArray()
    {
        return delegate.toArray();
    }

    @Override
    public <T> T[] toArray(T[] a)
    {
        return delegate.toArray(a);
    }

    @Override
    public String toString()
    {
        return delegate.toString();
    }

    @Override
    public void clear()
    {
        delegate.clear();
    }

    @Override
    public Iterator<E> iterator()
    {
        return wrap(delegate.iterator());
    }

    /**
     * @param delegateIterator the iterator to delegate to
     * @return an iterator that respects the rate-limit: traversal is free, but each {@code remove()} waits for
     *         the rate limit before removing the element
     */
    private Iterator<E> wrap(Iterator<E> delegateIterator)
    {
        return new Iterator<E>()
        {
            /** True between a call to {@code next()} and the matching {@code remove()}. */
            private boolean removable = false;

            @Override
            public boolean hasNext()
            {
                return delegateIterator.hasNext();
            }

            @Override
            public E next()
            {
                E element = delegateIterator.next();
                removable = true;
                return element;
            }

            @Override
            public void remove()
            {
                if (!removable)
                    throw new IllegalStateException("next() not invoked, or remove() already invoked");
                try
                {
                    rateLimit.consume();
                }
                catch (InterruptedException e)
                {
                    // Iterator.remove() cannot throw InterruptedException; keep the interrupt visible to callers.
                    Thread.currentThread().interrupt();
                    throw new IllegalStateException(e);
                }
                delegateIterator.remove();
                removable = false;
            }
        };
    }

    @Override
    public Iterator<E> descendingIterator()
    {
        return wrap(delegate.descendingIterator());
    }

    @Override
    public boolean addAll(Collection<? extends E> c)
    {
        requireThat("c", c).isNotNull().isNotEqualTo("this", this);
        boolean modified = false;
        for (E e: c)
            if (add(e))
                modified = true;
        return modified;
    }

    @Override
    public boolean isEmpty()
    {
        return delegate.isEmpty();
    }

    @Override
    public boolean containsAll(Collection<?> c)
    {
        return delegate.containsAll(c);
    }

    /**
     * Removes every element that is contained in {@code c}. Each removal waits for the rate limit.
     *
     * @return true if at least one element was removed
     */
    @Override
    public boolean removeAll(Collection<?> c)
    {
        Iterator<E> i = iterator();
        boolean modified = false;
        while (i.hasNext())
        {
            E element = i.next();
            if (c.contains(element))
            {
                i.remove();
                modified = true;
            }
        }
        return modified;
    }

    /**
     * Removes every element that is not contained in {@code c}. Each removal waits for the rate limit.
     *
     * @return true if at least one element was removed
     */
    @Override
    public boolean retainAll(Collection<?> c)
    {
        Iterator<E> i = iterator();
        boolean modified = false;
        while (i.hasNext())
        {
            E element = i.next();
            if (!c.contains(element))
            {
                i.remove();
                modified = true;
            }
        }
        return modified;
    }

    @Override
    public int hashCode()
    {
        // The delegate is never replaced, so its hash is stable and consistent with identity-based equals().
        return delegate.hashCode();
    }

    @Override
    public boolean equals(Object obj)
    {
        // Identity equality. Delegating to delegate.equals(obj) compared the delegate's identity with obj and
        // therefore returned false even for x.equals(x).
        return obj == this;
    }
}

Upvotes: 5

Related Questions