Wang

Reputation: 8214

Inserting newaxis at variable position in NumPy arrays

Normally, when we know where to insert the new axis, we can write a[:, np.newaxis, ...]. Is there a good way to insert the new axis at a position given by a variable?

Here is how I do it now. I think there must be a much better way than this:

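import numpy as np

def addNewAxisAt(x, axis):
    # Build the new shape with a 1 inserted at position `axis`, then reshape.
    _s = list(x.shape)
    _s.insert(axis, 1)
    return x.reshape(tuple(_s))

def addNewAxisAt2(x, axis):
    # Build an index of full slices with np.newaxis inserted at `axis`.
    ind = [slice(None)]*x.ndim
    ind.insert(axis, np.newaxis)
    return x[tuple(ind)]  # index with a tuple; a plain list fails in NumPy >= 1.23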
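For comparison, here is a minimal sketch (the helper name add_new_axis_at is hypothetical) that builds the index tuple by concatenation instead of list insertion: `axis` full slices, then np.newaxis, then Ellipsis for whatever dimensions remain. It assumes 0 <= axis <= x.ndim; negative axes are not handled.

```python
import numpy as np

def add_new_axis_at(x, axis):
    """Insert a length-1 axis at position `axis` (0 <= axis <= x.ndim).

    Hypothetical helper: (slice(None),) * axis gives `axis` full slices,
    np.newaxis (i.e. None) adds the new axis, and Ellipsis covers the
    remaining dimensions of x.
    """
    return x[(slice(None),) * axis + (np.newaxis, Ellipsis)]

x = np.ones((3, 2, 4))
print(add_new_axis_at(x, 0).shape)  # (1, 3, 2, 4)
print(add_new_axis_at(x, 2).shape)  # (3, 2, 1, 4)
print(add_new_axis_at(x, 3).shape)  # (3, 2, 4, 1)
```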

Upvotes: 4

Views: 454

Answers (2)

hpaulj

Reputation: 231615

Internally, np.insert does something like this:

slobj = [slice(None)]*ndim
...
slobj[axis] = slice(None, index)
...
new[slobj] = arr[slobj2]

Like your code, it constructs a list of slices and modifies one or more elements.
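The same slice-list pattern can be sketched on its own. The helper head_along_axis below is hypothetical (not part of NumPy): it keeps the first `index` entries along `axis` by replacing one element of a list of full slices, then indexes with the list converted to a tuple.

```python
import numpy as np

def head_along_axis(arr, index, axis):
    """Return arr sliced as [:index] along `axis`, full slices elsewhere.

    Hypothetical helper mirroring the np.insert excerpt: start with one
    full slice per dimension, then replace the entry for the target axis.
    """
    slobj = [slice(None)] * arr.ndim
    slobj[axis] = slice(None, index)
    return arr[tuple(slobj)]  # index with a tuple, not a list

a = np.arange(24).reshape(2, 3, 4)
print(head_along_axis(a, 2, axis=1).shape)  # (2, 2, 4)
print(head_along_axis(a, 1, axis=2).shape)  # (2, 3, 1)
```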

apply_along_axis constructs an index array and converts it to an indexing tuple:

outarr[tuple(i.tolist())] = res

Other numpy functions work this way as well.
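A small illustration of why the tuple conversion matters (the array sizes and index values here are arbitrary choices for the example): a tuple of ints addresses one element, while the same values as an array are treated as an integer-array index along axis 0 only.

```python
import numpy as np

out = np.zeros((3, 3, 4))
i = np.array([1, 2, 0])  # one integer index per dimension

# A tuple of Python ints is a multidimensional index,
# so this sets the single element out[1, 2, 0].
out[tuple(i.tolist())] = 7.0
print(out[tuple(i.tolist())])  # 7.0

# The array itself is an integer-array index along axis 0,
# selecting the whole sub-arrays out[1], out[2], out[0].
print(out[i].shape)  # (3, 3, 4)
```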

My suggestion is to make the initial list large enough to hold the None. Then there is no need for insert:

In [1076]: x=np.ones((3,2,4),int)

In [1077]: ind=[slice(None)]*(x.ndim+1)

In [1078]: ind[2]=None

In [1080]: x[ind].shape
Out[1080]: (3, 2, 1, 4)

In [1081]: x[tuple(ind)].shape   # converting the list to a tuple is wise; NumPy 1.23+ requires it
Out[1081]: (3, 2, 1, 4)
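The preallocated-list idea extends naturally to several new axes at once. This is a sketch with a hypothetical helper name; it assumes the positions refer to the result's dimensions, are distinct, and are non-negative.

```python
import numpy as np

def add_new_axes(x, axes):
    """Insert length-1 axes at the given positions of the *result*.

    Hypothetical helper: the index list has one slot per output dimension,
    the new-axis slots get None, and the remaining full slices line up with
    x's dimensions in order. Assumes distinct, non-negative positions,
    each < x.ndim + len(axes).
    """
    ind = [slice(None)] * (x.ndim + len(axes))
    for ax in axes:
        ind[ax] = None
    return x[tuple(ind)]

x = np.ones((3, 2, 4), int)
print(add_new_axes(x, [2]).shape)     # (3, 2, 1, 4)
print(add_new_axes(x, [0, 3]).shape)  # (1, 3, 2, 1, 4)
```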

It turns out there is np.expand_dims:

In [1090]: np.expand_dims(x,2).shape
Out[1090]: (3, 2, 1, 4)
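A few more np.expand_dims calls, as a quick sketch. Negative axes count from the end of the result, and (in NumPy 1.18 and later) a tuple of axes is accepted. The result is a view of the original data.

```python
import numpy as np

x = np.ones((3, 2, 4), int)

print(np.expand_dims(x, 2).shape)       # (3, 2, 1, 4)
print(np.expand_dims(x, -1).shape)      # (3, 2, 4, 1)
# A tuple of axes requires NumPy 1.18 or later.
print(np.expand_dims(x, (0, 3)).shape)  # (1, 3, 2, 1, 4)

# No data is copied: the result shares memory with x.
print(np.shares_memory(x, np.expand_dims(x, 2)))  # True
```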

It uses reshape like you do, but creates the new shape with tuple concatenation.

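from numpy import asarray

def expand_dims(a, axis):
    # Version quoted from the NumPy source at the time; single integer axis only.
    a = asarray(a)
    shape = a.shape
    if axis < 0:
        # A negative axis counts from the end of the result, which has one more dimension.
        axis = axis + len(shape) + 1
    return a.reshape(shape[:axis] + (1,) + shape[axis:])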
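To make the negative-axis arithmetic concrete, here is the tuple concatenation traced by hand for shape (3, 2, 4) and axis=-1 (values chosen only for illustration), checked against np.expand_dims:

```python
import numpy as np

shape = (3, 2, 4)
axis = -1
# The output has len(shape) + 1 dimensions, hence the extra + 1.
axis = axis + len(shape) + 1                     # -1 + 3 + 1 == 3
new_shape = shape[:axis] + (1,) + shape[axis:]   # (3, 2, 4) + (1,) + ()
print(new_shape)  # (3, 2, 4, 1)
print(new_shape == np.expand_dims(np.ones(shape), -1).shape)  # True
```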

Timings don't tell me much about which is better. They are all in the 2 µs range, where merely wrapping the code in a function makes a noticeable difference.
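If you want to repeat the comparison yourself, a timeit sketch like the one below works; the function names are hypothetical stand-ins for the approaches discussed, and no numbers are shown because they depend on the machine and NumPy version.

```python
import timeit
import numpy as np

x = np.ones((3, 2, 4), int)

def with_index(x, axis):
    # Preallocated list of slices with None at the new position.
    ind = [slice(None)] * (x.ndim + 1)
    ind[axis] = None
    return x[tuple(ind)]

def with_reshape(x, axis):
    # Tuple concatenation, as in np.expand_dims.
    return x.reshape(x.shape[:axis] + (1,) + x.shape[axis:])

candidates = {
    "index tuple": lambda: with_index(x, 2),
    "reshape": lambda: with_reshape(x, 2),
    "np.expand_dims": lambda: np.expand_dims(x, 2),
}
n = 10_000
for name, fn in candidates.items():
    # Best of 3 repeats, each timing n calls; report microseconds per call.
    best = min(timeit.repeat(fn, number=n, repeat=3)) / n
    print(f"{name:15s} {best * 1e6:.2f} us")
```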

Upvotes: 2

Divakar

Reputation: 221674

The singleton dimension (length 1) can be inserted into the original array's shape with np.insert, and the result assigned back to x.shape to change the shape directly, like so:

x.shape = np.insert(x.shape,axis,1)
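One practical point worth illustrating: assigning to x.shape modifies the array object in place, so every name bound to it sees the new shape. If that is not wanted, passing the same np.insert result to reshape gives a new view instead. The variable names are just for the example.

```python
import numpy as np

x = np.zeros((4, 5))
y = x  # a second name for the same array object

# Assigning to .shape changes the array object itself, so y sees it too.
x.shape = np.insert(x.shape, 1, 1)
print(y.shape)  # (4, 1, 5)

# Non-mutating alternative: reshape returns a new view and leaves z alone.
z = np.zeros((4, 5))
w = z.reshape(np.insert(z.shape, 1, 1))
print(z.shape, w.shape)  # (4, 5) (4, 1, 5)
```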

We might as well extend this to insert more than one new axis with a bit of an np.diff and np.cumsum trick, like so:

insert_idx = (np.diff(np.append(0,axis))-1).cumsum()+1
x.shape = np.insert(x.shape,insert_idx,1)
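Tracing the trick for axis=(1, 3, 4, 6) shows what it computes. For sorted, non-negative output positions, the cumulative sum telescopes, so insert_idx[k] works out to axis[k] - k: the output position minus the number of new axes placed before it.

```python
import numpy as np

axis = (1, 3, 4, 6)                    # positions in the *output* shape
with_zero = np.append(0, axis)         # [0 1 3 4 6]
gaps = np.diff(with_zero) - 1          # [0 1 0 1]
insert_idx = gaps.cumsum() + 1         # [1 2 2 3]
print(insert_idx)

# Same indices, computed directly as axis[k] - k.
print(np.asarray(axis) - np.arange(len(axis)))  # [1 2 2 3]

shape = (5, 6, 8, 9, 4, 2, 6, 7)
print(tuple(int(n) for n in np.insert(shape, insert_idx, 1)))
# (5, 1, 6, 1, 1, 8, 1, 9, 4, 2, 6, 7)
```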

Sample runs:

In [151]: def addNewAxisAt(x, axis):
     ...:     insert_idx = (np.diff(np.append(0,axis))-1).cumsum()+1
     ...:     x.shape = np.insert(x.shape,insert_idx,1)
     ...:     

In [152]: A = np.random.rand(4,5)

In [153]: addNewAxisAt(A, axis=1)

In [154]: A.shape
Out[154]: (4, 1, 5)

In [155]: A = np.random.rand(5,6,8,9,4,2)

In [156]: addNewAxisAt(A, axis=5)

In [157]: A.shape
Out[157]: (5, 6, 8, 9, 4, 1, 2)

In [158]: A = np.random.rand(5,6,8,9,4,2,6,7)

In [159]: addNewAxisAt(A, axis=(1,3,4,6))

In [160]: A.shape
Out[160]: (5, 1, 6, 1, 1, 8, 1, 9, 4, 2, 6, 7)
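The function above mutates A and expects sorted, non-negative positions. A variant that returns a view and normalizes the positions first could look like this; with_new_axes is a hypothetical name, and it assumes the positions are distinct.

```python
import numpy as np

def with_new_axes(x, axis):
    """Return a view of x with length-1 axes at the given output positions.

    Hypothetical variant: returns a reshaped view instead of assigning to
    x.shape, and normalizes negative and unsorted positions first.
    Positions must be distinct.
    """
    axis = np.atleast_1d(axis)
    out_ndim = x.ndim + len(axis)
    axis = np.sort(np.where(axis < 0, axis + out_ndim, axis))
    # For sorted output positions, the k-th new axis goes before original
    # axis axis[k] - k (the same indices the diff/cumsum trick produces).
    insert_idx = axis - np.arange(len(axis))
    return x.reshape(np.insert(x.shape, insert_idx, 1))

A = np.random.rand(5, 6, 8, 9, 4, 2, 6, 7)
print(with_new_axes(A, (6, 1, 4, 3)).shape)
# (5, 1, 6, 1, 1, 8, 1, 9, 4, 2, 6, 7)
print(with_new_axes(np.ones((4, 5)), -1).shape)  # (4, 5, 1)
```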

Upvotes: 3
