Reputation: 737
I'm trying to build a generic slice-sampling library in Cython: you supply a log-density and a starting value, and get a sample back. I'm working on the univariate case for now. Based on the response here, I've come up with the following.
I have a function defined in cSlice.pyx:
cdef double univariate_slice_sample(f_type_1 logd, double starter,
                                    double increment_size = 0.5):
    some stuff
    return value
I have defined in cSlice.pxd:
ctypedef double (*f_type_1)(double)
cdef double univariate_slice_sample(f_type_1 logd, double starter,
                                    double increment_size = *)
where logd is a generic univariate log-density.
In my distribution file, let's say cDistribution.pyx, I have the following:
from cSlice cimport univariate_slice_sample, f_type_1

cdef double log_distribution(alpha_k, y_k, prior):
    some stuff
    return value

cdef double _sample_alpha_k_slice(
        double starter,
        double[:] y_k,
        Prior prior,
        double increment_size
        ):
    cdef f_type_1 f = lambda alpha_k: log_distribution(alpha_k, y_k, prior)
    return univariate_slice_sample(f, starter, increment_size)

cpdef double sample_alpha_k_slice(
        double starter,
        double[:] y_1,
        Prior prior,
        double increment_size = 0.5
        ):
    return _sample_alpha_k_slice(starter, y_1, prior, increment_size)
The wrapper is there because lambdas apparently aren't allowed in cpdef functions.
When I try compiling the distribution file, I get the following:
cDistribution.pyx:289:22: Cannot convert Python object to 'f_type_1'
pointing at the cdef f_type_1 f = ... line.
I'm unsure of what else to do. I want this code to maintain C speed, and importantly not hit the GIL. Any ideas?
Upvotes: 2
Views: 94
Reputation: 34326
You can JIT-compile a C callback/wrapper for any Python function (a Python object cannot be cast to a function pointer implicitly), as explained, for example, in this SO post.
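As a rough illustration of such a wrapper, here is a minimal sketch using the standard-library ctypes module. That choice is an assumption on my part; the linked post may use a different mechanism. make_log_density and its scale parameter are hypothetical stand-ins for the lambda in the question. The integer address is what you would then cast to the function-pointer type on the Cython side.

```python
import ctypes

# C prototype double (*)(double), mirroring f_type_1 from the question.
F_TYPE = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)

def make_log_density(scale):
    # Hypothetical closure: captures `scale` the way the question's
    # lambda captures y_k and prior.
    def log_density(x):
        return -0.5 * (x / scale) ** 2
    return log_density

# ctypes generates a small C trampoline around the Python closure.
# Keep `callback` alive for as long as C code may call the pointer.
callback = F_TYPE(make_log_density(2.0))

# Integer address of the trampoline, usable as a plain C function pointer.
address = ctypes.cast(callback, ctypes.c_void_p).value

# Round trip: rebuild a callable from the raw address and invoke it.
rebuilt = F_TYPE(address)
result = rebuilt(2.0)
print(result)
```

Every call through this pointer still re-enters the interpreter and takes the GIL, so it addresses the conversion error but not the speed requirement.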
However, at its core the function will remain a slow pure-Python function. Numba lets you create real C callbacks via @cfunc. Here is a simplified example:
from numba import cfunc

@cfunc("float64(float64)")
def id_(x):
    return x
and this is how it could be used:
%%cython
ctypedef double(*f_type)(double)

cdef void c_print_double(double x, f_type f):
    print(2.0*f(x))

import numba
expected_signature = numba.float64(numba.float64)

def print_double(double x, f):
    # check the signature of f:
    if not f._sig == expected_signature:
        raise TypeError("cfunc has not the right type")
    # it is not possible to cast a Python object to a pointer directly,
    # so we cast the address first to unsigned long long
    c_print_double(x, <f_type><unsigned long long int>(f.address))
And now:
print_double(1.0, id_)
# 2.0
We need to check the signature of the cfunc object at run time; otherwise the cast <f_type><unsigned long long int>(f.address) would also "work" for functions with the wrong signature, only to (possibly) crash during the call or produce strange, hard-to-debug errors. I'm not sure my method is the best one, though, even if it works:
...
@cfunc("float32(float32)")
def id3_(x):
    return x

print_double(1.0, id3_)
# TypeError: cfunc has not the right type
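The same kind of run-time guard can be sketched for callbacks built with the standard-library ctypes module instead of Numba. This is only an illustration under that assumption; checked_address, DOUBLE_FN and FLOAT_FN are hypothetical names, not part of any library. The idea is identical: refuse to hand out an address unless the wrapper was created from the expected prototype.

```python
import ctypes

# Two different C prototypes: double(double) and float(float).
DOUBLE_FN = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)
FLOAT_FN = ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float)

def checked_address(callback, expected_type):
    """Return the callback's address, rejecting wrappers of another prototype."""
    if not isinstance(callback, expected_type):
        raise TypeError("callback does not have the expected C signature")
    return ctypes.cast(callback, ctypes.c_void_p).value

def identity(x):
    return x

good = DOUBLE_FN(identity)   # matches the expected prototype
bad = FLOAT_FN(identity)     # wrong prototype, must be rejected

address = checked_address(good, DOUBLE_FN)

try:
    checked_address(bad, DOUBLE_FN)
    rejected = False
except TypeError:
    rejected = True
```

Without the isinstance check, the address of bad could be cast to a double(double) pointer just as easily, which is exactly the silent mismatch the signature test above is meant to catch.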
Upvotes: 1