y.selivonchyk

Reputation: 9900

How to collapse two array axis together of a numpy array?

Basic idea: I have an array of images with images.shape # [10, 28, 28, 3], i.e. 10 images of 28x28 pixels with 3 colour channels. I want to stitch them together into one long strip: single_image.shape # [280, 28, 3]. What would be the best NumPy-based function for that?
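For this concrete case the image axis (0) and the row axis (1) are adjacent, so a plain reshape is enough. A minimal sketch, using random data as a stand-in for real images (an assumption for illustration):

```python
import numpy as np

# Stand-in data: 10 images of 28x28 pixels with 3 colour channels.
images = np.random.rand(10, 28, 28, 3)

# Merge the image axis into the row axis: image i occupies rows i*28 .. i*28+27.
single_image = images.reshape(-1, 28, 3)
print(single_image.shape)  # (280, 28, 3)

# Side by side instead: move the image axis next to the column axis first.
side_by_side = np.moveaxis(images, 0, 1).reshape(28, -1, 3)
print(side_by_side.shape)  # (28, 280, 3)
```

The first result equals stacking the images vertically (np.vstack(images)); the second equals np.hstack(images).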

More generally: is there a function along the lines of stitch(array, source_axis=0, target_axis=1) that would transform an array with A.shape # [a0, a1, s, a4, t, a6] (a source axis of length s at index 2 and a target axis of length t at index 4) into B.shape # [a0, a1, a4, s*t, a6] by concatenating the subarrays A[:,:,i,:,:,:] along the target axis?
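The definition above can be written down literally, which gives a slow but obviously correct reference to test faster versions against. A sketch, assuming source_axis < target_axis as in the question; the name stitch_reference is hypothetical:

```python
import numpy as np

def stitch_reference(A, source_axis=0, target_axis=1):
    """Concatenate the slices taken along source_axis, joining them along target_axis.

    Assumes source_axis < target_axis. Once the source axis is sliced away,
    the target axis sits one position earlier, hence target_axis - 1.
    """
    pieces = [np.take(A, i, axis=source_axis) for i in range(A.shape[source_axis])]
    return np.concatenate(pieces, axis=target_axis - 1)

A = np.random.rand(2, 3, 4, 5, 6, 7)
B = stitch_reference(A, source_axis=2, target_axis=4)
print(B.shape)  # (2, 3, 5, 24, 7)
```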

Upvotes: 1

Views: 446

Answers (2)

Divakar

Reputation: 221584

You can do it with a single moveaxis + reshape combination (assuming source_axis < target_axis):

import numpy as np

def merge_axis(a, source_axis=0, target_axis=1):
    # Assumes source_axis < target_axis
    shp = a.shape
    L = shp[source_axis]*shp[target_axis]  # merged axis length
    out_shp = np.insert(np.delete(shp, (source_axis, target_axis)), target_axis-1, L)
    # Move the source axis right before the target axis, then merge the two
    return np.moveaxis(a, source_axis, target_axis-1).reshape(out_shp)
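To see what out_shp works out to, here is the shape bookkeeping for one concrete case, plus a check that moveaxis + reshape reproduces the concatenation described in the question. A sketch; the shape (2, 3, 4, 5, 6) and the chosen axes are arbitrary:

```python
import numpy as np

a = np.random.rand(2, 3, 4, 5, 6)
source_axis, target_axis = 1, 3      # merge the length-3 axis into the length-5 axis

shp = a.shape
L = shp[source_axis] * shp[target_axis]              # 15
rest = np.delete(shp, (source_axis, target_axis))    # [2, 4, 6]
out_shp = np.insert(rest, target_axis - 1, L)        # [2, 4, 15, 6]

# moveaxis places the source axis directly before the target axis: (2, 4, 3, 5, 6)
merged = np.moveaxis(a, source_axis, target_axis - 1).reshape(out_shp)

# Same as concatenating a[:, i] along the (shifted) target axis.
expected = np.concatenate([a[:, i] for i in range(shp[source_axis])], axis=target_axis - 1)
print(np.array_equal(merged, expected))  # True
```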

Alternatively, out_shp can be set up with array manipulation, which might be easier to follow:

shp = np.array(a.shape)
shp[target_axis] *= shp[source_axis]
out_shp = np.delete(shp,source_axis)
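Both ways of building out_shp agree whenever source_axis < target_axis. A quick exhaustive check over every such axis pair of an arbitrary 5-D shape (the shape is an assumption for illustration):

```python
import itertools
import numpy as np

shape = (2, 3, 4, 5, 6)
for source_axis, target_axis in itertools.combinations(range(len(shape)), 2):
    # Version 1: drop both axes, then insert the merged length.
    L = shape[source_axis] * shape[target_axis]
    v1 = np.insert(np.delete(shape, (source_axis, target_axis)), target_axis - 1, L)
    # Version 2: scale the target length in place, then drop the source axis.
    shp = np.array(shape)
    shp[target_axis] *= shp[source_axis]
    v2 = np.delete(shp, source_axis)
    assert np.array_equal(v1, v2), (source_axis, target_axis)
print("out_shp formulas agree")
```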

If the source and target axes are adjacent, we can skip moveaxis and simply reshape. The additional benefit is that, for a C-contiguous input, the output is a view into the input and hence virtually free at runtime. So we introduce an if-conditional to check for that and modify our implementations to something like these:

import numpy as np

def merge_axis_v1(a, source_axis=0, target_axis=1):
    shp = a.shape
    L = shp[source_axis]*shp[target_axis]  # merged axis length
    out_shp = np.insert(np.delete(shp, (source_axis, target_axis)), target_axis-1, L)
    if target_axis == source_axis+1:
        # Adjacent axes: a plain reshape suffices (a view for C-contiguous input)
        return a.reshape(out_shp)
    else:
        return np.moveaxis(a, source_axis, target_axis-1).reshape(out_shp)

def merge_axis_v2(a, source_axis=0, target_axis=1):
    shp = np.array(a.shape)
    shp[target_axis] *= shp[source_axis]  # target axis absorbs the source axis
    out_shp = np.delete(shp, source_axis)
    if target_axis == source_axis+1:
        return a.reshape(out_shp)
    else:
        return np.moveaxis(a, source_axis, target_axis-1).reshape(out_shp)
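A brute-force check that the moveaxis + reshape technique matches the concatenation from the question for every axis pair with source_axis < target_axis. The helper merge_axis_sketch is hypothetical; it mirrors merge_axis_v2 without the adjacent-axis shortcut, which does not change the values:

```python
import itertools
import numpy as np

def merge_axis_sketch(a, source_axis=0, target_axis=1):
    # Assumes source_axis < target_axis.
    shp = np.array(a.shape)
    shp[target_axis] *= shp[source_axis]
    out_shp = np.delete(shp, source_axis)
    return np.moveaxis(a, source_axis, target_axis - 1).reshape(out_shp)

a = np.random.rand(2, 3, 4, 5, 6)
for s, t in itertools.combinations(range(a.ndim), 2):
    # Reference: concatenate the slices taken along the source axis.
    expected = np.concatenate([np.take(a, i, axis=s) for i in range(a.shape[s])], axis=t - 1)
    assert np.array_equal(merge_axis_sketch(a, s, t), expected), (s, t)
print("all axis pairs match")
```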

Verify that the output is a view:

In [156]: a = np.random.rand(10,10,10,10,10)

In [157]: np.shares_memory(merge_axis_v1(a, source_axis=0, target_axis=1),a)
Out[157]: True
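The view only comes for free in the adjacent case, and only when the input's memory layout allows it. A sketch contrasting three cases on a C-contiguous array like the one above (recreated so the snippet runs on its own):

```python
import numpy as np

a = np.random.rand(10, 10, 10, 10, 10)      # C-contiguous, like In [156]

# Adjacent axes (0 into 1): reshape alone gives a view.
adjacent = a.reshape(100, 10, 10, 10)
print(np.shares_memory(adjacent, a))         # True

# Non-adjacent axes (0 into 2): after moveaxis the two axes cannot be
# described by a single stride, so reshape copies.
non_adjacent = np.moveaxis(a, 0, 1).reshape(10, 100, 10, 10)
print(np.shares_memory(non_adjacent, a))     # False

# Adjacent axes on a non-contiguous input (a transposed view) also force a copy.
transposed = a.T.reshape(100, 10, 10, 10)
print(np.shares_memory(transposed, a))       # False
```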

Upvotes: 2

y.selivonchyk

Reputation: 9900

Here is my take:

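import numpy as np

def merge_axis(array, source_axis=0, target_axis=1):
    # Assumes source_axis < target_axis
    array = np.moveaxis(array, source_axis, 0)    # source axis first
    array = np.moveaxis(array, target_axis, 1)    # target axis second
    array = np.concatenate(array)                 # join array[i] along the target axis (copies)
    array = np.moveaxis(array, 0, target_axis-1)  # put the merged axis back in place
    return array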
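The key step here is np.concatenate(array) with a single array argument: NumPy iterates over the first axis and joins the resulting sub-arrays along their own axis 0, which merges the first two axes (first axis outermost) and always produces a new array. A small sketch of that equivalence, with an arbitrary shape:

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)

joined = np.concatenate(x)               # joins x[0] and x[1] along axis 0
print(joined.shape)                      # (6, 4)

print(np.array_equal(joined, x.reshape(6, 4)))  # True
print(np.shares_memory(joined, x))              # False: concatenate always copies
```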

Upvotes: 1
