How would I go about calling a method of a Cython extension type from within a Numba jitted class? My minimal example below fails with the error shown after the code. How would I amend it to make it work?
Thanks for any help!
I have a Cython module, shrubbery.pyx:

```cython
cdef class Shrubbery:
    cdef int height

    def __init__(self, h):
        self.height = h

    def describe(self):
        print('This shrubbery is', self.height, 'tall.')
```
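I have a setup file, setup.py:

```python
# Note: distutils was removed from the standard library in Python 3.12;
# setuptools provides drop-in replacements for setup and Extension.
from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext

ext_modules = [Extension('shrubbery', ['shrubbery.pyx'])]

setup(
    name='shrubbery',
    cmdclass={'build_ext': build_ext},
    ext_modules=ext_modules)
```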
I compile shrubbery.pyx into an extension module as usual (python setup.py build_ext --inplace). Then I try to use Shrubbery inside a Numba jitted class as follows:
```python
from shrubbery import Shrubbery
import numba as nb

spec = [('value', nb.int32)]


@nb.jitclass(spec)  # numba.experimental.jitclass in newer Numba releases
class Bag(object):
    def __init__(self, value):
        self.value = value

    def size(self):
        return self.value

    def mixed_class_method(self):
        __shrubbery = Shrubbery(5)
        __shrubbery.describe()

# pure numba class: works
_b = Bag(value=3)
print(_b.size())

# pure cython extension type: works
__shrubbery = Shrubbery(5)
__shrubbery.describe()

# mix of cython extension type and numba jitted class: fails
_b.mixed_class_method()
```
```
/Users/mg/anaconda/bin/python3 test.py
3
('This shrubbery is', 5, 'tall.')
Traceback (most recent call last):
  File "test.py", line 28, in <module>
    _b.mixed_class_method()
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/jitclass/boxing.py", line 62, in wrapper
    return method(*args, **kwargs)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/dispatcher.py", line 330, in _compile_for_args
    raise e
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/dispatcher.py", line 307, in _compile_for_args
    return self.compile(tuple(argtypes))
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/dispatcher.py", line 579, in compile
    cres = self._compiler.compile(args, return_type)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/dispatcher.py", line 80, in compile
    flags=flags, locals=self.locals)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 779, in compile_extra
    return pipeline.compile_extra(func)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 362, in compile_extra
    return self._compile_bytecode()
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 738, in _compile_bytecode
    return self._compile_core()
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 725, in _compile_core
    res = pm.run(self.status)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 248, in run
    raise patched_exception
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 240, in run
    stage()
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 454, in stage_nopython_frontend
    self.locals)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 881, in type_inference_stage
    infer.propagate()
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 846, in propagate
    raise errors[0]
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 137, in propagate
    constraint(typeinfer)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 415, in __call__
    self.resolve(typeinfer, typevars, fnty)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 441, in resolve
    sig = typeinfer.resolve_call(fnty, pos_args, kw_args, literals=literals)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 1115, in resolve_call
    literals=literals)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typing/context.py", line 204, in resolve_function_type
    return func.get_call_type_with_literals(self, args, kws, literals)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/types/functions.py", line 199, in get_call_type_with_literals
    return self.get_call_type(context, args, kws)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/types/functions.py", line 193, in get_call_type
    return self.template(context).apply(args, kws)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typing/templates.py", line 207, in apply
    sig = generic(args, kws)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/jitclass/base.py", line 322, in generic
    sig = disp_type.get_call_type(self.context, args, kws)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/types/functions.py", line 250, in get_call_type
    template, pysig, args, kws = self.dispatcher.get_call_template(args, kws)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/dispatcher.py", line 269, in get_call_template
    self.compile(tuple(args))
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/dispatcher.py", line 579, in compile
    cres = self._compiler.compile(args, return_type)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/dispatcher.py", line 80, in compile
    flags=flags, locals=self.locals)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 779, in compile_extra
    return pipeline.compile_extra(func)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 362, in compile_extra
    return self._compile_bytecode()
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 738, in _compile_bytecode
    return self._compile_core()
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 725, in _compile_core
    res = pm.run(self.status)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 248, in run
    raise patched_exception
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 240, in run
    stage()
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 454, in stage_nopython_frontend
    self.locals)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 880, in type_inference_stage
    infer.build_constraint()
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 802, in build_constraint
    self.constrain_statement(inst)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 961, in constrain_statement
    self.typeof_assign(inst)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 1023, in typeof_assign
    self.typeof_global(inst, inst.target, value)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 1119, in typeof_global
    typ = self.resolve_value_type(inst, gvar.value)
  File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 1042, in resolve_value_type
    raise TypingError(msg, loc=inst.loc)
numba.errors.TypingError: Failed at nopython (nopython frontend)
Failed at nopython (nopython frontend)
Untyped global name 'Shrubbery': cannot determine Numba type of <class 'type'>
File "test.py", line 16
[1] During: resolving callee type: BoundFunction((<class 'numba.types.misc.ClassInstanceType'>, 'mixed_class_method') for instance.jitclass.Bag#7fef29835df8<value:int32>)
[2] During: typing of call at <string> (3)
```
This is mostly a response to your suggestion in the comments that CFFI functions can be made to work. That is true, but it is very limited.
You can convert a Cython cdef function to a CFFI function by going through a C function pointer. This conversion must take place in Cython. To work with Numba in nopython mode, the cdef function must not take or return a Python object, which rules out your Shrubbery class. A simple function that only accepts and returns C types will work:
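```cython
# cy.pyx (the module name used by the Python import further down)
from libc.stdint cimport uintptr_t

# Takes and returns only C types, so Numba never has to handle a Python object.
# Declared nogil; the GIL is reacquired only around the Python-level print.
cdef void f(int x) nogil:
    with gil:
        print(x+1)

ctypedef void (*void_int_func_pointer)(int)

def get_cffi_f():
    # Take f's address as a C function pointer, then as a plain integer
    cdef void_int_func_pointer f_ptr = f
    cdef uintptr_t f_ptr_int = <uintptr_t>f_ptr
    # Re-wrap the integer address as a CFFI function pointer with the same signature
    from cffi import FFI
    ffi = FFI()
    return ffi.cast('void (*)(int)', f_ptr_int)
```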
From Python you call get_cffi_f() to get a CFFI wrapping of f that you can pass to Numba functions. Note that I've declared the function as nogil and reacquired the GIL inside it: I'm not 100% sure whether Numba releases the GIL, so I'm doing this to be safe. It may not be necessary.
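The key step in get_cffi_f is that the function's address travels as a plain integer and is then re-wrapped with an explicit C signature. Below is a stdlib-only sketch of the same round trip. It uses ctypes in place of CFFI (a swapped-in library, not what this answer uses) and a Python callback in place of the Cython cdef function, and it does not involve Numba at all. The names VoidIntFunc and _py_f are illustrative.

```python
import ctypes

# C signature void (*)(int): the same signature as 'void (*)(int)' in the CFFI cast
VoidIntFunc = ctypes.CFUNCTYPE(None, ctypes.c_int)

calls = []


def _py_f(x):
    # Stand-in for the Cython cdef function f: record x + 1 instead of printing it
    calls.append(x + 1)


# Wrap the callable as a genuine C function pointer.
# Keep this object alive for as long as the pointer may be called.
c_f = VoidIntFunc(_py_f)

# Analogue of <uintptr_t>f_ptr: the raw address as a plain Python int
f_ptr_int = ctypes.cast(c_f, ctypes.c_void_p).value

# Analogue of ffi.cast('void (*)(int)', f_ptr_int): re-wrap the bare address
# with an explicit C signature, then call through it
f_again = VoidIntFunc(f_ptr_int)
f_again(5)

print(calls)  # [6]
```

Only an integer address and a C signature cross the boundary, so whatever consumes the pointer never has to understand a Python object.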
You can then either pass those CFFI wrappings into Numba functions as arguments or access them as global variables:
```python
import numba as nb
from cy import get_cffi_f

# A CFFI function pointer used as a global inside compiled code
func_global = get_cffi_f()

@nb.jit(nopython=True)
def simple_func(func):
    func(5)          # called through the argument
    func_global(6)   # called through the global
    func(7)

@nb.jitclass([('value', nb.int32)])  # numba.experimental.jitclass in newer Numba releases
class Bag(object):
    def __init__(self, value):
        self.value = value

    def mixed_class_method(self, func):
        func(self.value)
        func_global(self.value - 1)

simple_func(get_cffi_f())
Bag(3).mixed_class_method(get_cffi_f())
```
In my view, trying to make something like a Cython class work here is a lost cause.
There are probably other ways of achieving the same thing: for example, you could get Cython to generate headers by declaring functions api or public, and then use those headers with CFFI.
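The reason a Python class is ruled out is visible in the question's own error, "Untyped global name 'Shrubbery': cannot determine Numba type of <class 'type'>". As a rough illustration only, here is a toy model in which every global referenced from compiled code must map to a fixed machine-level type. It is hypothetical and far simpler than Numba's real type inference, and the names toy_typeof_global and ToyTypingError are made up for this sketch.

```python
# Toy model (hypothetical, NOT Numba's implementation): a nopython-style compiler
# must assign every global it references a fixed machine-level type up front.
_SCALAR_TYPES = {bool: "boolean", int: "int64", float: "float64"}


class ToyTypingError(TypeError):
    """Stands in for Numba's TypingError in this sketch."""


def toy_typeof_global(name, value):
    """Return a type name for a global, or fail like the error in the question."""
    # Exact lookup on type(value); a class object such as Shrubbery has type `type`
    typ = _SCALAR_TYPES.get(type(value))
    if typ is None:
        raise ToyTypingError(
            f"Untyped global name {name!r}: cannot determine type of {type(value)}"
        )
    return typ


class Shrubbery:  # plain Python stand-in for the Cython extension type
    def __init__(self, h):
        self.height = h


print(toy_typeof_global("height", 5))  # int64
try:
    toy_typeof_global("Shrubbery", Shrubbery)
except ToyTypingError as exc:
    print(exc)  # Untyped global name 'Shrubbery': cannot determine type of <class 'type'>
```

In the toy, as in the real traceback, the failure happens at the typing stage, before any of Shrubbery's methods could be called.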
From the numba docs:
"All methods of a jitclass is compiled into nopython functions. The data of a jitclass instance is allocated on the heap as a C-compatible structure so that any compiled functions can have direct access to the underlying data, bypassing the interpreter."
As DavidW pointed out, Shrubbery is a Python type, not a C type, so you cannot use it in a jitclass.
You could, however, jit the individual methods instead.
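To make "C-compatible structure" concrete, here is an analogy using ctypes from the standard library. The helper struct_from_spec and the mapping of nb.int32 to ctypes.c_int32 are assumptions for illustration. Numba's real jitclass data model is internal and is not built with ctypes. The analogy shows that a spec made only of C types yields a fixed-size struct, while a Python class cannot be a field at all.

```python
import ctypes


def struct_from_spec(name, spec):
    """Hypothetical helper: build a ctypes Structure from a jitclass-style spec.

    Analogy only: ctypes types stand in for Numba types (nb.int32 -> c_int32).
    """
    class _Struct(ctypes.Structure):
        _fields_ = list(spec)  # raises TypeError if a field type is not a C type

    _Struct.__name__ = name
    return _Struct


# Mirrors spec = [('value', nb.int32)] from the question's Bag class
BagData = struct_from_spec("BagData", [("value", ctypes.c_int32)])
bag = BagData(3)
print(ctypes.sizeof(BagData))  # 4: the struct holds one 32-bit int
print(bag.value)  # 3


class Shrubbery:  # plain Python stand-in for the Cython extension type
    pass


# A Python class has no C layout, so it cannot be a struct field
try:
    struct_from_spec("BadBag", [("shrub", Shrubbery)])
except TypeError as exc:
    print("rejected:", exc)
```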