Erik Swan

Reputation: 595

asyncio.wait_for_completion() on dynamic list of tasks

I am writing a simple client to interact with a possibly overloaded and unreliable webserver. I assume that, for any individual request, the server may never respond (the request will time out), or may respond with an error after a lengthy delay.

Because of this, for each "request" I want to issue repeated requests according to the following logic:

- Issue an initial request.
- If no request has succeeded after 1 second, issue an additional request (leaving any earlier requests running).
- Keep issuing an additional request every second until one of the requests succeeds.
- As soon as any request succeeds, cancel all other outstanding requests.

I can accomplish something close to this if I issue a fixed number of requests simultaneously at the beginning, then use asyncio.as_completed() to handle the requests as they finish and cancel any remaining pending requests:

import asyncio
import logging
import random
import time
from sys import stdout


class FailedRequest(Exception):
    pass

async def get():
    '''A simple mock async GET request that returns a randomized status after a randomized delay'''
    await asyncio.sleep(random.uniform(0,10))
    return random.choices([200, 500], [0.2, 0.8])[0]

async def fetch(id):
    '''Makes a request using get(), checks response, and handles cancellation'''
    logging.info(f"Sending request {id}.")
    start_time = time.perf_counter()
    try:
        response = await get()
        elapsed_time = time.perf_counter() - start_time
            
        if response != 200:
            logging.error(f"Request {id} failed after {elapsed_time:.2f}s: {response}")
            raise FailedRequest()
        else:
            logging.info(f"Request {id} succeeded ({response}) after {elapsed_time:.2f}s!")
    except asyncio.CancelledError:
        logging.info(f"Cancelled request {id} after {time.perf_counter() - start_time:.2f}s.")
        raise

async def main():
    # Create 10 unique Tasks that wrap the fetch() coroutine
    tasks = [asyncio.create_task(fetch(i)) for i in range(10)]

    # Iterate through the tasks as they are completed
    for coro in asyncio.as_completed(tasks):
        try:
            # Wait for the next task to finish. If the request errored out,
            # this line will raise a FailedRequest exception (caught below)
            await coro

            # If we get here, then a request succeeded. Cancel all of the tasks we started.
            for t in tasks:
                t.cancel()

        except (FailedRequest, asyncio.CancelledError):
            pass

    logging.info("Finished!")
                
if __name__ == '__main__':
    logging.basicConfig(stream=stdout, level=logging.INFO, format='%(asctime)s:%(levelname)s: %(message)s')
    random.seed(3)
    asyncio.run(main())

Output:

2020-09-22 18:07:35,634:INFO: Sending request 0.
2020-09-22 18:07:35,635:INFO: Sending request 1.
2020-09-22 18:07:35,635:INFO: Sending request 2.
2020-09-22 18:07:35,635:INFO: Sending request 3.
2020-09-22 18:07:35,636:INFO: Sending request 4.
2020-09-22 18:07:35,636:INFO: Sending request 5.
2020-09-22 18:07:35,636:INFO: Sending request 6.
2020-09-22 18:07:35,636:INFO: Sending request 7.
2020-09-22 18:07:35,636:INFO: Sending request 8.
2020-09-22 18:07:35,637:INFO: Sending request 9.
2020-09-22 18:07:35,786:ERROR: Request 6 failed after 0.15s: 500
2020-09-22 18:07:36,301:ERROR: Request 5 failed after 0.66s: 500
2020-09-22 18:07:37,993:ERROR: Request 9 failed after 2.35s: 500
2020-09-22 18:07:38,023:ERROR: Request 0 failed after 2.39s: 500
2020-09-22 18:07:38,236:ERROR: Request 8 failed after 2.60s: 500
2020-09-22 18:07:39,351:INFO: Request 2 succeeded (200) after 3.72s!
2020-09-22 18:07:39,351:INFO: Cancelled request 1 after 3.72s.
2020-09-22 18:07:39,351:INFO: Cancelled request 3 after 3.72s.
2020-09-22 18:07:39,352:INFO: Cancelled request 4 after 3.72s.
2020-09-22 18:07:39,352:INFO: Cancelled request 7 after 3.72s.
2020-09-22 18:07:39,352:INFO: Finished!

However, I am struggling to find a clean way to start by issuing a single request, then issue an additional request every second until one of the requests succeeds, while still keeping track of all unfinished requests and cancelling any that are still pending.

This is as close as I've gotten:

import asyncio
import logging
import random
import time
from sys import stdout


class FailedRequest(Exception):
    pass

async def get():
    '''A simple mock async GET request that returns a randomized status after a randomized delay'''
    await asyncio.sleep(random.uniform(0,10))
    return random.choices([200, 500], [0.2, 0.8])[0]

async def fetch(id):
    '''Makes a request using get(), checks response, and handles cancellation'''
    logging.info(f"Sending request {id}.")
    start_time = time.perf_counter()
    try:
        response = await get()
        elapsed_time = time.perf_counter() - start_time
            
        if response != 200:
            logging.error(f"Request {id} failed after {elapsed_time:.2f}s: {response}")
            raise FailedRequest()
        else:
            logging.info(f"Request {id} succeeded ({response}) after {elapsed_time:.2f}s!")
    except asyncio.CancelledError:
        logging.info(f"Cancelled request {id} after {time.perf_counter() - start_time:.2f}s.")
        raise

async def issue_requests(finished, requests):
    i = 0
    while not finished.is_set():
        requests.add(asyncio.create_task(fetch(i)))
        await asyncio.sleep(1)
        i += 1

async def handle_requests(finished, requests):
    # Iterate through the requests as they are completed
    for coro in asyncio.as_completed(requests):
        try:
            # Wait for the next task to finish. If the request errored out,
            # this line will raise a FailedRequest exception (caught below)
            await coro

            # If we get here, then a request succeeded. Cancel all of the tasks we started.
            finished.set()
            for r in requests:
                r.cancel()

        except (FailedRequest, asyncio.CancelledError):
            pass


async def main():
    finished = asyncio.Event()
    requests = set()

    await asyncio.gather(issue_requests(finished, requests), handle_requests(finished, requests))
    logging.info("Finished!")
                
if __name__ == '__main__':
    logging.basicConfig(stream=stdout, level=logging.INFO, format='%(asctime)s:%(levelname)s: %(message)s')
    random.seed(3)
    asyncio.run(main())

However, although the requests are launched as expected, the process does not stop when the first successful request returns:

2020-09-22 18:03:38,256:INFO: Sending request 0.
2020-09-22 18:03:39,264:INFO: Sending request 1.
2020-09-22 18:03:40,265:INFO: Sending request 2.
2020-09-22 18:03:40,643:ERROR: Request 0 failed after 2.39s: 500
2020-09-22 18:03:41,281:INFO: Sending request 3.
2020-09-22 18:03:42,281:INFO: Sending request 4.
2020-09-22 18:03:42,948:INFO: Request 4 succeeded (200) after 0.67s!
# requests 1, 2, and 3 should be cancelled here and the script should finish
2020-09-22 18:03:43,279:INFO: Sending request 5.
2020-09-22 18:03:43,976:ERROR: Request 2 failed after 3.71s: 500
2020-09-22 18:03:44,281:INFO: Sending request 6.
2020-09-22 18:03:44,718:ERROR: Request 1 failed after 5.45s: 500
2020-09-22 18:03:45,295:INFO: Sending request 7.
2020-09-22 18:03:46,307:INFO: Sending request 8.
...

I think the issue is that asyncio.as_completed(requests) captures the contents of requests at the moment it is called in handle_requests(), so it sees at most the first request (or none at all, depending on scheduling). Tasks added to the set afterwards are never awaited, and handle_requests() returns as soon as that initial snapshot is exhausted, without ever setting finished.
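
To check, a minimal experiment (work() here is just a stand-in coroutine, not part of my client) seems to confirm that tasks added after the as_completed() call are never yielded:

import asyncio

async def work(n):
    await asyncio.sleep(n)
    return n

async def demo():
    pending = {asyncio.create_task(work(1))}
    iterator = asyncio.as_completed(pending)
    # This task is added to the set *after* the as_completed() call,
    # so the iterator never yields it
    late = asyncio.create_task(work(2))
    pending.add(late)
    print([await coro for coro in iterator])  # prints [1]
    late.cancel()

asyncio.run(demo())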

It feels like it should be possible to do this with asyncio at a high level, but I am struggling to figure it out.

Upvotes: 1

Views: 961

Answers (2)

user4815162342

Reputation: 155670

Your requirements look deceptively simple at first, but actually require some thought. If I understand them correctly, you want something like this:

async def _spawner(async_fn, done):
    # spawn a new task every second, notifying the caller when
    # any task completes
    running = set()
    def when_done(task):
        running.remove(task)
        done.put_nowait(task)

    while True:
        new_task = asyncio.create_task(async_fn())
        new_task.add_done_callback(when_done)
        running.add(new_task)
        try:
            await asyncio.sleep(1)
        except asyncio.CancelledError:
            # we're canceled, cancel the tasks that are still running
            for t in running:
                t.cancel()
            raise

async def try_until_successful(async_fn):
    done = asyncio.Queue()
    # run the spawner in the background so it can run
    # independently of us
    spawner_task = asyncio.create_task(_spawner(async_fn, done))

    # collect completed tasks until the one we're happy with
    while True:
        task = await done.get()
        if task.exception() is None:
            # task didn't raise - we got our result!
            spawner_task.cancel()
            return task.result()

These functions don't assume anything about the callable passed as async_fn. You can use a lambda or functools.partial to create any callable you like, as long as invoking it returns a coroutine object (the thing you get by calling an async def function, and what you normally pass to await). Also, they don't use global flags, so you can await several instances of try_until_successful in parallel (see the sketch after the next example).

To invoke fetch() with monotonically increasing IDs you could call it like this:

async def main():
    cnt = 0
    async def fetch_incrementing():
        nonlocal cnt
        cnt += 1
        await fetch(cnt)
    await try_until_successful(fetch_incrementing)
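
For instance, here is a minimal sketch of that parallel use, building on _spawner and try_until_successful above (flaky() is a made-up stand-in for a real request function, not something from your code):

import asyncio
import functools
import random

async def flaky(name):
    # Made-up stand-in for an unreliable request
    await asyncio.sleep(random.uniform(0, 2))
    if random.random() < 0.7:
        raise RuntimeError(f"{name} failed")
    return f"{name} ok"

async def main():
    # Two independent retry loops, each with its own spawner and queue
    result_a, result_b = await asyncio.gather(
        try_until_successful(functools.partial(flaky, "a")),
        try_until_successful(functools.partial(flaky, "b")),
    )
    print(result_a, result_b)

asyncio.run(main())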

Upvotes: 0

GProst

Reputation: 10237

You could do it like this, for example (using a while loop):

is_finished = False
tasks = []

def cancel_tasks():
  for t in tasks:
    t.cancel()

async def fetch(count):
  '''Makes a request using get(), checks response, and handles cancellation'''
  logging.info(f"Sending request {count}.")
  start_time = time.perf_counter()
  try:
    response = await get()
    elapsed_time = time.perf_counter() - start_time

    if response != 200:
      logging.error(f"Request {count} failed after {elapsed_time:.2f}s: {response}")
      raise FailedRequest()
    else:
      global is_finished
      is_finished = True
      logging.info(f"Request {count} succeeded ({response}) after {elapsed_time:.2f}s!")
      cancel_tasks()

  except asyncio.CancelledError:
    logging.info(f"Cancelled request {count} after {time.perf_counter() - start_time:.2f}s.")
    raise

async def main():
  count = 0
  while not is_finished:
    tasks.append(asyncio.create_task(fetch(count)))
    await asyncio.sleep(1)
    count += 1

  # Wait for all tasks to cancel:
  await asyncio.wait(tasks)

  logging.info("Finished!")

EDIT: slightly improved so that it cancels all the tasks ASAP and then waits for all of them to be cancelled before logging 'Finished'.

Upvotes: 1
