Physicist

Reputation: 3048

Speeding up a class with NumPy arrays using Cython

I have the following code:

import numpy as np


class _Particles:
    def __init__(self, num_particle, dim, fun, lower_bound, upper_bound):
        self.lower_bound = lower_bound   # np.array of shape (dim,)
        self.upper_bound = upper_bound   # np.array of shape (dim,)
        self.num_particle = num_particle   # a scalar
        self.dim = dim   # dimension, a scalar
        self.fun = fun   # a function

        self.pos = np.empty((num_particle,dim))
        self.val = np.empty(num_particle)
        self.randomize()


    def randomize(self):
        self.pos = np.random.rand(self.num_particle, self.dim)*(self.upper_bound\
                -self.lower_bound)+self.lower_bound
        self.val = self.fun(np.transpose(self.pos))
        self.best_idx = np.argmin(self.val)
        self.best_val = self.val[self.best_idx]
        self.best_pos = self.pos[self.best_idx]


    def move(self, displacement, idx='all', check_bound=True):
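        # compare strings with ==, not `is`; the isinstance guard avoids an
        # elementwise comparison when idx is an array
        if isinstance(idx, str) and idx == 'all':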
            self.pos += displacement
        elif isinstance(idx,(tuple,list,np.ndarray)):
            self.pos[idx] += displacement
        else:
            raise TypeError('Check the type of idx!',type(idx))

        self.pos = np.maximum(self.pos, self.lower_bound[np.newaxis,:])
        self.pos = np.minimum(self.pos, self.upper_bound[np.newaxis,:])
        self.val = self.fun(np.transpose(self.pos))
        self.best_idx = np.argmin(self.val)
        self.best_val = self.val[self.best_idx]
        self.best_pos = self.pos[self.best_idx]

I want to see if I can speed up the above code, and I'm thinking about using Cython, but I am not sure it will help, since the code mostly operates on NumPy arrays and most of the work is already vectorized. I tried something like:

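# the .pyx file that will be compiled
import numpy as np
import numpy.random as npr
cimport numpy as np

cdef class _Particles(object):
    cdef int num_particle
    cdef int dim
    cdef object fun
    cdef np.ndarray lower_bound
    cdef np.ndarray upper_bound
    cdef np.ndarray pos
    cdef np.ndarray val
    cdef int best_idx
    cdef double best_val
    # buffer types such as np.ndarray[np.float64_t, ndim=1] are not
    # allowed as cdef class attributes, so use a plain np.ndarray here
    cdef np.ndarray best_pos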

    def __init__(self, int num_particle, int dim, fun,
                 np.ndarray lower_bound, np.ndarray upper_bound):
        self.num_particle = num_particle
        self.dim = dim
        self.fun = fun
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

        self.pos = np.empty((num_particle,dim))
        self.val = np.empty(num_particle)
        self.randomize()

    def randomize(self):
        self.pos = npr.rand(self.num_particle,self.dim)*(self.upper_bound\
                -self.lower_bound)+self.lower_bound

        self.val = self.fun(np.transpose(self.pos))
        self.best_idx = np.argmin(self.val)
        self.best_val = self.val[self.best_idx]
        self.best_pos = self.pos[self.best_idx]

It is faster, but only by a bit, which is kind of expected as it is still mostly Python code. So are there any ways to speed up the above code using Cython (or can you point me to some completely different method)? In particular, how can I speed up calls such as self.fun(self.pos) and np.argmin(self.val)?

Thanks.

Upvotes: 1

Views: 214

Answers (1)

norok2

Reputation: 26886

Actually, I am afraid there isn't much to optimize in the above code. To make argmin faster, I would suggest getting (or compiling yourself) a NumPy build with multi-threading support, or re-implementing a multi-threaded argmin yourself.
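To illustrate what re-implementing argmin over several threads could look like, here is a minimal sketch. The name chunked_argmin and the num_chunks parameter are made up for this example, and whether it beats plain np.argmin depends on the array size and on NumPy releasing the GIL during the reduction, so treat it as something to benchmark rather than a guaranteed win.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor


def chunked_argmin(values, num_chunks=4):
    """Index of the minimum of a 1-D array, reducing chunks in threads.

    Hypothetical helper for illustration; not part of NumPy.
    """
    values = np.asarray(values)
    if values.size == 0:
        raise ValueError("cannot take argmin of an empty array")

    # Split into contiguous views; drop empty chunks for short inputs.
    chunks = [c for c in np.array_split(values, num_chunks) if c.size]

    # Starting offset of each chunk within the full array.
    sizes = [c.size for c in chunks]
    offsets = np.concatenate(([0], np.cumsum(sizes)[:-1]))

    # Each thread computes the argmin of one chunk.
    with ThreadPoolExecutor(max_workers=len(chunks)) as pool:
        local = list(pool.map(np.argmin, chunks))

    # Convert chunk-local indices to global ones and pick the smallest value.
    # Candidates are in ascending order, so ties resolve to the first index.
    candidates = [int(off + idx) for off, idx in zip(offsets, local)]
    return min(candidates, key=lambda i: values[i])


values = np.random.default_rng(0).random(1_000_000)
print(chunked_argmin(values) == int(np.argmin(values)))
```

For arrays with only a few thousand entries, the thread start-up cost will most likely outweigh any gain.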

As far as Cython goes, the real gains come when you start using C types, but I would not expect a large improvement for the code you posted: it is mostly glue code, with no number crunching involved.
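Before rewriting anything, it may be worth measuring where the time actually goes. The sketch below times the three vectorized steps of move separately; the sphere objective, the array sizes and the helper names are placeholders invented for this example, not taken from the question.

```python
import time
import numpy as np


def sphere(x):
    # Placeholder objective; x has shape (dim, num_particle).
    return np.sum(x * x, axis=0)


def time_call(func, *args, repeats=200):
    """Average wall-clock time of func(*args) over `repeats` calls."""
    start = time.perf_counter()
    for _ in range(repeats):
        func(*args)
    return (time.perf_counter() - start) / repeats


def time_steps(num_particle=1000, dim=10, fun=sphere):
    rng = np.random.default_rng(0)
    lower = -np.ones(dim)
    upper = np.ones(dim)
    pos = rng.uniform(-2.0, 2.0, size=(num_particle, dim))
    val = fun(pos.T)
    return {
        "clip": time_call(np.clip, pos, lower, upper),
        "fun": time_call(fun, pos.T),
        "argmin": time_call(np.argmin, val),
    }


for name, seconds in time_steps().items():
    print(f"{name:>6}: {seconds * 1e6:.1f} us per call")
```

If fun dominates, that is the part worth optimizing; if all three are comparably small, the remaining cost is mostly Python call overhead, which Cython type declarations on the class will not remove.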

I would expect the number crunching to happen in the function fun, and this is probably the only place where manual optimization can make a difference, provided it is not easy to vectorize (read: it contains a for loop or other manual looping). For that, I would start with numba, which is a much simpler drop-in speed-up for your code, if it works. If it doesn't, it is probably appropriate to start looking into Cython.
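As a rough sketch of the numba route, assuming fun contains explicit loops: the Rosenbrock function below is only a stand-in objective chosen for illustration, and the try/except fallback is there so the snippet still runs (uncompiled) when numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback for illustration: without numba the function runs as plain Python.
    def njit(func):
        return func


@njit
def rosenbrock(x):
    """Loop-based objective; x has shape (dim, num_particle)."""
    dim, num_particle = x.shape
    out = np.zeros(num_particle)
    for j in range(num_particle):
        total = 0.0
        for i in range(dim - 1):
            a = x[i + 1, j] - x[i, j] ** 2
            b = 1.0 - x[i, j]
            total += 100.0 * a * a + b * b
        out[j] = total
    return out


# Same calling convention as self.fun(np.transpose(self.pos)) in the question.
pos = np.random.default_rng(0).uniform(-2.0, 2.0, size=(1000, 5))
val = rosenbrock(np.transpose(pos))
print(val.shape, int(np.argmin(val)))
```

The first call includes compilation time, so time only the later calls when comparing against the pure NumPy version.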

Upvotes: 3
