dbrane

NumPy functions clobber my ndarray subclass

Say I have a class ndarray_plus that inherits from numpy.ndarray and adds some extra functionality. Sometimes I pass it to numpy functions like np.sum and get back an object of type ndarray_plus, as expected.

Other times, the NumPy functions I pass my enhanced object to return a plain numpy.ndarray object, losing the information stored in the extra ndarray_plus attributes. This usually happens when the function in question calls np.asarray instead of np.asanyarray.

Is there a way to prevent this from happening? I can't go into the numpy codebase and change all instances of np.asarray to np.asanyarray. Is there a Pythonic way to pre-emptively protect my inherited object?
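A minimal sketch that reproduces the behaviour described above. It assumes a toy ndarray_plus carrying one extra attribute, info (a hypothetical name), propagated with the usual __array_finalize__ hook, and a hypothetical third_party function standing in for any library routine that calls np.asarray on its input:

```python
import numpy as np


class ndarray_plus(np.ndarray):
    """Toy ndarray subclass with one extra attribute (info is hypothetical)."""

    def __new__(cls, data, info=None):
        obj = np.asarray(data).view(cls)
        obj.info = info
        return obj

    def __array_finalize__(self, obj):
        # Called for views, slices and ufunc outputs; carry the attribute over.
        self.info = getattr(obj, "info", None)


def third_party(x):
    # Hypothetical library function that normalises its input with asarray.
    return np.asarray(x) * 2


x = ndarray_plus([[1, 2], [3, 4]], info="metadata")

summed = np.sum(x, axis=0)  # reduction on the subclass: type survives
doubled = third_party(x)    # asarray inside: subclass is stripped

print(type(summed).__name__, type(doubled).__name__)
```

Run as a script, this should print ndarray_plus ndarray: the reduction keeps the subclass, while the asarray call inside third_party hands back a plain ndarray that no longer has an info attribute.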

Answers (1)

Paul Panzer

The defined and guaranteed behaviour of asarray is to convert your subclass instance back to the base class:

Help on function asarray in numpy:

numpy.asarray = asarray(a, dtype=None, order=None)
Convert the input to an array.

Parameters
----------
a : array_like
    Input data, in any form that can be converted to an array.  This
    includes lists, lists of tuples, tuples, tuples of tuples, tuples
    of lists and ndarrays.
dtype : data-type, optional
    By default, the data-type is inferred from the input data.
order : {'C', 'F'}, optional
    Whether to use row-major (C-style) or
    column-major (Fortran-style) memory representation.
    Defaults to 'C'.

Returns
-------
out : ndarray
    Array interpretation of `a`.  No copy is performed if the input
    is already an ndarray.  If `a` is a subclass of ndarray, a base
    class ndarray is returned.

See Also
--------
asanyarray : Similar function which passes through subclasses.

< - snip - >
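The last paragraph of that docstring can be checked directly. This sketch uses a hypothetical empty subclass, Tagged, and shows that np.asarray returns a base-class view over the same memory (no copy), while np.asanyarray passes the subclass instance through unchanged:

```python
import numpy as np


class Tagged(np.ndarray):
    """Minimal hypothetical subclass used only for this demonstration."""


base = np.arange(6).reshape(2, 3)
sub = base.view(Tagged)

plain = np.asarray(sub)      # base-class ndarray viewing the same buffer
passed = np.asanyarray(sub)  # the subclass instance itself, untouched

print(type(plain).__name__, np.shares_memory(plain, sub), passed is sub)
```

This should print ndarray True True.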

You could try monkeypatching:

>>> import numpy as np
>>> import mpt
>>> 
>>> s = np.matrix(3)
>>> mpt.aa(s)
array([[3]])
>>> np.asarray = np.asanyarray
>>> mpt.aa(s)
matrix([[3]])

File mpt.py:

import numpy as np

def aa(x):
    return np.asarray(x)
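The same experiment can be run as one self-contained script. This sketch assumes a hypothetical load_module helper that builds mpt from source text (standing in for import mpt) and a hypothetical subclass Tagged in place of np.matrix. It uses unittest.mock.patch.object so that the original np.asarray is restored when the with block exits, instead of staying replaced for the rest of the session:

```python
import types
from unittest import mock

import numpy as np

# Source of mpt.py: np.asarray is looked up on the numpy module
# every time aa() is called.
MPT_SOURCE = """
import numpy as np

def aa(x):
    return np.asarray(x)
"""


def load_module(name, source):
    """Hypothetical helper: build a throwaway module from source text."""
    module = types.ModuleType(name)
    exec(source, module.__dict__)
    return module


class Tagged(np.ndarray):
    """Minimal hypothetical subclass used only for this demonstration."""


mpt = load_module("mpt", MPT_SOURCE)
s = np.arange(3).view(Tagged)

before = mpt.aa(s)
with mock.patch.object(np, "asarray", np.asanyarray):
    during = mpt.aa(s)  # np.asarray now resolves to asanyarray
after = mpt.aa(s)       # original asarray restored on exit

print(type(before).__name__, type(during).__name__, type(after).__name__)
```

It should print ndarray Tagged ndarray. While the patch is active it affects every piece of code in the process that looks up np.asarray on the numpy module, not just mpt.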

Sadly, this doesn't always work, because the patch only reaches code that looks up asarray on the numpy module at call time.

Alternative mpt.py:

from numpy import asarray

def aa(x):
    return asarray(x)

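Here you'd be out of luck if mpt was imported before the patch: from numpy import asarray copies a reference to the original function into mpt's own namespace at import time, so reassigning np.asarray afterwards doesn't change what aa calls.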
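The name-binding effect can be seen in a short sketch. It defines a hypothetical load_module helper and a hypothetical Tagged subclass inline, then loads the alternative mpt once before the patch and once while the patch is active:

```python
import types
from unittest import mock

import numpy as np

# Alternative mpt.py: the from-import copies whatever numpy.asarray is
# at the moment the module body runs into the module's own namespace.
MPT_SOURCE = """
from numpy import asarray

def aa(x):
    return asarray(x)
"""


def load_module(name, source):
    """Hypothetical helper: build a throwaway module from source text."""
    module = types.ModuleType(name)
    exec(source, module.__dict__)
    return module


class Tagged(np.ndarray):
    """Minimal hypothetical subclass used only for this demonstration."""


s = np.arange(3).view(Tagged)

early = load_module("mpt_early", MPT_SOURCE)  # loaded before patching
with mock.patch.object(np, "asarray", np.asanyarray):
    patched_call = early.aa(s)                # still the original asarray
    late = load_module("mpt_late", MPT_SOURCE)  # binds asanyarray
late_call = late.aa(s)                        # keeps its binding after exit

print(type(patched_call).__name__, type(late_call).__name__)
```

It should print ndarray Tagged: the module loaded before the patch keeps the original asarray, while the one loaded during the patch keeps asanyarray even after the patch is undone.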
