duhaime

NumPy: reshaping slices before concatenating makes the result unequal to the original array

I am trying to run some reshape operations in NumPy but can't get something that should be straightforward to work.

The following works just fine:

import numpy as np

X = np.random.rand(55, 100, 3)

b = None
for i in range(X.shape[1]):
    r = X[:, i:i+1, :]
    b = r if b is None else np.concatenate((b, r), axis=1)  # first iteration: start b from r

assert np.all(X == b.reshape(X.shape[0], X.shape[1], X.shape[2])) # succeeds
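Why the first version works: each slice `X[:, i:i+1, :]` keeps the middle axis with length 1, so concatenating the slices along `axis=1` rebuilds a (55, 100, 3) array directly and the final reshape changes nothing. A minimal check, using `b is None` as the first-iteration test:

```python
import numpy as np

X = np.random.rand(55, 100, 3)

b = None
for i in range(X.shape[1]):
    r = X[:, i:i+1, :]  # shape (55, 1, 3): the column axis is kept
    b = r if b is None else np.concatenate((b, r), axis=1)

# b already has X's shape, so the reshape in the assertion is a no-op
print(b.shape)  # (55, 100, 3)
assert b.shape == X.shape
assert np.array_equal(b, X)
```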

However, when I reshape each slice r into a single row before concatenating it onto b, I can no longer reshape the final b back into the shape of X:

import numpy as np

X = np.random.rand(55, 100, 3)

b = None
for i in range(X.shape[1]):
    r = X[:, i:i+1, :].reshape(1, X.shape[0] * X.shape[2])  # column i flattened to one row
    b = r if b is None else np.concatenate((b, r), axis=1)  # first iteration: start b from r

assert np.all(X == b.reshape(X.shape[0], X.shape[1], X.shape[2])) # fails
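What goes wrong here: each r is a (1, 165) row holding `X[:, i, :]` flattened in C order (row first, then channel). Concatenating those rows along `axis=1` gives a (1, 16500) array whose flat order is (column, row, channel), not the (row, column, channel) order of X. `reshape` only reinterprets the flat order; it never moves elements between axes. The sketch below makes the layout visible:

```python
import numpy as np

X = np.random.rand(55, 100, 3)
n_rows, n_cols, n_ch = X.shape

b = None
for i in range(n_cols):
    # Flatten column i: shape (1, 55 * 3), ordered row-major over (row, channel)
    r = X[:, i:i+1, :].reshape(1, n_rows * n_ch)
    b = r if b is None else np.concatenate((b, r), axis=1)

print(b.shape)  # (1, 16500)

# The flat data is laid out as (column, row, channel) ...
by_column = b.reshape(n_cols, n_rows, n_ch)
for i in range(n_cols):
    assert np.array_equal(by_column[i], X[:, i, :])

# ... so reading it back as (row, column, channel) scrambles the values
print(np.array_equal(b.reshape(X.shape), X))  # False
```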

I know there are better ways to accomplish this kind of operation. I'm simplifying a more complex situation.

Does anyone know how I can make the second assertion succeed while maintaining the general structure of the second snippet? Any suggestions would be very helpful!

Answers (1)

Julien

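As you said, there are probably better ways to do this, but if you want to keep your structure, you just need to reorder the axes. Because each r holds one column of X flattened in (row, channel) order, the data in b is laid out as (column, row, channel). Reshape it in that order, then swap the first two axes back: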

assert np.all(X == np.transpose(b.reshape(X.shape[1], X.shape[0], X.shape[2]), axes=(1,0,2)))
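Here is that fix dropped into the question's loop, as a runnable sketch. `np.swapaxes(by_column, 0, 1)` is shown as an equivalent way to swap the first two axes. The last part is an optional variant: because `np.concatenate` inside the loop copies all of b on every iteration, collecting the rows in a list and concatenating once after the loop builds the same b with less copying.

```python
import numpy as np

X = np.random.rand(55, 100, 3)
n_rows, n_cols, n_ch = X.shape

b = None
for i in range(n_cols):
    r = X[:, i:i+1, :].reshape(1, n_rows * n_ch)
    b = r if b is None else np.concatenate((b, r), axis=1)

# Reshape in the order the data was written: (column, row, channel) ...
by_column = b.reshape(n_cols, n_rows, n_ch)
# ... then swap the first two axes to get back to (row, column, channel)
restored = np.transpose(by_column, axes=(1, 0, 2))
assert np.array_equal(restored, X)

# np.swapaxes performs the same swap of axes 0 and 1
assert np.array_equal(np.swapaxes(by_column, 0, 1), X)

# Variant: collect the rows in a list and concatenate once after the loop
rows = [X[:, i:i+1, :].reshape(1, n_rows * n_ch) for i in range(n_cols)]
b_list = np.concatenate(rows, axis=1)
assert np.array_equal(b_list, b)
```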
