I have the following numpy array:
import numpy as np
np.ones((10, 3, 2))
and I need to reshape it to (10, 1, 3, 2).
How can I do so?
As others mentioned, you can .reshape it. An alternative is to use np.newaxis or np.expand_dims, like this:
import numpy as np

arr = np.ones((10, 3, 2))
arr1 = arr[:, np.newaxis, ...]
print(arr1.shape) # (10, 1, 3, 2)
arr2 = np.expand_dims(arr, 1)
print(arr2.shape) # (10, 1, 3, 2)
# check if the two arrays are equal
print(np.array_equal(arr1, arr2)) # True
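The answer mentions .reshape without showing it. A minimal sketch, assuming you want to insert the new axis at position 1 without hard-coding the other sizes, plus the shorter None spelling (None is an alias for np.newaxis):

```python
import numpy as np

arr = np.ones((10, 3, 2))

# Insert a length-1 axis at position 1, reusing the remaining sizes from arr.shape
arr3 = arr.reshape(arr.shape[0], 1, *arr.shape[1:])
print(arr3.shape)  # (10, 1, 3, 2)

# None is an alias for np.newaxis; trailing axes are kept automatically
arr4 = arr[:, None]
print(arr4.shape)  # (10, 1, 3, 2)

print(np.array_equal(arr3, arr4))  # True
```

Reading the sizes from arr.shape keeps the call working if the original dimensions change.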
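import numpy as np

x = np.ones((10, 3, 2))

# Option 1: change the shape in place (only works when no data copy is needed)
y = x.copy()
y.shape = (10, 1, 3, 2)

# Option 2: reshape returns a new view; assign it, since x itself is unchanged
x2 = x.reshape((10, 1, 3, 2))

# Option 3: add a new axis by indexing
x3 = x[:, np.newaxis, :, :]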
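To illustrate the "new view" comment above, here is a small sketch showing that both reshape and np.newaxis indexing return views sharing memory with the original, so writes through the view show up in x:

```python
import numpy as np

x = np.ones((10, 3, 2))

# reshape and np.newaxis indexing both return views of x, not copies
v1 = x.reshape((10, 1, 3, 2))
v2 = x[:, np.newaxis, :, :]
print(np.shares_memory(x, v1), np.shares_memory(x, v2))  # True True

# Writing through one view therefore changes x and every other view of it
v1[0, 0, 0, 0] = 5.0
print(x[0, 0, 0])  # 5.0
print(v2[0, 0, 0, 0])  # 5.0
```

If you need an independent array, call .copy() on the result.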