danijar

Reputation: 34215

How to immediately re-raise an exception thrown in any of the worker threads?

I'm using ThreadPoolExecutor and need to abort the whole computation in case any of the worker threads fails.

Example 1. This prints Success regardless of the error, since ThreadPoolExecutor does not re-raise exceptions automatically.

from concurrent.futures import ThreadPoolExecutor

def task():
    raise ValueError

with ThreadPoolExecutor() as executor:
    executor.submit(task)
print('Success')

Example 2. This correctly crashes the main thread because .result() re-raises exceptions. But it waits for the first task to finish, so the main thread experiences the exception with a delay.

import time
from concurrent.futures import ThreadPoolExecutor

def task(should_raise):
    time.sleep(1)
    if should_raise:
        raise ValueError

with ThreadPoolExecutor() as executor:
    executor.submit(task, False).result()
    executor.submit(task, True).result()
print('Success')

How can I notice a worker exception in the main thread (more or less) immediately after it occurred, to handle the failure and abort the remaining workers?

Upvotes: 3

Views: 3028

Answers (2)

danijar

Reputation: 34215

First of all, we have to submit the tasks before requesting their results. Otherwise, the threads are not even running in parallel:

futures = []
with ThreadPoolExecutor() as executor:
    futures.append(executor.submit(good_task))
    futures.append(executor.submit(bad_task))
for future in futures:
    future.result()
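
The difference between the two submission orders can be made visible with a small timing sketch. The names slow_task, run_blocking and run_submit_first are hypothetical and only serve as illustration; the sleep duration is an arbitrary assumption.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_task():
    # Hypothetical task that just sleeps for a fixed time.
    time.sleep(0.2)


def run_blocking():
    # Calling .result() right after each submit serializes the tasks.
    start = time.perf_counter()
    with ThreadPoolExecutor() as executor:
        executor.submit(slow_task).result()
        executor.submit(slow_task).result()
    return time.perf_counter() - start


def run_submit_first():
    # Submitting everything first lets the tasks overlap.
    start = time.perf_counter()
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(slow_task) for _ in range(2)]
    for future in futures:
        future.result()
    return time.perf_counter() - start
```

With two tasks, run_blocking should take roughly twice as long as run_submit_first, since the second task is only submitted once the first has finished.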

Now we can store the exception information in a variable that is available to both the main thread and the worker threads:

exc_info = None

The main thread cannot really kill its worker threads, so we have the workers check whether the exception information has been set and stop if it has:

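import sys
import time

def good_task():
    # Keep working until some worker has recorded a failure.
    while not exc_info:
        time.sleep(0.1)

def bad_task():
    global exc_info
    time.sleep(0.2)
    try:
        raise ValueError()
    except Exception:
        # Record the failure so the other workers and the main thread see it.
        exc_info = sys.exc_info()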

After all threads have terminated, the main thread can check the variable holding the exception information. If it's populated, we re-raise the exception:

if exc_info:
    # exc_info is (type, value, traceback); re-raise the instance with its traceback.
    raise exc_info[1].with_traceback(exc_info[2])
print('Success')
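
Put together, the snippets above might look like the following sketch. The run helper and the 'worker failed' message are hypothetical additions for illustration; everything else follows the steps described in this answer.

```python
import sys
import time
from concurrent.futures import ThreadPoolExecutor

exc_info = None  # shared between the main thread and the workers


def good_task():
    # Poll the shared variable and stop as soon as a failure is recorded.
    while not exc_info:
        time.sleep(0.1)


def bad_task():
    global exc_info
    time.sleep(0.2)
    try:
        raise ValueError('worker failed')
    except Exception:
        exc_info = sys.exc_info()


def run():
    # Hypothetical helper: runs both tasks and re-raises a recorded failure.
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(good_task), executor.submit(bad_task)]
    for future in futures:
        future.result()
    if exc_info:
        # Re-raise the stored exception instance with its original traceback.
        raise exc_info[1].with_traceback(exc_info[2])
    print('Success')
```

Calling run() here should raise the ValueError in the main thread shortly after bad_task fails, because good_task stops polling once exc_info is set.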

Upvotes: 2

Laurent LAPORTE

Reputation: 22982

I think I would implement it like this:

In the main process, I create 2 queues:

  1. One to report exceptions,
  2. One to notify cancellation.


import multiprocessing as mp

error_queue = mp.Queue()
cancel_queue = mp.Queue()

I create each ThreadPoolExecutor, and pass these queues as parameters.

import concurrent.futures

class MyExecutor(concurrent.futures.ThreadPoolExecutor):
    def __init__(self, error_queue, cancel_queue):
        super().__init__()  # initialize the underlying thread pool
        self.error_queue = error_queue
        self.cancel_queue = cancel_queue

Each ThreadPoolExecutor has a main loop. In this loop, I first scan the cancel_queue to see if a "cancel" message is available.

In the main loop, I also implement an exception handler. If an error occurs, I put the exception on the error_queue and stop:

self.status = "running"
while True:  # <- or some other loop condition
    if not self.cancel_queue.empty():
        self.status = "cancelled"
        break
    try:
        # normal processing
        ...
    except Exception as exc:
        # you can log the exception here for debug
        self.error_queue.put(exc)
        self.status = "error"
        break
    time.sleep(.1)

In the main process:

Run all MyExecutor instances.

Scan the error_queue:

while True:
    if not error_queue.empty():
        cancel_queue.put("cancel")
        break  # stop scanning once cancellation has been requested
    time.sleep(.1)
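
As a rough, self-contained sketch of this idea: the version below uses plain worker functions in a single ThreadPoolExecutor and queue.Queue instead of multiprocessing queues, which is an assumption that only holds when all workers are threads in one process. The names worker, fail_after and run are hypothetical, and the main thread blocks on error_queue.get() rather than polling.

```python
import queue
import time
from concurrent.futures import ThreadPoolExecutor

error_queue = queue.Queue()   # workers report exceptions here
cancel_queue = queue.Queue()  # main thread posts a "cancel" message here


def worker(fail_after=None):
    # Hypothetical worker loop; fail_after is an illustrative parameter.
    steps = 0
    while True:
        if not cancel_queue.empty():
            return 'cancelled'
        try:
            steps += 1  # stands in for the normal processing
            if fail_after is not None and steps >= fail_after:
                raise ValueError('worker failed')
        except Exception as exc:
            error_queue.put(exc)
            return 'error'
        time.sleep(0.05)


def run():
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(worker), executor.submit(worker, 3)]
        # Blocks until a worker reports a failure (forever if none does),
        # then tells the remaining workers to stop.
        exc = error_queue.get()
        cancel_queue.put('cancel')
    return exc, [future.result() for future in futures]
```

The "cancel" message is never removed from cancel_queue, so every worker that checks the queue afterwards sees it and stops.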

Upvotes: 1
