fabian789

Merge equally sized arrays into tiled big array

I'm having a hard time phrasing what I want, which is why I didn't find it on Google. Let me start with an example before formulating the general case.

Say we have 7 arrays a1, ..., a7, each of shape (4, 5). I want a new array where the 7 arrays are arranged like this:

a1 a2 a3
a4 a5 a6
a7 0  0

This array has shape (3*4, 3*5) == (12, 15), where each 0 is a block of np.zeros((4, 5)).

In general, I have C arrays a1, ..., aC of shape (H, W), and I want to put them into an array of shape (h*H, w*W), where h = ceil(sqrt(C)) and w = ceil(C/h), with any leftover blocks filled with zeros. The C arrays are stored as a single array of shape (C, H, W).
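To make the grid rule concrete, here is a small sketch that computes h and w for a few values of C. It uses integer arithmetic (`math.isqrt`, Python 3.8+) instead of a floating-point `ceil`; the helper name `grid_shape` is hypothetical and not part of the question or answer.

```python
import math

def grid_shape(C):
    """Return (h, w) with h = ceil(sqrt(C)) and w = ceil(C / h), for C >= 1."""
    if C < 1:
        raise ValueError("need at least one array")
    r = math.isqrt(C)
    h = r if r * r == C else r + 1   # integer ceil(sqrt(C))
    w = -(-C // h)                   # integer ceil(C / h)
    return h, w

for C in (1, 4, 5, 7, 10):
    print(C, grid_shape(C))
```

For C = 7 this gives (3, 3), matching the 3 x 3 layout in the example; note that the last block row is not always full.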

What's the most elegant way to do this? I started hacking something together by iterating over the necessary indices, but it wasn't nice, so I stopped.

Speed is not top priority and the arrays are fairly small.
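For reference, a plain loop version of the index-iterating idea might look like the following. It is only a sketch, not the answer's method: it assumes the input is one `(C, H, W)` NumPy array and fills a zero array block by block, using `divmod` to turn the flat index into a (block row, block column) pair. The name `tile_loop` is hypothetical.

```python
import math
import numpy as np

def tile_loop(a):
    """Tile a (C, H, W) array row by row into an (h*H, w*W) array, zero-padded."""
    C, H, W = a.shape
    h = math.ceil(math.sqrt(C))
    w = math.ceil(C / h)
    out = np.zeros((h * H, w * W), dtype=a.dtype)
    for k in range(C):
        i, j = divmod(k, w)                       # block row i, block column j
        out[i*H:(i+1)*H, j*W:(j+1)*W] = a[k]
    return out

# Seven (4, 5) arrays, as in the example above
a = np.arange(7 * 4 * 5).reshape(7, 4, 5)
print(tile_loop(a).shape)  # (12, 15)
```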

Answers (1)

Divakar

Approach #1

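Some permuting of axes and reshaping should do the job. Lay the blocks out as an (h, w, m, n) grid, swap the middle two axes to get (h, m, w, n), and a final reshape gives the (h*m, w*n) result -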

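import numpy as np

# a: input array of shape (C, m, n)
C,m,n = a.shape
h = int(np.ceil(np.sqrt(C)))
w = int(np.ceil(C/h))

out = np.zeros((h,w,m,n),dtype=a.dtype)   # (h, w) grid of m x n blocks
out.reshape(-1,m,n)[:C] = a               # fill the first C blocks (the reshape is a view)
out = out.swapaxes(1,2).reshape(-1,w*n)   # (h, m, w, n) -> (h*m, w*n)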

Sample input, output -

In [340]: a
Out[340]: 
array([[[55, 58],
        [75, 78]],

       [[78, 20],
        [94, 32]],

       [[47, 98],
        [81, 23]],

       [[69, 76],
        [50, 98]],

       [[57, 92],
        [48, 36]],

       [[88, 83],
        [20, 31]],

       [[91, 80],
        [90, 58]]])

In [341]: out
Out[341]: 
array([[55, 58, 78, 20, 47, 98],
       [75, 78, 94, 32, 81, 23],
       [69, 76, 57, 92, 88, 83],
       [50, 98, 48, 36, 20, 31],
       [91, 80,  0,  0,  0,  0],
       [90, 58,  0,  0,  0,  0]])
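The Approach #1 snippet can be wrapped in a function and checked against this sample input and output. The name `tile_reshape` is only illustrative; the comments spell out the shape after each step.

```python
import numpy as np

def tile_reshape(a):
    """Approach #1 as a function: (C, m, n) -> (h*m, w*n), zero-padded."""
    C, m, n = a.shape
    h = int(np.ceil(np.sqrt(C)))
    w = int(np.ceil(C / h))
    out = np.zeros((h, w, m, n), dtype=a.dtype)  # (block row, block col, m, n)
    out.reshape(-1, m, n)[:C] = a                # view of a fresh contiguous array, so this fills `out`
    out = out.swapaxes(1, 2)                     # (block row, row in block, block col, col in block)
    return out.reshape(-1, w * n)                # (h*m, w*n)

a = np.array([[[55, 58], [75, 78]],
              [[78, 20], [94, 32]],
              [[47, 98], [81, 23]],
              [[69, 76], [50, 98]],
              [[57, 92], [48, 36]],
              [[88, 83], [20, 31]],
              [[91, 80], [90, 58]]])

expected = np.array([[55, 58, 78, 20, 47, 98],
                     [75, 78, 94, 32, 81, 23],
                     [69, 76, 57, 92, 88, 83],
                     [50, 98, 48, 36, 20, 31],
                     [91, 80,  0,  0,  0,  0],
                     [90, 58,  0,  0,  0,  0]])

print(np.array_equal(tile_reshape(a), expected))  # True
```

The swap is what makes the final reshape work: in C order, each output row must walk across all w blocks of one block row before moving down, which is exactly the (h, m, w, n) ordering.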

Approach #2

A simpler one with zeros-concatenation -

z = np.zeros((h*w-C,m,n),dtype=a.dtype)
out = np.concatenate((a,z)).reshape(h,w,m,n).swapaxes(1,2).reshape(-1,w*n)
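As a sketch, the concatenation variant can be wrapped the same way and spot-checked on the question's own sizes (seven 4 x 5 blocks). The name `tile_concat` is hypothetical; `h` and `w` follow the rule from the question.

```python
import numpy as np

def tile_concat(a):
    """Approach #2: append zero blocks, then reshape and swap axes as in Approach #1."""
    C, m, n = a.shape
    h = int(np.ceil(np.sqrt(C)))
    w = int(np.ceil(C / h))
    z = np.zeros((h * w - C, m, n), dtype=a.dtype)   # h*w - C empty slots (may be 0)
    full = np.concatenate((a, z))                    # (h*w, m, n)
    return full.reshape(h, w, m, n).swapaxes(1, 2).reshape(-1, w * n)

a = np.arange(1, 7 * 4 * 5 + 1).reshape(7, 4, 5)   # seven nonzero (4, 5) blocks
out = tile_concat(a)
print(out.shape)                                 # (12, 15)
# a[5] (a6 in the question's naming) sits in the middle block row, right column
print(np.array_equal(out[4:8, 10:15], a[5]))     # True
```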

That could be modified/simplified a bit by using zeros-padding with np.pad -

zp = np.pad(a,((0,h*w-C),(0,0),(0,0)),'constant')
out = zp.reshape(h,w,m,n).swapaxes(1,2).reshape(-1,w*n)
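A quick sketch checking that the `np.pad` call yields the same zero-padded stack as the concatenation above, so either can feed the same reshape. With `'constant'` mode, `np.pad` pads with `constant_values=0` by default; the random test data here is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.integers(1, 100, size=(7, 2, 2))   # C=7 blocks of 2 x 2, as in the sample
C, m, n = a.shape
h = int(np.ceil(np.sqrt(C)))
w = int(np.ceil(C / h))

# Pad only the first axis with h*w - C zero blocks; leave m and n untouched
zp = np.pad(a, ((0, h * w - C), (0, 0), (0, 0)), 'constant')
zc = np.concatenate((a, np.zeros((h * w - C, m, n), dtype=a.dtype)))
print(zp.shape, np.array_equal(zp, zc))   # (9, 2, 2) True

out = zp.reshape(h, w, m, n).swapaxes(1, 2).reshape(-1, w * n)
print(out.shape)                          # (6, 6)
```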
