Faydey

Reputation: 737

Cython defining type for functions

I'm trying to build a generic slice-sampling library in Cython: you supply a log-density and a starting value, and get a sample back. I'm working on the univariate case for now. Based on the response here, I've come up with the following.

So I have a function defined in cSlice.pyx:

cdef double univariate_slice_sample(f_type_1 logd, double starter, 
                                        double increment_size = 0.5):
    some stuff
    return value

I have defined in cSlice.pxd:

cdef ctypedef double (*f_type_1)(double)
cdef double univariate_slice_sample(f_type_1 logd, double starter, 
                                               double increment_size = *)

where logd is a generic univariate log-density.

In my distribution file, let's say cDistribution.pyx, I have the following:

from cSlice cimport univariate_slice_sample, f_type_1

cdef double log_distribution(alpha_k, y_k, prior):
    some stuff
    return value

cdef double _sample_alpha_k_slice(
        double starter,
        double[:] y_k,
        Prior prior,
        double increment_size
        ):
    cdef f_type_1 f = lambda alpha_k: log_distribution(alpha_k, y_k, prior)
    return univariate_slice_sample(f, starter, increment_size)

cpdef double sample_alpha_k_slice(
        double starter,
        double[:] y_1,
        Prior prior,
        double increment_size = 0.5
        ):
    return _sample_alpha_k_slice(starter, y_1, prior, increment_size)

The wrapper is there because apparently lambdas aren't allowed in cpdef functions.

When I try compiling the distribution file, I get the following:

cDistribution.pyx:289:22: Cannot convert Python object to 'f_type_1'

pointing at the cdef f_type_1 f = ... line.

I'm unsure of what else to do. I want this code to maintain C speed, and importantly not hit the GIL. Any ideas?

Upvotes: 2

Views: 94

Answers (1)

ead

Reputation: 34326

You can JIT-generate a C callback/wrapper for any Python function (a Python object cannot be cast to a function pointer implicitly), as explained, for example, in this SO post.
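As a minimal sketch of such a wrapper, assuming the standard-library ctypes module is used to generate it (the linked post may use a different mechanism), the following wraps a Python function in a C-callable thunk with signature double (*)(double) and calls it back through its raw address. The name log_density is a hypothetical stand-in, not part of the question's code.

```python
import ctypes

# C signature double (*)(double), matching the f_type typedef
F_TYPE = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)

def log_density(x):
    # Hypothetical stand-in for a univariate log-density
    # (standard normal, up to an additive constant)
    return -0.5 * x * x

# ctypes generates a small C-callable thunk around the Python function.
# Keep a reference to `callback` for as long as the address is in use.
callback = F_TYPE(log_density)

# Raw address of the thunk as a plain Python integer
address = ctypes.cast(callback, ctypes.c_void_p).value

# Rebuild a function pointer from the integer address and call through it
from_address = F_TYPE(address)
result = from_address(2.0)
```

The integer address is what could be handed to Cython and cast to a function pointer. Every call through it still runs the Python body, so it needs the GIL and brings no speed-up.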

However, at its core the function will remain a slow, pure-Python function. Numba lets you create real C callbacks via @cfunc. Here is a simplified example:

from numba import cfunc 
@cfunc("float64(float64)")
def id_(x):
    return x

and this is how it could be used:

%%cython
ctypedef double(*f_type)(double)

cdef void c_print_double(double x, f_type f):
    print(2.0*f(x))

import numba
expected_signature = numba.float64(numba.float64)
def print_double(double x,f):
    # check the signature of f:
    if not f._sig == expected_signature:
        raise TypeError("cfunc has not the right type")
    # it is not possible to cast a Python object to a pointer directly,
    # so we cast the address first to unsigned long long
    c_print_double(x, <f_type><unsigned long long int>(f.address))

And now:

print_double(1.0, id_)
# 2.0

We need to check the signature of the cfunc object at run time; otherwise the cast <f_type><unsigned long long int>(f.address) would also "work" for functions with the wrong signature, only to (possibly) crash during the call or produce strange, hard-to-debug errors. I'm not sure my way of checking is the best one, though, even if it works:

...
@cfunc("float32(float32)")
def id3_(x):
    return x

print_double(1.0, id3_)
# TypeError: cfunc has not the right type

Upvotes: 1
