mykd

Reputation: 191

Speeding up vector distance calculation using Numba

Below are some of the functions I wrote for calculating squared distances in a 3-D toroidal (periodic) geometry, for a collection of particles in that 3-D space:
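For reference, every function below computes the same quantity: the squared minimum-image distance. Write $L_k$ for `cellsize[k]` and $x_{i,k}$ for component $k$ of particle $i$. Each separation component is first wrapped to its nearest periodic image and then squared:

```latex
\Delta_k = x_{j,k} - x_{i,k}, \qquad
d_k = \Delta_k - L_k \operatorname{rint}\!\left(\frac{\Delta_k}{L_k}\right), \qquad
r_{ij}^2 = \sum_{k=1}^{3} d_k^2 .
```

After wrapping, $|d_k| \le L_k/2$ holds for each component. In an orthorhombic box this means the image used is the closest one.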

import itertools
import time
import numpy as np
import scipy
import numba
from numba import njit

@njit(cache=True)
def get_dr2(i=np.array([]),j=np.array([]),cellsize=np.array([])):
    k=np.zeros(3,dtype=np.float64)
    dr2=0.0
    for idx in numba.prange(cellsize.shape[0]):
        k[idx] = (j[idx]-i[idx])-cellsize[idx]*np.rint((j[idx]-i[idx])/cellsize[idx])
        dr2+=k[idx]**2
    return dr2


@numba.guvectorize(["void(float64[:],float64[:],float64[:],float64[:])"],
             "(m),(m),(m)->()",nopython=True,cache=True)
def get_dr2_vec(i,j,cellsize,dr2):
    dr2[:]=0.0
    k=np.zeros(3,dtype=np.float64)
    for idx in numba.prange(cellsize.shape[0]):
        k[idx] = (j[idx]-i[idx])-cellsize[idx]*np.rint((j[idx]-i[idx])/cellsize[idx])
        dr2[0]+=k[idx]**2


@njit(cache=True)
def pair_vec_gen(pIList=np.array([[]]),pJList=np.array([[]])):
    assert pIList.shape[1] == pJList.shape[1]
    vecI=np.zeros((pIList.shape[0]*pJList.shape[0],pIList.shape[1]))
    vecJ=np.zeros_like(vecI)
    for i in numba.prange(pIList.shape[0]):
        for j in numba.prange(pJList.shape[0]):
            for k in numba.prange(pIList.shape[1]):
                vecI[j+pJList.shape[0]*i][k]=pIList[i][k]
                vecJ[j+pJList.shape[0]*i][k]=pJList[j][k]

    return vecI,vecJ


@njit(cache=True)
def pair_vec_dist(pIList=np.array([[]]),pJList=np.array([[]]),cellsize=np.array([])):
    assert pIList.shape[1] == pJList.shape[1]
    vecI=np.zeros((pIList.shape[0]*pJList.shape[0],pIList.shape[1]))
    vecJ=np.zeros_like(vecI)
    r2List=np.zeros(vecI.shape[0])
    for i in numba.prange(pIList.shape[0]):
        for j in numba.prange(pJList.shape[0]):
            for k in numba.prange(pIList.shape[1]):
                vecI[j+pJList.shape[0]*i][k]=pIList[i][k]
                vecJ[j+pJList.shape[0]*i][k]=pJList[j][k]
    r2List=get_dr2_vec2(vecI,vecJ,cellsize)
    return r2List


@njit(cache=True)
def get_dr2_vec2(i=np.array([[]]),j=np.array([[]]),cellsize=np.array([])):
    dr2=np.zeros(i.shape[0],dtype=np.float64)
    k=np.zeros(i.shape[1],dtype=np.float64)
    for m in numba.prange(i.shape[0]):
        for n in numba.prange(i.shape[1]):
            k[n] = (j[m,n]-i[m,n])-cellsize[n]*np.rint((j[m,n]-i[m,n])/cellsize[n])
            dr2[m]+=k[n]**2
    return dr2


def pair_dist_calculator_cdist(pIList=np.array([[]]),pJList=np.array([[]]),cellsize=np.array([])):
    assert pIList.shape[1] == pJList.shape[1]
    r2List = (scipy.spatial.distance.cdist(pIList, pJList, metric=get_dr2_wrapper(cellsize=cellsize))).flatten()
    return np.array(r2List).flatten()

def get_dr2_wrapper(cellsize=np.array([])):
    return lambda u, v: get_dr2(u,v,cellsize)


frames=50
timedata=np.zeros((5,frames),dtype=np.float64)
N, dim = 100, 3  # 100 particles in 3D
cellsize=np.array([26.4,19.4,102.4])
for i in range(frames):
    
    print("\rIter {}".format(i),end='')
    vec = np.random.random((N, dim))
    
    rList1=[];rList2=[];rList3=[];rList4=[];rList5=[]
   
    #method 1
    #print("method 1")
    start = time.perf_counter()
    for (pI, pJ) in itertools.product(vec, vec):
        rList1.append(get_dr2(pI,pJ,cellsize))
    end =time.perf_counter()
    timedata[0,i]=(end-start)
    
    #method 2
    #print("method 2")
    pIvec=[];pJvec=[];rList2=[]
    start = time.perf_counter()
    for (pI, pJ) in itertools.product(vec, vec):
        pIvec.append(pI)
        pJvec.append(pJ)
    rList2=get_dr2_vec(np.array(pIvec),np.array(pJvec),cellsize)
    end =time.perf_counter()
    timedata[1,i]=(end-start)
    
    #method 3
    #print("method 3")
    start = time.perf_counter()
    rList3=get_dr2_vec(*pair_vec_gen(vec,vec),cellsize)
    end =time.perf_counter()
    timedata[2,i]=(end-start)
    
    #method 4
    #print("method 4")    
    start = time.perf_counter()
    rList4=pair_vec_dist(vec,vec,cellsize)
    end =time.perf_counter()
    timedata[3,i]=(end-start)

    #method 5
    #print("method 5")
    #start = time.perf_counter()
    #rList5=pair_dist_calculator_cdist(np.array(pIvec),np.array(pJvec),cellsize)
    #end =time.perf_counter()
    #timedata[4,i]=(end-start)
    
    assert (rList1 == rList2).all()
    assert (rList2 == rList3).all()
    assert (rList3 == rList4).all()
    #assert rList4 == rList5
       

print("\n")

for i in range(4):
    print("Method {} Average time {:.3g}s \u00B1 {:.3g}s".format(i+1,np.mean(timedata[i,1:]),np.std(timedata[i,1:])))

exit()

The essential idea is that, at a particular time, you have a snapshot (frame) containing the positions of the particles. To calculate all the pairwise distances between the particles, we can use the following approaches:

  1. Calculate the distance between points iteratively in pure Python, passing each combination of two particle positions one at a time to a Numba function.
  2. Create the list of pairs (in pure Python) beforehand and pass the whole list to a Numba @guvectorize function.
  3. Do (2), but with all steps in Numba.
  4. Merge all the steps in (3) into a single Numba function.
  5. (optional) Pass the positions to scipy.spatial.distance.cdist with the distance function as the distance metric.
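A pure-NumPy broadcasting version can serve as a reference for checking Methods 1-4 against each other. This is a sketch, not one of the methods above. `pair_dr2_numpy` is a hypothetical name. It allocates an (I, J, 3) temporary, so it is only meant for validation at moderate N:

```python
import numpy as np


def pair_dr2_numpy(pI, pJ, cellsize):
    """Squared minimum-image distances for every (i, j) pair, flattened row-major.

    Hypothetical reference helper: output index is i*len(pJ) + j, the same as
    j + J*i in the Numba versions. Builds an (I, J, 3) temporary array.
    """
    diff = pJ[np.newaxis, :, :] - pI[:, np.newaxis, :]   # shape (I, J, 3)
    diff -= cellsize * np.rint(diff / cellsize)           # wrap to nearest image
    return np.einsum("ijk,ijk->ij", diff, diff).ravel()   # sum of squares over k


rng = np.random.default_rng(0)
vec = rng.random((100, 3))
cellsize = np.array([26.4, 19.4, 102.4])
r2 = pair_dr2_numpy(vec, vec, cellsize)
print(r2.shape)  # (10000,)
```

The floating-point summation order may differ from the Numba loops, so comparing with `np.allclose` is safer than the exact `==` used in the benchmark asserts.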

For 50 frames containing 100 particles we have the respective times (frames, N = 50, 100):

Method 1 Average time 0.017s ± 0.00555s
Method 2 Average time 0.0181s ± 0.00573s
Method 3 Average time 0.00182s ± 0.000944s
Method 4 Average time 0.000485s ± 0.000348s

For 50 frames containing 1000 particles we have the respective times (frames, N = 50, 1000):

Method 1 Average time 2.11s ± 0.977s
Method 2 Average time 2.42s ± 0.859s
Method 3 Average time 0.349s ± 0.12s
Method 4 Average time 0.0694s ± 0.022s

and for 1000 frames containing 100 particles we have the respective times (frames, N = 1000, 100):

Method 1 Average time 0.0244s ± 0.0166s
Method 2 Average time 0.0288s ± 0.0254s
Method 3 Average time 0.00258s ± 0.00231s
Method 4 Average time 0.000636s ± 0.00086s

(All times shown above exclude the first iteration.) Method 5 simply fails due to its memory requirements and is much slower than any of the other methods.
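The shapes involved explain why Method 5 runs out of memory. The commented-out call passes `np.array(pIvec)` and `np.array(pJvec)`. Each already contains N² rows, one per pair from `itertools.product`. `scipy.spatial.distance.cdist` then returns a dense len(XA) × len(XB) float64 matrix. Here is a back-of-the-envelope check, where `cdist_output_bytes` is a hypothetical helper:

```python
def cdist_output_bytes(n_rows_a, n_rows_b, itemsize=8):
    """Bytes of the dense float64 matrix that cdist(XA, XB, ...) returns."""
    return n_rows_a * n_rows_b * itemsize


N = 100
# Method 5 as written: XA and XB each hold N**2 rows (all pairs already expanded)
print(cdist_output_bytes(N**2, N**2) / 1e9)         # 0.8  -> GB for N = 100
print(cdist_output_bytes(1000**2, 1000**2) / 1e12)  # 8.0  -> TB for N = 1000
# Passing the N positions directly gives the N x N matrix instead
print(cdist_output_bytes(N, N) / 1e6)               # 0.08 -> MB for N = 100
```

As written, Method 5 therefore evaluates N⁴ distances, each through a Python-level callable metric. Calling `cdist(vec, vec, ...)` on the positions themselves would produce the N × N result directly, although it would still make one Python call per pair.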

Given the above data, I tend to prefer Method 4, though I am a bit concerned about the increase in average time when I increase the number of frames from 50 to 1000. Are there any further optimizations I can make to these implementations, or does anyone have ideas for faster and more memory-conscious implementations? Any suggestions are welcome.

Update

Based on Jerome's answer the modified function is now:

@njit(cache=True,parallel=True)
def pair_vec_dist(pIList=np.array([[]]),pJList=np.array([[]]),cellsize=np.array([])):
    assert pIList.shape[1] == pJList.shape[1]
    assert cellsize.size == 3
    dr2=np.zeros(pIList.shape[0]*pJList.shape[0],dtype=np.float64)
    inv_cellsize = 1.0 / cellsize
    for i in numba.prange(pIList.shape[0]):
        for j in range(pJList.shape[0]):
            offset = j + pJList.shape[0] * i
            xdist = pJList[j,0]-pIList[i,0]
            ydist = pJList[j,1]-pIList[i,1]
            zdist = pJList[j,2]-pIList[i,2]

            xk =  xdist-cellsize[0]*np.rint(xdist*inv_cellsize[0])
            yk =  ydist-cellsize[1]*np.rint(ydist*inv_cellsize[1])
            zk =  zdist-cellsize[2]*np.rint(zdist*inv_cellsize[2])

            dr2[offset] = xk**2+yk**2+zk**2
    return dr2

As Jerome pointed out, a very simple optimization is to run the loops over just the lower half of the symmetric matrix that the distance calculation produces. However, in a realistic situation pI and pJ may be different vector lists, with pI a subset of pJ, which complicates things. Either I create two separate functions and control them via a wrapper function, or I somehow handle both cases in a single function. Any suggestions on how to do this would be really helpful.
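For the square case (pI is pJ), one option is to compute only the i < j half and write it directly in SciPy's condensed ordering. `scipy.spatial.distance.squareform` can then rebuild the full symmetric matrix if it is ever needed. This is a sketch under stated assumptions. `condensed_dr2` is a hypothetical name. It is serial, it excludes the diagonal (SciPy's convention), and it falls back to plain Python when Numba is not installed. The subset case could then go through a wrapper that calls this for smallVec and a rectangular kernel for smallVec × restOfVec.

```python
import numpy as np
from scipy.spatial.distance import squareform

try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs (slowly) without Numba
    def njit(*args, **kwargs):
        return args[0] if args and callable(args[0]) else (lambda f: f)


@njit
def condensed_dr2(p, cellsize):
    """Squared minimum-image distances for pairs i < j only (hypothetical helper).

    The output uses SciPy's condensed ordering, so squareform() can expand it
    into the full symmetric N x N matrix (zero diagonal) when really needed.
    """
    n = p.shape[0]
    out = np.empty(n * (n - 1) // 2)
    inv = 1.0 / cellsize
    for i in range(n - 1):
        base = n * i - (i + 2) * (i + 1) // 2   # condensed index of (i, j) is base + j
        for j in range(i + 1, n):
            acc = 0.0
            for k in range(3):
                d = p[j, k] - p[i, k]
                d -= cellsize[k] * np.rint(d * inv[k])
                acc += d * d
            out[base + j] = acc
    return out


rng = np.random.default_rng(0)
vec = rng.random((6, 3)) * 30.0
cellsize = np.array([26.4, 19.4, 102.4])
full = squareform(condensed_dr2(vec, cellsize))   # (6, 6) symmetric matrix
```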

Update 2

I should clarify the problem further. In this code I am trying to calculate the distances between all points in a frame/snapshot, which are then used for pair distance distribution analysis. In some cases, however, we might want to focus on a subset of the coordinates in a frame and calculate the distribution from their perspective. In that case we select this subset smallVec from the pool of all coordinates vec (such that smallVec + restOfVec = vec) and calculate pair_vec_dist(smallVec,vec) instead of pair_vec_dist(vec,vec). For this calculation one can use np.concatenate((pair_vec_dist(smallVec,smallVec), pair_vec_dist(smallVec,restOfVec))), which gives the same distances in a different order. Based on the discussion with Jerome, I modified my function as:
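The distances end up in a pair distance distribution, so a more memory-conscious option (not something proposed in this thread) is to bin each distance as soon as it is computed. The len(pI) × len(pJ) output array is then never stored. Below is a serial sketch. The names `pair_dr2_histogram`, `r_max` and `n_bins` are hypothetical, and the sketch falls back to plain Python when Numba is missing:

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs (slowly) without Numba
    def njit(*args, **kwargs):
        return args[0] if args and callable(args[0]) else (lambda f: f)


@njit
def pair_dr2_histogram(pI, pJ, cellsize, r_max, n_bins):
    """Histogram of minimum-image pair distances r (not r**2) in [0, r_max).

    Hypothetical helper: bins on the fly instead of returning every squared
    distance, so the extra memory is O(n_bins). Self-pairs (r == 0) land in
    bin 0, as they would if the full pair_vec_dist output were histogrammed.
    """
    counts = np.zeros(n_bins, dtype=np.int64)
    inv = 1.0 / cellsize
    r2_max = r_max * r_max
    scale = n_bins / r_max
    for i in range(pI.shape[0]):
        for j in range(pJ.shape[0]):
            r2 = 0.0
            for k in range(3):
                d = pJ[j, k] - pI[i, k]
                d -= cellsize[k] * np.rint(d * inv[k])
                r2 += d * d
            if r2 < r2_max:
                b = int(np.sqrt(r2) * scale)
                if b < n_bins:            # guard against rounding at the upper edge
                    counts[b] += 1
    return counts
```

Histograms of smallVec × smallVec and smallVec × restOfVec can simply be added together, which also avoids the concatenation. A parallel version would need per-thread count arrays to avoid a race on `counts`.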

@njit(cache=True,parallel=True)
def pair_vec_dist_cmb(pIList=np.array([[]]),pJList=np.array([[]]),cellsize=np.array([]),is_sq=True,is_nonsq=True):
    # pIList: smallVec, shape (N,3); pJList: restOfVec, shape (M-N,3)
    assert pIList.shape[1] == pJList.shape[1]
    assert cellsize.size == 3

    n_i = pIList.shape[0]
    n_j = pJList.shape[0]

    # section 1: pI-pI pairs with j <= i (triangle incl. diagonal) -> N(N+1)/2 slots
    dr2_1 = n_i*(n_i+1)//2 if is_sq else 0
    # section 2: pI-pJ pairs -> N*(M-N) slots, stored after section 1
    dr2_2 = n_i*n_j if is_nonsq else 0

    dr2 = np.zeros((dr2_1+dr2_2),dtype=np.float64)
    inv_cellsize = 1.0 / cellsize
    # a single outer loop over the rows i of pIList feeds both sections
    for i in numba.prange(n_i):

        if is_sq:
            for j in range(i+1):
                index_1 = i*(i+1)//2 + j
                xdist = pIList[j,0]-pIList[i,0]
                ydist = pIList[j,1]-pIList[i,1]
                zdist = pIList[j,2]-pIList[i,2]

                xk =  xdist-cellsize[0]*np.rint(xdist*inv_cellsize[0])
                yk =  ydist-cellsize[1]*np.rint(ydist*inv_cellsize[1])
                zk =  zdist-cellsize[2]*np.rint(zdist*inv_cellsize[2])

                dr2[index_1] = xk**2+yk**2+zk**2

        if is_nonsq:
            for j in range(n_j):
                index_2 = dr2_1 + j + n_j * i
                xdist = pJList[j,0]-pIList[i,0]
                ydist = pJList[j,1]-pIList[i,1]
                zdist = pJList[j,2]-pIList[i,2]

                xk =  xdist-cellsize[0]*np.rint(xdist*inv_cellsize[0])
                yk =  ydist-cellsize[1]*np.rint(ydist*inv_cellsize[1])
                zk =  zdist-cellsize[2]*np.rint(zdist*inv_cellsize[2])

                dr2[index_2] = xk**2+yk**2+zk**2

    return dr2

Here pIList (shape (N,3)) is smallVec, a subset of vec (shape (M,3)), and pJList is restOfVec (shape (M-N,3)). The code splits the calculation into two sections. The first is the pI-pI pair distances, which are symmetric, so only the lower triangular matrix needs to be calculated, i.e. N(N-1)/2 unique values. The second is the pI-restOfVec distances, of which there are N(M-N). To optimize the function further, I made two additional changes:

  1. Combine the outer loop for both sections. To do so, I now iterate over the triangular matrix including the diagonal, which translates to N(N+1)/2 values. One could also add an if check to skip identical coordinates, though I am not sure how much time that would save.
  2. To avoid appending the results of the two sections together, I preallocate the returned array and partition it by length.
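A quick way to check that the partitioned output array is filled exactly once per slot is to enumerate the indices in plain Python. The j ≤ i triangle comes first, then the smallVec × restOfVec block. `partition_indices` is a hypothetical helper that mirrors the index formulas used in the combined function:

```python
def partition_indices(n_small, n_rest):
    """Enumerate the output slots written by the combined function (sketch).

    Section 1: pairs (i, j) with j <= i inside smallVec -> i*(i+1)//2 + j
    Section 2: pairs (i, j) smallVec x restOfVec       -> n_tri + i*n_rest + j
    """
    n_tri = n_small * (n_small + 1) // 2
    slots = []
    for i in range(n_small):
        for j in range(i + 1):
            slots.append(i * (i + 1) // 2 + j)
        for j in range(n_rest):
            slots.append(n_tri + i * n_rest + j)
    return n_tri, slots


n_tri, slots = partition_indices(4, 5)
print(n_tri, len(slots), sorted(slots) == list(range(len(slots))))  # 10 30 True
```

Each value of i touches a disjoint set of slots in both sections. That is why a `prange` over i can write into the shared output without a race.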

A further assumption I have made is that the time needed to partition vec into smallVec and restOfVec is negligible compared with the pair distance calculation. If that is wrong, another optimization pathway may be needed.
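To see why the partitioning assumption is plausible, here is one way to split a frame with a boolean mask. It costs O(M) work, against the O(N·M) distance loop. `split_frame` and `selected` are hypothetical names, and this is a sketch rather than the author's code:

```python
import numpy as np


def split_frame(vec, selected):
    """Split a frame into (smallVec, restOfVec) using a boolean mask.

    `selected` is an index array of the particles of interest. Both outputs
    are new arrays (copies) and keep the original particle order of `vec`.
    """
    mask = np.zeros(vec.shape[0], dtype=bool)
    mask[selected] = True
    return vec[mask], vec[~mask]


rng = np.random.default_rng(0)
vec = rng.random((1000, 3))
smallVec, restOfVec = split_frame(vec, np.arange(0, 1000, 10))
print(smallVec.shape, restOfVec.shape)  # (100, 3) (900, 3)
```

The rows of smallVec follow their order in vec, not the order of `selected`. This matters only if the output ordering of the distances is used later.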

The resulting function is 1.5 times faster than the previous one. I would like to optimize it further, but I am very new to loop tiling and other advanced optimizations, so if you have any suggestions, please let me know.

Update 3

I decided to focus on optimizing the function for serial execution, since I might simply use Dask or multiprocessing to work on multiple sections of an input collection of frames. The reference function is now:

@njit(cache=True,parallel=False, fastmath=True, boundscheck=False, nogil=True)
def pair_vec_dist_test(pIList,pJList,cellsize):

    _I=pIList.shape[0]
    _J=pJList.shape[0]

    dr2 = np.empty(int(_I*_J),dtype=np.float32)
    inv_cellsize = 1.0 / cellsize

    for i in numba.prange(pIList.shape[0]):
        for j in range(pJList.shape[0]):
            index = j + pJList.shape[0] * i
            xdist = pJList[j,0]-pIList[i,0]
            ydist = pJList[j,1]-pIList[i,1]
            zdist = pJList[j,2]-pIList[i,2]
            xk =  xdist-cellsize[0]*np.rint(xdist*inv_cellsize[0])
            yk =  ydist-cellsize[1]*np.rint(ydist*inv_cellsize[1])
            zk =  zdist-cellsize[2]*np.rint(zdist*inv_cellsize[2])
            dr2[index] = xk**2+yk**2+zk**2

    return dr2

Going back to the main problem while ignoring the symmetry aspect, I tried to further optimize the distance function as:

@njit(cache=True,parallel=False, fastmath=True, boundscheck=False, nogil=True)
def pair_vec_dist_test_v2(pIList,pJList,cellsize):

    _I=pIList.shape[0]
    _J=pJList.shape[0]

    dr2 = np.empty(int(_I*_J),dtype=np.float32)
    inv_cellsize = 1.0 / cellsize

    tile=32
    
    
    for ii in range(0,_I,tile):
        for jj in range(0,_J,tile):
            for i in range(ii,min(_I,ii+tile)):
                for j in range(jj,min(_J,jj+tile)):
                    index = j + _J * i
                    xdist = pJList[j,0]-pIList[i,0]
                    ydist = pJList[j,1]-pIList[i,1]
                    zdist = pJList[j,2]-pIList[i,2]
                    xk =  xdist-cellsize[0]*np.rint(xdist*inv_cellsize[0])
                    yk =  ydist-cellsize[1]*np.rint(ydist*inv_cellsize[1])
                    zk =  zdist-cellsize[2]*np.rint(zdist*inv_cellsize[2])
                    dr2[index] = xk**2+yk**2+zk**2

    return dr2

This essentially tiles the two vector arrays. However, I couldn't get any speedup: the execution times of both functions are roughly the same. I also thought about working with the transposes of the vector arrays, but I couldn't figure out how to align them in the loop when the vector lengths are not a multiple of the tile length. Does anyone have further suggestions or ideas on how to proceed?
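On the alignment question, one option is a transposed (structure-of-arrays) layout with no manual tiling at all. The inner loop then runs over contiguous x, y and z rows up to the true length. The sketch assumes the inputs are passed as (3, n) C-contiguous arrays, e.g. `np.ascontiguousarray(vec.T)`. `pair_dr2_soa` is a hypothetical name. The expectation that LLVM's loop vectorizer emits its own scalar remainder loop for lengths that are not a multiple of the SIMD width is an assumption. Whether the loop actually vectorizes (with `np.rint` in particular) is worth checking with the dispatcher's `inspect_asm()`.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs (slowly) without Numba
    def njit(*args, **kwargs):
        return args[0] if args and callable(args[0]) else (lambda f: f)


@njit(fastmath=True)
def pair_dr2_soa(pI, pJ, cellsize):
    """Row-major (I*J) squared minimum-image distances from transposed inputs.

    Assumes pI and pJ have shape (3, n) and are C-contiguous, so the inner
    j-loop reads three unit-stride streams. The inner loop simply runs to
    pJ.shape[1]; no padding or masks are needed for odd lengths.
    """
    n_i = pI.shape[1]
    n_j = pJ.shape[1]
    out = np.empty(n_i * n_j)
    lx, ly, lz = cellsize[0], cellsize[1], cellsize[2]
    ix, iy, iz = 1.0 / lx, 1.0 / ly, 1.0 / lz
    xj, yj, zj = pJ[0], pJ[1], pJ[2]          # contiguous coordinate rows
    for i in range(n_i):
        xi, yi, zi = pI[0, i], pI[1, i], pI[2, i]
        row = out[i * n_j:(i + 1) * n_j]      # view into the output
        for j in range(n_j):
            dx = xj[j] - xi
            dy = yj[j] - yi
            dz = zj[j] - zi
            dx -= lx * np.rint(dx * ix)
            dy -= ly * np.rint(dy * iy)
            dz -= lz * np.rint(dz * iz)
            row[j] = dx * dx + dy * dy + dz * dz
    return out
```

Usage: `vt = np.ascontiguousarray(vec.T); r2 = pair_dr2_soa(vt, vt, cellsize)`. The output ordering (`j + J*i`) matches `pair_vec_dist_test`.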

Edit: Another failed trial

@njit(cache=True,parallel=False, fastmath=True, boundscheck=False, nogil=True)
def pair_vec_dist_test_v3(pIList,pJList,cellsize):
    
    inv_cellsize = 1.0 / cellsize
    
    tile=32
  
    _I=pIList.shape[0]
    _J=pJList.shape[0]

    vecI=np.empty((_I+2*tile,3),dtype=np.float64)     # for rolling effect
    vecJ=np.empty((_J+2*tile,3),dtype=np.float64)     # for rolling effect

    vecI_mask=np.ones((_I+2*tile),dtype=np.uint8)
    vecJ_mask=np.ones((_J+2*tile),dtype=np.uint8)

    vecI[:_I]=pIList
    vecJ[:_J]=pJList

    vecI[_I:]=0.
    vecJ[_J:]=0.

    vecI_mask[_I:]=0
    vecJ_mask[_J:]=0

    #print(vecI,vecJ)

    ILim=_I+(tile-_I%tile)
    JLim=_J+(tile-_J%tile)
    
    dr2 = np.empty((ILim*JLim),dtype=np.float64)

    vecI=vecI.T
    vecJ=vecJ.T

    for ii in range(ILim):
        for jj in range(0,JLim,tile):
            index = jj + JLim*ii
            #print(ii,jj,index)
            mask  = np.multiply(vecJ_mask[jj:jj+tile],vecI_mask[ii:ii+tile])
            xdist = vecJ[0,jj:jj+tile]-vecI[0,ii:ii+tile]
            ydist = vecJ[1,jj:jj+tile]-vecI[1,ii:ii+tile]
            zdist = vecJ[2,jj:jj+tile]-vecI[2,ii:ii+tile]
            xk =  xdist-cellsize[0]*np.rint(xdist*inv_cellsize[0])
            yk =  ydist-cellsize[1]*np.rint(ydist*inv_cellsize[1])
            zk =  zdist-cellsize[2]*np.rint(zdist*inv_cellsize[2])
            arr = xk**2+yk**2+zk**2
            dr2[index:index+tile] = np.multiply(arr,mask)


    return dr2

Upvotes: 2

Views: 276

Answers (1)

Jérôme Richard

Reputation: 50688

First things first: there are race conditions in your current code. This means the results can be corrupted (and performance can suffer too); in practice, it is undefined behaviour. For example, in get_dr2_vec2 the temporary array k is written and read by multiple threads. One needs to be very careful when using prange. Here, the race conditions can be removed by not using a temporary array (which is not really useful anyway) and by not using prange in the inner loop, since dr2[m] is updated there (updating it from multiple threads also causes a race condition).
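Applying those two fixes to get_dr2_vec2 could look like the sketch below. `get_dr2_vec2_fixed` is a hypothetical name. The only changes are a per-iteration scalar accumulator in place of the shared array `k`, and `prange` on the outer loop only. There is also a plain-Python fallback when Numba is unavailable.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # fallback so the sketch still runs (serially) without Numba
    prange = range
    def njit(*args, **kwargs):
        return args[0] if args and callable(args[0]) else (lambda f: f)


@njit(parallel=True)
def get_dr2_vec2_fixed(i, j, cellsize):
    """Race-free variant of get_dr2_vec2 (sketch).

    - prange only on the outer loop: iteration m is the only writer of dr2[m].
    - the shared temporary array k is replaced by a private scalar.
    """
    dr2 = np.empty(i.shape[0])
    for m in prange(i.shape[0]):
        acc = 0.0                       # private to iteration m
        for n in range(i.shape[1]):     # plain range inside
            d = j[m, n] - i[m, n]
            d -= cellsize[n] * np.rint(d / cellsize[n])
            acc += d * d
        dr2[m] = acc
    return dr2
```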

Moreover, prange is not useful unless parallel=True is set in the Numba decorator (without it, prange behaves like a plain range). Indeed, the current functions are not parallel since this flag is missing.

Finally, you can merge pair_vec_dist and get_dr2_vec2, including their inner loops, to avoid creating and filling large temporary arrays. RAM throughput is quite small nowadays compared with the computing power of modern processors, and this gap has been growing for the last two decades. This effect is called the "memory wall" and it is not expected to disappear any time soon. Less memory-bound code generally tends to be faster and to scale better.

Here is the resulting code:

@njit(cache=True, parallel=True)
def pair_vec_dist(pIList=np.array([[]]),pJList=np.array([[]]),cellsize=np.array([])):
    assert pIList.shape[1] == pJList.shape[1]
    dr2=np.zeros(pIList.shape[0]*pJList.shape[0],dtype=np.float64)
    inv_cellsize = 1.0 / cellsize
    for i in numba.prange(pIList.shape[0]):
        for j in range(pJList.shape[0]):
            offset = j + pJList.shape[0] * i
            for k in range(pIList.shape[1]):
                tmp = pJList[j,k]-pIList[i,k]
                wrapped = tmp-cellsize[k]*np.rint(tmp*inv_cellsize[k])  # separate name so k stays the loop index
                dr2[offset] += wrapped**2
    return dr2

It is 11 times faster with frames=50 and N=1000 on my 6-core machine (i5-9600KF).

The code can be optimized further. For example, dr2 is a flattened symmetric square matrix, so only the upper-right part needs to be computed and the bottom-left part can simply be copied. Note that to do this efficiently in parallel, the work needs to be balanced between the threads; otherwise the slowest thread becomes the bottleneck and little is gained. One can also write a specialized version of the function that only supports cellsize.size == 3. Moreover, one can use register tiling to make the code more cache-friendly. Finally, one can transpose the input so the layout is more SIMD-friendly (this will likely require the loop to be manually unrolled and the register tiling optimization to be done first).
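One common way to balance a triangular loop is to pair each short row with a long row, so that every group handed to a thread does about the same work. The sketch below only builds the groups. `balanced_row_pairs` is a hypothetical helper, and it assumes the j ≤ i triangle (row i holds i + 1 pairs) used in the question's combined function. A Numba kernel would then `prange` over the groups instead of over the rows.

```python
def balanced_row_pairs(n):
    """Group triangle rows so the groups carry equal work (sketch).

    Row i of the j <= i triangle holds i + 1 pairs. Pairing row g with row
    n - 1 - g gives (g + 1) + (n - g) = n + 1 pairs per group. With odd n,
    the middle row forms a group on its own and carries about half as much.
    """
    groups = []
    for g in range((n + 1) // 2):
        partner = n - 1 - g
        groups.append((g,) if partner == g else (g, partner))
    return groups


n = 7
print([sum(r + 1 for r in grp) for grp in balanced_row_pairs(n)])  # [8, 8, 8, 4]
```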

Upvotes: 3
