Crazymage

Reputation: 114

How to understand the result of this np.einsum('kij',A)?

For example,

import numpy as np

A = np.arange(24).reshape((2, 3, 4))
print(np.einsum('ijk', A))

This just returns A unchanged, as expected.

But if I do print(np.einsum('kij', A).shape), the shape is (3, 4, 2). Shouldn't it be (4, 2, 3)?

The shape of np.einsum('cab', A), on the other hand, is (4, 2, 3), as I expected. Why doesn't np.einsum('kij', A) behave the same way?

Upvotes: 0

Views: 248

Answers (1)

ali_m

Reputation: 74232

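If you specify only a single set of subscripts (NumPy's "implicit mode"), they label the axes of the input array, and the output axes are those same labels sorted into alphabetical order. With 'kij', the axes of A (sizes 2, 3 and 4) are labelled k, i and j, so the output has axes i, j, k with shape (3, 4, 2). In other words, the subscripts give the order of dimensions in the input array with respect to the output, not vice versa.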

For example:

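import numpy as np

A = np.arange(24).reshape((2, 3, 4))
B = np.einsum('kij', A)  # implicit mode: output axes are i, j, k

# index grids covering every position of B
i, j, k = np.indices(B.shape)

# every element B[i, j, k] is taken from A[k, i, j]
print(np.all(B[i, j, k] == A[k, i, j]))
# True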
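A more direct check of the implicit-mode rule, using nothing beyond NumPy itself: 'kij' should give exactly the same array as the explicit form 'kij->ijk'. Because only the alphabetical order of the letters matters, relabelling the axes as 'cab' (c, a, b occupy the same alphabetical positions as k, i, j) should not change the result either.

```python
import numpy as np

A = np.arange(24).reshape((2, 3, 4))

# Implicit mode: output labels are the input labels sorted alphabetically,
# so 'kij' is shorthand for 'kij->ijk'.
implicit = np.einsum('kij', A)
explicit = np.einsum('kij->ijk', A)
print(implicit.shape)                      # (3, 4, 2)
print(np.array_equal(implicit, explicit))  # True

# Only the relative alphabetical order of the letters matters:
# 'cab' labels A's axes the same way as 'kij' (c/k last, a/i first, b/j second).
print(np.einsum('cab', A).shape)           # (3, 4, 2)
```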

As @hpaulj pointed out in the comments, you can make the correspondence between the input and output dimensions more explicit by specifying both sets of subscripts:

# this is equivalent to np.einsum('kij', A)
print(np.einsum('kij->ijk', A).shape)
# (3, 4, 2)

# this is the behavior you are expecting
print(np.einsum('ijk->kij', A).shape)
# (4, 2, 3)
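Since a single-operand einsum with no repeated labels only permutes axes, both forms can be cross-checked against an ordinary transpose. For each letter on the output side, the matching transpose argument is the position of that letter on the input side. The sketch below assumes only the standard `ndarray.transpose` and `np.moveaxis`:

```python
import numpy as np

A = np.arange(24).reshape((2, 3, 4))

# 'kij->ijk': output axes (i, j, k) sit at input positions (1, 2, 0)
assert np.array_equal(np.einsum('kij->ijk', A), A.transpose(1, 2, 0))

# 'ijk->kij': output axes (k, i, j) sit at input positions (2, 0, 1)
assert np.array_equal(np.einsum('ijk->kij', A), A.transpose(2, 0, 1))

# The same permutation via np.moveaxis: move the last axis to the front
assert np.array_equal(np.einsum('ijk->kij', A), np.moveaxis(A, -1, 0))
```

For a pure reordering like this, `transpose` or `moveaxis` states the intent more directly, while the explicit `->` form of `einsum` is handy when the permutation is part of a larger contraction.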

Upvotes: 2
