eduardokapp

Reputation: 1751

Dist function between matrix objects in R

I have a very simple problem.

Suppose I have an N-dimensional point x (a vector in which each element is one dimension) and an M x N matrix y (a group of M points, each with N dimensions):

set.seed(999)
data <- matrix(runif(110), nrow = 11, ncol = 10)  # 11 points x 10 dimensions = 110 values

x <- data[1, ]
y <- data[2:nrow(data), ]

I want to calculate the distance between x and every point of y. One simple way to do that is:

distances <- dist(rbind(x, y))

However, I believe this is not very efficient for this specific case, for the following reasons:

  1. I need to use rbind, which copies all of y into a new matrix and can be memory-costly when y is large.
  2. dist calculates the distance between every pair of points, but I'm only interested in 10 of those distances: the distances between x and each point of y. I don't need the distances among the points of y.
  3. Because of (2), I need to manually extract the first row (or column) of the full distance matrix, since x is the first row of rbind(x, y), to get the distances I actually need.
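To make points 2 and 3 concrete, here is a small sketch using the same data (with runif(110), so the data length matches the 11 x 10 matrix). It counts how many distances dist computes and extracts the ones involving x, which sit in the first row of the full distance matrix after the diagonal entry.

```r
set.seed(999)
# 11 points in 10 dimensions: the first row is x, the rest is y
data <- matrix(runif(110), nrow = 11, ncol = 10)
x <- data[1, ]
y <- data[2:nrow(data), ]

d <- dist(rbind(x, y))

# dist stores every pairwise distance among the 11 rows: 11 * 10 / 2 = 55
n_computed <- length(d)

# Only the M = nrow(y) distances between x and the rows of y are needed
n_needed <- nrow(y)

# x is row 1 of rbind(x, y), so its distances are row 1 of the full
# matrix with the diagonal entry (x to itself) dropped
needed <- as.matrix(d)[1, -1]

cat("computed:", n_computed, "needed:", n_needed, "\n")
```

Only 10 of the 55 computed values are used, which is the overhead described in point 2.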

One possible solution I thought of was to compute the distances manually by looping over the rows of y with apply:

distances <- apply(y, MARGIN = 1, function(a, b = x) {
   sqrt(sum((a - b)^2))
})

However, timing both approaches shows that the apply version is nearly twice as slow as dist:

func1 <- function(x, y) {
  apply(y, MARGIN = 1, function(a, b = x) {
    sqrt(sum((a - b)^2))
  })
}

func2 <- function(x, y) {
  dist(rbind(x, y))
}

microbenchmark::microbenchmark(
  func1(x, y),
  func2(x, y)
)

Unit: microseconds
        expr    min     lq     mean median      uq      max neval
 func1(x, y) 29.602 30.450 61.21791 31.301 32.3510 2916.101   100
 func2(x, y) 15.101 15.801 28.55304 17.201 17.7015 1143.001   100

So my question is: is there a faster way to solve this problem than using dist?

Upvotes: 3

Views: 413

Answers (3)

Joel Kandiah

Reputation: 1525

Update 2: If we assume the data are complete (no missing values) and that ordinary double-precision accuracy is sufficient, we can implement an even faster version of the distance calculation with Rcpp (func6). It is the fastest method in both benchmark runs, and slightly faster with the bytecode compiler enabled. Those with experience with RcppParallel can likely improve it further.
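The Rcpp function func6 below receives y as a flat NumericVector and reads element (i, j) as y[j * maxiters + i], where maxiters is the number of rows. That relies on R storing matrices in column-major order (and on length(y) being a multiple of length(x)). A quick check of the same mapping in R, which is 1-based, hence the +1 shifts:

```r
# A small 3 x 4 matrix; R fills and stores it column by column
m <- matrix(1:12, nrow = 3, ncol = 4)
nr <- nrow(m)
flat <- as.vector(m)

# For every 0-based (i, j), the C++ index j * nr + i maps to the
# 1-based R index j * nr + i + 1 of the flattened matrix
ok <- TRUE
for (i in 0:(nr - 1)) {
  for (j in 0:(ncol(m) - 1)) {
    ok <- ok && m[i + 1, j + 1] == flat[j * nr + i + 1]
  }
}
print(ok)
```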

Update: The function rdist from the fields package (func5) was by far the fastest method found before the Rcpp version was added (see Calculating all distances between one point and a group of points efficiently in R). Apart from the Rcpp version, it is also the fastest with the bytecode compiler disabled.
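Like dist, func5 calls rdist on rbind(x, y), so it still computes all 11 x 11 distances. fields::rdist also accepts two coordinate matrices, rdist(x1, x2), and then returns only the nrow(x1) x nrow(x2) cross distances. A sketch of that variant (func5b is a hypothetical name and is not part of the benchmarks here), guarded so it only runs when fields is installed:

```r
set.seed(999)
data <- matrix(runif(110), nrow = 11, ncol = 10)
x <- data[1, ]
y <- data[2:nrow(data), ]

# Hypothetical variant of func5: pass x as a 1 x N matrix and y as the
# second argument, so only the 1 x M cross-distance matrix is computed
func5b <- function(x, y) {
  drop(fields::rdist(matrix(x, nrow = 1), y))
}

# Reference values from a plain base R formula
ref <- sqrt(colSums((t(y) - x)^2))

if (requireNamespace("fields", quietly = TRUE)) {
  res <- func5b(x, y)
  print(all.equal(unname(res), ref))
}
```

Whether this is faster than rdist(rbind(x, y)) for such a small y would need to be benchmarked.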

Briefly testing the earlier candidates, I found that vapply (func4) is faster than apply, dapply and dist when the bytecode compiler is enabled. The larger max times in the bytecode runs come from the first call, when the functions are compiled.
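compiler::enableJIT controls whether closures are byte-compiled automatically when they are first run. A function can also be compiled explicitly ahead of time with compiler::cmpfun, so the compilation cost is not paid inside the benchmark. A sketch with the vapply version (func4 from this answer); compilation should change speed, not results:

```r
library(compiler)

set.seed(999)
data <- matrix(runif(110), nrow = 11, ncol = 10)
x <- data[1, ]
y <- data[2:nrow(data), ]

# vapply version from this answer
func4 <- function(x, y) {
  vapply(seq_len(nrow(y)), function(i, b = x) sqrt(sum((y[i, ] - b)^2)), numeric(1))
}

# Byte-compile once up front instead of relying on the JIT
func4_cmp <- cmpfun(func4)

# Both versions should return the same distances
print(all.equal(func4(x, y), func4_cmp(x, y)))
```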

I have also included the methods from @akrun and @ThomasIsCoding here.

library(microbenchmark)
library(compiler)
library(collapse)
library(fields)
library(Rcpp)

set.seed(999)
data <- matrix(runif(110), nrow = 11, ncol = 10)  # 11 points x 10 dimensions = 110 values

x <- data[1, ]
y <- data[2:nrow(data), ]

distances <- dist(rbind(x, y))

func1 <- function(x, y) {
  apply(y, MARGIN = 1, function(a, b = x) {
    sqrt(sum((a - b)^2))
  })
}

func2 <- function(x, y) {
  dist(rbind(x, y))
}

func3 <- function(x, y) {
  dapply(y, function(a, b = x) {
    sqrt(sum((a-b)^2))
  }, MARGIN = 1)
}

func4 <- function(x, y) {
  vapply(seq_len(nrow(y)), function(i, b = x) sqrt(sum((y[i,]-b)^2)), numeric(1))
}

func5 <- function(x, y) {
  rdist(rbind(x, y))
}

cppFunction('NumericVector func6(NumericVector x, NumericVector y) {
  int n = x.size();
  int n2 = y.size();
  
  // number of points (rows of y); assumes length(y) is a multiple of n
  int maxiters = n2/n;
  
  NumericVector results(maxiters);
  
  for(int i = 0; i < maxiters; i++) {
    results[i] = 0;
    for(int j = 0; j < n; j++) {
      // y arrives as a flat column-major vector: y(i, j) is at index j * maxiters + i
      double val = x[j] - y[j * maxiters + i];
      results[i] += val * val;
    }
    results[i] = sqrt(results[i]);
  }
  
  return results;
  
}')

func7 <- function(x, y) sqrt(rowSums((y-x[col(y)])^2))

func8 <- function(x, y) sqrt(colSums((t(y) - x)^2))

# Disable the JIT bytecode compiler (the printed value is the previous JIT level)
compiler::enableJIT(0)
#> [1] 3

microbenchmark::microbenchmark(
  func1(x, y),
  func2(x, y),
  func3(x, y),
  func4(x, y),
  func5(x, y),
  func6(x, y),
  func7(x, y),
  func8(x, y)
)
#>Unit: microseconds
#>        expr    min      lq     mean  median      uq      max neval
#> func1(x, y) 37.001 42.8010 50.53103 45.4520 53.6515  138.302   100
#> func2(x, y) 20.201 25.3510 30.23096 27.8515 31.4010   70.401   100
#> func3(x, y) 23.901 27.6510 55.45699 30.0010 35.7505 2248.902   100
#> func4(x, y) 20.501 23.2010 28.20101 24.6020 31.4010  119.501   100
#> func5(x, y)  6.100  8.6020 19.27804  9.6515 11.4510  891.001   100
#> func6(x, y)  1.501  2.4010 11.60706  2.9010  3.4510  848.102   100
#> func7(x, y) 14.401 17.2505 27.73793 19.7510 23.2510  596.002   100
#> func8(x, y) 18.901 22.5510 27.91699 24.9015 29.3010   73.301   100


# Re-enable full JIT compilation (level 3); the printed value is the previous level
compiler::enableJIT(3)
#> [1] 0

microbenchmark::microbenchmark(
  func1(x, y),
  func2(x, y),
  func3(x, y),
  func4(x, y),
  func5(x, y),
  func6(x, y),
  func7(x, y),
  func8(x, y)
)
#>Unit: microseconds
#>        expr    min      lq     mean  median      uq      max neval
#> func1(x, y) 32.100 35.9510 85.49213 39.4015 44.2510 4298.002   100
#> func2(x, y) 19.701 23.6010 45.11697 26.0505 29.6005 1732.702   100
#> func3(x, y) 19.801 22.2515 76.96108 24.8010 27.6510 5023.201   100
#> func4(x, y) 16.302 19.2510 77.46094 20.3010 21.8005 5564.701   100
#> func5(x, y)  6.201  8.5010 41.53397  9.4510 11.0510 3032.301   100
#> func6(x, y)  1.401  2.3010 13.95802  2.7005  3.0020 1101.801   100
#> func7(x, y) 14.201 16.7010 64.09999 18.6510 21.0015 4307.901   100
#> func8(x, y) 19.201 22.4500 64.33288 24.8510 27.5010 3776.101   100

Created on 2021-04-04 by the reprex package (v2.0.0)


Upvotes: 3

ThomasIsCoding

Reputation: 101327

Here is another base R option:

sqrt(colSums((t(y) - x)^2))
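This works because t(y) is an N x M matrix with one point per column, and R recycles the length-N vector x down each column, so x is subtracted from every point; colSums then adds up the squared differences per point. A small check against an explicit loop:

```r
set.seed(999)
data <- matrix(runif(110), nrow = 11, ncol = 10)
x <- data[1, ]
y <- data[2:nrow(data), ]

# Vectorised: x (length N) is recycled down each column of t(y) (N x M)
d_vec <- sqrt(colSums((t(y) - x)^2))

# Explicit loop over the M rows of y for comparison
d_loop <- numeric(nrow(y))
for (i in seq_len(nrow(y))) {
  d_loop[i] <- sqrt(sum((y[i, ] - x)^2))
}

print(all.equal(d_vec, d_loop))
```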

Upvotes: 1

akrun

Reputation: 887088

One option is dapply from the collapse package:

library(collapse)
func3 <- function(x, y) {
  dapply(y, function(a, b = x) {
    sqrt(sum((a - b)^2))
  }, MARGIN = 1)
}

Or use vapply:

func4 <- function(x, y) {
  vapply(seq_len(nrow(y)), function(i, b = x) sqrt(sum((y[i,]-b)^2)), numeric(1))
}
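The numeric(1) argument is vapply's FUN.VALUE template: every call must return a single double, and vapply raises an error otherwise instead of silently changing the shape of the result the way sapply can. A small illustration:

```r
# Every result matches the template numeric(1), so a plain vector comes back
ok <- vapply(1:3, function(i) sqrt(i), numeric(1))

# Returning two numbers per call violates the template and raises an error
err <- tryCatch(
  vapply(1:3, function(i) c(i, i), numeric(1)),
  error = function(e) conditionMessage(e)
)

print(ok)
print(err)
```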

Or replicate x to the shape of y, subtract, and use rowSums:

func7 <- function(x, y) sqrt(rowSums((y-x[col(y)])^2))
microbenchmark::microbenchmark(func1(x, y), func3(x, y), func4(x, y), func7(x, y))
#Unit: microseconds
#        expr    min      lq     mean  median      uq      max neval cld
# func1(x, y) 37.605 39.7475 61.17471 40.7595 42.1865 1955.888   100   a
# func3(x, y) 22.212 23.5945 68.63660 24.8320 25.8670 4333.933   100   a
# func4(x, y) 21.089 22.7930 24.11542 23.5945 24.2315   58.050   100   a
# func7(x, y)  7.731  8.9135 44.45935 10.0615 10.9500 3415.959   100   a
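In func7, col(y) returns a matrix of the same shape as y holding each cell's column index, so x[col(y)] is x laid out in the same column-major order as y, and y - x[col(y)] subtracts x from every row. A tiny illustration with two points in three dimensions:

```r
# Rows of y_small are the points (1, 2, 3) and (4, 5, 6)
y_small <- matrix(c(1, 4,
                    2, 5,
                    3, 6), nrow = 2)
x_small <- c(1, 2, 3)

idx <- col(y_small)             # column index of every cell
x_rep <- x_small[idx]           # x repeated per row, in column-major order

diffs <- y_small - x_rep        # each row of y_small minus x_small
d_small <- sqrt(rowSums(diffs^2))

print(x_rep)    # 1 1 2 2 3 3
print(d_small)  # 0 and sqrt(27)
```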

Upvotes: 3
