astrofrog

Reputation: 34091

How to prevent NumPy from splitting up an array-like object

If I consider the following simple class:

class Quantity(object):

    def __init__(self, value, unit):
        self.unit = unit
        self.value = value

    def __getitem__(self, key):
        return Quantity(self.value[key], unit=self.unit)

    def __len__(self):
        return len(self.value)

and create an instance:

import numpy as np
q = Quantity(np.array([1,2,3]), 'degree')
print(repr(np.array(q)))

If I then pass this object to NumPy, it splits the object into an object array of three Quantity instances:

array([<__main__.Quantity object at 0x1073a0d50>,
       <__main__.Quantity object at 0x1073a0d90>,
       <__main__.Quantity object at 0x1073a0dd0>], dtype=object)

This is due to the presence of the __len__ and __getitem__ methods: if I remove either of them, the object is no longer split up:

array(<__main__.Quantity object at 0x110a4e610>, dtype=object)
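The splitting can be reproduced with a short script. This is a minimal sketch, assuming a reasonably recent NumPy; `NoGetitem` is a hypothetical variant of the class that keeps `__len__` but drops `__getitem__`. NumPy treats an object that supports item access and reports a length as a nested sequence, so the first call recurses one level. The inner `Quantity` objects wrap NumPy scalars, whose `len()` fails, so the recursion stops there and each one becomes an object element:

```python
import numpy as np


class Quantity(object):
    """The class from the question: defines both __len__ and __getitem__."""

    def __init__(self, value, unit):
        self.unit = unit
        self.value = value

    def __getitem__(self, key):
        return Quantity(self.value[key], unit=self.unit)

    def __len__(self):
        return len(self.value)


class NoGetitem(object):
    """Hypothetical variant with __len__ but without __getitem__."""

    def __init__(self, value, unit):
        self.unit = unit
        self.value = value

    def __len__(self):
        return len(self.value)


split = np.array(Quantity(np.array([1, 2, 3]), 'degree'))
whole = np.array(NoGetitem(np.array([1, 2, 3]), 'degree'))

# Treated as a sequence: a 1-d object array holding three Quantity elements.
print(split.shape, split.dtype)  # (3,) object
# Not a sequence: the instance becomes the single element of a 0-d array.
print(whole.shape, whole.dtype)  # () object
```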

I would like to keep __len__ and __getitem__. Is there a way to prevent NumPy from splitting up the object?

EDIT: I am interested in solutions other than making Quantity an ndarray subclass.
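Since subclassing is ruled out, one workaround that keeps __len__ and __getitem__ is to avoid calling np.array(q) at all: allocate a 0-d object array and assign the instance as its single element, so NumPy never inspects it as a sequence. This is a sketch under that assumption; `as_object_scalar` is a hypothetical helper name, and the approach only helps where you control the conversion (any code that still calls np.array(q) will split the object as before):

```python
import numpy as np


class Quantity(object):
    def __init__(self, value, unit):
        self.unit = unit
        self.value = value

    def __getitem__(self, key):
        return Quantity(self.value[key], unit=self.unit)

    def __len__(self):
        return len(self.value)


def as_object_scalar(obj):
    """Hypothetical helper: wrap obj in a 0-d object array without
    letting NumPy treat it as a sequence."""
    arr = np.empty((), dtype=object)  # a single slot, initialised to None
    arr[()] = obj                     # single-element assignment stores obj as-is
    return arr


q = Quantity(np.array([1, 2, 3]), 'degree')
wrapped = as_object_scalar(q)
print(wrapped.shape, wrapped.dtype)  # () object
print(wrapped[()] is q)              # True
```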

Upvotes: 3

Views: 183

Answers (1)

keflavich

Reputation: 19215

Is this what you're looking for?

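import numpy as np


class Quantity(object):

    def __init__(self, value, unit):
        self.unit = unit
        self.value = value

    def __getitem__(self, key):
        return Quantity(self.value[key], unit=self.unit)

    def __len__(self):
        return len(self.value)

    def __array__(self, dtype=None, copy=None):
        # np.array/np.asarray check for __array__ before treating the
        # object as a sequence, so NumPy gets the values directly.
        # NumPy 2 may also pass dtype and copy keywords.
        if copy:
            return np.array(self.value, dtype=dtype)  # a copy was requested
        return np.asarray(self.value, dtype=dtype)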

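np.array (and np.asarray) look for an __array__ method before falling back to the sequence protocol (__len__/__getitem__), so defining it lets NumPy convert the object directly: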

In [11]: q
Out[11]: <__main__.Quantity at 0x1042bdf90>

In [12]: np.array(q)
Out[12]: array([ 1.,  2.,  3.])

In [13]: print(repr(np.array(q)))
array([ 1.,  2.,  3.])

In [14]: len(q)
Out[14]: 3

In [15]: q[1]
Out[15]: <__main__.Quantity at 0x1042bdd50>

In [16]: q[0]
Out[16]: <__main__.Quantity at 0x1042bdd90>

In [17]: q[0].value
Out[17]: 1.0
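A self-contained sketch of the __array__ approach that also shows a few side effects, assuming the NumPy 2 style signature __array__(self, dtype=None, copy=None) (older NumPy calls __array__ without copy, which this signature also accepts). Note that the unit is not carried into the result, and that a list of such objects now stacks into a 2-D numeric array rather than an object array:

```python
import numpy as np


class Quantity(object):
    def __init__(self, value, unit):
        self.unit = unit
        self.value = value

    def __getitem__(self, key):
        return Quantity(self.value[key], unit=self.unit)

    def __len__(self):
        return len(self.value)

    def __array__(self, dtype=None, copy=None):
        # copy=True: return a fresh array; otherwise avoid copying if possible.
        if copy:
            return np.array(self.value, dtype=dtype)
        return np.asarray(self.value, dtype=dtype)


q = Quantity(np.array([1, 2, 3]), 'degree')

plain = np.array(q)                    # numeric array; the unit is dropped
stacked = np.array([q, q])             # each element converted via __array__
as_float = np.asarray(q, dtype=float)  # dtype conversion also works

print(plain, plain.dtype)
print(stacked.shape)  # (2, 3)
print(as_float)       # [1. 2. 3.]
```

This sketch does not strictly enforce copy=False (no-copy) requests; it only avoids a copy where np.asarray can.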

Upvotes: 2
