Simon

Reputation: 5412

How to handle errors in multiprocesses?

MWE is below.
My code spawns processes with torch.multiprocessing.Pool, and I manage communication with the parent through a JoinableQueue. I followed an online guide to handle CTRL+C gracefully, and everything works fine. In some cases, though (my code does more than the MWE), I hit errors in the function run by the children (online_test()). When that happens, the code hangs forever, because the children do not notify the parent that something went wrong. I tried adding try ... except ... finally to the children's main loop, with queue.task_done() in the finally block, but nothing changed.

I need the parent to be notified of any error in a child and to terminate everything gracefully. How could I do that? Thanks!

EDIT
The suggested solution does not work. The handler catches the exception, but the main code is left hanging because test_queue.join() keeps waiting for every queued item to be marked as done.

import signal
import multiprocessing as mp

STOP = 'STOP'

def initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def error_handler(exception):
    print(f'{exception} occurred, terminating pool.')
    pool.terminate()

def online_test(queue):
    while True:
        epoch, data_id = queue.get()
        if data_id == STOP:
            print(f'... test function is stopping ...')
            break
        print(f'testing function for {data_id} has started for epoch {epoch}')
        if epoch == 1:
            raise NotImplementedError
        queue.task_done()


if __name__ == '__main__':
    mp.set_start_method('spawn')
    manager = mp.Manager()
    test_queue = manager.JoinableQueue()
    pool = mp.Pool(initializer=initializer)
    pool.apply_async(online_test,
        args=(test_queue,), error_callback=error_handler)
    for i in ['a', 'b', 'c']:
        test_queue.put((0, i))

    try:
        for epoch in range(10):
            print('training epoch', epoch)
            print('... waiting for testing before moving on to next epoch ...')
            test_queue.join()
            print(f'... epoch {epoch} testing is done')
            for i in ['a', 'b', 'c']:
                test_queue.put((epoch + 1, i))

        for i in ['a', 'b', 'c']:
            test_queue.put((-1, STOP))

    except KeyboardInterrupt:
        pool.terminate()
    else:
        pool.close()

    pool.join()

Upvotes: 1

Views: 778

Answers (1)

martineau

Reputation: 123463

I got the code in your EDIT to work by doing these two things in the error handler function:

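  1. Cleaned up the test_queue by calling test_queue.task_done() until it raises ValueError, which marks every outstanding item as done so that test_queue.join() stops blocking.
  2. Set a global flag variable named aborted to True to indicate processing should stop.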

Then in the __main__ process I added code to check the aborted flag before waiting for the previous epoch to finish and starting another.

Using a global seems a little hacky, but it works because the error handler function is executed as part of the main process, so it has access to that process's globals. I remember when that detail dawned on me while I was working on the linked answer, and as you can see, it can prove to be important/useful.
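
If you want to check for yourself that the callback runs in the parent, comparing process IDs is a quick way to do it. This is only a minimal sketch and not part of the code below: failing_task, make_recorder and demo are hypothetical names I made up for illustration.

```python
import multiprocessing as mp
import os


def failing_task():
    """Runs in a pool worker and always fails (hypothetical stand-in)."""
    raise RuntimeError(f'raised in worker pid {os.getpid()}')


def make_recorder(seen):
    """Build an error callback that stores the pid it runs in plus the exception."""
    def on_error(exception):
        seen.append((os.getpid(), exception))
    return on_error


def demo():
    seen = []
    with mp.Pool(processes=1) as pool:
        result = pool.apply_async(failing_task,
                                  error_callback=make_recorder(seen))
        result.wait()  # Returns once the failure has been delivered to the parent.
    callback_pid, exception = seen[0]
    print('parent pid:  ', os.getpid())
    print('callback pid:', callback_pid)  # Expected to match the parent pid.
    print('exception:   ', exception)     # The message carries the worker pid.


if __name__ == '__main__':
    demo()
```

The callback is a closure and never has to be pickled, which is another hint that it is not shipped to a worker.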

import signal
import multiprocessing as mp

STOP = 'STOP'

def initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)

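def error_handler(exception):
    """Runs in the main process when online_test raises in a worker."""
    global aborted
    print(f'{exception=} occurred, terminating pool.')
    pool.terminate()
    print('pool terminated.')
    # Mark outstanding items as done so test_queue.join() in the main
    # process returns; task_done() raises ValueError once none are left.
    while not test_queue.empty():
        try:
            test_queue.task_done()
        except ValueError:
            break
    print('test_queue cleaned.')
    aborted = True  # Indicate an error occurred to the main process.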

def online_test(queue):
    while True:
        epoch, data_id = queue.get()
        if data_id == STOP:
            print(f'... test function is stopping ...')
            break
        print(f'testing function for {data_id} has started for epoch {epoch}')
        if epoch == 1:
            raise NotImplementedError('epoch == 1')  # Fake error for testing.
        queue.task_done()


if __name__ == '__main__':
    aborted = False
    mp.set_start_method('spawn')
    manager = mp.Manager()
    test_queue = manager.JoinableQueue()
    pool = mp.Pool(initializer=initializer)
    pool.apply_async(online_test, args=(test_queue,), error_callback=error_handler)
    for i in ['a', 'b', 'c']:
        test_queue.put((0, i))

    try:
        for epoch in range(10):
            if aborted:  # Error occurred?
                print('ABORTED by error_handler!')
                break
            print('training epoch', epoch)
            print('... waiting for testing before moving on to next epoch ...')
            test_queue.join()
            print(f'... epoch {epoch} testing is done')
            for i in ['a', 'b', 'c']:
                test_queue.put((epoch + 1, i))

        for i in ['a', 'b', 'c']:
            test_queue.put((-1, STOP))

    except KeyboardInterrupt:
        pool.terminate()
    else:
        pool.close()

    pool.join()
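
The queue-cleaning loop in error_handler can be tried in isolation. The sketch below uses a plain queue.Queue as a stand-in, on the assumption (true in CPython as far as I can tell) that manager.JoinableQueue() is a proxy for a queue.Queue living in the manager process, so task_done() and join() behave the same way. release_joiners is a hypothetical helper name, not something from the code above.

```python
import queue


def release_joiners(q):
    """Call task_done() until the unfinished-task counter reaches zero.

    Returns how many outstanding tasks were marked as done.
    """
    released = 0
    while True:
        try:
            q.task_done()
        except ValueError:  # Raised once task_done() is called too many times.
            return released
        released += 1


if __name__ == '__main__':
    q = queue.Queue()
    for item in ['a', 'b', 'c']:
        q.put(item)
    q.get()  # One item fetched, but never acknowledged with task_done().
    print(release_joiners(q))  # 3
    q.join()  # Returns immediately: nothing is counted as unfinished.
    print(q.qsize())  # 2
```

Note that the two unfetched items are still sitting in the queue afterwards: task_done() only adjusts the counter that join() waits on, it does not remove anything.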

Output from sample run:

training epoch 0
... waiting for testing before moving on to next epoch ...
testing function for a has started for epoch 0
testing function for b has started for epoch 0
testing function for c has started for epoch 0
... epoch 0 testing is done
testing function for a has started for epoch 1
training epoch 1
... waiting for testing before moving on to next epoch ...
exception=NotImplementedError('epoch == 1') occurred, terminating pool.
pool terminated.
... epoch 1 testing is done
test_queue cleaned.
ABORTED by error_handler!
Press any key to continue . . .
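
If the global flag bothers you, one possible variation (not what I ran above, just a sketch under the same assumption that the callback executes in a thread of the main process) is to pass a threading.Event to the handler through a closure. failing_task, make_error_handler and demo are hypothetical names, and the timed wait merely stands in for one epoch of work.

```python
import multiprocessing as mp
import threading


def failing_task():
    """Hypothetical job that always fails, standing in for online_test."""
    raise NotImplementedError('fake error')


def make_error_handler(abort_event, errors):
    """Build an error callback that records the exception and raises the flag."""
    def on_error(exception):
        errors.append(exception)
        abort_event.set()
    return on_error


def demo():
    abort_event = threading.Event()
    errors = []
    with mp.Pool(processes=1) as pool:
        pool.apply_async(failing_task,
                         error_callback=make_error_handler(abort_event, errors))
        for epoch in range(10):
            # Stand-in for one epoch of work; wakes early if a worker failed.
            if abort_event.wait(timeout=0.5):
                print(f'aborted in epoch {epoch}: {errors[0]!r}')
                break
            print('finished epoch', epoch)


if __name__ == '__main__':
    demo()
```

The Event does the same job as aborted, but it can also be waited on with a timeout, so the main loop does not have to poll a variable.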

Upvotes: 1
