mpi4py performance discrepancy after profiling

I have been working with NumPy arrays in mpi4py, and I recently noticed the performance gain from using Scatterv(). I have written a function that inspects the data type of an input object: if it is a numeric NumPy array, it scatters it with Scatterv(); otherwise it falls back to a separate function (comm.scatter() in the code below).

The code looks like this:

import numpy as np
from mpi4py import MPI
import cProfile
from line_profiler import LineProfiler

def ScatterV(object, comm, root = 0):
    optimize_scatter, object_type = np.zeros(1), None

    if rank == root:
        if isinstance(object, np.ndarray):
            # np.float was only an alias for the builtin float (removed in
            # NumPy 1.24); np.float64 already covers it.
            if object.dtype in [np.float64, np.float32, np.float16,
                                np.int8, np.int16, np.int32, np.int64]:
                optimize_scatter = 1
                object_type = object.dtype

            else: optimize_scatter, object_type = 0, None
        else: optimize_scatter, object_type = 0, None

        optimize_scatter = np.array(optimize_scatter, dtype=np.float64).ravel()

    comm.Bcast([optimize_scatter, 1, MPI.DOUBLE], root=root)
    object_type = comm.bcast(object_type, root=root)

    if int(optimize_scatter) == 1:

        if rank == root:

            displs = [int(i)*object.shape[1] for i in
                          np.linspace(0, object.shape[0], comm.size + 1)]
            counts = [displs[i+1] - displs[i] for i in range(len(displs)-1)]
            lens = [int((displs[i+1] - displs[i])/(object.shape[1]))
                        for i in range(len(displs)-1)]
            displs = displs[:-1]
            shape = object.shape

            object = object.ravel().astype(np.float64, copy=False)

        else:
            object, counts, displs, shape, lens = None, None, None, None, None

        counts = comm.bcast(counts, root=root)
        displs = comm.bcast(displs, root=root)
        lens = comm.bcast(lens, root=root)
        shape = list(comm.bcast(shape, root=root))

        shape[0] = lens[rank]
        shape = tuple(shape)

        x = np.zeros(counts[rank])

        comm.Scatterv([object, counts, displs, MPI.DOUBLE], x, root=root)

        return np.reshape(x, (-1,) + shape[1:]).astype(object_type, copy=False)

    else:
        return comm.scatter(object, root=root)

comm = MPI.COMM_WORLD
size, rank = comm.Get_size(), comm.Get_rank()

if rank == 0:
    arra = (np.random.rand(10000000, 10) * 100).astype(np.float64, copy=False)
else: arra = None

lp = LineProfiler()

lp_wrapper = lp(ScatterV)
lp_wrapper(arra, comm)

if rank == 4: lp.print_stats()

pr = cProfile.Profile()

pr.enable()   # start collecting; without this print_stats() has nothing to show
f2 = ScatterV(arra, comm)
pr.disable()

if rank == 4: pr.print_stats()

The analysis with LineProfiler yields the following results (trimmed to the relevant lines):

Timer unit: 1e-06 s

Total time: 2.05001 s
File: /media/SETH_DATA/SETH_Alex/BigMPI4py/
Function: ScatterV at line 26

Line #      Hits         Time  Per Hit   % Time  Line Contents
    41         1    1708453.0 1708453.0     83.3      comm.Bcast([optimize_scatter, 1, MPI.DOUBLE], root=root)
    42         1        148.0    148.0      0.0      object_type = comm.bcast(object_type, root=root)
    76         1        264.0    264.0      0.0          counts = comm.bcast(counts, root=root)
    77         1         16.0     16.0      0.0          displs = comm.bcast(displs, root=root)
    78         1         14.0     14.0      0.0          lens = comm.bcast(lens, root=root)
    79         1          9.0      9.0      0.0          shape = list(comm.bcast(shape, root=root))
    86         1     340971.0 340971.0     16.6          comm.Scatterv([object, counts, displs, MPI.DOUBLE], x, root=root)

The analysis with cProfile yields the following results:

         17 function calls in 0.462 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.127    0.127    0.127    0.127 {method 'Bcast' of 'mpi4py.MPI.Comm' objects}
        1    0.335    0.335    0.335    0.335 {method 'Scatterv' of 'mpi4py.MPI.Comm' objects}
        5    0.000    0.000    0.000    0.000 {method 'bcast' of 'mpi4py.MPI.Comm' objects}

In both cases, Bcast takes a surprising amount of time relative to Scatterv. With LineProfiler, Bcast is even 5 times slower than Scatterv, which makes no sense to me, since Bcast is only broadcasting a one-element array.

If I swap lines 41 and 42, these are the results (LineProfiler first, then cProfile):


41         1    1666718.0 1666718.0     83.0      object_type = comm.bcast(object_type, root=root)
42         1         47.0     47.0      0.0      comm.Bcast([optimize_scatter, 1, MPI.DOUBLE], root=root)
87         1     341728.0 341728.0     17.0          comm.Scatterv([object, counts, displs, MPI.DOUBLE], x, root=root)


1    0.000    0.000    0.000    0.000 {method 'Bcast' of 'mpi4py.MPI.Comm' objects}
1    0.339    0.339    0.339    0.339 {method 'Scatterv' of 'mpi4py.MPI.Comm' objects}
5    0.129    0.026    0.129    0.026 {method 'bcast' of 'mpi4py.MPI.Comm' objects}

If I vary the size of the array to be scattered, the time taken by Scatterv and by Bcast varies at the same rate. For instance, if I increase the number of rows tenfold (to 100000000), the results are:


41         1   16304301.0 16304301.0     82.8      comm.Bcast([optimize_scatter, 1, MPI.DOUBLE], root=root)
42         1        235.0    235.0      0.0      object_type = comm.bcast(object_type, root=root)
87         1    3393658.0 3393658.0     17.2          comm.Scatterv([object, counts, displs, MPI.DOUBLE], x, root=root)


    1    1.348    1.348    1.348    1.348 {method 'Bcast' of 'mpi4py.MPI.Comm' objects}
    1    4.517    4.517    4.517    4.517 {method 'Scatterv' of 'mpi4py.MPI.Comm' objects}
    5    0.000    0.000    0.000    0.000 {method 'bcast' of 'mpi4py.MPI.Comm' objects}

If I print the results for any rank > 1 instead of rank 4, I get the same pattern. For rank 0, however, the results differ:


41         1        186.0    186.0      0.0      comm.Bcast([optimize_scatter, 1, MPI.DOUBLE], root=root)
42         1        244.0    244.0      0.0      object_type = comm.bcast(object_type, root=root)
87         1    4722349.0 4722349.0    100.0          comm.Scatterv([object, counts, displs, MPI.DOUBLE], x, root=root)


    1    0.000    0.000    0.000    0.000 {method 'Bcast' of 'mpi4py.MPI.Comm' objects}
    1    5.921    5.921    5.921    5.921 {method 'Scatterv' of 'mpi4py.MPI.Comm' objects}
    5    0.000    0.000    0.000    0.000 {method 'bcast' of 'mpi4py.MPI.Comm' objects}

In this case, Bcast takes about as long as the other bcast calls.

I have also tried replacing Bcast on line 41 with bcast and with scatter; both yield the same results.

Given that, I think the extra time is being wrongly attributed to whichever collective call comes first, which would mean that both profilers report misleading timings for parallel code.

I am fairly sure these profilers were not designed with parallel code in mind, but I am posting this question to find out whether anyone has seen similar results.
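To illustrate the suspicion above without involving MPI at all, here is a small thread-based simulation. It is only a sketch under stated assumptions: threads stand in for ranks, threading.Barrier stands in for a blocking collective, and the names N_RANKS, ROOT_DELAY and worker are made up for this example. The "root" is busy for a while before it reaches the first synchronisation point, and the other threads end up charging that waiting time to the first call they block in.

```python
import threading
import time

N_RANKS = 4        # hypothetical number of "ranks" (threads in this sketch)
ROOT_DELAY = 0.5   # seconds the root spends preparing data before the call

first_sync = threading.Barrier(N_RANKS)   # stand-in for the first collective
second_sync = threading.Barrier(N_RANKS)  # stand-in for a later collective
timings = {}                              # rank -> (first, second) in seconds


def worker(rank):
    if rank == 0:
        time.sleep(ROOT_DELAY)   # root is busy, e.g. building the array

    t0 = time.perf_counter()
    first_sync.wait()            # non-root ranks block here until root arrives
    t1 = time.perf_counter()
    second_sync.wait()           # everyone is in step now, so this is quick
    t2 = time.perf_counter()

    timings[rank] = (t1 - t0, t2 - t1)


threads = [threading.Thread(target=worker, args=(r,)) for r in range(N_RANKS)]
for t in threads:
    t.start()
for t in threads:
    t.join()

for rank in sorted(timings):
    first, second = timings[rank]
    print(f"rank {rank}: first sync {first:.3f} s, second sync {second:.3f} s")
```

On the non-root threads the first synchronisation should take roughly ROOT_DELAY while the second is near zero; on the root both should be near zero. That mirrors the pattern in the profiles above, where rank 0 shows a cheap Bcast and the other ranks an expensive one.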

In response to Gilles Gouaillardet, I have added comm.Barrier() calls before and after each bcast call, and most of the time now shows up in these comm.Barrier() calls.

Here is an example with LineProfiler.

Timer unit: 1e-06 s

Total time: 2.17248 s
File: /media/SETH_DATA/SETH_Alex/BigMPI4py/
Function: ScatterV at line 26

Line #      Hits         Time  Per Hit   % Time  Line Contents
    26                                           def ScatterV(object, comm, root = 0):
    27         1          7.0      7.0      0.0      optimize_scatter, object_type = np.zeros(1), None
    29         1          2.0      2.0      0.0      if rank == root:
    30                                                   if isinstance(object, np.ndarray):
    31                                                       if object.dtype in [np.float64, np.float32, np.float16, np.float,
    32                                                                 , np.int8, np.int16, np.int32, np.int64]:
    33                                                           optimize_scatter = 1
    34                                                           object_type = object.dtype
    36                                                       else: optimize_scatter, object_type = 0, None
    37                                                   else: optimize_scatter, object_type = 0, None
    39                                                   optimize_scatter = np.array(optimize_scatter, dtype=np.float64).ravel()
    41         1    1677662.0 1677662.0     77.2      comm.Barrier()
    42         1         76.0     76.0      0.0      comm.Bcast([optimize_scatter, 1, MPI.DOUBLE], root=root)
    43         1        345.0    345.0      0.0      comm.Barrier()
    44         1        111.0    111.0      0.0      object_type = comm.bcast(object_type, root=root)
    45         1        166.0    166.0      0.0      comm.Barrier()
    49         1          7.0      7.0      0.0      if int(optimize_scatter) == 1:
    51         1          2.0      2.0      0.0          if rank == root:
    52                                                       if object.ndim > 1:
    53                                                           displs = [int(i)*object.shape[1] for i in
    54                                                                     np.linspace(0, object.shape[0], comm.size + 1)]
    55                                                       else:
    56                                                           displs = [int(i) for i in np.linspace(0, object.shape[0], comm.size + 1)]
    58                                                       counts = [displs[i+1] - displs[i] for i in range(len(displs)-1)]
    60                                                       if object.ndim > 1:
    61                                                           lens = [int((displs[i+1] - displs[i])/(object.shape[1]))
    62                                                                   for i in range(len(displs)-1)]
    63                                                       else:
    64                                                           lens = [displs[i+1] - displs[i] for i in range(len(displs)-1)]
    66                                                       displs = displs[:-1]
    69                                                       shape = object.shape
    73                                                       if object.ndim > 1:
    74                                                           object = object.ravel().astype(np.float64, copy=False)
    77                                                   else:
    78         1          2.0      2.0      0.0              object, counts, displs, shape, lens = None, None, None, None, None
    80         1        295.0    295.0      0.0          counts = comm.bcast(counts, root=root)
    81         1         66.0     66.0      0.0          displs = comm.bcast(displs, root=root)
    82         1          6.0      6.0      0.0          lens = comm.bcast(lens, root=root)
    83         1          9.0      9.0      0.0          shape = list(comm.bcast(shape, root=root))
    85         1          2.0      2.0      0.0          shape[0] = lens[rank]
    86         1          3.0      3.0      0.0          shape = tuple(shape)
    88         1         33.0     33.0      0.0          x = np.zeros(counts[rank])
    90         1         76.0     76.0      0.0          comm.Barrier()
    91         1     351187.0 351187.0     16.2          comm.Scatterv([object, counts, displs, MPI.DOUBLE], x, root=root)
    92         1     142352.0 142352.0      6.6          comm.Barrier()
    94         1          5.0      5.0      0.0          if len(shape) > 1:
    95         1         66.0     66.0      0.0              return  np.reshape(x, (-1,) + shape[1:]).astype(object_type, copy=False)
    96                                                   else:
    97                                                       return x.view(object_type)
   100                                               else:
   101                                                   return comm.scatter(object, root=root)

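77.2% of the time is spent in the first comm.Barrier() call, so I can safely assume that neither bcast call is actually taking that long. A plausible explanation is that the non-root ranks reach the first collective almost immediately and then block until rank 0, which first has to build the array, catches up; the profilers charge that waiting time to whichever collective call comes first. I will consider adding comm.Barrier() calls in future profiling runs.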
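For future runs, the Barrier trick could be wrapped in a small helper so that waiting time and communication time are reported separately. This is a minimal sketch, not part of mpi4py: timed_collective and demo are hypothetical names, and the helper only assumes that comm has a Barrier() method. In demo() the timer is MPI.Wtime, which mpi4py does provide.

```python
import time


def timed_collective(comm, call, timer=time.perf_counter):
    """Return (result, wait_seconds, call_seconds) for one collective.

    comm : any object with a Barrier() method (e.g. MPI.COMM_WORLD)
    call : zero-argument callable that performs the collective
    """
    t0 = timer()
    comm.Barrier()     # absorbs the time spent waiting for slower ranks
    t1 = timer()
    result = call()    # now measures (mostly) the collective itself
    t2 = timer()
    return result, t1 - t0, t2 - t1


def demo():
    """Hypothetical usage; meant to be launched under mpirun/mpiexec."""
    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    flag = np.zeros(1)
    _, wait, cost = timed_collective(
        comm,
        lambda: comm.Bcast([flag, 1, MPI.DOUBLE], root=0),
        timer=MPI.Wtime,
    )
    print(f"rank {comm.Get_rank()}: waited {wait:.6f} s, Bcast took {cost:.6f} s")
```

Calling demo() at the end of a script and launching it with, for example, mpirun -n 4 python script.py prints one line per rank. A large "waited" value next to a small "Bcast took" value would point to load imbalance before the call rather than to a slow broadcast.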

