toelzstudios

Reputation: 11

Preserve numpy ndarray subclass when multiplying with scipy csr_matrix

I would like to preserve the type of an ndarray subclass when doing matrix-vector multiplication with a scipy csr_matrix.

My subclass is

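import numpy as np

class FlattenedMeshVector(np.ndarray):

    __array_priority__ = 15

    def __new__(cls, input_array):
        # View the input data as this subclass (no copy)
        obj = np.asarray(input_array).view(cls)
        return obj

    def __array_finalize__(self, obj):
        # Copy extra attributes from the array this instance was created from
        if obj is None: return
        self.nx = getattr(obj, 'nx', None)

    def __array_wrap__(self, out_arr, context=None):
        # super() is already bound to self, so self must not be passed again
        return super().__array_wrap__(out_arr, context)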

Since I set __array_priority__ = 15, matrix-vector multiplication with a regular ndarray (which has __array_priority__ = 0) nicely preserves the subclass type:

>>> a = FlattenedMeshVector([1,1,1,1])
>>> id_mat = np.diag(np.ones(4))
>>> id_mat.dot(a)
FlattenedMeshVector([1., 1., 1., 1.])

However, when doing the same thing with a scipy sparse matrix, the subclass type is lost, even though csr_matrix has __array_priority__ = 10.1, which is lower than 15. The same happens with the @ operator, which is preferred since Python 3.5:

>>> from scipy.sparse import csr_matrix

>>> a = FlattenedMeshVector([1,1,1,1])
>>> id_mat = csr_matrix(np.diag(np.ones(4)))
>>> id_mat.dot(a)
array([1., 1., 1., 1.])

>>> id_mat @ a
array([1., 1., 1., 1.])

I assume that csr_matrix.dot does some conversion to an ndarray at some point. Any idea how I might circumvent this?

Upvotes: 0

Views: 81

Answers (1)

hpaulj

Reputation: 231665

Looking at the source of M.__mul__(other), it has several test cases:

if other.__class__ is np.ndarray:   # exact class only; subclasses skip this
    self._mul_vector(other)         # vector-shaped other, or
    self._mul_multivector(other)    # 2-d other

if issparse(other):
    self._mul_sparse_matrix(other)

other_a = np.asanyarray(other)      # anything else

After the multiplication, it converts the result to np.matrix if other was one:

if isinstance(other, np.matrix):
    result = asmatrix(result)
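To see why a subclass never takes the fast path, here is a NumPy-only check of the two conditions quoted above. It reproduces the tests rather than calling scipy's private methods, and the class is trimmed to what the check needs:

```python
import numpy as np


class FlattenedMeshVector(np.ndarray):
    """Trimmed copy of the question's subclass (only what this check needs)."""
    __array_priority__ = 15

    def __new__(cls, input_array):
        return np.asarray(input_array).view(cls)

    def __array_finalize__(self, obj):
        if obj is None:
            return
        self.nx = getattr(obj, 'nx', None)


a = FlattenedMeshVector([1, 1, 1, 1])

# The fast path tests the exact class, so a subclass never takes it.
takes_fast_path = a.__class__ is np.ndarray
print(takes_fast_path)                 # False

# The fall-through path uses np.asanyarray, which keeps the subclass.
other_a = np.asanyarray(a)
print(type(other_a).__name__)          # FlattenedMeshVector
```

So the subclass survives into the fall-through branch; it is only lost once the product is written into a new array.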

So the three cases are:

In [646]: M@M
Out[646]: 
<11x11 sparse matrix of type '<class 'numpy.int64'>'
    with 31 stored elements in Compressed Sparse Row format>
In [647]: type(M@M.A)
Out[647]: numpy.ndarray
In [648]: type(M@M.todense())
Out[648]: numpy.matrix
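A runnable version of those three cases, using a small 3x3 csr_matrix as a stand-in for M (the values are arbitrary; the M above was 11x11):

```python
import numpy as np
from scipy.sparse import csr_matrix, issparse

# Small sparse stand-in for M: ones on the diagonal and first superdiagonal.
M = csr_matrix(np.eye(3) + np.eye(3, k=1))

sparse_result = M @ M               # sparse @ sparse   -> sparse matrix
array_result = M @ M.toarray()      # sparse @ ndarray  -> ndarray
matrix_result = M @ M.todense()     # sparse @ np.matrix -> np.matrix

print(issparse(sparse_result))      # True
print(type(array_result))           # <class 'numpy.ndarray'>
print(type(matrix_result))          # <class 'numpy.matrix'>
```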

In your case id_mat@a ends up in _mul_vector, so these two give the same result:

In [671]: id_mat@a
Out[671]: array([1., 1., 1., 1.])
In [672]: id_mat._mul_vector(a)
Out[672]: array([1., 1., 1., 1.])

For CSR and CSC matrices, _mul_vector does:

    M, N = self.shape
    # output array: a fresh, plain ndarray
    result = np.zeros(M, dtype=upcast_char(self.dtype.char,
                                           other.dtype.char))
    # csr_matvec or csc_matvec
    fn = getattr(_sparsetools, self.format + '_matvec')
    fn(M, N, self.indptr, self.indices, self.data, other, result)
    return result

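sparse._sparsetools.csr_matvec is a "builtin", i.e. a compiled function (scipy's sparsetools are C++ extension code). In any case, result is a plain np.zeros array of the right shape and dtype, filled in with the computed values, so the subclass of other never reaches the output.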
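To make the point about result concrete, here is a pure-NumPy sketch of what a CSR mat-vec kernel does. csr_matvec_py and Tagged are hypothetical names for illustration, and np.result_type stands in for scipy's upcast_char; this is not scipy's implementation, only the same row-by-row loop over indptr/indices/data:

```python
import numpy as np


def csr_matvec_py(n_rows, indptr, indices, data, x):
    """Hypothetical pure-NumPy stand-in for the compiled csr_matvec kernel.

    Like _mul_vector, it writes into a freshly allocated np.zeros array,
    so the result is a plain ndarray whatever subclass x is.
    """
    # np.result_type plays the role of scipy's upcast_char helper here.
    result = np.zeros(n_rows, dtype=np.result_type(data.dtype, x.dtype))
    for row in range(n_rows):
        acc = 0
        # The nonzeros of this row are data[indptr[row]:indptr[row + 1]].
        for k in range(indptr[row], indptr[row + 1]):
            acc += data[k] * x[indices[k]]
        result[row] = acc
    return result


class Tagged(np.ndarray):
    """Trivial ndarray subclass, used only to show the type is dropped."""


# CSR form of [[1, 0, 2], [0, 3, 0], [4, 0, 5]]
indptr = np.array([0, 2, 3, 5])
indices = np.array([0, 2, 1, 0, 2])
data = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
x = np.array([1.0, 2.0, 3.0]).view(Tagged)

y = csr_matvec_py(3, indptr, indices, data, x)
print(y, type(y).__name__)   # [ 7.  6. 19.] ndarray
```

The input type plays no role in choosing the output container; only the dtype is taken from it.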

So, taking a clue from how it handles np.matrix, I think your only option is to wrap the result yourself:

In [678]: id_mat@(a)
Out[678]: array([1., 1., 1., 1.])
In [679]: FlattenedMeshVector(id_mat@(a))
Out[679]: FlattenedMeshVector([1., 1., 1., 1.])
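One caveat: FlattenedMeshVector(id_mat@(a)) builds the new object through __new__, so __array_finalize__ only sees the plain result and nx ends up as None. If you also want to carry attributes across, a small wrapper can re-run __array_finalize__ against the input. matvec_keep_type is a hypothetical helper, and copying nx assumes the operator maps a vector to one on the same mesh:

```python
import numpy as np
from scipy.sparse import csr_matrix


class FlattenedMeshVector(np.ndarray):
    """The question's subclass (without __array_wrap__, not needed here)."""
    __array_priority__ = 15

    def __new__(cls, input_array):
        return np.asarray(input_array).view(cls)

    def __array_finalize__(self, obj):
        if obj is None:
            return
        self.nx = getattr(obj, 'nx', None)


def matvec_keep_type(A, v):
    """Hypothetical helper: compute A @ v, then restore v's subclass and attributes."""
    out = np.asarray(A @ v).view(type(v))
    # view() ran __array_finalize__ with the plain product, so nx is None;
    # running it again against v copies v's attributes onto the result.
    finalize = getattr(out, '__array_finalize__', None)
    if callable(finalize):
        finalize(v)
    return out


a = FlattenedMeshVector([1, 1, 1, 1])
a.nx = 4
id_mat = csr_matrix(np.diag(np.ones(4)))

b = matvec_keep_type(id_mat, a)
print(type(b).__name__, b.nx)   # FlattenedMeshVector 4
```

Reusing __array_finalize__ keeps the attribute-copying logic in one place, so adding attributes to the class later needs no change to the helper.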

Upvotes: 1
