Tim

Reputation: 2049

Are there rules for the interaction between numpy reshape() and transpose()?

I've put this question in quite a bit of context, to hopefully make it easier to understand, but feel free to skip down to the actual question.


Context

Here is the work I was doing which sparked this question:

I'm working with an API to access some tabular data, which is effectively a labelled N-dimensional array. The data is returned as a flattened list of lists (of the actual data values), plus a list of the different axes and their labels, e.g.:

raw_data = [
    ['nrm', 'nrf'],
    ['ngm', 'ngf'],
    ['nbm', 'nbf'],
    ['srm', 'srf'],
    ['sgm', 'sgf'],
    ['sbm', 'sbf'],
    ['erm', 'erf'],
    ['egm', 'egf'],
    ['ebm', 'ebf'],
    ['wrm', 'wrf'],
    ['wgm', 'wgf'],
    ['wbm', 'wbf'],
]

axes = [
    ('Gender', ['Male', 'Female']),
    ('Color', ['Red', 'Green', 'Blue']),
    ('Location', ['North', 'South', 'East', 'West']),
]

The data is normally numeric, but I've used strings here so you can easily see how it matches up with the labels, e.g. nrm is the value for North, Red, Male.

The data loops through axis 0 as you go across (within) a list, and then loops through axes 1 and 2 as you go down the lists, with axis 1 (on the "inside") varying most rapidly, then 2 (and for higher-dimensional data continuing to work "outwards"), viz:

       axis 0 ->
a a [ # # # # # # ]
x x [ # # # # # # ]
i i [ # # # # # # ]
s s [ #  R A W  # ]
    [ # D A T A # ]
2 1 [ # # # # # # ]
↓ ↓ [ # # # # # # ]
    [ # # # # # # ]
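Here is a minimal sketch, assuming the three-axis example above, of how a set of label indices maps to a position in raw_data. The helper `raw_lookup` is hypothetical and not part of the API. It treats the row number as a Fortran-order (first index fastest) flat index over the sizes of axes 1, 2, ..., which is what the diagram describes.

```python
import numpy as np

raw_data = [
    ['nrm', 'nrf'], ['ngm', 'ngf'], ['nbm', 'nbf'],
    ['srm', 'srf'], ['sgm', 'sgf'], ['sbm', 'sbf'],
    ['erm', 'erf'], ['egm', 'egf'], ['ebm', 'ebf'],
    ['wrm', 'wrf'], ['wgm', 'wgf'], ['wbm', 'wbf'],
]
sizes = (2, 3, 4)  # (Gender, Color, Location)


def raw_lookup(index):
    """Hypothetical helper: value at label indices (k0, k1, k2, ...)."""
    k0, *rest = index
    # Axis 0 picks the column; the remaining axes pick the row, with axis 1
    # varying fastest, i.e. a Fortran-order flat index over sizes[1:].
    row = np.ravel_multi_index(tuple(rest), sizes[1:], order="F")
    return raw_data[row][k0]


print(raw_lookup((0, 0, 0)))  # Male, Red, North  -> nrm
print(raw_lookup((1, 2, 3)))  # Female, Blue, West -> wbf
```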

I want to reshape this data and match it up with its labels, which I did using the following to output it into a Pandas (multi-index) DataFrame:

import numpy as np
import pandas as pd

names = [name for (name, _) in axes]
labels = [lbls for (_, lbls) in axes]

sizes = tuple(len(L) for L in labels)  # (2, 3, 4)
data_as_array = np.array(raw_data)  # shape = (12, 2) = (3*4, 2)
A = len(sizes)  # number of axes
new_shape = (*sizes[1:],sizes[0])  # (3, 4, 2)

data = data_as_array.reshape(new_shape, order="F").transpose(A - 1, *range(A - 1))
# With my numbers: data_as_array.reshape((3, 4, 2), order="F").transpose(2, 0, 1)

df = pd.DataFrame(
    data.ravel(),
    index=pd.MultiIndex.from_product(labels, names=names),
    columns=["Value"],
)

(I've noted with comments what some of the particular values are for my example, but the code is meant to be generalised for any N-dimensional data.)

This gives:

                      Value
Gender Color Location      
Male   Red   North      nrm
             South      srm
             East       erm
             West       wrm
       Green North      ngm
             South      sgm
             East       egm
             West       wgm
       Blue  North      nbm
             South      sbm
             East       ebm
             West       wbm
Female Red   North      nrf
             South      srf
             East       erf
             West       wrf
       Green North      ngf
             South      sgf
             East       egf
             West       wgf
       Blue  North      nbf
             South      sbf
             East       ebf
             West       wbf

This is all as desired and expected, and you can see that the values have ended up in the correct places, i.e. attached to their matching labels.
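One way to double-check the alignment without pandas relies on the fact that MultiIndex.from_product enumerates label combinations with the last level varying fastest (the same order as itertools.product), which is also the order of data.ravel(). The sketch below rebuilds raw_data from the labels rather than typing it out, and checks every pairing. It assumes the string codes are the lower-cased initials of Location, Color and Gender, as in the example.

```python
import itertools

import numpy as np

axes = [
    ('Gender', ['Male', 'Female']),
    ('Color', ['Red', 'Green', 'Blue']),
    ('Location', ['North', 'South', 'East', 'West']),
]
labels = [lbls for (_, lbls) in axes]
genders, colors, locations = labels

# Rebuild raw_data from the labels: one row per (location, color) with
# color varying fastest, one column per gender.
raw_data = [[(l[0] + c[0] + g[0]).lower() for g in genders]
            for l in locations for c in colors]

sizes = tuple(len(lbls) for lbls in labels)
A = len(sizes)
new_shape = (*sizes[1:], sizes[0])
data = np.array(raw_data).reshape(new_shape, order="F").transpose(A - 1, *range(A - 1))

# Pair each label tuple (in from_product / itertools.product order) with the
# C-order ravel of data, and check that the codes match the labels.
pairs = list(zip(itertools.product(*labels), data.ravel()))
for (g, c, l), value in pairs:
    assert value == (l[0] + c[0] + g[0]).lower()
```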


Question

My actual question concerns this line:

data = data_as_array.reshape(new_shape, order="F").transpose(A - 1, *range(A - 1))

which with the specific numbers in my example was:

data = data_as_array.reshape((3, 4, 2), order="F").transpose(2, 0, 1)

After some experimentation, I discovered that all three of the following are equivalent (the first is the original version):

data1 = data_as_array.reshape(new_shape, order="F").transpose(A - 1, *range(A - 1))
data2 = data_as_array.T.reshape(*reversed(new_shape)).T.transpose(A - 1, *range(A - 1))
data3 = data_as_array.reshape(*reversed(sizes)).T
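A quick numerical check of that equivalence. It uses a hypothetical integer array in place of the API data so that every element is distinct. This is a spot check for the (2, 3, 4) example, not a proof.

```python
import numpy as np

sizes = (2, 3, 4)                    # (Gender, Color, Location)
A = len(sizes)
new_shape = (*sizes[1:], sizes[0])   # (3, 4, 2)
# Numeric stand-in for the API data, shape (3*4, 2)
data_as_array = np.arange(24).reshape(12, 2)

data1 = data_as_array.reshape(new_shape, order="F").transpose(A - 1, *range(A - 1))
data2 = data_as_array.T.reshape(*reversed(new_shape)).T.transpose(A - 1, *range(A - 1))
data3 = data_as_array.reshape(*reversed(sizes)).T

assert data1.shape == data2.shape == data3.shape == sizes
assert np.array_equal(data1, data2) and np.array_equal(data1, data3)
```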

But this got me thinking (and here is my question at last!):

Are there any rules that I could use to manipulate the expression, to get from e.g. data1 to data3?

In particular, it seems like transpose() and reshape() are closely linked and that there might be a way to "absorb" the action of the transpose into the reshape(), so that you can drop it or at least transform it into a neater .T (as per data3).


My attempt

I managed to establish the following rule:

a.reshape(shape, order="F") == a.T.reshape(*reversed(shape)).T

You can apply .T to both sides, or substitute a.T in for a to get these variations of it:

a.reshape(shape) == a.T.reshape(*reversed(shape), order="F").T
a.reshape(shape).T == a.T.reshape(*reversed(shape), order="F")
a.T.reshape(shape) == a.reshape(*reversed(shape), order="F").T

a.reshape(shape, order="F") == a.T.reshape(*reversed(shape)).T
a.reshape(shape, order="F").T == a.T.reshape(*reversed(shape))
a.T.reshape(shape, order="F") == a.reshape(*reversed(shape)).T

I think this is effectively the definition of the difference between row-major and column-major ordering, and how they relate.
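These identities can be spot-checked numerically. The sketch below tests all six forms; the helper `check_rules` and the test shapes are made up for illustration. One case uses a 4-d input and a target shape with fewer axes, which suggests the rules are not specific to 2-d or 3-d arrays.

```python
import numpy as np


def check_rules(a, shape):
    """Hypothetical helper: True if all six reshape/.T identities hold."""
    rev = tuple(reversed(shape))
    rules = [
        (a.reshape(shape, order="F"), a.T.reshape(*rev).T),
        (a.reshape(shape), a.T.reshape(*rev, order="F").T),
        (a.reshape(shape).T, a.T.reshape(*rev, order="F")),
        (a.T.reshape(shape), a.reshape(*rev, order="F").T),
        (a.reshape(shape, order="F").T, a.T.reshape(*rev)),
        (a.T.reshape(shape, order="F"), a.reshape(*rev).T),
    ]
    return all(np.array_equal(lhs, rhs) for lhs, rhs in rules)


# (shape of a, target shape); each pair has the same total size
cases = [((12, 2), (3, 4, 2)), ((6, 4), (2, 3, 4)), ((2, 3, 5, 2), (5, 12))]
for a_shape, shape in cases:
    a = np.arange(np.prod(a_shape)).reshape(a_shape)
    assert check_rules(a, shape)
```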

But what I haven't managed to show is how you can go from:

data = data_as_array.reshape((3, 4, 2), order="F").transpose(2, 0, 1)

to:

data = data_as_array.reshape((4, 3, 2)).T

So somehow put the transposition into the reshape.

But I'm not even sure if this is generally true, or is specific to my data or e.g. 3 dimensions.
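On the question of generality, a numerical spot check (again, not a proof) suggests that data1 and data3 agree for other numbers of axes too. The intermediate `mid` is one possible stepping stone, assuming the first rule above: rewrite the F-order reshape as `a.T.reshape(reversed shape).T`, then merge that `.T` with the following partial transpose into a single permutation, which keeps axis 0 and reverses the rest. The helper name `forms` is hypothetical.

```python
import numpy as np


def forms(sizes):
    """Hypothetical helper: build data1, an intermediate form, and data3."""
    A = len(sizes)
    rows = int(np.prod(sizes[1:]))
    a = np.arange(rows * sizes[0]).reshape(rows, sizes[0])  # like data_as_array
    new_shape = (*sizes[1:], sizes[0])
    data1 = a.reshape(new_shape, order="F").transpose(A - 1, *range(A - 1))
    # F-order reshape rewritten with the .T rule, then .T (transpose(A-1, ..., 0))
    # combined with transpose(A-1, 0, ..., A-2): keep axis 0, reverse the rest.
    mid = a.T.reshape(sizes[0], *reversed(sizes[1:])).transpose(0, *range(A - 1, 0, -1))
    data3 = a.reshape(*reversed(sizes)).T
    return data1, mid, data3


for sizes in [(2, 3), (2, 3, 4), (3, 2, 4, 5), (2, 3, 2, 3, 2)]:
    d1, mid, d3 = forms(sizes)
    assert np.array_equal(d1, mid) and np.array_equal(d1, d3)
```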

EDIT: To clarify, I'm reasonably happy with how a straight-up .T transpose works, and the rules above cover that. (Note that .T is equivalent to .transpose(2, 1, 0) for 3 axes, or .transpose(n-1, n-2, ..., 2, 1, 0) for the general case of n axes.)

It's the case of using .transpose() to do a "partial" transpose that I'm curious about, e.g. .transpose(1, 0, 2), where you're doing something other than just reversing the order of the axes.
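For partial transposes, one rule that does hold in general is that two successive transposes collapse into one: if `y = x.transpose(p)`, then axis j of `y.transpose(q)` is axis `p[q[j]]` of x. Below is a minimal sketch with arbitrary permutations, plus the specific case of a `.T` followed by `.transpose(2, 0, 1)`.

```python
import numpy as np

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
p = (2, 0, 3, 1)
q = (3, 0, 2, 1)

# transpose(p) then transpose(q) equals one transpose whose axes are p indexed by q
combined = tuple(p[j] for j in q)
assert np.array_equal(x.transpose(p).transpose(q), x.transpose(combined))

# .T is the full reversal, so .T followed by transpose(2, 0, 1) on a 3-d array
# is the single permutation (0, 2, 1)
y = np.arange(24).reshape(2, 4, 3)
assert tuple((2, 1, 0)[j] for j in (2, 0, 1)) == (0, 2, 1)
assert np.array_equal(y.T.transpose(2, 0, 1), y.transpose(0, 2, 1))
```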


Some references:

Upvotes: 2

Views: 561

Answers (1)

hpaulj

Reputation: 231510

I'm not going to try to go through all your cases (for now), but here's an illustration of how reshape, transpose, and order interact:

In [176]: x = np.arange(12)                                                                                  
In [177]: x.strides, x.shape                                                                                 
Out[177]: ((8,), (12,))
In [178]: y = x.reshape(3,4)                                                                                 
In [179]: y.strides, y.shape                                                                                 
Out[179]: ((32, 8), (3, 4))        # (32=4*8)
In [180]: z = y.T                                                                                            
In [181]: z.strides, z.shape                                                                                 
Out[181]: ((8, 32), (4, 3))         # strides have been swapped
In [182]: w = x.reshape(4,3, order='F')                                                                      
In [183]: w.strides, w.shape                                                                                 
Out[183]: ((8, 32), (4, 3))
In [184]: z                                                                                                  
Out[184]: 
array([[ 0,  4,  8],
       [ 1,  5,  9],
       [ 2,  6, 10],
       [ 3,  7, 11]])
In [185]: w                                                                                                  
Out[185]: 
array([[ 0,  4,  8],
       [ 1,  5,  9],
       [ 2,  6, 10],
       [ 3,  7, 11]])

The reshape with 'F' produces the same thing as a transpose.
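The same comparison as a check you can run, assuming int64 data so the strides match the 8-byte values shown above:

```python
import numpy as np

x = np.arange(12, dtype=np.int64)  # int64 assumed, to match the strides above
z = x.reshape(3, 4).T
w = x.reshape(4, 3, order="F")

assert np.array_equal(z, w)
assert z.strides == w.strides == (8, 32)
# w's layout is column-major: F-contiguous but not C-contiguous
assert w.flags["F_CONTIGUOUS"] and not w.flags["C_CONTIGUOUS"]
```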

The same goes for ravel, which is essentially reshape(-1) (to 1d):

In [186]: w.ravel()     # order C                                                                                          
Out[186]: array([ 0,  4,  8,  1,  5,  9,  2,  6, 10,  3,  7, 11])
In [187]: w.ravel(order='F')                                                                                 
Out[187]: array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
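A small sketch of the ravel/reshape correspondence: ravel with a given order gives the same values as reshape(-1) with that order, and reading w column by column (order 'F') is the same as reading w.T row by row, which here recovers x.

```python
import numpy as np

x = np.arange(12)
w = x.reshape(4, 3, order="F")

for order in ("C", "F"):
    assert np.array_equal(w.ravel(order=order), w.reshape(-1, order=order))

# Column-by-column through w == row-by-row through w.T == original x
assert np.array_equal(w.ravel(order="F"), w.T.ravel())
assert np.array_equal(w.ravel(order="F"), x)
```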

Note that w (and z) is a view of x:

In [190]: w.base                                                                                             
Out[190]: array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
In [191]: x.__array_interface__                                                                              
Out[191]: 
{'data': (139649452400704, False),
 'strides': None,
 'descr': [('', '<i8')],
 'typestr': '<i8',
 'shape': (12,),
 'version': 3}
In [192]: w.__array_interface__                                                                              
Out[192]: 
{'data': (139649452400704, False),   # same data buffer address
 'strides': (8, 32),
 'descr': [('', '<i8')],
 'typestr': '<i8',
 'shape': (4, 3),
 'version': 3}
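Another way to confirm the view relationship, without reading __array_interface__ by hand, is np.shares_memory; writing through the view also shows up in x. Whether ravel can avoid a copy depends on whether the requested order matches the memory layout. A minimal sketch:

```python
import numpy as np

x = np.arange(12)
w = x.reshape(4, 3, order="F")
z = x.reshape(3, 4).T

# Both are views on x's buffer, just with different shape/strides
assert np.shares_memory(w, x) and np.shares_memory(z, x)
assert w.__array_interface__["data"][0] == x.__array_interface__["data"][0]

# ravel in memory order ('F' for w) can return a view; 'C' order needs a copy
assert np.shares_memory(w.ravel(order="F"), x)
assert not np.shares_memory(w.ravel(), x)

# Writing through a view changes x (w[1, 0] is element 1 in F order)
w[1, 0] = 100
```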

For a partial transpose:

In [194]: x = np.arange(24)                                                                                  
In [195]: y = x.reshape(2,3,4)                                                                               
In [196]: y.strides                                                                                          
Out[196]: (96, 32, 8)
In [197]: z = y.transpose(1,0,2)                                                                             
In [198]: z                                                                                                  
Out[198]: 
array([[[ 0,  1,  2,  3],
        [12, 13, 14, 15]],

       [[ 4,  5,  6,  7],
        [16, 17, 18, 19]],

       [[ 8,  9, 10, 11],
        [20, 21, 22, 23]]])
In [199]: z.shape                                                                                            
Out[199]: (3, 2, 4)
In [200]: z.strides                                                                                          
Out[200]: (32, 96, 8)

The partial transpose has permuted the shape and strides. The result is in neither F nor C order.
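Here is a sketch of what "permuted shape and strides" means, assuming int64 data (8-byte items) as above. transpose just reorders the shape and strides tuples. Because the result is neither C- nor F-contiguous, flattening it with reshape(-1) has to copy.

```python
import numpy as np

x = np.arange(24, dtype=np.int64)  # int64 assumed, to match the strides above
y = x.reshape(2, 3, 4)
perm = (1, 0, 2)
z = y.transpose(perm)

# No data moves: shape and strides are simply permuted
assert z.shape == tuple(y.shape[i] for i in perm) == (3, 2, 4)
assert z.strides == tuple(y.strides[i] for i in perm) == (32, 96, 8)
assert np.shares_memory(z, x)

# Neither memory layout, so a flat reshape cannot be a view
assert not z.flags["C_CONTIGUOUS"] and not z.flags["F_CONTIGUOUS"]
assert not np.shares_memory(z.reshape(-1), x)
```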

Element order in the base:

In [201]: z.ravel(order='K')                                                                                 
Out[201]: 
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23])

Order going by rows:

In [202]: z.ravel(order='C')                                                                                 
Out[202]: 
array([ 0,  1,  2,  3, 12, 13, 14, 15,  4,  5,  6,  7, 16, 17, 18, 19,  8,
        9, 10, 11, 20, 21, 22, 23])

Order going by columns:

In [203]: z.ravel(order='F')                                                                                 
Out[203]: 
array([ 0,  4,  8, 12, 16, 20,  1,  5,  9, 13, 17, 21,  2,  6, 10, 14, 18,
       22,  3,  7, 11, 15, 19, 23])
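The three orderings can also be expressed as identities, as a quick check: 'K' follows memory (here the original arange), 'C' matches reshape(-1), and 'F' matches ravelling the full transpose.

```python
import numpy as np

x = np.arange(24)
z = x.reshape(2, 3, 4).transpose(1, 0, 2)

assert np.array_equal(z.ravel(order="K"), x)              # memory order
assert np.array_equal(z.ravel(order="C"), z.reshape(-1))  # last axis fastest
assert np.array_equal(z.ravel(order="F"), z.T.ravel())    # first axis fastest
```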

Upvotes: 2
