pbreach

Splitting an N dimensional numpy array into multiple 1D arrays

I have a simulation model that integrates a set of variables whose states are represented by numpy arrays of an arbitrary number of dimensions. After the simulation, I now have a list of arrays whose elements represent the variable state at a particular point in time.

In order to output the simulation results I want to split these arrays into multiple 1D arrays where the elements correspond to the same component of the state variable through time. Here is an example of a 2D state variable over a number of time steps.

import numpy as np

# Arbitrary state that is constant
arr = np.arange(9).reshape((3, 3))

# State variable through 3 time steps
state = [arr.copy() for _ in range(3)]

# Stack the arrays up to 3d. Axis could be rolled here if it makes it easier.
stacked = np.stack(state)

The output I need to get is:

[np.array([0, 0, 0]), np.array([1, 1, 1]), np.array([2, 2, 2]), ...]

I've tried np.split(stacked, sum(stacked.shape[:-1]), axis=...) with every value of axis= I could think of, but I get the following error: ValueError: array split does not result in an equal division. Is there a way to do this using np.split, or maybe np.nditer, that will work for the general case?

I guess this would be equivalent to doing:

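# This indexing assumes time is on the last axis, i.e. stacked along axis=-1
stacked = np.stack(state, axis=-1)
I, J, K = stacked.shape

result = []

for i in range(I):
    for j in range(J):
        # component (i, j) of the state through all K time steps
        result.append(stacked[i, j, :])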

This is also the ordering I'm hoping to get. It's easy enough, but I'm hoping numpy has something I can take advantage of that will be more general.

Answers (2)

hpaulj

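If I reshape it to a 9x3 array, then a simple list() will turn it into a list of 3-element arrays. This assumes time is the last axis of stacked, as in the np.stack(state, axis=2) version further down: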

In [190]: stacked.reshape(-1,3)
Out[190]: 
array([[0, 0, 0],
       [1, 1, 1],
       [2, 2, 2],
       [3, 3, 3],
       [4, 4, 4],
       [5, 5, 5],
       [6, 6, 6],
       [7, 7, 7],
       [8, 8, 8]])
In [191]: list(stacked.reshape(-1,3))
Out[191]: 
[array([0, 0, 0]),
 array([1, 1, 1]),
 array([2, 2, 2]),
 array([3, 3, 3]),
 array([4, 4, 4]),
 array([5, 5, 5]),
 array([6, 6, 6]),
 array([7, 7, 7]),
 array([8, 8, 8])]

np.split(stacked.reshape(-1,3),9) produces a list of 1x3 arrays.
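
For illustration, a small check of that shape difference, assuming the time axis is stacked last (axis=-1) so that each row of the reshape is one component: np.split keeps the split axis, so every piece is 2D, whereas list() iterates over the first axis and yields 1D rows. Calling .ravel() on each piece recovers the same 1D arrays.

```python
import numpy as np

# Same constant 3x3 state as in the question, stacked with time as the LAST axis
arr = np.arange(9).reshape((3, 3))
stacked = np.stack([arr.copy() for _ in range(3)], axis=-1)
flat = stacked.reshape(-1, 3)  # shape (9, 3): one row per component

pieces = np.split(flat, 9)  # 9 pieces; each keeps the split axis, shape (1, 3)
rows = list(flat)           # iterating over axis 0 gives 1D rows, shape (3,)

print(pieces[0].shape, rows[0].shape)  # (1, 3) (3,)

# Flattening each piece recovers the same 1D arrays as list()
same = all(np.array_equal(p.ravel(), r) for p, r in zip(pieces, rows))
print(same)  # True
```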

np.split only works along one axis, but you want to split on the first two, hence the need for a reshape or ravel.
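
The same reshape generalizes to a state with any number of dimensions: stack the time steps on the last axis, collapse every other axis into one, and iterate. A minimal sketch; split_components is a hypothetical helper name, not a numpy function, and it assumes every time step has the same shape.

```python
import numpy as np

def split_components(state):
    """Return one 1D time series per state component (hypothetical helper).

    Assumes `state` is a list of equally shaped arrays, one per time step.
    """
    stacked = np.stack(state, axis=-1)  # time becomes the last axis
    n_steps = stacked.shape[-1]
    # Collapse all state axes into one: shape (n_components, n_steps)
    return list(stacked.reshape(-1, n_steps))

# Works for a 2D state ...
arr2 = np.arange(9).reshape((3, 3))
series2 = split_components([arr2.copy() for _ in range(3)])
print(len(series2), series2[0])  # 9 [0 0 0]

# ... and for a 3D state with a different number of time steps
arr3 = np.arange(24).reshape((2, 3, 4))
series3 = split_components([arr3 * t for t in range(5)])
print(len(series3), series3[1])  # 24 [0 1 2 3 4]
```

Because the reshape follows C (row-major) order, the components come out in the same order as nested i, j, ... loops.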

And forget about nditer. That's a stepping stone to reworking code in Cython. It doesn't help with ordinary iteration, except that np.ndindex, which uses it internally, can streamline your i,j double loop:

In [196]: [stacked[idx] for idx in np.ndindex(stacked.shape[:2])]
Out[196]: 
[array([0, 0, 0]),
 array([1, 1, 1]),
 array([2, 2, 2]),
 array([3, 3, 3]),
 array([4, 4, 4]),
 array([5, 5, 5]),
 array([6, 6, 6]),
 array([7, 7, 7]),
 array([8, 8, 8])]
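
As a sketch, passing stacked.shape[:-1] instead of stacked.shape[:2] lets the same comprehension handle any number of leading state axes, again assuming time is the last axis. np.ndindex yields index tuples in C order, so the result lines up with the reshape approach. The 3D state below is made up for illustration.

```python
import numpy as np

# Made-up 3D state (2x2x2) over 4 time steps, time stacked as the LAST axis
arr = np.arange(8).reshape((2, 2, 2))
stacked = np.stack([arr + 10 * t for t in range(4)], axis=-1)  # shape (2, 2, 2, 4)

# One index tuple per state component, in C order: (0, 0, 0), (0, 0, 1), ...
series = [stacked[idx] for idx in np.ndindex(stacked.shape[:-1])]

print(len(series))  # 8
print(series[3])    # [ 3 13 23 33]

# Same components, same order, as the reshape approach
flat = stacked.reshape(-1, stacked.shape[-1])
print(all(np.array_equal(s, row) for s, row in zip(series, flat)))  # True
```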

======================

With the different state, just stack on a different axis, putting time last:

In [302]: state
Out[302]: 
[array([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]]), array([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]]), array([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])]
In [303]: np.stack(state,axis=2).reshape(-1,3)
Out[303]: 
array([[0, 0, 0],
       [1, 1, 1],
       [2, 2, 2],
       [3, 3, 3],
       [4, 4, 4],
       [5, 5, 5],
       [6, 6, 6],
       [7, 7, 7],
       [8, 8, 8]])

np.stack is rather like np.array, except it gives more control over where the new dimension is added. But do look at its code.
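
To illustrate the comparison: np.array(state) always adds the new (time) axis in front, and np.moveaxis can move it afterwards, which should give the same array as np.stack(state, axis=2). If you would rather keep time first, reshaping to (n_steps, -1) and transposing yields the same per-component rows. A small check, using a made-up state that changes over time so the ordering is visible:

```python
import numpy as np

# Made-up state that changes over time: step t adds 10 * t to every element
arr = np.arange(9).reshape((3, 3))
state = [arr + 10 * t for t in range(3)]

via_stack = np.stack(state, axis=2)             # new axis inserted at position 2
via_array = np.moveaxis(np.array(state), 0, 2)  # new axis added at 0, then moved to 2
print(np.array_equal(via_stack, via_array))     # True

# Alternative that keeps time first: rows of the transpose are the time series
time_first = np.stack(state)                    # shape (3, 3, 3), time on axis 0
series = list(time_first.reshape(len(state), -1).T)
print(series[5])  # [ 5 15 25]
```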

Divakar

You could use np.split on a flattened version and cut it into the appropriate number of parts as 1D arrays, like so:

np.split(stacked.ravel(),np.prod(stacked.shape[:2]))

Sample run -

In [406]: stacked
Out[406]: 
array([[[0, 0, 0],
        [1, 1, 1]],

       [[2, 2, 2],
        [3, 3, 3]],

       [[4, 4, 4],
        [5, 5, 5]],

       [[6, 6, 6],
        [7, 7, 7]]])

In [407]: np.split(stacked.ravel(),np.prod(stacked.shape[:2]))
Out[407]: 
[array([0, 0, 0]),
 array([1, 1, 1]),
 array([2, 2, 2]),
 array([3, 3, 3]),
 array([4, 4, 4]),
 array([5, 5, 5]),
 array([6, 6, 6]),
 array([7, 7, 7])]
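
For an arbitrary number of state dimensions, the number of pieces is the product of all axes except the time axis, np.prod(stacked.shape[:-1]), rather than the sum tried in the question. A sketch assuming time is stacked last; the 3D state is made up for illustration:

```python
import numpy as np

# Made-up 3D state (2x3x2) over 4 time steps, time stacked as the LAST axis
arr = np.arange(12).reshape((2, 3, 2))
stacked = np.stack([arr * (t + 1) for t in range(4)], axis=-1)  # shape (2, 3, 2, 4)

n_components = int(np.prod(stacked.shape[:-1]))   # 2 * 3 * 2 = 12
series = np.split(stacked.ravel(), n_components)  # 12 equal 1D pieces of length 4

print(len(series), series[2])  # 12 [2 4 6 8]
```

Since ravel walks the array in C order with time varying fastest, each equal-length chunk is exactly one component's time series.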
