msm1089

Reputation: 1504

Cython error thrown unexpectedly for buffer types which ARE function local variables

I'm just learning Cython and ran into an error I don't understand. Specifically, I don't understand why it is thrown for the function below. The pure Python version works as expected.

import numpy as np
cimport numpy as np

ctypedef np.int8_t DTYPE_int8
ctypedef np.int64_t DTYPE_int64

cdef int fitness(np.ndarray[DTYPE_int8, ndim=1] c):
    
    cdef np.ndarray[DTYPE_int8, ndim=1] a
    cdef np.ndarray[DTYPE_int8, ndim=1] d1_cnts
    cdef np.ndarray[DTYPE_int8, ndim=1] d2_cnts
    
    a = np.arange(N, dtype=np.int8)
    d1_cnts = np.unique(c - a, return_counts=True)[1]
    d2_cnts = np.unique(c + a, return_counts=True)[1]
    d1_cnts = sum(sum(range(n)) for n in d1_cnts)
    d2_cnts = sum(sum(range(n)) for n in d2_cnts)
    return d1_cnts + d2_cnts

pop = np.arange(N, dtype=np.int8)
print(fitness(pop))

I get the error "Buffer types only allowed as function local variables" for d1_cnts and d2_cnts. Both of these variables ARE function local variables (right?), so why does this error occur?

Upvotes: 0

Views: 93

Answers (1)

DavidW

Reputation: 30888

d1_cnts and d2_cnts are captured by the generator expressions, for example:

(sum(range(n)) for n in d1_cnts)

In this case the generator expression is consumed immediately by sum, but Cython is worried about the general case, where you might return the generator expression (along with its captured variables) so that it outlives the function.

Your two options are:

  1. Use typed memoryviews, which are less restricted
  2. Switch to a sum over a list comprehension, which doesn't capture variables.

I think Cython should end up optimizing out the generator expression fairly efficiently, but the error message comes at an earlier stage, before it realises that this is an option.
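
Option 2 can be sketched in plain Python as below. This is only an illustration of the expression shape, not a tested Cython build: N = 8 is a hypothetical value (the question never defines N), and the names d1_total and d2_total are my own, introduced so the integer totals are not assigned back to the array variables. In a .pyx file the cdef declarations from the question would stay, and only the generator expressions would change to the bracketed form.

```python
import numpy as np

N = 8  # hypothetical size; the question does not define N


def fitness(c):
    """Plain-Python sketch of the question's function using list comprehensions."""
    a = np.arange(N, dtype=np.int8)
    # Occurrence counts of each distinct value of c - a and c + a
    d1_cnts = np.unique(c - a, return_counts=True)[1]
    d2_cnts = np.unique(c + a, return_counts=True)[1]
    # Square brackets make these list comprehensions rather than
    # generator expressions; the totals get their own (hypothetical) names
    d1_total = sum([sum(range(n)) for n in d1_cnts])
    d2_total = sum([sum(range(n)) for n in d2_cnts])
    return d1_total + d2_total


pop = np.arange(N, dtype=np.int8)
print(fitness(pop))
```

With pop equal to arange(N), every entry of c - a is zero, so there is a single count of 8 contributing sum(range(8)) = 28, while the entries of c + a are all distinct and contribute nothing.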


Also, sum(range(n)) should be replaceable with the triangular number formula, (n-1)*n//2 with integer division, but that's a separate optimization.
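
As a quick sanity check of that formula, here is a small plain-Python sketch comparing the two forms. The function names are hypothetical and chosen only for this illustration.

```python
def pairs_naive(n):
    # 0 + 1 + ... + (n - 1), as written in the question
    return sum(range(n))


def pairs_closed_form(n):
    # Triangular number formula; // keeps the result an integer
    return (n - 1) * n // 2


# The two forms agree for every small non-negative n checked here
for n in range(100):
    assert pairs_naive(n) == pairs_closed_form(n)

print(pairs_closed_form(8))
```

Using // rather than / matters if the result is returned from a function declared to return int, since / would produce a float in Python 3.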

Upvotes: 1
