Basj

Reputation: 46483

Iterate over last axis of a numpy array

Let's say we have a (20, 5) array. We can iterate over each row very pythonically:

import numpy as np
xs = np.arange(100).reshape(20, 5)
for x in xs:
    print(x)

If we want to iterate over another axis (here, in the example, over the columns, but I'm looking for a solution that works for any axis of an ndarray), it's less direct. We can use the method from "Iterating over arbitrary dimension of numpy.array":

for i in range(xs.shape[-1]):
    x = xs[..., i]
    print(x)
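The xs[..., i] trick only reaches the last axis. For an arbitrary axis, the index can be built as a tuple of slices with the integer placed at the chosen position. A minimal sketch, where the generator name iter_axis_by_index is hypothetical and not a NumPy API:

```python
import numpy as np

def iter_axis_by_index(arr, axis=-1):
    """Hypothetical helper: yield the slices of `arr` along `axis`, one at a time."""
    if not -arr.ndim <= axis < arr.ndim:
        raise ValueError(f"axis {axis} is out of bounds for a {arr.ndim}-d array")
    axis = axis % arr.ndim                 # normalize negative axes such as -1
    index = [slice(None)] * arr.ndim       # equivalent to arr[:, :, ..., :]
    for i in range(arr.shape[axis]):
        index[axis] = i                    # fix the chosen axis to position i
        yield arr[tuple(index)]            # basic indexing, so each slice is a view

xs = np.arange(100).reshape(20, 5)
for x in iter_axis_by_index(xs, axis=-1):
    print(x)                               # the same columns as xs[..., i]
```

This still loops over indices under the hood; it only hides the index bookkeeping.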

Is there a more direct way to iterate over another axis, something like this (pseudo-code)?

for x in xs.iterator(axis=-1):
    print(x)

Upvotes: 3

Views: 1092

Answers (1)

itamar kanter

Reputation: 1360

I think that as_strided from the numpy.lib.stride_tricks module should do the job here.

It creates a view into the array and not a copy (as stated by the docs).
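To make "a view, not a copy" concrete: as_strided only reinterprets the existing buffer with a new shape and new strides (the number of bytes to step along each axis). A small sketch, assuming an explicit int64 dtype so that the hard-coded strides are 8-byte multiples:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

xs = np.arange(36, dtype=np.int64).reshape(3, 3, 4)
print(xs.strides)               # (96, 32, 8): bytes to step along each axis

# Swap the stride entries of axes 0 and 1: no data is moved or copied
v = as_strided(xs, shape=(3, 3, 4), strides=(32, 96, 8))
print(np.shares_memory(v, xs))  # True

xs[0, 1, 0] = -1                # write through the original array...
print(v[1, 0, 0])               # ...and the view sees it: -1
```

The NumPy documentation warns that as_strided must be used with care: a shape or stride that does not match the buffer makes the view read (or write) memory outside the array.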

Here is a simple demonstration of what as_strided can do. First, the example array, iterated over its first axis as usual:

from numpy.lib.stride_tricks import as_strided
import numpy as np
xs = np.arange(3 * 3 * 4).reshape(3, 3, 4)
for x in xs:
    print(x)

Output:

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
[[12 13 14 15]
 [16 17 18 19]
 [20 21 22 23]]
[[24 25 26 27]
 [28 29 30 31]
 [32 33 34 35]]

A function that uses as_strided to iterate over a specific axis of the array:

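def iterate_over_axis(arr, axis=0):
    # Normalize negative axes (e.g. axis=-1); otherwise the slices below
    # produce too many entries and as_strided reads out of bounds
    axis = axis % arr.ndim
    strides = arr.strides
    # Move the stride and length of `axis` to the front, keeping the rest in order
    strides_ = [strides[axis], *strides[:axis], *strides[axis + 1:]]
    shape = arr.shape
    shape_ = [shape[axis], *shape[:axis], *shape[axis + 1:]]
    # Same buffer, new strides: this is a view, not a copy
    return as_strided(arr, strides=strides_, shape=shape_)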

for x in iterate_over_axis(xs, axis=1):
    print(x)

Output:

[[ 0  1  2  3]
 [12 13 14 15]
 [24 25 26 27]]
[[ 4  5  6  7]
 [16 17 18 19]
 [28 29 30 31]]
[[ 8  9 10 11]
 [20 21 22 23]
 [32 33 34 35]]
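As a sanity check, the stride permutation built by iterate_over_axis can be compared against np.moveaxis, which moves the given axis to the front while keeping the other axes in order and also returns a view. This sketch normalizes negative axes with axis % arr.ndim so that axis=-1 works too:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

def iterate_over_axis(arr, axis=0):
    # Same stride permutation as the function above, negative axes normalized
    axis = axis % arr.ndim
    strides, shape = arr.strides, arr.shape
    strides_ = [strides[axis], *strides[:axis], *strides[axis + 1:]]
    shape_ = [shape[axis], *shape[:axis], *shape[axis + 1:]]
    return as_strided(arr, strides=strides_, shape=shape_)

xs = np.arange(3 * 3 * 4).reshape(3, 3, 4)
for axis in range(-xs.ndim, xs.ndim):
    a = iterate_over_axis(xs, axis)
    b = np.moveaxis(xs, axis, 0)   # moves `axis` to position 0; a view of xs
    assert a.shape == b.shape and np.array_equal(a, b)
    assert np.shares_memory(b, xs)
print("as_strided permutation matches np.moveaxis for every axis")
```

Since both produce the same view, for x in np.moveaxis(xs, -1, 0): iterates over the last axis without computing strides by hand.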

  

Upvotes: 1
