Julius

Reputation: 765

Pass Cython functions via a Python interface

Can a cdef Cython function be passed to another (Python-visible) Cython function from a Python script?

Minimal example:

test_module.pyx

cpdef min_arg(f, int N):
    cdef double x = 100000.
    cdef int best_i = -1

    for i in range(N):
        if f(i) < x:
            x = f(i)
            best_i = i
    return best_i

def py_f(x):
    return (x-5)**2

cdef public api double cy_f(double x):
    return (x-5)**2

test.py

import pyximport; pyximport.install()
import test_module

test_module.min_arg(test_module.py_f, 100)

This works well, but I want to be able to also do

test_module.min_arg(test_module.cy_f, 100)

from test.py, to get Cython's speed (no Python overhead for each f(i) call). But Python obviously doesn't know about cy_f, because it is declared with cdef rather than def or cpdef.

I was hoping something like this existed:

from scipy import LowLevelCallable
cy_f = LowLevelCallable.from_cython(test_module, 'cy_f')
test_module.min_arg(cy_f, 100)

But this gives TypeError: 'LowLevelCallable' object is not callable.

Thank you in advance.

Upvotes: 1

Views: 551

Answers (1)

Pierre de Buyl

Reputation: 7293

A LowLevelCallable is not a Python callable: it wraps a C function pointer, and it only works when passed to a routine that was written to accept it. This has been done for a few SciPy modules, including the quadrature routine scipy.integrate.quad.

If you want to use the same wrapping mechanism, you have to go through one of the SciPy routines that accept it, such as scipy.ndimage.generic_filter1d or scipy.integrate.quad. The code that unpacks it, however, sits in SciPy's compiled extensions.
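To illustrate the point, here is a minimal sketch assuming SciPy is installed and that ctypes.util.find_library("m") can locate the C math library (typical on Linux and macOS). It wraps the C function cos (not from the question; chosen only because it has the double (double) signature that quad accepts) in a LowLevelCallable. quad can use it, while calling the wrapper directly raises the same TypeError as in the question.

```python
# Sketch: a LowLevelCallable is data for SciPy routines, not a Python callable.
# Assumes SciPy is installed and that ctypes can find the C math library.
import ctypes
import ctypes.util
import math

from scipy import LowLevelCallable
from scipy.integrate import quad

# Load libm and describe the C signature: double cos(double)
libm = ctypes.CDLL(ctypes.util.find_library("m"))
libm.cos.restype = ctypes.c_double
libm.cos.argtypes = (ctypes.c_double,)

# SciPy derives the signature "double (double)" from restype/argtypes
c_cos = LowLevelCallable(libm.cos)

# quad unpacks the function pointer and calls it from compiled code
value, abserr = quad(c_cos, 0.0, math.pi / 2)
print(value)  # close to 1.0

# Calling the wrapper directly fails, as in the question
try:
    c_cos(0.0)
except TypeError as exc:
    print(exc)
```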

The alternative, if your problem is reasonably well defined for the callback, is to implement this yourself. I have done this in one of my projects, so here is the link for simplicity:

  1. In a .pxd file, I define the interface cyfunc_d_d: https://github.com/pdebuyl/skl1/blob/master/skl1/core.pxd
  2. I can re-use this interface in the "base" cython module https://github.com/pdebuyl/skl1/blob/master/skl1/euler.pyx and also in a "user-defined" module.

The resulting code makes plain Cython-to-Cython calls while still passing the callback around as an object at the Cython level.

I adapted the code to your problem:

  1. test_interface.pxd

    cdef class cyfunc:
        cpdef double f(self, double x)
    
    cdef class pyfunc(cyfunc):
        cdef object py_f
        cpdef double f(self, double x)
    
  2. test_interface.pyx

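    # Base interface: subclasses override f(). Calls through a typed cyfunc
    # reference are C-level method calls.
    cdef class cyfunc:
        cpdef double f(self, double x):
            return 0
        def __cinit__(self):
            pass
    
    
    # Adapter that wraps an arbitrary Python callable as a cyfunc.
    cdef class pyfunc(cyfunc):
        cpdef double f(self, double x):
            return self.py_f(x)
        def __init__(self, f):
            self.py_f = f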
    
  3. setup.py

    from setuptools import setup, Extension
    from Cython.Build import cythonize
    
    setup(
        ext_modules=cythonize([
            Extension("test_interface", ["test_interface.pyx"]),
            Extension("test_module", ["test_module.pyx"]),
        ])
    )
    
  4. test_module.pyx

    from test_interface cimport cyfunc, pyfunc
    
    cpdef min_arg(f, int N):
        cdef double x = 100000.
        cdef int best_i = -1
        cdef int i
        cdef double current_value
    
        # Typed reference, so func.f() below is a C-level method call
        cdef cyfunc func
    
        if isinstance(f, cyfunc):
            func = f
            print('cyfunc')
        elif callable(f):
            # Wrap plain Python callables so the loop has a single code path
            func = pyfunc(f)
            print('no cyfunc')
        else:
            raise ValueError("f should be a callable or a cyfunc")
    
        for i in range(N):
            current_value = func.f(i)
            if current_value < x:
                x = current_value
                best_i = i
        return best_i
    
    def py_f(x):
        return (x-5)**2
    
    # A "native" callback: subclass the interface and override f()
    cdef class cy_f(cyfunc):
        cpdef double f(self, double x):
            return (x-5)**2
    

To use:

python3 setup.py build_ext --inplace
python3 -c 'import test_module ; print(test_module.min_arg(test_module.cy_f(), 10))'
python3 -c 'import test_module ; print(test_module.min_arg(test_module.py_f, 10))'
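To make the control flow of the adapted code easier to follow, here is a pure-Python model of it. The class and function names mirror the Cython ones for illustration only; the local name func and the float() conversion are my own choices. In plain Python every f() call still goes through the interpreter, so this shows the dispatch design, not the speed-up.

```python
# Pure-Python model of the cyfunc / pyfunc dispatch from the Cython answer.
# Illustration only: no speed benefit, since every call is a Python call.


class cyfunc:
    """Interface: subclasses override f(x) and return a float."""

    def f(self, x: float) -> float:
        return 0.0


class pyfunc(cyfunc):
    """Adapter that wraps an arbitrary Python callable in the interface."""

    def __init__(self, py_f):
        self.py_f = py_f

    def f(self, x: float) -> float:
        return float(self.py_f(x))


class cy_f(cyfunc):
    """A 'native' implementation of the interface."""

    def f(self, x: float) -> float:
        return (x - 5) ** 2


def py_f(x):
    return (x - 5) ** 2


def min_arg(f, N: int) -> int:
    # Normalize the argument to the interface once, outside the loop
    if isinstance(f, cyfunc):
        func = f
    elif callable(f):
        func = pyfunc(f)
    else:
        raise ValueError("f should be a callable or a cyfunc")

    x = 100000.0
    best_i = -1
    for i in range(N):
        current_value = func.f(i)
        if current_value < x:
            x = current_value
            best_i = i
    return best_i


print(min_arg(cy_f(), 10))  # 5
print(min_arg(py_f, 10))    # 5
```

The type check runs once per min_arg call rather than once per iteration. In the Cython version, the typed cdef cyfunc reference is what turns func.f(i) into a C-level method call.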

Upvotes: 1
