Andy

Reputation: 83

Sharing an MPI communicator using pybind11

Suppose that I have created a wrapper around an MPI communicator:

class Communicator {
   public:
      Communicator() : comm(MPI_COMM_WORLD) {}

      Communicator(int const color, int const key) {
        MPI_Comm_split(MPI_COMM_WORLD, color, key, &comm);
      }

      Communicator(MPI_Comm comm) : comm(comm) {}

      MPI_Comm GetComm() const { return comm; }
    private:
      MPI_Comm comm;
}; 

I would like to use pybind11 to create a Python wrapper around this class that looks something like this:

/// Naive binding of Communicator into module @p m.
///
/// MPI_Comm is passed through pybind11 as its raw C type, so Python sees the
/// underlying handle (an int in the asker's MPI build) rather than an
/// mpi4py.MPI.Comm object.
void CommunicatorWrapper(pybind11::module &m) {
   py::class_<Communicator, std::shared_ptr<Communicator> >(m, "Communicator")
      .def(py::init<>())
      .def(py::init<int, int>())
      .def(py::init<MPI_Comm>())
      .def("GetComm", &Communicator::GetComm);
}

However, I'd like the MPI_Comm type that Python sees to be mpi4py.MPI.Comm. Is this possible? If so, how?

The above (naive) implementation results in the following behavior:

# No overload accepts an mpi4py.MPI.Comm, so this raises the TypeError below.
world = MPI.COMM_WORLD
comm = Communicator(world)

Error:

TypeError: __init__(): incompatible constructor arguments. The following argument types are supported:
1. Communicator()
2. Communicator(arg0: int, arg1: int)
3. Communicator(arg0: int)

and

# GetComm hands back the raw MPI_Comm value, not an mpi4py.MPI.Comm.
handle = Communicator().GetComm()
print(handle)

prints -2080374784, i.e. the raw MPI_Comm handle (an integer in this MPI implementation). This behavior makes sense given what MPI_Comm is, but it is obviously not the functionality I need.

Upvotes: 2

Views: 486

Answers (1)

Andy

Reputation: 83

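I solved this by changing the wrapper to the following, which uses mpi4py's C API (`mpi4py/mpi4py.h`) to convert between MPI_Comm and mpi4py.MPI.Comm: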

#include <mpi4py/mpi4py.h>

/// Returns the communicator wrapped by @p comm as a new mpi4py.MPI.Comm.
///
/// @param comm wrapper whose handle is exported; must not be null.
/// @return an owning reference to an mpi4py.MPI.Comm built from comm->GetComm().
/// @throws pybind11::error_already_set if mpi4py's C API cannot be imported or
///         the Comm object cannot be created (the pending Python error is raised).
pybind11::object CallGetComm(const Communicator *comm) {
    // The mpi4py C-API table must be loaded before any PyMPIComm_* call; a
    // negative result means the import failed with a Python error set.
    if (import_mpi4py() < 0) {
        throw pybind11::error_already_set();
    }
    PyObject *py_comm = PyMPIComm_New(comm->GetComm());
    if (py_comm == nullptr) {
        throw pybind11::error_already_set();
    }
    // PyMPIComm_New returns a new reference. Returning it as a plain handle
    // would make pybind11 add another reference (a leak); stealing it into an
    // object transfers the single reference to the caller.
    return pybind11::reinterpret_steal<pybind11::object>(py_comm);
}

/// Registers the Communicator class with the Python module @p m.
///
/// The third constructor accepts an mpi4py.MPI.Comm; any other object makes
/// it raise the Python error reported by mpi4py instead of crashing.
/// GetComm returns an mpi4py.MPI.Comm via CallGetComm.
void CommunicatorWrapper(pybind11::module &m) {
   py::class_<Communicator, std::shared_ptr<Communicator> > commWrap(m, "Communicator");

   commWrap.def(py::init( []() { return new Communicator(); } ));
   commWrap.def(py::init( [](int const color, int const key) { return new Communicator(color, key); } ));
   commWrap.def(py::init( [](pybind11::handle const& comm) {
     // Load mpi4py's C API; unlike assert(), this check survives NDEBUG builds.
     if (import_mpi4py() < 0) {
       throw pybind11::error_already_set();
     }
     // PyMPIComm_Get returns nullptr (with a Python exception set) when comm
     // is not an mpi4py.MPI.Comm; this overload accepts any Python object, so
     // the pointer must be checked before it is dereferenced.
     MPI_Comm *mpi_comm = PyMPIComm_Get(comm.ptr());
     if (mpi_comm == nullptr) {
       throw pybind11::error_already_set();
     }
     return new Communicator(*mpi_comm);
    } ));
   commWrap.def("GetComm", &CallGetComm);
}

Upvotes: 1

Related Questions