Let's say we have a (20, 5) array. We can iterate over each row very pythonically:
import numpy as np
xs = np.array(range(100)).reshape(20, 5)
for x in xs:
    print(x)
If we want to iterate over another axis (in this example the columns, but I'm looking for a solution that works for any axis of an ndarray), it's less direct. We can use the method from "Iterating over arbitrary dimension of numpy.array":
for i in range(xs.shape[-1]):
    x = xs[..., i]
    print(x)
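The xs[..., i] indexing above only reaches the last axis. A minimal sketch of the same loop for an arbitrary axis, assuming a hypothetical helper named select_index (it is not a NumPy function): it builds a full index tuple with slice(None) on every axis except the chosen one, so ordinary basic indexing still returns a view rather than a copy.

```python
import numpy as np


def select_index(arr, i, axis):
    # Hypothetical helper (not part of NumPy): generalizes arr[..., i] to any axis.
    # Start with ":" on every axis, then fix position i on the chosen axis.
    index = [slice(None)] * arr.ndim
    index[axis] = i  # list indexing also accepts negative axes such as -1
    return arr[tuple(index)]  # basic indexing, so the result is a view


xs = np.array(range(100)).reshape(20, 5)
axis = -1
for i in range(xs.shape[axis]):
    x = select_index(xs, i, axis=axis)
    print(x)  # same as xs[..., i] when axis == -1
```

Setting axis = 0 walks the rows instead; the loop itself does not change.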
Is there a more direct way to iterate over another axis, something like this (pseudo-code)?
for x in xs.iterator(axis=-1):
    print(x)
I think as_strided from the numpy.lib.stride_tricks module should do the job here.
It creates a view into the array, not a copy (as stated in the docs).
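The "view, not a copy" claim can be checked directly. A small sketch, assuming a 2-D example array rather than the one used below: swapping the shape and strides by hand gives a transpose, np.shares_memory confirms that both arrays use the same buffer, and a write through the original shows up in the view.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

xs = np.arange(12).reshape(3, 4)
# Swap the two axes by swapping shape and strides (a hand-made transpose)
view = as_strided(xs, shape=(4, 3), strides=(xs.strides[1], xs.strides[0]))
print(np.shares_memory(view, xs))  # True: same underlying buffer
xs[0, 1] = 100
print(view[1, 0])  # 100: the write to xs is visible through the view
```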
Here is a simple demonstration of what as_strided can do:
from numpy.lib.stride_tricks import as_strided
import numpy as np
xs = np.array(range(3 * 3 * 4)).reshape(3, 3, 4)
for x in xs:
    print(x)
Output:
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
[[12 13 14 15]
 [16 17 18 19]
 [20 21 22 23]]
[[24 25 26 27]
 [28 29 30 31]
 [32 33 34 35]]
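A function to iterate over a specific axis of the array:
def iterate_over_axis(arr, axis=0):
    # Normalize negative axes (e.g. -1) so the slicing below stays correct
    axis = axis % arr.ndim
    strides = arr.strides
    # Put the chosen axis first and keep the other axes in their original order
    strides_ = [strides[axis], *strides[0:axis], *strides[(axis+1):]]
    shape = arr.shape
    shape_ = [shape[axis], *shape[0:axis], *shape[(axis+1):]]
    # as_strided returns a view, so no data is copied
    return as_strided(arr, strides=strides_, shape=shape_)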
for x in iterate_over_axis(xs, axis=1):
    print(x)
Output:
[[ 0  1  2  3]
 [12 13 14 15]
 [24 25 26 27]]
[[ 4  5  6  7]
 [16 17 18 19]
 [28 29 30 31]]
[[ 8  9 10 11]
 [20 21 22 23]
 [32 33 34 35]]
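As a cross-check, here is a sketch comparing iterate_over_axis against np.moveaxis, NumPy's built-in function for moving an axis to a new position. This version of iterate_over_axis normalizes negative axes with axis % arr.ndim; without that step, axis=-1 makes strides[(axis+1):] return every stride. np.moveaxis also returns a view, so "for x in np.moveaxis(xs, -1, 0):" is one direct spelling of the pseudo-code loop from the question, and it avoids the manual stride arithmetic that the as_strided documentation warns can produce views into invalid memory when the shape or strides are wrong.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided


def iterate_over_axis(arr, axis=0):
    # Same stride reordering as above, with negative axes normalized
    axis = axis % arr.ndim
    strides, shape = arr.strides, arr.shape
    strides_ = [strides[axis], *strides[:axis], *strides[axis + 1:]]
    shape_ = [shape[axis], *shape[:axis], *shape[axis + 1:]]
    return as_strided(arr, strides=strides_, shape=shape_)


xs = np.arange(3 * 3 * 4).reshape(3, 3, 4)
for axis in range(-xs.ndim, xs.ndim):
    strided = iterate_over_axis(xs, axis)
    moved = np.moveaxis(xs, axis, 0)  # built-in: bring `axis` to the front
    assert np.array_equal(strided, moved)
    assert np.shares_memory(moved, xs)  # moveaxis returns a view too
print("all axes match")
```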