In chapter 2 of "Python Data Science Handbook" by Jake VanderPlas, he computes the sum of squared differences of several 2-d points using the following code:
import numpy as np
rand = np.random.RandomState(42)
X = rand.rand(10, 2)
dist_sq = np.sum((X[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2, axis=-1)
Two questions:
Why is a third axis created? What is the best way to visualize what is going on?
Adding new dimensions before adding or subtracting is a fairly common trick for generating all pairs via broadcasting (None is the same as np.newaxis here):
>>> a = np.arange(10)
>>> a[:,None]
array([[0],
       [1],
       [2],
       [3],
       [4],
       [5],
       [6],
       [7],
       [8],
       [9]])
>>> a[None,:]
array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
>>> a[:,None] + 100*a[None,:]
array([[  0, 100, 200, 300, 400, 500, 600, 700, 800, 900],
       [  1, 101, 201, 301, 401, 501, 601, 701, 801, 901],
       [  2, 102, 202, 302, 402, 502, 602, 702, 802, 902],
       [  3, 103, 203, 303, 403, 503, 603, 703, 803, 903],
       [  4, 104, 204, 304, 404, 504, 604, 704, 804, 904],
       [  5, 105, 205, 305, 405, 505, 605, 705, 805, 905],
       [  6, 106, 206, 306, 406, 506, 606, 706, 806, 906],
       [  7, 107, 207, 307, 407, 507, 607, 707, 807, 907],
       [  8, 108, 208, 308, 408, 508, 608, 708, 808, 908],
       [  9, 109, 209, 309, 409, 509, 609, 709, 809, 909]])
Your example does the same, just with 2-vectors instead of scalars at the innermost level:
>>> X[:,np.newaxis,:].shape
(10, 1, 2)
>>> X[np.newaxis,:,:].shape
(1, 10, 2)
>>> (X[:,np.newaxis,:] - X[np.newaxis,:,:]).shape
(10, 10, 2)
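Thus the 'magical subtraction' is just every pair of points in X subtracted from each other: element [i, j] of the result is the 2-vector X[i] - X[j]. Squaring it and summing over the last axis (axis=-1) then collapses each difference vector into a single squared distance, leaving a (10, 10) matrix.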
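One way to convince yourself of this is to compare the broadcast expression against a plain double loop. The sketch below is only an illustration; the names diff and dist_sq_loop are made up here and do not come from the book.

```python
import numpy as np

rand = np.random.RandomState(42)
X = rand.rand(10, 2)

# Broadcasting version: (10, 1, 2) - (1, 10, 2) -> (10, 10, 2)
diff = X[:, np.newaxis, :] - X[np.newaxis, :, :]
dist_sq = np.sum(diff ** 2, axis=-1)  # sum over the coordinate axis -> (10, 10)

# Explicit-loop equivalent (illustrative only; much slower for large inputs)
n = X.shape[0]
dist_sq_loop = np.empty((n, n))
for i in range(n):
    for j in range(n):
        d = X[i] - X[j]                  # 2-vector between points i and j
        dist_sq_loop[i, j] = np.sum(d ** 2)

print(diff.shape)                          # shape of the pairwise differences
print(np.allclose(dist_sq, dist_sq_loop))  # do both approaches agree?
```

The loop makes the indexing explicit: the first new axis plays the role of i, the second the role of j, and the last axis holds the coordinates that get summed away.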
Is there a more intuitive way to perform this calculation?
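Yes, use scipy.spatial.distance.pdist for pairwise distances. It returns the distances in condensed (1-D) form, so squareform is needed to expand them into a square matrix. To get a result equivalent to your example (up to floating-point rounding):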
from scipy.spatial.distance import pdist, squareform
dist_sq = squareform(pdist(X))**2
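As a quick sanity check, the two approaches can be compared with np.allclose. This is a minimal sketch assuming SciPy is installed; the variable names are illustrative, and the sqeuclidean metric is shown as an optional alternative that skips the square root and the re-squaring.

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform

rand = np.random.RandomState(42)
X = rand.rand(10, 2)

# Broadcasting version from the question
dist_sq_broadcast = np.sum((X[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2, axis=-1)

# pdist returns one Euclidean distance per unordered pair: n*(n-1)/2 entries
condensed = pdist(X)
dist_sq_scipy = squareform(condensed) ** 2  # expand to (10, 10), then square

# Alternative: ask pdist for squared Euclidean distances directly
dist_sq_direct = squareform(pdist(X, metric="sqeuclidean"))

print(condensed.shape)
print(np.allclose(dist_sq_broadcast, dist_sq_scipy))
print(np.allclose(dist_sq_broadcast, dist_sq_direct))
```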