Reputation: 93
I am trying to access a NumPy array's axes systematically. For example, suppose I have an array
import numpy as np
a = np.random.random((10, 10, 10, 10, 10, 10, 10))
# choosing 7:9 from axis 2
b = a[:, :, 7:9, ...]
# choosing 7:9 from axis 3
c = a[:, :, :, 7:9, ...]
Typing colons gets very repetitive with high-dimensional arrays. I would like a function choose_from_axis such that
# choosing 7:9 from axis 2
b = choose_from_axis(a, 2, 7, 9)
# choosing 7:9 from axis 3
c = choose_from_axis(a, 3, 7, 9)
In other words, I want to select a range along an axis specified by its number. The only way I know how to do this is to use np.rollaxis back and forth, but I am looking for a more direct way.
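For reference, the rollaxis round trip mentioned above might look roughly like the sketch below. choose_from_axis_rollaxis is a hypothetical name, the sketch assumes a non-negative axis, and a smaller array stands in for the 7-D one: move the chosen axis to the front, slice it, then move it back.

```python
import numpy as np

def choose_from_axis_rollaxis(a, axis, start, stop):
    # Hypothetical helper illustrating the rollaxis round trip.
    # Assumes 0 <= axis < a.ndim.
    # 1. Move `axis` to the front (position 0).
    front = np.rollaxis(a, axis, 0)
    # 2. Slice the leading axis, which is now the one we care about.
    sliced = front[start:stop]
    # 3. Roll axis 0 back so it sits at position `axis` again.
    return np.rollaxis(sliced, 0, axis + 1)

a = np.random.random((4, 5, 10, 6))   # smaller stand-in for the 7-D array
b = choose_from_axis_rollaxis(a, 2, 7, 9)
print(b.shape)                          # (4, 5, 2, 6)
print(np.array_equal(b, a[:, :, 7:9]))  # True
```

On NumPy 1.11 and later, np.moveaxis(a, axis, 0) and np.moveaxis(sliced, 0, axis) do the same job and are generally preferred over rollaxis.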
Upvotes: 9
Views: 8395
Reputation: 362517
It sounds like you may be looking for take (np.take, also available as the ndarray.take method):
>>> a = np.random.randint(0,100, (3,4,5))
>>> a[:,1:3,:]
array([[[61,  4, 89, 24, 86],
        [48, 75,  4, 27, 65]],

       [[57, 55, 55,  6, 95],
        [19, 16,  4, 61, 42]],

       [[24, 89, 41, 74, 85],
        [27, 84, 23, 70, 29]]])
>>> a.take(np.arange(1,3), axis=1)
array([[[61,  4, 89, 24, 86],
        [48, 75,  4, 27, 65]],

       [[57, 55, 55,  6, 95],
        [19, 16,  4, 61, 42]],

       [[24, 89, 41, 74, 85],
        [27, 84, 23, 70, 29]]])
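Wrapping take gives the choose_from_axis interface asked for in the question. This is a minimal sketch, assuming 0 <= start <= stop <= a.shape[axis] (take raises on out-of-range indices instead of clipping like a slice does). Note that take returns a new array (a copy) rather than a view.

```python
import numpy as np

def choose_from_axis(a, axis, start, stop):
    # Build the explicit index list start, start+1, ..., stop-1
    # and let take pick those positions along `axis`.
    return np.take(a, np.arange(start, stop), axis=axis)

a = np.random.random((4, 5, 10, 6))
b = choose_from_axis(a, 2, 7, 9)
c = a[:, :, 7:9, ...]
print(b.shape)               # (4, 5, 2, 6)
print(np.array_equal(b, c))  # True
```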
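take also accepts an arbitrary sequence of indices, such as a tuple, so you can pick non-contiguous positions along the axis, just as with fancy indexing. Example: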
>>> a = np.arange(2*3*4).reshape(2,3,4)
>>> a
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])
>>> a[:,:,(0,1,3)]
array([[[ 0,  1,  3],
        [ 4,  5,  7],
        [ 8,  9, 11]],

       [[12, 13, 15],
        [16, 17, 19],
        [20, 21, 23]]])
>>> a.take((0,1,3), axis=2)
array([[[ 0,  1,  3],
        [ 4,  5,  7],
        [ 8,  9, 11]],

       [[12, 13, 15],
        [16, 17, 19],
        [20, 21, 23]]])
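If you prefer plain indexing to take, the same selection can be written for any axis number by building the index tuple programmatically. A sketch: select_indices is a hypothetical helper name, and it relies on the NumPy rule that a single integer-array index mixed with slices keeps that axis in its original position.

```python
import numpy as np

def select_indices(a, axis, indices):
    # Hypothetical helper: full slices (':') for the axes before `axis`,
    # then an integer array for `axis`; trailing axes are kept implicitly.
    index = (slice(None),) * axis + (np.asarray(indices),)
    return a[index]

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
picked = select_indices(a, 2, (0, 1, 3))
print(picked.shape)                                       # (2, 3, 3)
print(np.array_equal(picked, a.take((0, 1, 3), axis=2)))  # True
```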
Upvotes: 8
Reputation: 69182
You could build the index out of slice objects, one per dimension:
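def choose_from_axis(a, axis, start, stop):
    # one full slice (the equivalent of ':') per dimension
    s = [slice(None)] * a.ndim
    # replace the slice for the chosen axis with start:stop
    s[axis] = slice(start, stop)
    # index with a tuple; recent NumPy versions reject a list of slices
    return a[tuple(s)]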
For example, for a 3x3x3 array x (np.arange(27).reshape(3, 3, 3) reproduces the output below), the following two expressions give the same result:
x[:,1:2,:]
choose_from_axis(x, 1, 1, 2)
# [[[ 3  4  5]]
#
#  [[12 13 14]]
#
#  [[21 22 23]]]
as does the example from the question:
import numpy as np
a = np.random.random((10, 10, 10, 10, 10, 10, 10))
a0 = a[:, :, 7:9, ...]
a1 = choose_from_axis(a, 2, 7, 9)
print(np.all(a0 == a1))  # True
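One advantage of the slice-based approach over take is that basic slicing returns a view into the original array instead of a copy. The variant below is a sketch that also accepts a step; choose_slice_from_axis and its step parameter are illustrative additions, not part of the answer above. It normalizes negative axes, since a negative repeat count on the slice tuple would otherwise silently target axis 0.

```python
import numpy as np

def choose_slice_from_axis(a, axis, start=None, stop=None, step=None):
    # Hypothetical helper: validate and normalize the axis number.
    if not -a.ndim <= axis < a.ndim:
        raise ValueError(f"axis {axis} is out of bounds for an array of dimension {a.ndim}")
    axis %= a.ndim  # allow negative axes, as ordinary NumPy calls do
    # ':' for every axis before `axis`, then the requested slice;
    # axes after `axis` are omitted and therefore taken in full.
    index = (slice(None),) * axis + (slice(start, stop, step),)
    return a[index]

a = np.random.random((4, 5, 10, 6))
v = choose_slice_from_axis(a, 2, 7, 9)
print(np.array_equal(v, a[:, :, 7:9]))                    # True
print(np.shares_memory(v, a))                             # True: a view, not a copy
print(choose_slice_from_axis(a, 2, None, None, 3).shape)  # (4, 5, 4, 6)
```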
Upvotes: 7