Andre

Reputation: 558

Fail fast with MPI4PY

I'd like the following behavior when running an MPI script with mpi4py: when any process raises an unhandled exception, mpirun (and all the processes it spawned) should immediately exit with a non-zero error code. Instead, I find that execution continues even if one or more processes raise an exception.

I am using mpi4py 3.0.0 with OpenMPI 2.1.2. I'm running the script below with mpirun --verbose -mca orte_abort_on_non_zero_status 1 -n 4 python my_script.py. I expected the run to end immediately, before the sleep is reached, but instead the processes with rank != 0 keep going and sleep:

import time
import mpi4py

def main():
    import mpi4py.MPI
    mpi_comm = mpi4py.MPI.COMM_WORLD
    if mpi_comm.rank == 0:
        raise ValueError('Failure')


    print('{} continuing to execute'.format(mpi_comm.rank))
    time.sleep(10)
    print('{} exiting'.format(mpi_comm.rank))


if __name__ == '__main__':
    main()

How can I get the behavior I'd like (fail quickly if any process fails)?

Thank you!

Upvotes: 5

Views: 1917

Answers (2)

Andre

Reputation: 558

It turns out that running the script through mpi4py as a module fixes this issue. On an unhandled exception, the module runner internally calls Abort(), as jcgiret describes:

mpirun --verbose -mca orte_abort_on_non_zero_status 1 -n 4 python -m mpi4py my_script.py
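For illustration, a rough stdlib-only approximation of that runner behavior might look like this. `run_script_or_abort` is a hypothetical name and this is not mpi4py's actual source; the only thing assumed about mpi4py is that `MPI.COMM_WORLD.Abort(errorcode)` exists. The abort callable is a parameter so the logic can be tried without MPI:

```python
import runpy
import sys
import traceback


def run_script_or_abort(path, abort=None):
    """Hypothetical sketch of the fail-fast behavior of
    `python -m mpi4py script.py`: run the script as __main__ and, on an
    unhandled exception, print the traceback and abort instead of letting
    the interpreter reach the collective MPI_Finalize()."""
    try:
        runpy.run_path(path, run_name="__main__")
    except SystemExit:
        # A plain sys.exit() is passed through untouched in this sketch.
        raise
    except BaseException:
        # Report first: a real Abort() may kill the process immediately.
        traceback.print_exc()
        sys.stderr.flush()
        if abort is None:
            # Imported lazily so the helper can be exercised without MPI.
            from mpi4py import MPI
            abort = MPI.COMM_WORLD.Abort
        abort(1)  # non-zero error code so mpirun reports a failure
```

In practice there is no need to write this yourself: `python -m mpi4py my_script.py` already gives you the fail-fast behavior on every rank.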

Upvotes: 0

jcgiret

Reputation: 758

This seems to be a known issue with mpi4py. In https://groups.google.com/forum/#!topic/mpi4py/RovYzJ8qkbc I read:

mpi4py initializes/finalizes MPI for you. The initialization occurs at import time, and the finalization when the Python process is about to finalize (I'm using Py_AtExit() C-API call to do this). As MPI_Finalize() is collective and likely blocking in most MPI impls, you get the deadlock.
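The key point is that exit-time hooks still run after an uncaught exception, so the failing rank still enters the collective MPI_Finalize(). A stdlib-only sketch (no MPI involved; the `atexit` handler `fake_finalize` is a hypothetical stand-in for the finalizer mpi4py registers with Py_AtExit()) shows a crashing child process still running its exit hook:

```python
import subprocess
import sys
import textwrap

# Child script: register an exit hook standing in for MPI_Finalize(),
# then fail with an uncaught exception.
CHILD = textwrap.dedent("""
    import atexit

    def fake_finalize():
        # With real mpi4py, this is where a collective, blocking
        # MPI_Finalize() would wait for the other ranks.
        print("exit hook ran")

    atexit.register(fake_finalize)
    raise ValueError("Failure")
""")

result = subprocess.run(
    [sys.executable, "-c", CHILD],
    capture_output=True,
    text=True,
)
print(result.returncode)      # non-zero: the exception was never handled
print(result.stdout.strip())  # yet the exit hook still ran
```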

A solution is to override sys.excepthook and explicitly call MPI.COMM_WORLD.Abort in it.

Here is your code modified:

import sys
import time
import mpi4py.MPI
mpi_comm = mpi4py.MPI.COMM_WORLD

def mpiabort_excepthook(exc_type, exc_value, exc_traceback):
    # Print the traceback first: Abort() usually terminates the process at once
    sys.__excepthook__(exc_type, exc_value, exc_traceback)
    sys.stderr.flush()
    # Abort every rank in COMM_WORLD with a non-zero error code
    mpi_comm.Abort(1)

def main():
    if mpi_comm.rank == 0:
        raise ValueError('Failure')

    print('{} continuing to execute'.format(mpi_comm.rank))
    time.sleep(10)
    print('{} exiting'.format(mpi_comm.rank))

if __name__ == "__main__":
    sys.excepthook = mpiabort_excepthook
    main()
    sys.excepthook = sys.__excepthook__
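If you want to reuse the hook across scripts, or test it without launching MPI, one option is to build it from a factory that takes the abort function as a parameter. `make_abort_excepthook` is a hypothetical helper, not part of mpi4py; in a real run you would pass `mpi4py.MPI.COMM_WORLD.Abort` as `abort`:

```python
import sys
import traceback


def make_abort_excepthook(abort, errorcode=1):
    """Hypothetical factory for a sys.excepthook replacement that prints
    the traceback and then calls abort(errorcode)."""
    def hook(exc_type, exc_value, exc_traceback):
        # Report first: MPI Abort() typically kills the process immediately.
        traceback.print_exception(exc_type, exc_value, exc_traceback)
        sys.stderr.flush()
        abort(errorcode)
    return hook


# Usage in an MPI script (assumes mpi4py is installed):
#   from mpi4py import MPI
#   sys.excepthook = make_abort_excepthook(MPI.COMM_WORLD.Abort)
```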

Upvotes: 6
