import numpy as np
a=np.arange(8)
a=a.reshape(2,2,2)
print(a)
I understand this output:
[[[0 1]
  [2 3]]

 [[4 5]
  [6 7]]]
But when I print(np.rollaxis(a, 2)), I can't understand the output:
[[[0 2]
  [4 6]]

 [[1 3]
  [5 7]]]
And when I print(np.rollaxis(a, 2, 1)), I can't understand the output either:
[[[0 2]
  [1 3]]

 [[4 6]
  [5 7]]]
What is rollaxis actually doing in these two calls?
You're making it hard on yourself by using an array with the same size along every axis, which makes it difficult to see the transformation that rollaxis is performing. The operation is much easier to understand on an array with a different size along each axis.
Here is a better example:
a = np.arange(8).reshape(4,2,1)
rollaxis takes the axis you specify and moves it to a given position (the default position is 0):
>>> a.shape
(4, 2, 1)
>>> np.rollaxis(a, 1).shape # Rolls axis 1 to position 0
(2, 4, 1)
>>> np.rollaxis(a, 2).shape # Rolls axis 2 to position 0
(1, 4, 2)
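To connect this back to the 2x2x2 array in the question, here is a small check (a sketch that uses only np.rollaxis and ndarray.transpose) showing that rolling an axis only permutes the indices: the element at a[i, j, k] ends up at [k, i, j] after np.rollaxis(a, 2), and at [i, k, j] after np.rollaxis(a, 2, 1). Both calls are equivalent to an explicit transpose with the corresponding axis order.

```python
import numpy as np

a = np.arange(8).reshape(2, 2, 2)  # same array as in the question

b = np.rollaxis(a, 2)     # axis 2 moved to position 0 -> axes order (2, 0, 1)
c = np.rollaxis(a, 2, 1)  # axis 2 moved to position 1 -> axes order (0, 2, 1)

# Every element keeps its value; only its index tuple is permuted.
for i in range(2):
    for j in range(2):
        for k in range(2):
            assert b[k, i, j] == a[i, j, k]
            assert c[i, k, j] == a[i, j, k]

# The same results written as explicit axis permutations.
assert np.array_equal(b, a.transpose(2, 0, 1))
assert np.array_equal(c, a.transpose(0, 2, 1))

# b[0] collects every a[i, j, k] with k == 0,
# i.e. the first column of each 2x2 block of a.
print(b[0].tolist())  # [[0, 2], [4, 6]]
```

Reading the question's first output this way, the first block [[0 2] [4 6]] is just the first column of each original 2x2 block, and the second block is the second column.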
While this function is still supported, best practice is to use numpy.moveaxis, which behaves similarly but does not have a default argument for the destination of an axis:
>>> np.moveaxis(a, 2).shape
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-87-77b5e96d3a20> in <module>()
----> 1 np.moveaxis(a, 2).shape
TypeError: moveaxis() missing 1 required positional argument: 'destination'
>>> np.moveaxis(a, 2, 0).shape
(1, 4, 2)
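One difference worth knowing when switching: in rollaxis the third argument (start) means "roll the axis until it lies before this position" when the axis moves to a higher position, so it ends up at start - 1, whereas the destination argument of moveaxis is always the axis's final position. A short sketch using the (4, 2, 1) array from above, plus the question's two calls rewritten with moveaxis:

```python
import numpy as np

a = np.arange(8).reshape(4, 2, 1)

# Moving an axis toward the front: both functions agree.
assert np.rollaxis(a, 2, 0).shape == np.moveaxis(a, 2, 0).shape == (1, 4, 2)

# Moving an axis toward the back: rollaxis stops *before* `start`,
# while moveaxis puts the axis exactly at `destination`.
print(np.rollaxis(a, 0, 2).shape)  # (2, 4, 1): axis 0 ends up at position 1
print(np.moveaxis(a, 0, 2).shape)  # (2, 1, 4): axis 0 ends up at position 2

# The question's calls on the 2x2x2 array, written with moveaxis.
q = np.arange(8).reshape(2, 2, 2)
assert np.array_equal(np.rollaxis(q, 2), np.moveaxis(q, 2, 0))
assert np.array_equal(np.rollaxis(q, 2, 1), np.moveaxis(q, 2, 1))
```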