How to create a list of NumPy arrays in a jitclass

I want to create a jitclass that stores some NumPy arrays, but I don't know in advance how many of them there will be, so I want to store them in a list of NumPy arrays. I am confused by Numba types, but I found a strange solution that runs fine:

import numba
from numba import types, typed, typeof
from numba.experimental import jitclass
import numpy as np


spec = [
    ('test', typeof(typed.List.empty_list(numba.int64[:])))
]

@jitclass(spec)
class myLIST(object):
    def __init__(self, haha=typed.List.empty_list(numba.int64[:])):
        self.test = haha
        self.test.append(np.asarray([0]))

    def dump(self):
        self.test.append(np.asarray([1]))
        print(self.test)

a = myLIST()
a.dump()

But when I remove the redundant variable, it fails:

spec = [
    ('test', typeof(typed.List.empty_list(numba.int64[:])))
]

@jitclass(spec)
class myLIST(object):
    def __init__(self):
        self.test = typed.List.empty_list(numba.int64[:])
        self.test.append(np.asarray([0]))

    def dump(self):
        self.test.append(np.asarray([1]))
        print(self.test)

a = myLIST()
a.dump()

Why does this happen?

Answers (1)

aerobiomat

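It seems that declaring the array type as nb.int64[:] inside the jitted __init__ doesn't give Numba enough information to compile the class. Your first version works because the default value for haha is created by the Python interpreter when the class is defined, so Numba receives an already-built typed list from which it can infer the type.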

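Instead, you can declare the array type once at module level, use it in the spec, and create the list in __init__ with nb.typed.List.empty_list(int_vector):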

import numba as nb

int_vector = nb.types.Array(dtype=nb.int64, ndim=1, layout="C")
spec = [('test', nb.typeof(nb.typed.List.empty_list(int_vector)))]

Or shorter:

import numba as nb

int_vector = nb.types.Array(dtype=nb.int64, ndim=1, layout="C")
spec = [('test', nb.types.ListType(int_vector))]

Or, if you can use type annotations:

import numba as nb
import numpy as np

int_vector = nb.types.Array(dtype=nb.int64, ndim=1, layout="C")

@nb.experimental.jitclass
class my_list:

    test: nb.types.ListType(int_vector)

    def __init__(self):
        self.test = nb.typed.List.empty_list(int_vector)
        self.test.append(np.array([0]))

    def dump(self):
        self.test.append(np.array([1]))
        print(self.test)

a = my_list()
a.dump()
