Ipse Lium

Reputation: 1080

Polymorphism with Cython extension types

I have a Cython extension type that I want to make more general. One of its attributes is a double, and I would like it to be a memoryview (double[::1]) when needed.

Here is a simple example:

import numpy as np
cimport numpy as np
cimport cython

cdef class Test:

    cdef bint numeric
    cdef double du  

    def __init__(self, bint numeric):

        self.numeric = numeric

        if self.numeric:
            self.du = 1

        else:
            self.du = np.ones(10)

    def disp(self):
        print(self.du)

Test(True).disp()   # prints 1.0
Test(False).disp()  # of course gives an error (raised already in __init__: an array cannot be assigned to a double)

I tried to subclass Test, changing the type of du to double[::1] and implementing a new __init__, but it seems that attributes of an extension type cannot be redeclared in a subclass. Even if it worked, it would not be satisfactory, because I don't want one extension type per case.

Ideally, my extension type would handle both cases directly (scalar du and memoryview du).

Is there a way to do this with Cython?

Upvotes: 2

Views: 422

Answers (1)

BlueSheepToken

Reputation: 6159

Unfortunately, you cannot use a fused type as the type of an attribute. You have two options here:

You could store the memory address of the variable you want to access and cast it when needed (everything is explained here). Unfortunately, I did not manage to make this work with typed memoryviews.
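A minimal sketch of that address idea, under some assumptions: instead of a memoryview attribute, the class keeps a raw double* that points either at a scalar field or at the first element of a NumPy array, plus a reference to the array so its buffer stays alive. The attribute names du_scalar, du_owner, du_ptr and du_len are hypothetical, chosen only for this illustration, and the code has to be compiled with Cython like the rest of the module.

```cython
# cython: language_level=3
import numpy as np

cdef class Test:
    # Hypothetical attribute names, used only for this sketch.
    cdef bint numeric
    cdef double du_scalar    # storage used in the scalar case
    cdef object du_owner     # keeps the NumPy array alive in the array case
    cdef double* du_ptr      # address of du_scalar or of the array's first element
    cdef Py_ssize_t du_len   # 1 in the scalar case, the array length otherwise

    def __init__(self, bint numeric):
        cdef double[::1] mv
        self.numeric = numeric
        if numeric:
            self.du_scalar = 1
            self.du_ptr = &self.du_scalar
            self.du_len = 1
        else:
            arr = np.ones(10)
            mv = arr                  # typed view, only used to take the address
            self.du_owner = arr       # the pointer stays valid only while this reference lives
            self.du_ptr = &mv[0]
            self.du_len = mv.shape[0]

    def disp(self):
        cdef Py_ssize_t i
        if self.numeric:
            print(self.du_ptr[0])
        else:
            values = []
            for i in range(self.du_len):
                values.append(self.du_ptr[i])
            print(values)
```

The trade-off is that the pointer bypasses the bounds checking and reference management that a typed memoryview gives you, so du_owner must never be dropped or replaced while du_ptr is in use.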

Or you can use the numeric attribute you already defined to select the appropriate attribute:

import numpy as np
cimport numpy as np
cimport cython

cdef class Test:

    cdef bint numeric
    cdef double du_numeric
    cdef double[:] du_mem_view

    def __init__(self, bint numeric):

        self.numeric = numeric

        if self.numeric:
            self.du_numeric = 1

        else:
            self.du_mem_view = np.ones(10)

    def disp(self):
        if self.numeric:
            print(self.du_numeric)
        else:
            # convert the memoryview to a NumPy array so that its values are printed
            print(np.asarray(self.du_mem_view))

Test(True).disp()   # prints 1.0
Test(False).disp()  # no longer gives an error!

Upvotes: 1
