6659081

Reputation: 401

Easy way to fold a multidimensional NumPy array

I have a NumPy 2D array:

a = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])

where each column of this array (i.e. [1, 3, 5, 7] and [2, 4, 6, 8]) needs to be reshaped into a matrix of a given size M1 x M2 (using order='F' when reshaping). In this case, M1 = M2 = 2, so my desired output is:

b = np.array([[[1, 5], [3, 7]], [[2, 6], [4, 8]]])

I can easily achieve this by iterating over the columns. However, the number of columns is arbitrary, and the input array can have up to 8 dimensions rather than just 2. How can I extend this solution to handle more dimensions?
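For reference, the column-by-column loop mentioned above might look like the following minimal sketch. `fold_loop` is a hypothetical helper name, and `dims` is assumed to be the target shape, e.g. (M1, M2) or a longer tuple whose product equals the number of rows.

```python
import numpy as np

def fold_loop(a, dims):
    """Hypothetical helper: fold each column of the 2-D array `a` into an
    array of shape `dims` (column-major, order='F'), one column at a time."""
    a = np.asarray(a)
    # Reshape every column in Fortran order, then stack the results
    # along a new leading axis (one entry per column).
    return np.stack([a[:, c].reshape(dims, order="F") for c in range(a.shape[1])])

a = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
b = fold_loop(a, (2, 2))
print(b.tolist())  # [[[1, 5], [3, 7]], [[2, 6], [4, 8]]]
```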

I suspect this is a common operation and that there is a built-in function for it, but I haven't been able to find one.

Upvotes: 2

Views: 1442

Answers (2)

Nils Werner

Reputation: 36795

It is as simple as reshape, then transpose:

a.reshape(2, 2, -1).T
# array([[[1, 5],
#         [3, 7]],
# 
#        [[2, 6],
#         [4, 8]]])
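A hedged generalization of the reshape-then-transpose idea: when M1 != M2 the two sizes have to be passed in reverse order, i.e. `a.reshape(M2, M1, -1).T`, to get shape (n_cols, M1, M2). An equivalent form that works for any number of target dimensions is to reshape `a.T` with order='F' directly. `fold_columns` below is a hypothetical helper, sketched under the assumption that `a` is 2-D with one column per vector to fold.

```python
import numpy as np

def fold_columns(a, dims):
    """Hypothetical helper: fold every column of the 2-D array `a` into shape
    `dims` using Fortran (column-major) order. Returns shape (n_cols, *dims)."""
    a = np.asarray(a)
    n_rows, n_cols = a.shape
    if int(np.prod(dims)) != n_rows:
        raise ValueError(f"prod({dims}) must equal the number of rows ({n_rows})")
    # Row c of a.T is column c of a. In Fortran order the leading axis varies
    # fastest, so keeping n_cols first folds each row into `dims` independently.
    return a.T.reshape((n_cols, *dims), order="F")

a = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
print(fold_columns(a, (2, 2)).tolist())  # [[[1, 5], [3, 7]], [[2, 6], [4, 8]]]

# Non-square target (M1=2, M2=3): reshape-then-transpose needs reshape(M2, M1, -1).T
x = np.arange(12).reshape(6, 2)
assert np.array_equal(fold_columns(x, (2, 3)), x.reshape(3, 2, -1).T)
```

Because `dims` is just a tuple, the same call covers three or more target dimensions, e.g. `fold_columns(x, (2, 2, 2))` for a column of length 8.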

Upvotes: 3

Michael Szczesny

Reputation: 5026

You can use reshape, swapaxes, reshape. @divakar posted a detailed explanation.

a.T.reshape(2,2,2,1).swapaxes(1,2).reshape(2,2,-1)

Out:

array([[[1, 5],
        [3, 7]],

       [[2, 6],
        [4, 8]]])
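The shapes above are hard-coded for the 2x2 case. For an arbitrary M1 x M2 target, a sketch of the same reshape/swapaxes idea is `a.T.reshape(n_cols, M2, M1).swapaxes(1, 2)`; the trailing size-1 axis and the final reshape are not needed when the target is 2-D. `fold_swap` is a hypothetical helper name used only for illustration.

```python
import numpy as np

def fold_swap(a, m1, m2):
    """Hypothetical helper: reshape + swapaxes version of the column fold.
    Column c of `a` becomes out[c], an (m1, m2) matrix filled in order='F'."""
    a = np.asarray(a)
    n_cols = a.shape[1]
    # a.T has one row per column of a. Splitting each row into (m2, m1) in
    # C order and then swapping the last two axes matches an order='F' fill.
    return a.T.reshape(n_cols, m2, m1).swapaxes(1, 2)

a = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
print(fold_swap(a, 2, 2).tolist())  # [[[1, 5], [3, 7]], [[2, 6], [4, 8]]]
```

Note that `swapaxes` returns a view, so the result is generally not C-contiguous; calling `np.ascontiguousarray` on it gives a compact copy if one is needed.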

Upvotes: 2
