Reputation: 11201
The following is a Cython code snippet currently used in scikit-learn's binary trees:
# Some compound datatypes used below:
cdef struct NodeHeapData_t:
    DTYPE_t val
    ITYPE_t i1
    ITYPE_t i2
# build the corresponding numpy dtype for NodeHeapData
cdef NodeHeapData_t nhd_tmp
NodeHeapData = np.asarray(<NodeHeapData_t[:1]>(&nhd_tmp)).dtype
(full source here)
The last line creates a numpy dtype from this Cython struct. I haven't been able to find much documentation about it, and in particular I don't understand why the slicing [:1] is needed or what it does. More discussion can be found in scikit-learn#17228. Would anyone have ideas about it?
Upvotes: 6
Views: 925
Reputation: 34316
This is a clever but confusing trick!
The following code creates a Cython array of length 1, because the memory it uses (but does not own!) holds exactly one element.
cdef NodeHeapData_t nhd_tmp
<NodeHeapData_t[:1]>(&nhd_tmp)
Now, the Cython array implements the buffer protocol, and thus Cython has the machinery to create a format string which describes the type of the element it holds. np.asarray also uses the buffer protocol and is able to construct a dtype object from the format string provided by Cython's array.
You can see the format-string via:
%%cython
import numpy as np
# Some compound datatypes used below:
cdef struct NodeHeapData_t:
    double val
    int i1
    int i2
# build the corresponding numpy dtype for NodeHeapData
cdef NodeHeapData_t nhd_tmp
NodeHeapData = np.asarray(<NodeHeapData_t[:1]>(&nhd_tmp)).dtype
print("format string:", memoryview(<NodeHeapData_t[:1]>(&nhd_tmp)).format)
print(NodeHeapData)
which leads to
format string: T{d:val:i:i1:i:i2:}
[('val', '<f8'), ('i1', '<i4'), ('i2', '<i4')]
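The same buffer-protocol mechanism can be observed from plain Python, without Cython. The sketch below is only an analogy, not what scikit-learn does: it assumes a ctypes.Structure (given the made-up name NodeHeapData_c) that mirrors the struct above, wraps one instance in a length-1 ctypes array, and lets np.asarray derive the dtype from the exported format string.

```python
import ctypes
import numpy as np


# Hypothetical ctypes mirror of the Cython struct above (double, int, int).
class NodeHeapData_c(ctypes.Structure):
    _fields_ = [
        ("val", ctypes.c_double),
        ("i1", ctypes.c_int),
        ("i2", ctypes.c_int),
    ]


# A length-1 array of the struct plays the role of
# <NodeHeapData_t[:1]>(&nhd_tmp): a buffer with exactly one element.
buf = (NodeHeapData_c * 1)()

# ctypes exports a struct format string through the buffer protocol ...
fmt = memoryview(buf).format
print("format string:", fmt)

# ... and NumPy builds a structured dtype from it. The resulting array
# is a view on the ctypes memory, not a copy.
arr = np.asarray(buf)
node_heap_dtype = arr.dtype
print(node_heap_dtype)

buf[0].val = 3.5      # write through ctypes
print(arr["val"][0])  # the NumPy view sees the change
```

The length-1 array is what matters here: the buffer protocol describes arrays of elements, so a single struct has to be presented as a one-element buffer before its element type can be read off.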
Off the top of my head, I cannot come up with a less confusing solution, other than creating the dtype object by hand, which might get ugly for some data types on different platforms*, but should be straightforward for most cases.
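As a rough sketch of the by-hand alternative, assuming the simplified struct from the example above (double, int, int) rather than scikit-learn's actual DTYPE_t/ITYPE_t typedefs, the dtype could be spelled out as follows. np.intc is NumPy's name for the C int type, and align=True asks NumPy to pad fields the way a C compiler typically would; the variable name is made up for illustration.

```python
import numpy as np

# Hand-written counterpart of the struct {double val; int i1; int i2;}.
# np.intc tracks the platform's C int, so no fixed width is hard-coded.
node_heap_dtype_by_hand = np.dtype(
    [("val", np.float64), ("i1", np.intc), ("i2", np.intc)],
    align=True,  # use C-like alignment and padding for the fields
)

print(node_heap_dtype_by_hand.names)
print(node_heap_dtype_by_hand.itemsize)
```

The drawback is that the field list now has to be kept in sync with the struct definition manually, which is exactly what the cast trick avoids.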
*) np.int is such a problematic case. It is easy to overlook that np.int maps onto long and not int (confusing, isn't it?). For example, memoryview(np.zeros(1, dtype=np.int)).itemsize evaluates to
4 (size of long in bytes on Windows).
8 (size of long in bytes on 64-bit Linux).
Upvotes: 7