user3067923

Reputation: 457

Making a function that checks if a vector exists in a matrix faster

I have the following function (funtest) to test whether a specific vector exists as a row of a matrix. The vector will always have length 2 and the matrix will always have two columns. The function works fine; I would just like to make it faster (ideally much faster), because my matrices can have hundreds to thousands of rows.

x = c(1,2)

set.seed(100)
m <- matrix(sample(c(1,-2,3,4), 500*2, replace=TRUE), ncol=2)

funtest(m,x)
[1] FALSE

This is how fast it currently is:

library(microbenchmark)
microbenchmark(funtest(m, x), times=100)
Unit: milliseconds
          expr      min       lq     mean   median       uq      max
 funtest(m, x) 1.501247 1.536157 1.674668 1.567826 1.708293 2.900046
 neval
   100

This is the function:

funtest = function(m, x) {
    out = any(apply(m,1,function(n,x) all(n==x),x=x))
    return(out)
}

Upvotes: 1

Views: 183

Answers (3)

Sathish

Reputation: 12723

base::bitwXor() will produce 0 for a match between two integers.

Note: bitwXor() works for integers only
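
As a minimal sketch of the idea (the toy matrix and the names toy, target and row_hit are made up for this illustration and are not part of the benchmark below): XOR-ing two equal integers gives 0, so a row matches when both columns give 0.

```r
# XOR of two equal integers is 0; any differing bit gives a non-zero result
bitwXor(5L, 5L)   # 0
bitwXor(5L, 3L)   # 6

# Toy matrix (hypothetical data); rows are (1, 2), (3, -2), (4, 1)
toy <- matrix(c(1L, 3L, 4L,
                2L, -2L, 1L), ncol = 2)
target <- c(3L, -2L)

# A row matches when the XOR is 0 in both columns
row_hit <- bitwXor(toy[, 1], target[1]) == 0 & bitwXor(toy[, 2], target[2]) == 0
any(row_hit)      # TRUE, because the second row is (3, -2)
```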

EDIT: Added a comparison of the bitwXor() result with 0 and added a data.table solution.

library(microbenchmark)
set.seed(100)
m <- matrix(sample(c(1,-2,3,4), 500*2, replace=TRUE), ncol=2)

fun1 <- function(m,x) {any(apply(m,1,function(n,x) all(n==x),x=x))}
fun2 <- function(m,x) {paste(x[1], x[2], sep='&') %in% paste(m[,1], m[,2], sep='&')}
fun3 <- function(m,x) {any((bitwXor(m[,1], x[1]) == 0) & (bitwXor(m[,2], x[2]) == 0))}
fun4 <- function(m,x) {setDT(m)[X1 == x[1] & X2 == x[2], .N > 0]}

x <-  c(1,2)

microbenchmark(fun1(m,x),     # @user3067923
               fun2(m,x),     # @Zheyuan Li
               rcppFn(m, x),  # @jav
               fun3(m,x),
               times = 1000)

# Unit: microseconds
#         expr      min       lq       mean   median       uq      max neval
#   fun1(m, x) 1802.483 1920.007 2156.93459 1995.865 2094.820 9915.013  1000
#   fun2(m, x) 1540.716 1602.534 1674.39556 1641.256 1702.848 2832.344  1000
# rcppFn(m, x)   14.040   16.305   23.43586   21.739   29.439   95.107  1000
#   fun3(m, x)   70.650   76.992   86.36290   82.879   88.766  314.303  1000

data.table solution:

library(data.table)
m <- data.frame(m)
microbenchmark(fun4(m,x), times = 1000)

# Unit: microseconds
#       expr     min       lq     mean median      uq      max neval
# fun4(m, x) 836.026 887.6555 985.8596 920.49 968.269 9025.546  1000

Upvotes: 3

jav

Reputation: 1495

Here's an Rcpp (specifically RcppArmadillo) approach. Benchmarks are given at the end:

# Import the relevant packages (All for compiling the C++ code inline)
library(Rcpp)
library(RcppArmadillo)
library(inline)

# We need to include these namespaces in the C++ code 
includes <- '
using namespace Rcpp;
using namespace arma;
'

# This is the main C++ function
# We cast 'm' as an Armadillo matrix 'm1' and compute the number of rows 'numRows'
# We cast 'x' as a row vector 'x1'
# We then loop through the rows of the matrix
# As soon as we find a matching row (anyEqual = true), we stop and return TRUE
# If no matching row is found, then anyEqual = false and we return FALSE
# Note: Within the for loop, we do an elementwise comparison of a row of m1 to x1
# If the row is equal to x1, then the sum of the elementwise comparison should equal the number of elements of x1
src <- '
mat m1 = as<mat>(m);
int numRows = m1.n_rows;
rowvec x1 = as<rowvec>(x);
bool anyEqual = false;
for (int i = 0; i < numRows && !anyEqual; i++){
    anyEqual = (sum(m1.row(i) == x1) == x1.size());
}
return(wrap(anyEqual));
'

# Here, we compile the function above
# Do this once (in a given R session) and use it as many times as desired
rcppFn <- cxxfunction(signature(m="numeric", x="numeric"), body=src, plugin='RcppArmadillo', includes=includes)

Benchmarks are here (Edit: I've also added a benchmark for @zheyuan-li's very simple solution below; it is called pasteFn):

# Your function is called funtest
# Rcpp function is rcppFn
# Zheyuan's solution is pasteFn
microbenchmark(funtest(m, x), rcppFn(m, x), pasteFn(m, x), times=100, unit = "ms")
Unit: milliseconds
          expr      min        lq       mean    median        uq      max neval
 funtest(m, x) 1.127903 1.1984755 1.30559130 1.2514455 1.3431040 2.641258   100
  rcppFn(m, x) 0.005420 0.0061355 0.00879676 0.0073660 0.0084130 0.030305   100
 pasteFn(m, x) 0.741269 0.7610905 0.79174042 0.7752145 0.8228895 0.894389   100

Edit: If you would like to use a matrix 'x' instead, the following source code should work:

src <- '
mat m1 = as<mat>(m);
int numRows = m1.n_rows;
mat x1 = as<mat>(x);
vec anyEqual = zeros<vec>(x1.n_rows);
for (int j = 0; j < x1.n_rows; j++){
    for (int i = 0; i < numRows && !anyEqual(j); i++){
        anyEqual(j) = (sum(m1.row(i) == x1.row(j)) == x1.n_cols);
    }
}
return(wrap(anyEqual));
'

Here, I am just checking, for each row of x, whether it exists in m. It is very similar to the original code, except that there is one extra for loop. It returns 1 or 0 for each row of x depending on whether there's a match (I'm not experienced enough with RcppArmadillo to create a vector of bools).
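
For cross-checking the compiled version, the same per-row check can be sketched in base R. This is only an illustration under stated assumptions: the helper name rowsInMatrix and the toy data are hypothetical, both matrices are assumed to have exactly two columns, and the result is a logical vector rather than 1/0.

```r
# Hypothetical helper (not from the original answer): for each row of x,
# report whether that row also occurs in m. Assumes two columns in both.
rowsInMatrix <- function(m, x) {
  vapply(
    seq_len(nrow(x)),
    function(j) any(m[, 1] == x[j, 1] & m[, 2] == x[j, 2]),
    logical(1)
  )
}

# Toy data, made up for this illustration
toy_m <- matrix(c(1, 3, 4,
                  2, -2, 1), ncol = 2)   # rows: (1, 2), (3, -2), (4, 1)
toy_x <- matrix(c(3, 1,
                  -2, -2), ncol = 2)     # rows: (3, -2), (1, -2)

found <- rowsInMatrix(toy_m, toy_x)
found   # TRUE FALSE
```

This is expected to be slower than the compiled loop, so it is meant as a readable reference for checking results, not as a replacement.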

Upvotes: 3

Zheyuan Li

Reputation: 73385

How about

paste(x[1], x[2], sep='&') %in% paste(m[,1], m[,2], sep='&')

This should be super efficient! It is based on matching. As soon as the first match is found, no further search will be done!
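
To make the matching step concrete, here is a small worked sketch on a toy matrix (the names toy and keys and the data are made up for this illustration): each row is collapsed into one string key, and the target key is then looked up among them.

```r
# Toy matrix (hypothetical data); rows are (1, 2), (3, -2), (4, 1)
toy <- matrix(c(1, 3, 4,
                2, -2, 1), ncol = 2)

# One string key per row, built from the two columns
keys <- paste(toy[, 1], toy[, 2], sep = "&")
keys                                  # "1&2"  "3&-2" "4&1"

# Membership test on the keys
paste(3, -2, sep = "&") %in% keys     # TRUE
paste(1, -2, sep = "&") %in% keys     # FALSE
```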

However, I am sure this is not the fastest. The optimal solution is to write this operation in C code with a single while loop. But the potential speedup factor should be no more than 2.

Upvotes: 3
