Andi

Numpy: Reshape/horizontally split 3D array into 4D array

I have a 3D np.array like this:

import numpy as np

arr3d = np.arange(36).reshape(3, 2, 6)

array([[[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]],

       [[12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23]],

       [[24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35]]])

I need to horizontally split every pane of arr3d into 3 chunks. For the first pane, that is:

np.array(np.hsplit(arr3d[0, :, :], 3))

array([[[ 0,  1],
        [ 6,  7]],

       [[ 2,  3],
        [ 8,  9]],

       [[ 4,  5],
        [10, 11]]])

Doing this for every pane should produce a 4D array.

arr4d[0, :, :, :] should contain the split 3D array of the first pane of the original 3D array, i.e. np.array(np.hsplit(arr3d[0, :, :], 3)).
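A minimal sketch of the per-pane requirement, written as a plain loop with np.stack. It only restates what is being asked (split each pane, then stack the panes) and is not the vectorized solution the question is looking for:

```python
import numpy as np

arr3d = np.arange(36).reshape(3, 2, 6)

# Literal translation of the requirement: hsplit each (2, 6) pane into
# 3 chunks of shape (2, 2), stack the chunks into a (3, 2, 2) array,
# then stack the 3 panes into a (3, 3, 2, 2) array.
arr4d = np.stack([np.stack(np.hsplit(pane, 3)) for pane in arr3d])

print(arr4d.shape)  # (3, 3, 2, 2)
print(arr4d[0])     # same as np.array(np.hsplit(arr3d[0, :, :], 3))
```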

The final result should look like this:

result = np.array(
    [
        [[[0, 1], [6, 7]], [[2, 3], [8, 9]], [[4, 5], [10, 11]]],
        [[[12, 13], [18, 19]], [[14, 15], [20, 21]], [[16, 17], [22, 23]]],
        [[[24, 25], [30, 31]], [[26, 27], [32, 33]], [[28, 29], [34, 35]]],
    ]
)

result.shape
(3, 3, 2, 2)

array([[[[ 0,  1],
         [ 6,  7]],

        [[ 2,  3],
         [ 8,  9]],

        [[ 4,  5],
         [10, 11]]],


       [[[12, 13],
         [18, 19]],

        [[14, 15],
         [20, 21]],

        [[16, 17],
         [22, 23]]],


       [[[24, 25],
         [30, 31]],

        [[26, 27],
         [32, 33]],

        [[28, 29],
         [34, 35]]]])

I am looking for a pythonic way to perform this reshaping/splitting.


Answers (1)

Pierre D

Try:

sh = arr3d.shape[:-1] + (3, -1)
arr4d = arr3d.reshape(*sh).swapaxes(1, 2)

>>> arr4d
array([[[[ 0,  1],
         [ 6,  7]],

        [[ 2,  3],
         [ 8,  9]],

        [[ 4,  5],
         [10, 11]]],


       [[[12, 13],
         [18, 19]],

        [[14, 15],
         [20, 21]],

        [[16, 17],
         [22, 23]]],


       [[[24, 25],
         [30, 31]],

        [[26, 27],
         [32, 33]],

        [[28, 29],
         [34, 35]]]])

Explanation

It's the last dimension (in your example, of size 6) that you want to split into (3, -1). That's why we first reshape into (a, b, 3, -1), where (a, b, _) is the shape of arr3d. But hsplit() puts the chunk index before the row index, so the shape you actually want is (a, 3, b, -1). We therefore need to swap axes 1 and 2 (more precisely: move axis 2 to position 1, which makes a difference for higher dimensions, as shown below).
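To make the two steps above concrete, here is a small sketch that prints the intermediate shapes and checks the index identity the result satisfies. The names k (number of chunks) and w (chunk width, what -1 resolves to) are introduced here for illustration only:

```python
import numpy as np

arr3d = np.arange(36).reshape(3, 2, 6)
a, b, n = arr3d.shape
k = 3          # number of chunks per row (illustrative name)
w = n // k     # width of each chunk, i.e. what -1 resolves to

step1 = arr3d.reshape(a, b, k, w)  # split the last axis: (3, 2, 3, 2)
step2 = step1.swapaxes(1, 2)       # move the chunk axis forward: (3, 3, 2, 2)

print(step1.shape, step2.shape)

# Element-wise: step2[p, i, r, c] == arr3d[p, r, i*w + c]
p, i, r, c = 1, 2, 0, 1
print(step2[p, i, r, c], arr3d[p, r, i * w + c])

# Both steps return views, so no data is copied, but the result is not
# C-contiguous; np.ascontiguousarray(step2) makes a compact copy if needed.
print(np.shares_memory(step2, arr3d), step2.flags["C_CONTIGUOUS"])
```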

Another example

shape = 7, 2, 3*3
arr3d = np.arange(np.prod(shape)).reshape(*shape)
check = np.array([np.array(np.hsplit(arr3d[k], 3)) for k in range(shape[0])])

sh = arr3d.shape[:-1] + (3, -1)
arr4d = arr3d.reshape(*sh).swapaxes(1, 2)
>>> np.equal(arr4d, check).all()
True

Generalization to higher dimensions

shape = 4, 5, 2, 3*3
ar = np.arange(np.prod(shape)).reshape(*shape)
check = np.array([np.array(np.split(ar[k], 3, axis=-1)) for k in range(shape[0])])

# works for any number of dimensions (>= 2)
sh = ar.shape[:-1] + (3, -1)
out = np.rollaxis(ar.reshape(*sh), -2, 1)
>>> np.equal(out, check).all()
True
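NumPy's documentation notes that np.rollaxis is kept for backward compatibility and recommends np.moveaxis instead. A minimal sketch of the same generalization with np.moveaxis, wrapped in a helper; the name split_last_axis and its divisibility check are illustrative assumptions, not part of NumPy:

```python
import numpy as np


def split_last_axis(arr, k):
    """Hypothetical helper: split the last axis of `arr` into `k` equal chunks
    and put the chunk axis at position 1, which matches stacking
    np.split(arr[j], k, axis=-1) over the first axis."""
    if arr.ndim < 2:
        raise ValueError("arr must have at least 2 dimensions")
    if arr.shape[-1] % k:
        raise ValueError("last axis length must be divisible by k")
    sh = arr.shape[:-1] + (k, -1)
    # Move the new chunk axis (second to last) to position 1.
    return np.moveaxis(arr.reshape(sh), -2, 1)


shape = 4, 5, 2, 3 * 3
ar = np.arange(np.prod(shape)).reshape(*shape)
out = split_last_axis(ar, 3)
print(out.shape)  # (4, 3, 5, 2, 3)
```

For a 3D input, np.moveaxis(x, -2, 1) on the 4D reshaped array is the same as swapaxes(1, 2), so the helper reproduces the answer's original result.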
