Helin

What is numpy.rollaxis doing?

import numpy as np
a = np.arange(8)
a = a.reshape(2, 2, 2)
print(a)

I understand this output:

[[[0 1]
  [2 3]]

 [[4 5]
  [6 7]]]

But when I print np.rollaxis(a, 2), I don't understand the output:

[[[0 2]
  [4 6]]

 [[1 3]
  [5 7]]]

And when I print np.rollaxis(a, 2, 1), I don't understand that output either:

[[[0 2]
  [1 3]]

 [[4 6]
  [5 7]]]

What exactly is rollaxis doing in these two cases?

Answers (1)

user3483203

You're making it hard on yourself by using an array with the same length along every axis, which makes it difficult to see the transformation rollaxis is performing. The operation is much easier to follow on an array whose axes all have different lengths.

Here is a better example:

a = np.arange(8).reshape(4,2,1)

rollaxis takes the axis you specify and moves it so that it lies before a given position (the start argument, default 0). The other axes keep their relative order:

>>> a.shape
(4, 2, 1)
>>> np.rollaxis(a, 1).shape               # Rolls axis 1 to position 0
(2, 4, 1)
>>> np.rollaxis(a, 2).shape               # Rolls axis 2 to position 0
(1, 4, 2)
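To connect this back to the 2x2x2 array in the question, it can help to look at indices instead of printed values. The sketch below is only a check of that mapping, not how NumPy implements rollaxis internally: np.rollaxis(a, 2) produces b with b[k, i, j] == a[i, j, k], and np.rollaxis(a, 2, 1) produces c with c[i, k, j] == a[i, j, k]. Both are plain axis permutations, so they match a.transpose(2, 0, 1) and a.transpose(0, 2, 1) respectively.

```python
import numpy as np

# The 2x2x2 array from the question: a[i, j, k] == 4*i + 2*j + k
a = np.arange(8).reshape(2, 2, 2)

b = np.rollaxis(a, 2)     # axis 2 moved to the front
c = np.rollaxis(a, 2, 1)  # axis 2 moved to just before position 1

# Check the index mapping element by element
for i in range(2):
    for j in range(2):
        for k in range(2):
            assert b[k, i, j] == a[i, j, k]
            assert c[i, k, j] == a[i, j, k]

# Both results are ordinary transposes with a reordered axes tuple
assert np.array_equal(b, a.transpose(2, 0, 1))
assert np.array_equal(c, a.transpose(0, 2, 1))

# b[0] collects every element whose last index k is 0
print(b[0])
```

Reading b[0] makes the first output in the question concrete: it gathers the left column of each 2x2 block of a, which gives [[0 2] [4 6]]. Likewise, c[0] is just the first 2x2 block of a transposed.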

While rollaxis is still supported for backward compatibility, the NumPy documentation recommends numpy.moveaxis instead. It behaves similarly, but it has no default for the destination of the axis, so you must always pass one:

>>> np.moveaxis(a, 2).shape
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-87-77b5e96d3a20> in <module>()
----> 1 np.moveaxis(a, 2).shape

TypeError: moveaxis() missing 1 required positional argument: 'destination'

>>> np.moveaxis(a, 2, 0).shape
(1, 4, 2)
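One difference is worth knowing before switching: rollaxis's second argument means "lies before this position", while moveaxis's destination is the axis's final position. The two agree when an axis moves toward the front, but they differ by one when it moves toward the back. A short check using the same (4, 2, 1) array:

```python
import numpy as np

a = np.arange(8).reshape(4, 2, 1)

# Moving an axis toward the front: both functions agree
assert np.rollaxis(a, 2, 0).shape == np.moveaxis(a, 2, 0).shape == (1, 4, 2)

# Moving an axis toward the back: rollaxis stops *before* position 2 ...
print(np.rollaxis(a, 0, 2).shape)   # (2, 4, 1)
# ... while moveaxis puts the axis *at* position 2
print(np.moveaxis(a, 0, 2).shape)   # (2, 1, 4)

# So the moveaxis equivalent of rollaxis(a, 0, 2) uses destination 1
assert np.rollaxis(a, 0, 2).shape == np.moveaxis(a, 0, 1).shape
```

When translating old rollaxis calls, subtract one from start whenever start is greater than the axis being moved.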
