gdamms

Reputation: 1

How to propagate KeyboardInterrupt to multiprocessing Pool workers and gracefully exit?

I am using multiprocessing.Pool in Python to parallelize a task that enters and exits a context manager for each item in the workload. I want to ensure that when a KeyboardInterrupt occurs, the program exits gracefully by propagating the interruption so that each worker process can complete the __exit__ method of its context manager (or something like that).

Here is a simple code snippet to illustrate my thoughts:

import multiprocessing
import time

class MyContext:
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        print(f'Entering context {self.name}')
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        print(f'Exiting context {self.name}')
        time.sleep(5)
        print(f'Exited context {self.name}')


def process_name(name):
    with MyContext(name):
        print(f'Processing {name}')
        time.sleep(100)
        print(f'Finished {name}')


names = ['Alice', 'Bob']

with multiprocessing.Pool(2) as pool:
    pool.map(process_name, names)

Here is the result I expect:

Entering context Alice
Processing Alice
Entering context Bob
Processing Bob
^CExiting context Alice
Exiting context Bob
Exited context Alice
Exited context Bob
KeyboardInterrupt (maybe multiple times)

Here is the actual output:

Entering context Alice
Processing Alice
Entering context Bob
Processing Bob
^CExiting context Alice
Exiting context Bob
KeyboardInterrupt

The context manager's exit method is called, but it doesn't finish before the program crashes. What should I change to make it work as intended?

EDIT: If there are remaining tasks, I don't want the pool to keep mapping them. So if I set the pool size to 1, I don't want to get this:

Entering context Alice
Processing Alice
^CExiting context Alice, exc_type = <class 'KeyboardInterrupt'>
Exited context Alice
Entering context Bob
Processing Bob
Finished Bob
Exiting context Bob, exc_type = None
Exited context Bob

but simply this:

Entering context Alice
Processing Alice
^CExiting context Alice, exc_type = <class 'KeyboardInterrupt'>
Exited context Alice

Upvotes: -1

Views: 97

Answers (1)

Booboo

Reputation: 44283

Update

Based on your updated question, here is my updated answer. The main process now has to handle the keyboard interrupt. However, if it uses the multiprocessing.pool.Pool.map method, the interrupt will not be handled until map returns, which happens only after all submitted tasks have executed. That is not what you want, so we cannot use map. Instead we use the map_async method and loop, testing whether it has completed. Even so, the interrupt will not be handled until the currently executing result = async_result.get(0) call completes (assuming that is the statement being executed by the main process when the interrupt occurs). So there can be a slight delay between the interrupt occurring and the pool being terminated, which happens automatically when the with multiprocessing.Pool(1) as pool: block is exited. Consequently you will not see the __exit__ method execute the print(f'Exited context {self.name}') statement: that line runs only 5 seconds after the prior message is output, and the pool will have been terminated long before then.

import multiprocessing
import time

class MyContext:
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        print(f'Entering context {self.name}')
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        print(f'Exiting context {self.name}, exc_type = {exc_type}')
        time.sleep(5)
        print(f'Exited context {self.name}')
        # We need this to cause the with statement to suppress the exception
        # and continue execution with the statement immediately
        # following the with statement:
        return True

def process_name(name):
    with MyContext(name):
        print(f'Processing {name}')
        time.sleep(100)
        print(f'Finished {name}')

names = ['Alice', 'Bob']

# The __main__ guard is required on Windows:
if __name__ == '__main__':
    try:
        with multiprocessing.Pool(1) as pool:
            async_result = pool.map_async(process_name, names)
            while True:
                try:
                    result = async_result.get(0)
                except multiprocessing.TimeoutError:
                    pass
                else:
                    break
    except KeyboardInterrupt:
        print('The pool and all tasks not currently executing have been terminated.')

    ... # Any additional code
    print('I am additional code.')

Prints:

Entering context Alice
Processing Alice
^CExiting context Alice, exc_type = <class 'KeyboardInterrupt'>
The pool and all tasks not currently executing have been terminated.
I am additional code.

Update 2

The previous update terminated the pool, guaranteeing that tasks still on the pool's input queue would never run. But as a side effect, code in the __exit__ method might not complete. So if we cannot terminate the pool for that reason, there has to be another mechanism to prevent tasks that haven't started executing from running. The only way I see to do that is to set an abort flag to True on a keyboard interrupt; every task checks the flag when it starts running and, if it is set, returns immediately. If the MyContext context manager is already active when the exception occurs, the exception will be handled by __exit__, which checks whether a KeyboardInterrupt has occurred and prevents it from propagating.

import multiprocessing
import time
import signal

def init_pool():
    global abort_flag

    def set_abort(*args):
        """Executed when Ctrl-C has been entered.
        We set the global abort_flag so that tasks waiting to run will
        immediately terminate when started. We then re-raise the exception
        so that it will be handled by MyContext.__exit__.
        """

        global abort_flag

        abort_flag = True
        raise KeyboardInterrupt()

    abort_flag = False
    signal.signal(signal.SIGINT, set_abort)

class MyContext:
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        print(f'Entering context {self.name}')
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        print(f'Exiting context {self.name}')
        time.sleep(1)
        print(f'Exited context {self.name}')
        # Return True to suppress a KeyboardInterrupt so it does not propagate:
        return exc_type is KeyboardInterrupt

def process_name(name):
    if abort_flag:   # Should we abort?
        return  # Perform no processing

    with MyContext(name):
        print(f'Processing {name}')
        time.sleep(100)
        print(f'Finished {name}')

names = ['Alice', 'Bob']

# The __main__ guard is required on Windows:
if __name__ == '__main__':
    with multiprocessing.Pool(1, initializer=init_pool) as pool:
        # For Linux do this after the pool has been created
        # so that the child processes do not inherit this handler:
        signal.signal(signal.SIGINT, signal.SIG_IGN)  # Ignore keyboard interrupts
        pool.map(process_name, names)

    ... # Any additional code
    print('I am additional code.')

Prints:

Entering context Alice
Processing Alice
^CExiting context Alice
Exited context Alice
I am additional code.
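To see why the signal.SIG_IGN line in the main process matters: with the default handler, SIGINT raises KeyboardInterrupt; after SIG_IGN, the same signal is silently discarded. A minimal sketch, assuming a POSIX system where os.kill can send SIGINT to our own process:

```python
import os
import signal
import time

# With the default handler, SIGINT raises KeyboardInterrupt in the main thread.
interrupted = False
try:
    os.kill(os.getpid(), signal.SIGINT)
    time.sleep(0.2)  # give the interpreter a chance to deliver the signal
except KeyboardInterrupt:
    interrupted = True
print(interrupted)  # True

# After ignoring SIGINT, the same signal has no effect on this process.
signal.signal(signal.SIGINT, signal.SIG_IGN)
os.kill(os.getpid(), signal.SIGINT)
time.sleep(0.2)
print('still running')
```

This is why, in the answer above, Ctrl-C reaches only the pool workers (whose initializer installed a custom handler) while the main process carries on and lets pool.map finish.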

Upvotes: 0
