Reputation: 105
I have two arrays which are lists of points in 3D space. These arrays have different lengths.
np.shape(arr1) == (34709, 3)
np.shape(arr2) == (4835053, 3)
I have a function which can compute the Pythagorean distance between a single point in one array and all points in another, given periodic boundary conditions:
import numpy as np

def pythag_periodic(array, point, dimensions):
    # Per-axis separations, wrapped with the minimum-image convention
    delta = np.abs(array - point)
    delta = np.where(delta > 0.5 * dimensions, delta - dimensions, delta)
    return np.sqrt((delta ** 2).sum(axis=-1))
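To make the periodic wrap concrete, here is a tiny worked example with assumed values (a cubic box of side 10 passed as a length-3 array, and two made-up points). A point at x = 9.5 seen from x = 0.5 is 9 apart directly but only 1 apart through the boundary, and the np.where line is what picks the shorter image.

```python
import numpy as np

def pythag_periodic(array, point, dimensions):
    # Per-axis separations, wrapped with the minimum-image convention
    delta = np.abs(array - point)
    delta = np.where(delta > 0.5 * dimensions, delta - dimensions, delta)
    return np.sqrt((delta ** 2).sum(axis=-1))

box = np.array([10.0, 10.0, 10.0])          # assumed box size per axis
points = np.array([[9.5, 0.5, 0.5],         # 9 apart directly, 1 apart via the wrap
                   [3.5, 0.5, 0.5]])        # 3 apart, no wrap needed
origin = np.array([0.5, 0.5, 0.5])

print(pythag_periodic(points, origin, box))  # -> [1. 3.]
```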
I am trying to apply this operation to all points in both arrays. I have a loop which calls this function repeatedly, once for every point in arr1, but it is agonisingly slow.
pp = []
for i in arr1:
    pp.append(pythag_periodic(arr2, i, dimensions))
Any suggestions as to how I might speed this up would be much appreciated.
Upvotes: 1
Views: 162
Reputation: 803
Another cool option would be to exploit NumPy's broadcasting (via None when indexing the arrays) to avoid the loop, and the super neat einsum function to perform the square and sum operations in a single step.
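A small sketch of what the two ingredients do, on random toy arrays rather than the question's data: indexing with None inserts a length-1 axis so the subtraction broadcasts to every pair, and the einsum subscripts "ijk,ijk->ij" multiply delta by itself elementwise and sum over the coordinate axis k.

```python
import numpy as np

rng = np.random.default_rng(0)
a1 = rng.random((4, 3))   # 4 toy points
a2 = rng.random((5, 3))   # 5 toy points

# (4, 1, 3) - (1, 5, 3) broadcasts to all pairwise differences, shape (4, 5, 3)
delta = a1[:, None, :] - a2[None, ...]

# Square and sum over k in one call, versus the two-step equivalent
sq_einsum = np.einsum("ijk,ijk->ij", delta, delta)
sq_plain = (delta ** 2).sum(axis=-1)

print(delta.shape, sq_einsum.shape)       # (4, 5, 3) (4, 5)
print(np.allclose(sq_einsum, sq_plain))   # True
```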
Note, however, that this approach is slightly slower for small matrices, but once you get to sizes greater than about 4000 elements it is much faster. Also, beware of running out of RAM, which is the downside of vectorization: the intermediate delta array has shape (N, M, 3), three times the size of the NxM result (although your code already stores the NxM
array anyway).
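For the shapes in the question the RAM warning is decisive. A back-of-the-envelope sketch, assuming float64 values (8 bytes each, NumPy's default float dtype) and counting only the final distance matrix and the (N, M, 3) delta temporary, not the other temporaries that np.abs and np.where also allocate:

```python
n1, n2 = 34709, 4835053          # shapes from the question
bytes_per_float64 = 8            # assumed dtype: float64

result_bytes = n1 * n2 * bytes_per_float64   # the (N, M) distance matrix
delta_bytes = result_bytes * 3               # the (N, M, 3) delta temporary

print(f"distance matrix: {result_bytes / 1e12:.2f} TB")  # 1.34 TB
print(f"delta temporary: {delta_bytes / 1e12:.2f} TB")   # 4.03 TB
```

So neither the looped version nor the fully vectorized one can hold everything in memory at once at this size; the distances have to be processed in pieces.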
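import numpy as np

def pythag_periodic_vectorized(a1, a2, dimensions):
    # (N, 1, 3) - (1, M, 3) broadcasts to all pairwise differences, shape (N, M, 3)
    delta = np.abs(a1[:, None, :] - a2[None, ...])
    # Minimum-image wrap using the box size `dimensions`, as in the question
    delta = np.where(delta > 0.5 * dimensions, delta - dimensions, delta)
    # Square and sum over the last axis in one step, shape (N, M)
    return np.sqrt(np.einsum("ijk,ijk->ij", delta, delta))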
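One way to keep the speed of broadcasting while bounding memory is to feed a1 through in blocks of rows and reduce each block before moving on. This is a minimal sketch under assumptions: pythag_periodic_chunked and count_within are hypothetical names, chunk is a tuning knob you would size to your RAM (with M = 4835053, each row of the (chunk, M, 3) temporary is about 116 MB), and counting pairs below a cutoff is only an example of a reduction; substitute whatever you actually need from the distances.

```python
import numpy as np

def pythag_periodic_chunked(a1, a2, dimensions, chunk=8):
    """Yield (start, distances) for rows a1[start:start + chunk] against all of a2."""
    for start in range(0, len(a1), chunk):
        block = a1[start:start + chunk]
        # (c, 1, 3) - (1, M, 3) -> (c, M, 3): only `chunk` rows exist at once
        delta = np.abs(block[:, None, :] - a2[None, :, :])
        # Minimum-image convention, same as the question's function
        delta = np.where(delta > 0.5 * dimensions, delta - dimensions, delta)
        yield start, np.sqrt(np.einsum("ijk,ijk->ij", delta, delta))

def count_within(a1, a2, dimensions, cutoff, chunk=8):
    # Example reduction: count pairs closer than `cutoff` without storing N x M values
    return sum(int((d < cutoff).sum())
               for _, d in pythag_periodic_chunked(a1, a2, dimensions, chunk))
```

Each yielded block is an ordinary (chunk, M) array, so any per-block statistic can be accumulated in the loop that consumes it.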
Upvotes: 1
Reputation: 781
You should use Numba: https://numba.pydata.org/ (disclosure: I am not one of its authors). It is a library that translates Python functions to optimized machine code at runtime, so Numba-compiled numerical algorithms in Python can approach the speeds of C or FORTRAN.
Applying it to your code is simple: in a nutshell, import the library and then add the decorator. Numba also offers other options that may be relevant for you, such as parallelizing your algorithms (have a look at their website). For instance:
import numpy as np
from numba import jit

@jit(nopython=True)
def pythag_periodic(array, point, dimensions):
    delta = np.abs(array - point)
    delta = np.where(delta > 0.5 * dimensions, delta - dimensions, delta)
    return np.sqrt((delta ** 2).sum(axis=-1))
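Jitting a function whose body is already whole-array NumPy operations often gains little, since that work already runs in compiled NumPy code and the Python loop over arr1 stays in Python. Numba tends to pay off most on explicit loops. Below is a sketch of the parallel option mentioned above, written as plain loops with numba.prange so that no (N, M, 3) temporary is created. Assumptions: pairwise_periodic is a hypothetical name, dimensions is passed as a length-3 array of box sizes, and the try/except fallback exists only so the sketch still runs (slowly, as plain Python) where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # fallback so the sketch runs without numba (plain Python, slow)
    prange = range

    def njit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        return lambda func: func

@njit(parallel=True)
def pairwise_periodic(a1, a2, box):
    n, m, dim = a1.shape[0], a2.shape[0], a1.shape[1]
    out = np.empty((n, m))
    for i in prange(n):            # rows are independent, so Numba can split them across threads
        for j in range(m):
            s = 0.0
            for k in range(dim):
                d = abs(a1[i, k] - a2[j, k])
                if d > 0.5 * box[k]:   # minimum-image convention
                    d -= box[k]
                s += d * d
            out[i, j] = np.sqrt(s)
    return out
```

This still returns the full N x M matrix, so at the question's sizes you would either call it on slices of arr1 or replace the out[i, j] assignment with whatever reduction you need inside the loop.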
Upvotes: 1