MWE is below.
My code spawns processes with torch.multiprocessing.Pool and I manage the communication with the parent through a JoinableQueue. I followed an online guide to handle CTRL+C gracefully. Everything works fine. In some cases (my code has more things than the MWE), though, I encounter errors in the function run by the children (online_test()). If that happens, the code just hangs forever, because the children do not notify the parent that something happened. I tried adding try ... except ... finally in the main children loop, with queue.task_done() in finally, but nothing changed.
I need the parent to be notified about any child error and terminate everything gracefully. How could I do that? Thanks!
EDIT
Suggested solution does not work. The handler catches the exception but the main code is left hanging because it waits for the queue to be empty.
import signal
import numpy as np
import multiprocessing as mp

STOP = 'STOP'

def initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def error_handler(exception):
    print(f'{exception} occurred, terminating pool.')
    pool.terminate()

def online_test(queue):
    while True:
        epoch, data_id = queue.get()
        if data_id == STOP:
            print(f'... test function is stopping ...')
            break
        print(f'testing function for {data_id} has started for epoch {epoch}')
        if epoch == 1:
            raise NotImplementedError
        queue.task_done()

if __name__ == '__main__':
    mp.set_start_method('spawn')
    manager = mp.Manager()
    test_queue = manager.JoinableQueue()
    pool = mp.Pool(initializer=initializer)
    pool.apply_async(online_test,
                     args=(test_queue,), error_callback=error_handler)
    for i in ['a', 'b', 'c']:
        test_queue.put((0, i))
    try:
        for epoch in range(10):
            print('training epoch', epoch)
            print('... waiting for testing before moving on to next epoch ...')
            test_queue.join()
            print(f'... epoch {epoch} testing is done')
            for i in ['a', 'b', 'c']:
                test_queue.put((epoch + 1, i))
        for i in ['a', 'b', 'c']:
            test_queue.put((-1, STOP))
    except KeyboardInterrupt:
        pool.terminate()
    else:
        pool.close()
    pool.join()
Reputation: 123463
I got the code in your EDIT to work by doing these two things in the error handler function:
1. Cleared out the test_queue.
2. Set a global flag variable named aborted to true to indicate processing should stop.
Then in the __main__ process I added code to check the aborted flag before waiting for the previous epoch to finish and starting another.
Using a global seems a little hacky, but it works because the error handler function is executed as part of the main process, so it has access to its globals. I remember when that detail dawned on me while I was working on the linked answer, and as you can see, it can prove to be important/useful.
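Why the handler has to touch the queue at all can be seen in a single process. The sketch below is illustrative only: it assumes that manager.JoinableQueue() behaves like an ordinary queue.Queue held by the manager (which is my understanding of the standard library, not something stated in the question), and the helpers join_finishes and release_join are hypothetical names. It shows that join() waits on the count of put() calls not yet matched by task_done(), and that calling task_done() until it raises ValueError releases the waiter.

```python
import queue
import threading


def join_finishes(q, timeout=0.2):
    """Return True if q.join() returns within `timeout` seconds."""
    waiter = threading.Thread(target=q.join, daemon=True)
    waiter.start()
    waiter.join(timeout)
    return not waiter.is_alive()


def release_join(q):
    """Call task_done() until nothing is outstanding; return the call count."""
    released = 0
    while True:
        try:
            q.task_done()
        except ValueError:  # raised once task_done() calls outnumber put() calls
            return released
        released += 1


q = queue.Queue()
for item in ['a', 'b', 'c']:
    q.put(item)

q.get()  # a worker takes 'a' and then fails before calling task_done()

blocked_before = not join_finishes(q)  # join() still waits on 3 unfinished tasks
released = release_join(q)             # roughly what the error handler loop does
blocked_after = not join_finishes(q)   # the counter is now zero

print(blocked_before, released, blocked_after)  # True 3 False
```

The loop in error_handler is guarded by test_queue.empty() rather than running unconditionally, but it ends the same way: either the queue is empty or task_done() raises ValueError.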
import signal
import numpy as np
import multiprocessing as mp

STOP = 'STOP'

def initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def error_handler(exception):
    print(f'{exception=} occurred, terminating pool.')
    pool.terminate()
    print('pool terminated.')
    while not test_queue.empty():
        try:
            test_queue.task_done()
        except ValueError:
            break
    print(f'test_queue cleaned.')
    global aborted
    aborted = True  # Indicate an error occurred to the main process.

def online_test(queue):
    while True:
        epoch, data_id = queue.get()
        if data_id == STOP:
            print(f'... test function is stopping ...')
            break
        print(f'testing function for {data_id} has started for epoch {epoch}')
        if epoch == 1:
            raise NotImplementedError('epoch == 1')  # Fake error for testing.
        queue.task_done()

if __name__ == '__main__':
    aborted = False
    mp.set_start_method('spawn')
    manager = mp.Manager()
    test_queue = manager.JoinableQueue()
    pool = mp.Pool(initializer=initializer)
    pool.apply_async(online_test, args=(test_queue,), error_callback=error_handler)
    for i in ['a', 'b', 'c']:
        test_queue.put((0, i))
    try:
        for epoch in range(10):
            if aborted:  # Error occurred?
                print('ABORTED by error_handler!')
                break
            print('training epoch', epoch)
            print('... waiting for testing before moving on to next epoch ...')
            test_queue.join()
            print(f'... epoch {epoch} testing is done')
            for i in ['a', 'b', 'c']:
                test_queue.put((epoch + 1, i))
        for i in ['a', 'b', 'c']:
            test_queue.put((-1, STOP))
    except KeyboardInterrupt:
        pool.terminate()
    else:
        pool.close()
    pool.join()
Output from sample run:
training epoch 0
... waiting for testing before moving on to next epoch ...
testing function for a has started for epoch 0
testing function for b has started for epoch 0
testing function for c has started for epoch 0
... epoch 0 testing is done
testing function for a has started for epoch 1
training epoch 1
... waiting for testing before moving on to next epoch ...
exception=NotImplementedError('epoch == 1') occurred, terminating pool.
pool terminated.
... epoch 1 testing is done
test_queue cleaned.
ABORTED by error_handler!
Press any key to continue . . .