EelkeSpaak

Reputation: 2825

Easy way to collapse trailing dimensions of numpy array?

In Matlab, I can do the following:

X = randn(25,25,25);
size(X(:,:))

ans = 
    25   625

I often find myself wanting to quickly collapse the trailing dimensions of an array, and do not know how to do this in numpy.

I know I can do this:

In [22]: x = np.random.randn(25,25,25)
In [23]: x = x.reshape(x.shape[:-2] + (-1,))
In [24]: x.shape
Out[24]: (25, 625)

but x.reshape(x.shape[:-2] + (-1,)) is a lot less concise (and requires more information about x) than simply doing x(:,:).

I've obviously tried the analogous numpy indexing, but that does not work as desired:

In [25]: x = np.random.randn(25,25,25)
In [26]: x[:,:].shape
Out[26]: (25, 25, 25)

Any hints on how to collapse the trailing dimensions of an array in a concise manner?

Edit: note that I'm after the resulting array itself, not just its shape. I merely use size() and x.shape in the above examples to indicate what the array is like.

Upvotes: 6

Views: 2102

Answers (3)

hpaulj

Reputation: 231540

What is supposed to happen with a 4d or higher-dimensional array?

octave:7> x=randn(25,25,25,25);
octave:8> size(x(:,:))
ans =
      25   15625

Your (:,:) reduces it to 2 dimensions, combining the last ones. The last dimension is where MATLAB automatically adds and collapses dimensions.

In [605]: x=np.ones((25,25,25,25))

In [606]: x.reshape(x.shape[0],-1).shape  # like Joe's
Out[606]: (25, 15625)

In [607]: x.reshape(x.shape[:-2]+(-1,)).shape
Out[607]: (25, 25, 625)

Your reshape example does something different from MATLAB: it just collapses the last 2 dimensions. Collapsing down to 2 dimensions, as MATLAB does, is a simpler expression.

The MATLAB is concise simply because your needs match its assumptions. The numpy equivalent isn't quite so concise, but gives you more control.

For example, to keep the last dimension, or to combine dimensions 2 by 2:
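If the reshape expression is needed often, it can be wrapped in a small helper. This is only a sketch: `collapse_trailing` and its `keep` parameter are names made up for illustration, not part of numpy. `keep=1` gives the 2-d result that MATLAB's `x(:,:)` produces in shape, while larger values keep more leading axes.

```python
import numpy as np

def collapse_trailing(x, keep=1):
    """Merge every axis after the first `keep` axes into one trailing axis.

    Hypothetical helper, not a numpy function.
    """
    x = np.asarray(x)
    return x.reshape(x.shape[:keep] + (-1,))

x3 = np.ones((25, 25, 25))
x4 = np.ones((25, 25, 25, 25))

print(collapse_trailing(x3).shape)          # (25, 625)
print(collapse_trailing(x4).shape)          # (25, 15625)
print(collapse_trailing(x4, keep=2).shape)  # (25, 25, 625)
```

Note that only the shapes match MATLAB here. numpy fills in row-major (C) order and MATLAB in column-major order, so the element layout within each row differs.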

In [608]: x.reshape(-1,x.shape[-1]).shape
Out[608]: (15625, 25)
In [610]: x.reshape(-1,np.prod(x.shape[-2:])).shape
Out[610]: (625, 625)
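Since the question asks for the array itself and not just its shape, here is a small check of where the elements end up. It uses a smaller array with distinct values (my choice, not from the question) so the mapping is visible.

```python
import numpy as np

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# Keep the last axis and merge all leading ones.
rows = x.reshape(-1, x.shape[-1])             # shape (24, 5)

# Row r of `rows` is x[i, j, k, :], with r the C-order flat index of (i, j, k).
i, j, k = 1, 2, 3
r = np.ravel_multi_index((i, j, k), x.shape[:-1])
same_row = np.array_equal(rows[r], x[i, j, k])

# For a contiguous array like this one, reshape returns a view, not a copy.
is_view = np.shares_memory(rows, x)

# Merge the last two axes into the columns and everything else into the rows.
pairs = x.reshape(-1, np.prod(x.shape[-2:]))  # shape (6, 20)

print(rows.shape, same_row, is_view, pairs.shape)
```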

What's the equivalent MATLAB?

octave:24> size(reshape(x,[],size(x)(end)))
ans =
15625      25
octave:31> size(reshape(x,[],prod(size(x)(3:end))))
ans =
   625   625

Upvotes: 3

Joe Kington

Reputation: 284810

You might find it a bit more concise to modify the shape attribute directly. For example:

import numpy as np

x = np.random.randn(25, 25, 25)
x.shape = x.shape[0], -1

print(x.shape)
print(x)

This is functionally equivalent to reshape (in the sense of data ordering, etc.), except that it changes x in place instead of returning a new array. Obviously, it still requires the same information about x's shape, but it is a more concise way of handling the reshape.
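One difference worth knowing about: assigning to shape never copies, so it fails when the new shape cannot be expressed on the existing memory layout, whereas reshape falls back to a copy. A minimal sketch on a small array (the variable names are just for illustration):

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)

# reshape returns a new array object (a view here) and leaves x alone.
y = x.reshape(x.shape[0], -1)
shape_after_reshape = x.shape          # still (2, 3, 4)

# Assigning to .shape changes x itself.
x.shape = x.shape[0], -1
shape_after_assign = x.shape           # now (2, 12)

# A transposed array is not contiguous, so its trailing axes cannot be
# merged without copying. In-place assignment fails; reshape copies.
t = np.arange(24).reshape(2, 3, 4).T   # shape (4, 3, 2)
try:
    t.shape = t.shape[0], -1
    in_place_ok = True
except AttributeError:
    in_place_ok = False
copied = t.reshape(t.shape[0], -1)     # shape (4, 6), a copy

print(shape_after_reshape, shape_after_assign, in_place_ok, copied.shape)
```

As far as I recall, the numpy documentation for ndarray.shape also discourages assigning to it and recommends reshape, so the assignment form is best kept for quick interactive use.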

Upvotes: 2

Kasravnd

Reputation: 107347

You can use np.hstack:

>>> np.hstack(x).shape
(25, 625)

np.hstack takes a sequence of arrays and stacks them horizontally to make a single array.
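A caveat: the matching shape in the example above is a coincidence of all three dimensions being 25. np.hstack iterates over the first axis and joins the pieces along the second one, so the element order differs from reshape, and for non-cubic arrays even the shape differs. A quick check on small arrays (sizes chosen only for illustration):

```python
import numpy as np

x = np.arange(27).reshape(3, 3, 3)

flat = x.reshape(x.shape[0], -1)     # rows are x[0], x[1], x[2] flattened
stacked = np.hstack(x)               # x[0], x[1], x[2] placed side by side

same_shape = flat.shape == stacked.shape        # True, both (3, 9)
same_values = np.array_equal(flat, stacked)     # False, the order differs

# hstack here is the same as swapping the first two axes and then reshaping.
swapped = x.transpose(1, 0, 2).reshape(x.shape[1], -1)
matches_swapped = np.array_equal(stacked, swapped)  # True

# With unequal dimensions the shapes no longer agree either.
z = np.arange(24).reshape(2, 3, 4)
hstack_shape = np.hstack(z).shape               # (3, 8)
reshape_shape = z.reshape(z.shape[0], -1).shape # (2, 12)

print(same_shape, same_values, matches_swapped, hstack_shape, reshape_shape)
```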

Upvotes: 1
