Yuval Langer

Reputation: 649

How do I define a numba class with ndarray arguments and/or output?

I tried this:

import numpy as np
import numba

@numba.jit
class foo(object):
    @numba.void(numba.int32)
    def __init__(self, somenum):
        self.somenumarray = np.arange(somenum)

    @numba.jit('f8[:](f8[:])')
    def somemethod1(self, a):
        return self.somenumarray + a

Using @numba.double[:](numba.double[:]) as the method decorator results in an error.

Upvotes: 2

Views: 1005

Answers (1)

Yuval Langer

Reputation: 649

This can be done using numba.FunctionType:

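import numpy as np
import numba

# Reusable signature: takes a 1-D float64 array, returns a 1-D float64 array
bar = numba.FunctionType(return_type=numba.f8[:], args=[numba.f8[:]])

@numba.jit
class foo(object):
    # The constructor takes one int32 and returns nothing
    @numba.FunctionType(return_type=numba.void, args=[numba.int32])
    def __init__(self, somenum):
        self.somenumarray = np.arange(somenum)

    # Apply the array signature defined above to the method
    @bar
    def somemethod1(self, a):
        return self.somenumarray + a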

You can later do this:

quux = foo(3)
quux.somemethod1(np.arange(3))

Upvotes: 1
