Robert Hickman

Reputation: 907

Find index of all max/min values in vector in Rcpp

Let's say I have a vector

v = c(1,2,3)

I can easily find which element is the max using

library(Rcpp)

cppFunction('int which_maxCpp(NumericVector v) {
  int z = which_max(v);
  return z;
}')

which_maxCpp(v)

2

However, if I have a vector such as

v2 = c(1,2,3,1,2,3)

I also get

which_maxCpp(v2)

2

whereas I want to find that index 2 and index 5 (or index 3 and index 6 with R's 1-based indexing) are both equal to the max of the vector.

Is there a way to get which_max (or which_min) to find the indices of all the max/min elements of a vector, or is another function (I'd assume native C++) needed?

Upvotes: 3

Views: 1450

Answers (1)

SymbolixAU

Reputation: 26258

I don't know of a native function, but a loop is fairly straightforward to write.
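The C++ standard library's nearest equivalent, std::max_element, has the same limitation as which_max: it returns an iterator to the first largest element only, so ties are never reported. A minimal sketch, assuming plain C++ with std::vector<double> in place of NumericVector (first_max_index is a hypothetical name, not part of Rcpp), reproducing the behaviour from the question with 0-based indices:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Returns the 0-based index of the FIRST maximum only, like the single
// index which_max reports.
std::ptrdiff_t first_max_index(const std::vector<double>& v) {
    assert(!v.empty());  // precondition: an empty vector has no maximum
    // std::max_element stops at the first of any tied maxima
    auto it = std::max_element(v.begin(), v.end());
    return std::distance(v.begin(), it);
}
```

Called on {1, 2, 3, 1, 2, 3} it returns 2; the tied maximum at index 5 is never reported, which is why a loop is needed.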

Here are three versions. Like which_max, they all return 0-based indices.

Two of them first find Rcpp::max() of the vector, then collect the indices whose values match that max. One stores the result in a pre-allocated Rcpp::IntegerVector, which is then subset to drop the unused trailing zeroes. The other stores the results in a std::vector<int> via .push_back().

library(Rcpp)

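cppFunction('IntegerVector which_maxCpp1(NumericVector v) {
  double m = Rcpp::max(v);
  Rcpp::IntegerVector res( v.size() );  // pre-allocate result vector

  int i;
  int counter = 0;
  for( i = 0; i < v.size(); ++i) {
    if( v[i] == m ) {
      res[ counter ] = i;  // store the 0-based index of a match
      counter++;
    }
  }
  // no matches (empty input, or m is NA): Range(0, -1) would be invalid
  if( counter == 0 ) return Rcpp::IntegerVector(0);
  Rcpp::Range rng(0, counter - 1);  // keep only the filled slots
  return res[rng];
}')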

v = c(1,2,3,1,2,3)

which_maxCpp1(v)
# [1] 2 5

cppFunction('IntegerVector which_maxCpp2(NumericVector v) {
  double m = Rcpp::max(v);
  std::vector< int > res;

  int i;
  for( i = 0; i < v.size(); ++i) {
    if( v[i] == m ) {
      res.push_back( i );
    }
  }
  Rcpp::IntegerVector iv( res.begin(), res.end() );
  return iv;
}')

which_maxCpp2(v)
# [1] 2 5
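Both two-pass versions follow the same outline. Stripped of the Rcpp types, the pre-allocate-then-truncate variant can be sketched in plain C++ as below. This is an illustration under stated assumptions, not the code above: std::vector<double> stands in for NumericVector, std::max_element stands in for Rcpp::max, NaN values are not handled, and which_max_two_pass is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Two-pass sketch mirroring which_maxCpp1: find the max first, then write
// the matching 0-based indices into a buffer pre-sized to v.size() and
// shrink it to the number of matches. Assumes no NaN values.
std::vector<std::size_t> which_max_two_pass(const std::vector<double>& v) {
    if (v.empty()) return {};  // no maximum to report

    const double m = *std::max_element(v.begin(), v.end());  // pass 1

    std::vector<std::size_t> res(v.size());  // worst case: every element ties
    std::size_t counter = 0;
    for (std::size_t i = 0; i < v.size(); ++i) {  // pass 2
        if (v[i] == m) {
            res[counter++] = i;
        }
    }
    res.resize(counter);  // drop the unused trailing zeroes
    return res;
}
```

Here resize() plays the role of the Rcpp::Range subset; replacing the pre-sized buffer with push_back() gives the which_maxCpp2 variant.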

The third option avoids the double pass over the vector by finding the max and keeping track of its indices at the same time, in a single loop.

cppFunction('IntegerVector which_maxCpp3(NumericVector v) {

  int n = v.size();
  std::vector< int > res;
  if( n == 0 ) return Rcpp::IntegerVector(0);  // v[0] below needs a non-empty vector

  double current_max = v[0];
  res.push_back( 0 );
  int i;

  for( i = 1; i < n; ++i) {
    double x = v[i];
    if( x > current_max ) {
      res.clear();          // new, larger max: drop the earlier indices
      current_max = x;
      res.push_back( i );
    } else if ( x == current_max ) {
      res.push_back( i );   // tie with the current max
    }
  }
  Rcpp::IntegerVector iv( res.begin(), res.end() );
  return iv;
}')

which_maxCpp3(v)
# [1] 2 5
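Since the question also asks about which_min, the single-pass idea generalises by swapping the comparison. A plain C++ sketch under these assumptions: standard library only, std::vector<double> instead of NumericVector, no NaN handling, and which_extreme is a hypothetical name. Passing std::greater<double>() collects all maxima; std::less<double>() collects all minima.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Single-pass sketch in the spirit of which_maxCpp3, generalised with a
// comparator: better(a, b) returns true when a should replace b as the
// current extreme. Returns 0-based indices of every tied extreme.
template <typename Better>
std::vector<std::size_t> which_extreme(const std::vector<double>& v,
                                       Better better) {
    std::vector<std::size_t> res;
    if (v.empty()) return res;  // avoid reading v[0] on empty input

    double current = v[0];
    res.push_back(0);
    for (std::size_t i = 1; i < v.size(); ++i) {
        const double x = v[i];
        if (better(x, current)) {
            // strictly better value: earlier candidates are discarded
            res.clear();
            current = x;
            res.push_back(i);
        } else if (x == current) {
            // tie with the current extreme: keep this index too
            res.push_back(i);
        }
    }
    return res;
}
```

For {1, 2, 3, 1, 2, 3}, the max version gives {2, 5} and the min version gives {0, 3}.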

Benchmarking

Here are some benchmarks showing how these functions stack up against the base R approach.

library(microbenchmark)

x <- sample(1:100, size = 1e6, replace = TRUE)

microbenchmark(
  iv = { which_maxCpp1(x) },
  stl = { which_maxCpp2(x) },
  max = { which_maxCpp3(x) },
  r = { which(x == max(x)) }  # base R; note this returns 1-based indices
)

# Unit: milliseconds
# expr      min        lq      mean    median       uq        max neval
#   iv 6.638583 10.617945 14.028378 10.956616 11.63981 165.719783   100
#  stl 6.830686  9.506639  9.787291  9.744488 10.17247  11.275061   100
#  max 3.161913  5.690886  5.926433  5.913899  6.19489   7.427020   100
#    r 4.044166  5.558075  5.819701  5.719940  6.00547   7.080742   100

Upvotes: 5
