Zachary Turner

Reputation: 772

Python asyncio wait() with cumulative timeout

I am writing a job scheduler that schedules M jobs across N coroutines (N < M). As soon as one job finishes, I start a new one so that it runs in parallel with the remaining jobs. I also want to ensure that no single job takes more than a fixed amount of time; any job that takes too long should be cancelled. I have something pretty close:

import asyncio
from typing import List

# MyJob and _run_job_coroutine are defined elsewhere (omitted from the question).

def update_run_set(waiting, running, max_concurrency):
    # Move jobs from `waiting` into `running` until the concurrency limit is reached.
    number_to_add = min(len(waiting), max_concurrency - len(running))
    for _ in range(number_to_add):
        next_one = waiting.pop()
        # Wrap the coroutine in a Task: asyncio.wait() rejects bare coroutines (Python 3.11+).
        running.add(asyncio.ensure_future(next_one))

async def _run_test_invocations_asynchronously(jobs: List[MyJob], max_concurrency: int, timeout_seconds: int):
    running = set()  # Tasks that are actively being run
    # Coroutines that have not yet started
    waiting = {_run_job_coroutine(job) for job in jobs}
    update_run_set(waiting, running, max_concurrency)

    while running:
        # Every call gets a fresh timeout; this is the problem described below.
        done, running = await asyncio.wait(running, timeout=timeout_seconds,
                                           return_when=asyncio.FIRST_COMPLETED)
        if not done:
            timeout_count = len(running)
            for r in running:
                r.cancel()                                # Start cancelling the timed-out jobs
            done, running = await asyncio.wait(running)  # Wait for cancellation to finish
            assert len(done) == timeout_count
            assert len(running) == 0
        else:
            for d in done:
                job_return_code = await d

        if waiting:
            update_run_set(waiting, running, max_concurrency)
            assert len(running) > 0

The problem: suppose my timeout is 5 seconds and I'm scheduling 3 jobs across 4 cores. Job A takes 2 seconds, job B takes 6 seconds, and job C takes 7 seconds.

We have something like this:

  t=0     t=1     t=2     t=3     t=4     t=5     t=6     t=7
-------|-------|-------|-------|-------|-------|-------|-------|
AAAAAAAAAAAAAAA
BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC

However, at t=2 the asyncio.wait() call returns because A completed. The loop then goes back to the top and calls wait() again. At this point B has already been running for 2 seconds, but the countdown starts over, and since B needs only 4 more seconds to complete, it appears to succeed. So 4 seconds later (t=6) wait() returns again with B done; the loop starts over, and C completes at t=7.
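A scaled-down, self-contained reproduction of this timeline may help. It is a sketch assuming a hypothetical `fake_job` that only sleeps, with 1 "second" mapped to 0.1 s; `naive_loop` is also a hypothetical name, and it mirrors the question's loop in that each `asyncio.wait()` call starts a fresh timeout:

```python
import asyncio

UNIT = 0.1  # 1 "second" from the question's example = 0.1 s here

async def fake_job(name: str, units: float) -> str:
    """Hypothetical job: sleep for `units`, then return its name."""
    await asyncio.sleep(units * UNIT)
    return name

async def naive_loop(timeout_units: float) -> list:
    """Like the question's loop: every asyncio.wait() call gets a fresh timeout."""
    running = {
        asyncio.create_task(fake_job("A", 2)),
        asyncio.create_task(fake_job("B", 6)),
        asyncio.create_task(fake_job("C", 7)),
    }
    finished = []
    while running:
        done, running = await asyncio.wait(
            running, timeout=timeout_units * UNIT, return_when=asyncio.FIRST_COMPLETED
        )
        if not done:  # a single wait() exceeded the timeout: cancel what is left
            for task in running:
                task.cancel()
            await asyncio.wait(running)
            break
        finished.extend(sorted(task.result() for task in done))
    return finished

if __name__ == "__main__":
    print(asyncio.run(naive_loop(5)))  # B and C are never cancelled
```

Run as a script, it should report all three jobs as finished: the three wait() calls last roughly 2, 4, and 1 units, so none of them reaches the 5-unit timeout even though B and C run for 6 and 7 units in total.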

How do I make it so that B and C both fail? I somehow need the time to be preserved across calls to asyncio.wait().

One idea that I had is to do my own bookkeeping of how much time each job is allowed to continue running, and pass the minimum of these into asyncio.wait(). Then when something times out, I can cancel only those jobs whose time remaining was equal to the value I passed in for timeout_seconds.

This requires a lot of manual bookkeeping on my part, though, and I worry about floating-point issues causing me to decide that it's not yet time to cancel a job even though it really is. So I suspect there's something easier. I'd appreciate any ideas.

Upvotes: 0

Views: 1230

Answers (1)

user4815162342

Reputation: 154836

You can wrap each job into a coroutine that checks its timeout, e.g. using asyncio.wait_for. Limiting the number of parallel invocations could be done in the same coroutine using an asyncio.Semaphore. With those two combined, you only need one call to wait() or even just gather(). For example (untested):

import asyncio

# Run the job, limiting concurrency and time. This code could likely
# be part of _run_job_coroutine, omitted from the question.
async def _run_job_with_limits(job, sem, timeout):
    async with sem:
        try:
            return await asyncio.wait_for(_run_job_coroutine(job), timeout)
        except asyncio.TimeoutError:
            # timed out and canceled, decide what you want to return
            return None

async def _run_test_invocations_async(jobs, max_concurrency, timeout):
    sem = asyncio.Semaphore(max_concurrency)
    return await asyncio.gather(
        *(_run_job_with_limits(job, sem, timeout) for job in jobs)
    )

Upvotes: 4
