rth

Creating a numpy datatype from Cython struct

The following Cython snippet is currently used in scikit-learn's binary trees:

  # Some compound datatypes used below:
  cdef struct NodeHeapData_t:
      DTYPE_t val
      ITYPE_t i1
      ITYPE_t i2

  # build the corresponding numpy dtype for NodeHeapData
  cdef NodeHeapData_t nhd_tmp
  NodeHeapData = np.asarray(<NodeHeapData_t[:1]>(&nhd_tmp)).dtype

(full source here)

The last line creates a numpy dtype from this Cython struct. I haven't been able to find much documentation about it, and in particular I don't understand why the slicing [:1] is needed or what it does. More discussion can be found in scikit-learn#17228. Would anyone have ideas about it?

Answers (1)

ead

This is a clever but confusing trick!

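The following code creates a Cython array of length 1: the memory it uses (but does not own!) holds exactly one element. That is also why the [:1] is needed. A pointer such as &nhd_tmp carries no length information, so when casting it to a Cython array the shape has to be given explicitly, and here it is a single element.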

cdef NodeHeapData_t nhd_tmp
<NodeHeapData_t[:1]>(&nhd_tmp)
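If you want to poke at the "length-1 view of a single struct" idea without compiling anything, ctypes offers a rough analogue. This is a sketch under assumptions, not what Cython does internally: `NodeHeapData_t` below is a hypothetical ctypes mirror of the `double`/`int`/`int` struct from the %%cython example in this answer, and `(NodeHeapData_t * 1).from_address(...)` stands in for `<NodeHeapData_t[:1]>(&nhd_tmp)`. Like the Cython cast, it wraps existing memory as an array of exactly one element, without copying and without owning that memory.

```python
import ctypes

# Hypothetical ctypes mirror of the answer's struct: double val; int i1; int i2;
class NodeHeapData_t(ctypes.Structure):
    _fields_ = [("val", ctypes.c_double),
                ("i1", ctypes.c_int),
                ("i2", ctypes.c_int)]

# A single struct instance, playing the role of `cdef NodeHeapData_t nhd_tmp`.
nhd_tmp = NodeHeapData_t()

# Reinterpret the struct's memory as an array of length 1 (no copy).
# The address alone carries no length, so the "1" must be given explicitly,
# just like the [:1] in the Cython cast.
one_elem = (NodeHeapData_t * 1).from_address(ctypes.addressof(nhd_tmp))

view = memoryview(one_elem)
print("shape:", view.shape)    # a single dimension holding one element
print("format:", view.format)  # PEP 3118 struct format string

# Writing through the length-1 array changes nhd_tmp itself.
one_elem[0].i1 = 5
print("nhd_tmp.i1 =", nhd_tmp.i1)
```

Because `from_address` does not keep `nhd_tmp` alive, `nhd_tmp` has to outlive `one_elem`, which is the same "uses but does not own" caveat as in the Cython version.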

A Cython array implements the buffer protocol, so Cython has the machinery to generate a format string describing the type of the elements it holds.

np.asarray also uses the buffer protocol and can construct a dtype object from the format string provided by Cython's array.
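The np.asarray step can be sketched the same way, again with a hypothetical ctypes struct standing in for the Cython one (an assumption for illustration; both export a PEP 3118 style format string through the buffer protocol). NumPy only sees a buffer with a shape and a format string, and derives the structured dtype from that.

```python
import ctypes

import numpy as np

# Hypothetical ctypes mirror of the answer's struct: double val; int i1; int i2;
class NodeHeapData_t(ctypes.Structure):
    _fields_ = [("val", ctypes.c_double),
                ("i1", ctypes.c_int),
                ("i2", ctypes.c_int)]

nhd_tmp = NodeHeapData_t()
# Length-1 array view over nhd_tmp's memory (no copy).
one_elem = (NodeHeapData_t * 1).from_address(ctypes.addressof(nhd_tmp))

# np.asarray consumes the buffer protocol: it parses the exported format
# string and turns it into a structured dtype with named fields.
arr = np.asarray(memoryview(one_elem))
NodeHeapData = arr.dtype
print(NodeHeapData)

# The array is a view on nhd_tmp's memory, so writes are shared.
arr["i1"][0] = 42
print(nhd_tmp.i1)
```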

You can see the format-string via:

%%cython
import numpy as np

# Some compound datatypes used below:
cdef struct NodeHeapData_t:
  double val
  int i1
  int i2

# build the corresponding numpy dtype for NodeHeapData
cdef NodeHeapData_t nhd_tmp
NodeHeapData = np.asarray(<NodeHeapData_t[:1]>(&nhd_tmp)).dtype

print("format string:", memoryview(<NodeHeapData_t[:1]>(&nhd_tmp)).format)
print(NodeHeapData)

which leads to

format string: T{d:val:i:i1:i:i2:}
[('val', '<f8'), ('i1', '<i4'), ('i2', '<i4')]
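To make the format-string-to-dtype step more concrete, here is a deliberately tiny parser for strings shaped exactly like the one printed above: a flat `T{code:name:...}` with single-character codes, no nesting, no byte-order or alignment prefixes. `dtype_from_flat_struct_format` is a hypothetical helper written for illustration only; NumPy's real PEP 3118 parser handles much more. It relies on the fact that these particular struct codes (`d`, `i`) are also valid NumPy type characters.

```python
import numpy as np

def dtype_from_flat_struct_format(fmt):
    """Hypothetical helper: parse e.g. 'T{d:val:i:i1:i:i2:}' into a dtype.

    Only handles flat structs whose fields use single-character native
    codes; this is not a general PEP 3118 parser.
    """
    if not (fmt.startswith("T{") and fmt.endswith("}")):
        raise ValueError("expected a flat 'T{...}' struct format")
    parts = fmt[2:-1].split(":")
    # The body ends with ':', so the last piece is empty; drop it.
    if parts and parts[-1] == "":
        parts.pop()
    codes, names = parts[0::2], parts[1::2]
    if len(codes) != len(names):
        raise ValueError("unbalanced code:name pairs")
    # align=True applies native C alignment and padding, like a C struct.
    return np.dtype([(name, code) for code, name in zip(codes, names)],
                    align=True)

NodeHeapData = dtype_from_flat_struct_format("T{d:val:i:i1:i:i2:}")
print(NodeHeapData.names, NodeHeapData.itemsize)
```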

Off the top of my head, I cannot come up with a less confusing solution, other than creating the dtype object by hand, which might get ugly for some data types on different platforms*, but should be straightforward for most cases.
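Creating the dtype by hand could look roughly like this, assuming the fields are the `double`, `int`, `int` of the %%cython example above. Using NumPy's C-type aliases (`np.double`, `np.intc`) instead of guessing fixed widths keeps the field sizes tied to the platform's C types, and `align=True` asks NumPy for the same padding a C compiler would insert. The ctypes struct (`CNodeHeapData`, a hypothetical name) is only there as an independent cross-check of the layout.

```python
import ctypes

import numpy as np

# Hand-written dtype for: struct { double val; int i1; int i2; }
# np.double / np.intc are NumPy's aliases for C double / C int.
NodeHeapData = np.dtype(
    [("val", np.double), ("i1", np.intc), ("i2", np.intc)],
    align=True,  # C-style field alignment and trailing padding
)

# Cross-check: ctypes lays out structs the way the platform C compiler does.
class CNodeHeapData(ctypes.Structure):
    _fields_ = [("val", ctypes.c_double),
                ("i1", ctypes.c_int),
                ("i2", ctypes.c_int)]

print("itemsize:", NodeHeapData.itemsize, ctypes.sizeof(CNodeHeapData))
for name in NodeHeapData.names:
    # (offset in the dtype, offset in the C struct) for each field
    print(name, NodeHeapData.fields[name][1], getattr(CNodeHeapData, name).offset)
```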


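*) np.int is such a problematic case. It is easy to overlook that np.int maps onto C long and not C int (confusing, isn't it?). Note that np.int was merely an alias for Python's built-in int; it was deprecated in NumPy 1.20 and removed in NumPy 1.24, so current code should write dtype=int instead.

For example

memoryview(np.zeros(1, dtype=int)).itemsize

evaluates to the following with NumPy < 2.0 (NumPy 2.0 made the default integer pointer-sized, so on 64-bit Windows it now gives 8 as well):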

  • On Windows: 4 (size of long in bytes on Windows).
  • On Linux: 8 (size of long in bytes on Linux).
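One way to sidestep the long-vs-int surprise is to name C types explicitly. The sketch below compares NumPy's C-type dtypes with ctypes' sizes on whatever machine runs it; it assumes `np.intp` corresponds to `ctypes.c_ssize_t`, which holds on mainstream platforms.

```python
import ctypes

import numpy as np

# NumPy spellings that name C types directly:
#   np.intc -> C int, 'l' -> C long, np.intp -> pointer-sized integer.
pairs = [
    ("C int", np.dtype(np.intc), ctypes.c_int),
    ("C long", np.dtype("l"), ctypes.c_long),
    ("pointer-sized", np.dtype(np.intp), ctypes.c_ssize_t),
]
for label, dt, ctype in pairs:
    print(f"{label:14s} numpy={dt.itemsize} ctypes={ctypes.sizeof(ctype)}")

# Plain `int` maps to NumPy's default integer, whose size depends on the
# NumPy version (C long before 2.0, intp from 2.0 on) and the platform.
print("dtype=int:", np.dtype(int).itemsize)
```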
