Reputation: 1751
I have a very simple problem.
Given an N-dimensional point (a vector where each element represents one dimension) represented by `x`, and an M×N matrix (a group of M points, each with N dimensions) represented by `y`:
# Reproducible toy data: 11 points in 10 dimensions. The first point is the
# query `x`; the remaining 10 points form the matrix `y`.
set.seed(999)
# Draw exactly 11 * 10 values. runif(1100) generated ten times more values
# than the matrix can hold; matrix() discards the surplus and (in recent R
# versions) warns "data length differs from size of matrix". With the same
# seed the first 110 draws are identical, so `data` itself is unchanged.
data <- matrix(runif(11 * 10), nrow = 11, ncol = 10)
x <- data[1, ]
# Negative indexing with drop = FALSE keeps `y` a matrix even if only one
# point remains; 2:nrow(data) would produce c(2, 1) when nrow(data) == 1.
y <- data[-1, , drop = FALSE]
I want to calculate a distance measure between `x` and every point of `y`. I know a simple way of doing that is:
# Naive approach: stack x on top of the rows of y and compute every pairwise
# distance; only the distances between the first row (x) and the others are
# actually needed.
distances <- dist(rbind(x, y))
However, I believe this is not very efficient for this specific case, for the following reason: `dist` calculates the distance between every pair of points, but I'm only interested in 10 of those distances, namely the distances between `x` and each point of `y`. I'm not interested in the distances among the points of `y` themselves.
One possible solution I thought of was to compute the distances manually by looping over the rows of `y`:
# Euclidean distance from x to each row of y, computed one row at a time.
distances <- apply(
  y,
  MARGIN = 1,
  FUN = function(point) sqrt(sum((point - x)^2))
)
However, when timing both approaches, I get:
# Euclidean distance from point x to every row of matrix y, via apply() over
# rows. Returns one distance per row of y.
func1 <- function(x, y) {
  row_distance <- function(point) sqrt(sum((point - x)^2))
  apply(y, MARGIN = 1, FUN = row_distance)
}
# Stack x on top of the rows of y and let dist() compute every pairwise
# Euclidean distance; only the entries involving the first row (x) matter.
func2 <- function(x, y) dist(rbind(x, y))
# Time both implementations on the same inputs; times = 100L is
# microbenchmark's default and is written out only for clarity.
microbenchmark::microbenchmark(
  func1(x, y),
  func2(x, y),
  times = 100L
)
Unit: microseconds
expr min lq mean median uq max neval
func1(x, y) 29.602 30.450 61.21791 31.301 32.3510 2916.101 100
func2(x, y) 15.101 15.801 28.55304 17.201 17.7015 1143.001 100
So my question is: is there a way to solve this problem faster than using `dist`?
Upvotes: 3
Views: 413
Reputation: 1525
Update 2: If we assume complete data (no missing values) and that the accuracy is acceptable, we can implement an even faster version of the distance calculation using Rcpp. I have added this below (`func6`); its fastest timings are with the bytecode compiler enabled. For those with experience with RcppParallel, this can likely be improved further.
Update: The function `rdist` from the fields package was by far the fastest method found before the Rcpp version was added (see Calculating all distances between one point and a group of points efficiently in R). It also appears to be the fastest when the bytecode compiler is not used.
In brief earlier testing, I found that `vapply` was faster than all the other methods when using the bytecode compiler. This holds after the first run, which compiles the functions; that first-run compilation is why the max time is greater in the bytecode runs.
I have also tried the methods from @akrun and @ThomasIsCoding here.
library(microbenchmark)
library(compiler)
library(collapse)
library(fields)
library(Rcpp)
# Reproducible toy data: 11 points in 10 dimensions. The first point is the
# query `x`; the remaining 10 points form the matrix `y`.
set.seed(999)
# Draw exactly 11 * 10 values. runif(1100) generated ten times more values
# than the matrix can hold; matrix() discards the surplus and (in recent R
# versions) warns "data length differs from size of matrix". With the same
# seed the first 110 draws are identical, so `data` itself is unchanged.
data <- matrix(runif(11 * 10), nrow = 11, ncol = 10)
x <- data[1, ]
# Negative indexing with drop = FALSE keeps `y` a matrix even if only one
# point remains; 2:nrow(data) would produce c(2, 1) when nrow(data) == 1.
y <- data[-1, , drop = FALSE]
# Reference result: all pairwise distances of x stacked on top of y.
distances <- dist(rbind(x, y))
# Euclidean distance from point x to every row of matrix y, via apply() over
# rows. Returns one distance per row of y.
func1 <- function(x, y) {
  row_distance <- function(point) sqrt(sum((point - x)^2))
  apply(y, MARGIN = 1, FUN = row_distance)
}
# Stack x on top of the rows of y and let dist() compute every pairwise
# Euclidean distance; only the entries involving the first row (x) matter.
func2 <- function(x, y) dist(rbind(x, y))
# Euclidean distance from x to every row of y using collapse::dapply, which
# applies the helper to each row (MARGIN = 1).
func3 <- function(x, y) {
  dist_to_x <- function(point) sqrt(sum((point - x)^2))
  dapply(y, dist_to_x, MARGIN = 1)
}
# Euclidean distance from x to every row of y, iterating over row indices
# with vapply() so the result is guaranteed to be a double vector.
func4 <- function(x, y) {
  row_ids <- seq_len(nrow(y))
  vapply(
    row_ids,
    function(k) sqrt(sum((y[k, ] - x)^2)),
    FUN.VALUE = numeric(1)
  )
}
# Full pairwise distance matrix (fields::rdist) for x stacked on top of the
# rows of y; row/column 1 holds the distances from x.
# NOTE(review): rdist(matrix(x, nrow = 1), y) would presumably compute only
# the needed 1 x nrow(y) block; confirm against the fields docs.
func5 <- function(x, y) {
  stacked <- rbind(x, y)
  rdist(stacked)
}
# Rcpp implementation of the distances from x to every row of y.
# `y` is read as a flat column-major vector (how an R matrix is stored), so
# with maxiters = length(y) / length(x) rows, row i occupies positions
# i, i + maxiters, i + 2 * maxiters, ... .
# Errors early if x is empty (the original divided by zero) or if length(y)
# is not a multiple of length(x) (the original silently dropped the tail).
# Indices use R_xlen_t so long vectors cannot overflow int arithmetic.
cppFunction('NumericVector func6(NumericVector x, NumericVector y) {
  const R_xlen_t n = x.size();
  const R_xlen_t n2 = y.size();
  if (n == 0) {
    stop("x must have at least one element");
  }
  if (n2 % n != 0) {
    stop("length(y) must be a multiple of length(x)");
  }
  const R_xlen_t maxiters = n2 / n;
  NumericVector results(maxiters);
  for (R_xlen_t i = 0; i < maxiters; i++) {
    double acc = 0.0;
    for (R_xlen_t j = 0; j < n; j++) {
      const double val = x[j] - y[j * maxiters + i];
      acc += val * val;
    }
    results[i] = sqrt(acc);
  }
  return results;
}')
# Row-wise Euclidean distance: x[col(y)] repeats x[j] in every cell of
# column j, so the subtraction lines each coordinate of x up with y.
func7 <- function(x, y) {
  deltas <- y - x[col(y)]
  sqrt(rowSums(deltas^2))
}
# Column-wise Euclidean distance: transposing y turns each point into a
# column, so x recycles down every column in the subtraction.
func8 <- function(x, y) {
  yt <- t(y)
  sqrt(colSums((yt - x)^2))
}
# Benchmark run 1: disable the JIT compiler. enableJIT() returns the previous
# JIT level, which the reprex output below shows was 3.
compiler::enableJIT(0)
#> [1] 3
# Time all eight implementations on the same x and y (neval = 100 each).
microbenchmark::microbenchmark(
func1(x, y),
func2(x, y),
func3(x, y),
func4(x, y),
func5(x, y),
func6(x, y),
func7(x, y),
func8(x, y)
)
# Recorded output of the run above (timings are machine-specific):
#>Unit: microseconds
#> expr min lq mean median uq max neval
#> func1(x, y) 37.001 42.8010 50.53103 45.4520 53.6515 138.302 100
#> func2(x, y) 20.201 25.3510 30.23096 27.8515 31.4010 70.401 100
#> func3(x, y) 23.901 27.6510 55.45699 30.0010 35.7505 2248.902 100
#> func4(x, y) 20.501 23.2010 28.20101 24.6020 31.4010 119.501 100
#> func5(x, y) 6.100 8.6020 19.27804 9.6515 11.4510 891.001 100
#> func6(x, y) 1.501 2.4010 11.60706 2.9010 3.4510 848.102 100
#> func7(x, y) 14.401 17.2505 27.73793 19.7510 23.2510 596.002 100
#> func8(x, y) 18.901 22.5510 27.91699 24.9015 29.3010 73.301 100
# Benchmark run 2: re-enable the JIT compiler at level 3. The returned
# previous level (0, shown below) confirms it had been switched off.
compiler::enableJIT(3)
#> [1] 0
# Same eight implementations, now with JIT compilation. The larger max times
# presumably include compiling each closure on its first call.
microbenchmark::microbenchmark(
func1(x, y),
func2(x, y),
func3(x, y),
func4(x, y),
func5(x, y),
func6(x, y),
func7(x, y),
func8(x, y)
)
# Recorded output of the run above (timings are machine-specific):
#>Unit: microseconds
#> expr min lq mean median uq max neval
#> func1(x, y) 32.100 35.9510 85.49213 39.4015 44.2510 4298.002 100
#> func2(x, y) 19.701 23.6010 45.11697 26.0505 29.6005 1732.702 100
#> func3(x, y) 19.801 22.2515 76.96108 24.8010 27.6510 5023.201 100
#> func4(x, y) 16.302 19.2510 77.46094 20.3010 21.8005 5564.701 100
#> func5(x, y) 6.201 8.5010 41.53397 9.4510 11.0510 3032.301 100
#> func6(x, y) 1.401 2.3010 13.95802 2.7005 3.0020 1101.801 100
#> func7(x, y) 14.201 16.7010 64.09999 18.6510 21.0015 4307.901 100
#> func8(x, y) 19.201 22.4500 64.33288 24.8510 27.5010 3776.101 100
Created on 2021-04-04 by the reprex package (v2.0.0)
Just the results
#JIT disabled (compiler::enableJIT(0))
#>Unit: microseconds
#> expr min lq mean median uq max neval
#> func1(x, y) 37.001 42.8010 50.53103 45.4520 53.6515 138.302 100
#> func2(x, y) 20.201 25.3510 30.23096 27.8515 31.4010 70.401 100
#> func3(x, y) 23.901 27.6510 55.45699 30.0010 35.7505 2248.902 100
#> func4(x, y) 20.501 23.2010 28.20101 24.6020 31.4010 119.501 100
#> func5(x, y) 6.100 8.6020 19.27804 9.6515 11.4510 891.001 100
#> func6(x, y) 1.501 2.4010 11.60706 2.9010 3.4510 848.102 100
#> func7(x, y) 14.401 17.2505 27.73793 19.7510 23.2510 596.002 100
#> func8(x, y) 18.901 22.5510 27.91699 24.9015 29.3010 73.301 100
#JIT enabled (compiler::enableJIT(3), bytecode compiler)
#>Unit: microseconds
#> expr min lq mean median uq max neval
#> func1(x, y) 32.100 35.9510 85.49213 39.4015 44.2510 4298.002 100
#> func2(x, y) 19.701 23.6010 45.11697 26.0505 29.6005 1732.702 100
#> func3(x, y) 19.801 22.2515 76.96108 24.8010 27.6510 5023.201 100
#> func4(x, y) 16.302 19.2510 77.46094 20.3010 21.8005 5564.701 100
#> func5(x, y) 6.201 8.5010 41.53397 9.4510 11.0510 3032.301 100
#> func6(x, y) 1.401 2.3010 13.95802 2.7005 3.0020 1101.801 100
#> func7(x, y) 14.201 16.7010 64.09999 18.6510 21.0015 4307.901 100
#> func8(x, y) 19.201 22.4500 64.33288 24.8510 27.5010 3776.101 100
Upvotes: 3
Reputation: 887088
One option is `dapply` from collapse:
library(collapse)
# Euclidean distance from x to every row of y using collapse::dapply, which
# applies the helper to each row (MARGIN = 1).
func3 <- function(x, y) {
  dist_to_x <- function(point) sqrt(sum((point - x)^2))
  dapply(y, dist_to_x, MARGIN = 1)
}
Or use `vapply`:
# Euclidean distance from x to every row of y, iterating over row indices
# with vapply() so the result is guaranteed to be a double vector.
func4 <- function(x, y) {
  row_ids <- seq_len(nrow(y))
  vapply(
    row_ids,
    function(k) sqrt(sum((y[k, ] - x)^2)),
    FUN.VALUE = numeric(1)
  )
}
Or replicate the vector across the columns and use `rowSums` after subtracting:
# Row-wise Euclidean distance: x[col(y)] repeats x[j] in every cell of
# column j, so the subtraction lines each coordinate of x up with y.
func7 <- function(x, y) {
  deltas <- y - x[col(y)]
  sqrt(rowSums(deltas^2))
}
# Compare the original apply() approach with the three alternatives;
# times = 100L is microbenchmark's default, written out for clarity.
microbenchmark::microbenchmark(
  func1(x, y),
  func3(x, y),
  func4(x, y),
  func7(x, y),
  times = 100L
)
#Unit: microseconds
# expr min lq mean median uq max neval cld
# func1(x, y) 37.605 39.7475 61.17471 40.7595 42.1865 1955.888 100 a
# func3(x, y) 22.212 23.5945 68.63660 24.8320 25.8670 4333.933 100 a
# func4(x, y) 21.089 22.7930 24.11542 23.5945 24.2315 58.050 100 a
# func7(x, y) 7.731 8.9135 44.45935 10.0615 10.9500 3415.959 100 a
Upvotes: 3