Kater Tatianow

Reputation: 3

Overloading operators in numpy

I have the following problem: I want to create and use a NumPy array with a slight change to the [] operator. As far as I understand, this is done through the __getitem__(self, index) method. However, I can't figure out how to declare an array that is a NumPy array in every respect except for that one detail (for the sake of example, say I want array[i] to be interpreted as array[i-1]).

I tried to solve it the following way:

import numpy as np

class myarray(np.ndarray):
    def __getitem__(self, index):
        # intended: read element index+1 instead of element index
        return self[index+1]

k = np.linspace(0, 10, 10).view(myarray)

but it doesn't work.
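To see concretely why this attempt fails, here is a minimal sketch (assuming Python 3 and a recent NumPy) that reproduces it. The flag `hit_recursion` is only an illustrative name: indexing never reaches NumPy's own indexing, because `self[...]` inside `__getitem__` dispatches straight back to the same method.

```python
import numpy as np


class myarray(np.ndarray):
    def __getitem__(self, index):
        # self[...] calls this very method again with a larger index,
        # so no element is ever fetched and the call stack keeps growing
        return self[index + 1]


k = np.linspace(0, 10, 10).view(myarray)

hit_recursion = False  # illustrative flag recording whether k[0] blew the stack
try:
    k[0]
except RecursionError as exc:
    hit_recursion = True
    print("RecursionError:", exc)
```

Creating the view itself works; the failure only appears on the first indexing operation.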

Upvotes: 0

Views: 1107

Answers (2)

Kater Tatianow

Reputation: 3

Thanks to onodip's answer, I've solved my original problem. It was slightly different from what I posted; I've learned to be more specific in the future (and not to ask by way of an example).

What I actually wanted was to iterate over a matrix cyclically: an index one past the end of an axis should wrap around to 0, and so on for every axis. In other words, each index is taken modulo the corresponding dimension of the shape.

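import numpy as np


class myarray(np.ndarray):

    def __getitem__(self, index):
        # Wrap every integer index around modulo the length of its axis.
        # Only integer indices are handled; a slice would raise a TypeError.
        if isinstance(index, tuple):
            new_index = tuple(index[i] % self.shape[i] for i in range(len(index)))
        else:
            new_index = index % self.shape[0]

        # Delegate to ndarray's own indexing to avoid infinite recursion.
        return super(myarray, self).__getitem__(new_index)


my_k = np.linspace(0, 10, 10).view(myarray)
print(my_k)
print(my_k[7])
print(my_k[17])  # 17 % 10 == 7, so this prints the same value as my_k[7]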

It's been a great lesson for me. Thanks, everyone, for your answers and your time!
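If only wrap-around reads are needed, a subclass may not be necessary at all. Below is a sketch using NumPy's `np.take` with `mode="wrap"`, which maps out-of-range (and negative) indices back into range modulo the axis length. The helper `wrapped_get` is a hypothetical name, not part of NumPy, and it handles integer indices only.

```python
import numpy as np


def wrapped_get(arr, *index):
    """Hypothetical helper: read arr[index] with every integer index
    taken modulo the length of its axis (integer indices only)."""
    result = arr
    # Each take with a scalar index drops the leading axis, so the
    # next index in the tuple always applies to the new axis 0.
    for i in index:
        result = np.take(result, i, axis=0, mode="wrap")
    return result


k = np.linspace(0, 10, 10)
print(wrapped_get(k, 7), wrapped_get(k, 17))  # both read k[7]

grid = np.arange(12).reshape(3, 4)
print(wrapped_get(grid, 4, -1))  # 4 % 3 == 1 and -1 wraps to 3, so grid[1, 3]
```

The trade-off is that ordinary `arr[i]` syntax keeps its normal bounds checking, and wrapping only happens where it is explicitly requested.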

Upvotes: 0

onodip

Reputation: 655

There are two problems with your code. First, the index can also be a tuple (for example a[1, 2]), not just an int. Second, in the return statement you fetch the item with [], which calls __getitem__ again, so the method keeps calling itself and you get infinite recursion. You have to call the parent class's method via super() instead.

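import numpy as np


class myarray(np.ndarray):

    def __getitem__(self, index):
        # The index may be a tuple, e.g. a[1, 2]: shift only its first
        # entry and keep the remaining entries unchanged.
        if isinstance(index, tuple):
            index = (index[0] + 1,) + index[1:]
        else:
            index += 1  # plain integer index
        # Call ndarray's __getitem__; self[index] here would recurse forever.
        return super(myarray, self).__getitem__(index)

my_k = np.linspace(0, 10, 10).view(myarray)
k = np.linspace(0, 10, 10).view(np.ndarray)
print(my_k)
print(k)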
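The point that the index is not always an int is easy to check. Here is a minimal sketch (the class name `IndexSpy` is made up for illustration) that records the raw `index` argument before delegating to `np.ndarray.__getitem__` through `super()`:

```python
import numpy as np


class IndexSpy(np.ndarray):
    """Hypothetical ndarray subclass that logs each raw index it receives."""

    seen = []  # class-level log shared by all instances

    def __getitem__(self, index):
        IndexSpy.seen.append(index)
        # Delegate to the parent implementation to avoid recursion.
        return super().__getitem__(index)


a = np.arange(12).reshape(3, 4).view(IndexSpy)
IndexSpy.seen.clear()

a[1]       # arrives as an int
a[1, 2]    # arrives as the tuple (1, 2)
a[1:3]     # arrives as slice(1, 3, None)
a[..., 0]  # arrives as the tuple (Ellipsis, 0)

for idx in IndexSpy.seen:
    print(type(idx).__name__, idx)
```

This is why both versions in this thread branch on isinstance(index, tuple), and why slices would still need separate handling in either of them.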

Upvotes: 1
