Reputation:
What is the easiest way to expand a given NumPy array over an extra dimension?
For example, suppose I have
>>> np.arange(4)
array([0, 1, 2, 3])
>>> _.shape
(4,)
>>> expand(np.arange(4), 0, 6)
array([[0, 1, 2, 3],
       [0, 1, 2, 3],
       [0, 1, 2, 3],
       [0, 1, 2, 3],
       [0, 1, 2, 3],
       [0, 1, 2, 3]])
>>> _.shape
(6, 4)
Or this one, which is a bit more complicated:
>>> np.eye(2)
array([[ 1., 0.],
       [ 0., 1.]])
>>> _.shape
(2, 2)
>>> expand(np.eye(2), 0, 3)
array([[[ 1., 0.],
        [ 0., 1.]],
       [[ 1., 0.],
        [ 0., 1.]],
       [[ 1., 0.],
        [ 0., 1.]]])
>>> _.shape
(3, 2, 2)
Upvotes: 8
Views: 3192
Reputation: 5522
I think modifying the strides of the array makes it easy to write expand:
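import numpy as np

def expand(arr, axis, length):
    # Insert the new axis (with the requested length) into the shape.
    new_shape = list(arr.shape)
    new_shape.insert(axis, length)
    # Give the new axis a stride of 0 bytes, so every index along it reads the same data.
    new_strides = list(arr.strides)
    new_strides.insert(axis, 0)
    # Build a view with the new shape and strides; no data is copied.
    return np.lib.stride_tricks.as_strided(arr, new_shape, new_strides)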
The function returns a view of the original array, so it doesn't take any extra memory.
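A quick way to check the memory claim is np.shares_memory. The sketch below repeats the expand function from this answer so it runs on its own; the names base and view are just illustrative. It also shows a caveat worth knowing: as_strided returns a writeable view by default, and every row aliases the same elements, so a single write changes the original array and every row at once.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

def expand(arr, axis, length):
    # Same stride trick as above: a new axis whose stride is 0 bytes.
    new_shape = list(arr.shape)
    new_shape.insert(axis, length)
    new_strides = list(arr.strides)
    new_strides.insert(axis, 0)
    return as_strided(arr, new_shape, new_strides)

base = np.arange(4)
view = expand(base, 0, 6)

# No copy was made: the view reads directly from base's buffer.
print(view.shape)                    # (6, 4)
print(np.shares_memory(base, view))  # True

# Caveat: all six rows point at the same four elements, so writing
# one element through the view modifies base and every row.
view[0, 0] = 99
print(base)        # [99  1  2  3]
print(view[:, 0])  # [99 99 99 99 99 99]
```

If the result only needs to be read, passing writeable=False to as_strided avoids this kind of accidental aliasing.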
The stride for the new axis is 0, so whatever index you use along that axis, you read the same values, which is exactly the desired behaviour.
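To see the zero stride directly, the sketch below applies expand (repeated here so the snippet is self-contained) to np.eye(2) and prints the resulting strides. For comparison it also builds the same view with np.broadcast_to after np.expand_dims; this is a swapped-in alternative rather than part of the answer, and expand_broadcast is a hypothetical helper name. broadcast_to likewise gives broadcast axes a stride of 0, but it returns a read-only view.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

def expand(arr, axis, length):
    new_shape = list(arr.shape)
    new_shape.insert(axis, length)
    new_strides = list(arr.strides)
    new_strides.insert(axis, 0)  # zero stride: every index on the new axis hits the same memory
    return as_strided(arr, new_shape, new_strides)

eye = np.eye(2)  # float64, C order, so strides are (16, 8)
stacked = expand(eye, 0, 3)
print(stacked.shape)    # (3, 2, 2)
print(stacked.strides)  # (0, 16, 8)

# Hypothetical alternative built from public helpers: add a length-1 axis,
# then broadcast it to the target length (also a zero-stride view).
def expand_broadcast(arr, axis, length):
    new_shape = list(arr.shape)
    new_shape.insert(axis, length)
    return np.broadcast_to(np.expand_dims(arr, axis), tuple(new_shape))

alt = expand_broadcast(eye, 0, 3)
print(alt.strides)                   # (0, 16, 8)
print(np.array_equal(alt, stacked))  # True
print(alt.flags.writeable)           # False
```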
Upvotes: 2
Reputation: 19547
I would recommend np.tile.
>>> a=np.arange(4)
>>> a
array([0, 1, 2, 3])
>>> np.tile(a,(6,1))
array([[0, 1, 2, 3],
       [0, 1, 2, 3],
       [0, 1, 2, 3],
       [0, 1, 2, 3],
       [0, 1, 2, 3],
       [0, 1, 2, 3]])
>>> b = np.eye(2)
>>> b
array([[ 1., 0.],
       [ 0., 1.]])
>>> np.tile(b,(3,1,1))
array([[[ 1., 0.],
        [ 0., 1.]],
       [[ 1., 0.],
        [ 0., 1.]],
       [[ 1., 0.],
        [ 0., 1.]]])
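np.tile adds new axes at the front, so the examples above cover axis 0. For the general expand(arr, axis, length) from the question, one possible sketch (under the assumption that a copy is acceptable) is to insert a length-1 axis with np.expand_dims and then tile only along it. expand_tile is a hypothetical helper name. Unlike the stride trick, np.tile returns a new array, so it uses length times the memory but is safe to write to.

```python
import numpy as np

def expand_tile(arr, axis, length):
    """Hypothetical helper: copy-based version of expand() using np.tile."""
    # Insert a length-1 axis at the requested position...
    arr = np.expand_dims(arr, axis)
    # ...then repeat only along that axis (reps of 1 leave other axes unchanged).
    reps = [1] * arr.ndim
    reps[axis] = length
    return np.tile(arr, reps)

a = np.arange(4)
print(expand_tile(a, 0, 6).shape)  # (6, 4)
print(expand_tile(a, 1, 3))
# [[0 0 0]
#  [1 1 1]
#  [2 2 2]
#  [3 3 3]]

copy = expand_tile(a, 0, 6)
print(np.shares_memory(a, copy))  # False: np.tile makes a real copy
```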
Tiling along several dimensions at once is also easy:
>>> np.tile(b,(2,2,2))
array([[[ 1., 0., 1., 0.],
        [ 0., 1., 0., 1.],
        [ 1., 0., 1., 0.],
        [ 0., 1., 0., 1.]],
       [[ 1., 0., 1., 0.],
        [ 0., 1., 0., 1.],
        [ 1., 0., 1., 0.],
        [ 0., 1., 0., 1.]]])
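The shape of that last result follows from how np.tile handles a reps tuple longer than the array's number of dimensions: the array is first promoted by prepending length-1 axes, and then each axis length is multiplied by the matching rep. The sketch below computes the expected shape by hand for b = np.eye(2) and reps = (2, 2, 2); the variable names are only illustrative.

```python
import numpy as np

b = np.eye(2)
reps = (2, 2, 2)

# b has shape (2, 2); with three reps, np.tile treats it as shape (1, 2, 2),
# then multiplies each axis length by the matching rep.
promoted = (1,) * (len(reps) - b.ndim) + b.shape
expected_shape = tuple(n * r for n, r in zip(promoted, reps))

tiled = np.tile(b, reps)
print(expected_shape)  # (2, 4, 4)
print(tiled.shape)     # (2, 4, 4)

# Each block along the new leading axis is the same 2-D tiling of b.
print(np.array_equal(tiled[0], np.tile(b, (2, 2))))  # True
```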
Upvotes: 6