Reputation: 39477
I am reading a text (see K-nearest neighbors example)
which gives this line of code
dist_sq = np.sum((X[:,np.newaxis,:] - X[np.newaxis,:,:]) ** 2, axis=-1)
Here X is a numpy 10x2 array which represents 10 points in the 2D plane.
It was initialized like this:
X = np.random.rand(10, 2)
OK... The text claims this line computes the pairs of squared distances between the points.
I have no idea why this works, or whether it works at all. I tried to understand it but I just can't. I personally try to avoid such cryptic code. This is just not human IMHO. The text explains this code in some detail, but I don't get that explanation either.
Also, axis=-1
adds to the confusion.
Could someone decrypt this line of code?
Also, what is the point of saying e.g. X[:,np.newaxis,:]
, X[np.newaxis,:,:]
?
Isn't X[:,np.newaxis]
, X[np.newaxis,:]
enough? Isn't it doing the same?!
Also, from combinatorics, the squared distances count should be 10*9/2
or 10*10/2
(if we include equal points which have distance 0), but this dist_sq
is a 10x10x2 array. This also adds to the confusion. Why 200 elements?!
Upvotes: 1
Views: 239
Reputation: 961
You can analyze the expression one piece at a time.
First check the shape of X: X.shape is (10, 2).
What does X[np.newaxis,:,:] do? It inserts a new axis of length 1 in front of the existing axes, giving an array of shape (1, 10, 2). Similarly, X[:,np.newaxis,:] inserts the new axis in the middle and creates an array of shape (10, 1, 2).
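(X[:,np.newaxis,:] - X[np.newaxis,:,:]) ** 2 has shape (10, 10, 2): broadcasting stretches the two length-1 axes to length 10, so entry [i, j, :] holds the squared coordinate differences between point i and point j.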
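The shapes described above can be checked directly. This is a minimal sketch, assuming a seeded generator (`rng`, with an arbitrary seed) in place of the question's `np.random.rand`, and using the hypothetical names `a`, `b` and `diff` for the intermediate arrays. It also suggests that the shorter spellings from the question, `X[:,np.newaxis]` and `X[np.newaxis,:]`, produce the same arrays, because trailing `:` slices are implied.

```python
import numpy as np

# Assumption: a seeded generator instead of np.random.rand, for reproducibility.
rng = np.random.default_rng(0)
X = rng.random((10, 2))

a = X[:, np.newaxis, :]   # new axis in the middle -> shape (10, 1, 2)
b = X[np.newaxis, :, :]   # new axis in front      -> shape (1, 10, 2)
diff = a - b              # broadcasting           -> shape (10, 10, 2)

print(a.shape, b.shape, diff.shape)

# diff[i, j] is the coordinate difference X[i] - X[j].
assert np.allclose(diff[3, 7], X[3] - X[7])

# Trailing ':' slices are implied, so the shorter forms give the same arrays.
assert np.array_equal(X[:, np.newaxis], a)
assert np.array_equal(X[np.newaxis, :], b)
```

Writing out every `:` is a readability choice: it makes it explicit which position the new axis occupies.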
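What about dist_sq = np.sum((X[:,np.newaxis,:] - X[np.newaxis,:,:]) ** 2, axis=-1)? It computes the squared Euclidean distance between each pair of points in X. Summing over the last axis collapses the (10, 10, 2) array to shape (10, 10), so dist_sq has 100 elements, not 200.
For example: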
Y =
array([[0.79410882, 0.38156374],
[0.93574123, 0.6510161 ]])
The result of (Y[:,np.newaxis,:] - Y[np.newaxis,:,:]) ** 2 has shape (2, 2, 2), and np.sum sums over one specific axis. Which one? axis=-1 means the last axis, i.e. the one that holds the x and y coordinates.
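To see what axis=-1 does in isolation, here is a small sketch on a made-up (2, 2, 2) array (`sq` and `summed` are illustrative names, not from the original text). Negative axis numbers count from the end, so for a 3-D array axis=-1 is the same as axis=2.

```python
import numpy as np

# Illustrative 3-D array: sq[i, j] is a length-2 vector.
sq = np.arange(8).reshape(2, 2, 2)
# sq[0, 0] = [0, 1], sq[0, 1] = [2, 3], sq[1, 0] = [4, 5], sq[1, 1] = [6, 7]

summed = np.sum(sq, axis=-1)   # add up each length-2 vector
print(summed)                  # [[ 1  5]
                               #  [ 9 13]]
print(summed.shape)            # (2, 2): the last axis is gone

# For a 3-D array, axis=-1 and axis=2 refer to the same axis.
assert np.array_equal(summed, np.sum(sq, axis=2))
```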
dist_sq = np.sum((Y[:,np.newaxis,:] - Y[np.newaxis,:,:]) ** 2, axis=-1)
dist_sq=
array([[0. , 0.09266431],
[0.09266431, 0. ]])
Checking the off-diagonal entry by hand:
(0.79410882-0.93574123)**2 + (0.38156374-0.6510161)**2 = 0.09266431387197768
So the final result is a symmetric square matrix with zeros on the diagonal.
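If the broadcasting one-liner still feels cryptic, it may help to compare it with an explicit double loop. The sketch below is not from the original text: `loop`, `i_idx`, `j_idx` and `unique` are illustrative names, and the seeded generator is an assumption. It also picks out the 10*9/2 = 45 unique pairs from the question using `np.triu_indices`; the full matrix stores each pair twice, plus the zero diagonal.

```python
import numpy as np

# Assumption: a seeded generator instead of np.random.rand, for reproducibility.
rng = np.random.default_rng(0)
X = rng.random((10, 2))

# The one-liner from the question.
dist_sq = np.sum((X[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2, axis=-1)

# The same computation as a plain double loop.
n = X.shape[0]
loop = np.empty((n, n))
for i in range(n):
    for j in range(n):
        loop[i, j] = np.sum((X[i] - X[j]) ** 2)

assert dist_sq.shape == (10, 10)          # 100 elements, not 200
assert np.allclose(dist_sq, loop)         # both approaches agree
assert np.allclose(dist_sq, dist_sq.T)    # symmetric
assert np.allclose(np.diag(dist_sq), 0)   # zero distance to itself

# Unique pairs only: entries strictly above the diagonal.
i_idx, j_idx = np.triu_indices(n, k=1)
unique = dist_sq[i_idx, j_idx]
print(unique.shape)   # (45,), i.e. 10*9/2
```

Including the diagonal (k=0) would give 10*11/2 = 55 entries rather than 10*10/2.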
Upvotes: 2