Tom

Reputation: 105

Rcpp function to find the median, given a vector of values and their frequencies

I'm writing a function to find the median of a set of values. The data is presented as a vector of the unique values (call them 'values') and a vector of their frequencies ('freqs'). The frequencies are often very high, so expanding each value by its frequency uses an exorbitant amount of memory. I have a slow R implementation and it is the major bottleneck in my code, so I am writing a custom Rcpp function for use in an R/Bioconductor package. Bioconductor's site suggests not using C++11, so I would like to avoid C++11 features.

My problem lies in trying to sort the two vectors together, according to the order of the values. In R, we can just use the order() function. I cannot seem to get the equivalent to work in Rcpp, despite following the advice in this question: C++ sorting and keeping track of indexes

The following lines are where the problem lies:

    // sort vector based on order of values
    IntegerVector idx_ord = std::sort(idx.begin(), idx.end(),
        bool (int i1, int i2) {return values[i1] < values[i2];});

And here is the full function, for anyone's interest. Any further tips would be greatly appreciated:

#include <Rcpp.h>
using namespace Rcpp;

// [[Rcpp::export]]
double median_freq(NumericVector values, IntegerVector freqs) {
    int len = freqs.size();
    if (any(freqs!=0)){
        int med = 0;
        return med;
    }
    // filter out the zeros pre-sorting
    IntegerVector non_zeros;
    for (int i = 0; i < len; i++){
        if(freqs[i] != 0){
            non_zeros.push_back(i);
        }
    }
    freqs = freqs[non_zeros];
    values = values[non_zeros];
    // find the order of values
    // create integer vector of indices
    IntegerVector idx(len);
    for (int i = 0; i < len; ++i) idx[i] = i;

    // sort vector based on order of values
    IntegerVector idx_ord = std::sort(idx.begin(), idx.end(),
        bool (int i1, int i2) {return values[i1] < values[i2];});

    //apply to freqs and values
    freqs = freqs[idx_ord];
    values=values[idx_ord];
    IntegerVector cum_freqs(len);
    cum_freqs[0] = freqs[0];
    for (int i = 1; i < len; ++i) cum_freqs[i] = freqs[i] + cum_freqs[i-1];
    int total_freqs = cum_freqs[len-1];
    // split into odd and even frequencies and calculate the median
    if (total_freqs % 2 == 1) {
        int med_ind = (total_freqs + 1)/2 - 1; // C++ indexes from 0
        int i = 0;
        while ((i < len) && cum_freqs[i] < med_ind){
            i++;
        }
        double ret = values[i];
        return ret;
    } else {
        int med_ind_1 = total_freqs/2 - 1; // C++ indexes from 0
        int med_ind_2 = med_ind_1 + 1; // C++ indexes from 0
        int i = 0;
        while ((i < len) && cum_freqs[i] < med_ind_1){
            i++;
        }
        double ret_1 = values[i];
        i = 0;
        while ((i < len) && cum_freqs[i] < med_ind_2){
            i++;
        }
        double ret_2 = values[i];
        double ret = (ret_1 + ret_2)/2;
        return ret;
    }
}

For anyone using the RUnit testing framework, here are some basic unit tests:

test_median_freq <- function(){
    checkEquals(median_freq(1:10,1:10),7)
    checkEquals(median_freq(1:10,rep(1,10)),5.5)
    checkEquals(median_freq(2:6,c(1,2,1,45,2)),5)
}

Thanks!

Upvotes: 2

Views: 668

Answers (1)

josliber

Reputation: 44330

I would actually combine the value and frequency into a std::pair<double, int> and then just sort them with std::sort; in this way you always keep a value and its frequency together. This enables you to write much cleaner code because there isn't an additional set of indices floating around:

#include <Rcpp.h>
using namespace Rcpp;

// [[Rcpp::export]]
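double median_freq(NumericVector values, IntegerVector freqs) {
  const int len = freqs.size();
  // Keep each value paired with its frequency; entries with zero frequency
  // are dropped because they cannot contribute to the median
  std::vector<std::pair<double, int> > allDat;
  int freqSum = 0;
  for (int i=0; i < len; ++i) {
    if (freqs[i] == 0) continue;
    allDat.push_back(std::pair<double, int>(values[i], freqs[i]));
    freqSum += freqs[i];
  }
  std::sort(allDat.begin(), allDat.end());  // Orders by value
  const int n = allDat.size();
  int accum = 0;  // Cumulative frequency up to and including element i
  for (int i=0; i < n; ++i) {
    accum += allDat[i].second;
    if (freqSum % 2 == 0) {
      if (accum > freqSum / 2) {
        // Both middle observations share this value
        return allDat[i].first;
      } else if (accum == freqSum / 2) {
        // The middle observations straddle this value and the next one
        return (allDat[i].first + allDat[i+1].first) / 2;
      }
    } else {
      if (accum >= (freqSum+1)/2) {
        return allDat[i].first;
      }
    }
  }
  return NA_REAL;  // Only reached when every frequency is zero
}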

Try it out in R:

median_freq(1:10, 1:10)
# [1] 7
median_freq(1:10,rep(1,10))
# [1] 5.5
median_freq(2:6,c(1,2,1,45,2))
# [1] 5
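
For reference, the index-sorting approach from the question can also be made to work without C++11. Two things go wrong in the original lines: std::sort sorts its range in place and returns void, so its result cannot be assigned to anything, and the comparator has to be a real function or function object rather than an inline body. Below is a minimal sketch in plain C++, using std::vector instead of Rcpp types so that it stands alone; the names IndexLess and order_by_value are made up for illustration.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Comparator functor (a C++98-compatible stand-in for a lambda): orders two
// indices by the values they refer to.
struct IndexLess {
  const std::vector<double>& values;
  explicit IndexLess(const std::vector<double>& v) : values(v) {}
  bool operator()(int i1, int i2) const { return values[i1] < values[i2]; }
};

// Hypothetical helper: returns the permutation that sorts `values` in
// ascending order, i.e. a 0-based analogue of R's order().
std::vector<int> order_by_value(const std::vector<double>& values) {
  std::vector<int> idx(values.size());
  for (std::size_t i = 0; i < idx.size(); ++i) idx[i] = static_cast<int>(i);
  // std::sort works in place and returns void, so idx itself is sorted.
  std::sort(idx.begin(), idx.end(), IndexLess(values));
  return idx;
}
```

Unlike R's order(), std::sort does not promise to keep tied values in their original order; std::stable_sort takes the same arguments if that matters.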

We can also code up a naive R implementation to determine the efficiency gains that we get from using Rcpp:

med.freq.r <- function(values, freqs) {
  ord <- order(values)
  values <- values[ord]
  freqs <- freqs[ord]
  s <- sum(freqs)
  cs <- cumsum(freqs)
  idx <- min(which(cs >= s/2))
  if (s %% 2 == 0 && cs[idx] == s/2) {
    (values[idx] + values[idx+1]) / 2
  } else {
    values[idx]
  }
}
med.freq.r(1:10, 1:10)
# [1] 7
med.freq.r(1:10,rep(1,10))
# [1] 5.5
med.freq.r(2:6,c(1,2,1,45,2))
# [1] 5

To benchmark, let's look at a very large set of values:

set.seed(144)
values <- rnorm(1000000)
freqs <- sample(1:100, 1000000, replace=TRUE)
all.equal(median_freq(values, freqs), med.freq.r(values, freqs))
# [1] TRUE
library(microbenchmark)
microbenchmark(median_freq(values, freqs), med.freq.r(values, freqs), times=10)
# Unit: milliseconds
#                        expr      min       lq     mean   median       uq      max neval
#  median_freq(values, freqs) 128.5322 131.6095 146.8360 145.6389 159.6117 165.0306    10
#   med.freq.r(values, freqs) 715.2187 744.5709 776.0539 765.9178 817.7157 855.1898    10

For 1 million entries, the Rcpp solution is about 5x faster than the R solution; given the compilation overhead, that performance is only attractive if you're working on really huge vectors or if this is a frequently repeated operation.
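
One caveat for really huge inputs, given that the question mentions very high frequencies: freqSum and accum in median_freq are plain ints, so the running total can overflow once the frequencies add up to more than INT_MAX (2^31 - 1 on typical platforms). A possible workaround, sketched here in plain C++ with std::vector rather than Rcpp types, is to hold the totals in a double, which represents integer sums exactly up to 2^53. The function name median_freq_wide is hypothetical, and the sketch assumes non-negative frequencies.

```cpp
#include <algorithm>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

// Hypothetical variant of the sort-then-scan median with the frequency
// totals kept in double, so they cannot overflow the way an int would.
double median_freq_wide(const std::vector<double>& values,
                        const std::vector<int>& freqs) {
  std::vector<std::pair<double, int> > allDat;
  double freqSum = 0.0;
  for (std::size_t i = 0; i < values.size(); ++i) {
    if (freqs[i] == 0) continue;  // zero-frequency values carry no weight
    allDat.push_back(std::pair<double, int>(values[i], freqs[i]));
    freqSum += freqs[i];
  }
  std::sort(allDat.begin(), allDat.end());

  const double half = freqSum / 2.0;
  double accum = 0.0;  // cumulative frequency through element i
  for (std::size_t i = 0; i < allDat.size(); ++i) {
    accum += allDat[i].second;
    // Strictly past the halfway point: the median is this value.
    if (accum > half) return allDat[i].first;
    // Exactly halfway (only possible for an even total): the two middle
    // observations are this value and the next one.
    if (accum == half) return (allDat[i].first + allDat[i + 1].first) / 2.0;
  }
  // Only reached when there are no observations at all.
  return std::numeric_limits<double>::quiet_NaN();
}
```

With values 1 and 2 at a frequency of 2,000,000,000 each, the total is 4e9, which no longer fits in a 32-bit int but is handled exactly here.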

Linear-time approach

In general, we can compute the median without sorting by using a quickselect-style selection algorithm (for details, check out http://www.cc.gatech.edu/~mihail/medianCMU.pdf). While the algorithm is a bit more delicate than just sorting and iterating, it runs in expected linear time and can yield significant speedups:

// [[Rcpp::export]]
double fast_median_freq(NumericVector values, IntegerVector freqs) {
  const int len = freqs.size();
  std::vector<std::pair<double, int> > allDat;
  int freqSum = 0;
  for (int i=0; i < len; ++i) {
    if (freqs[i] == 0) continue;  // Zero-frequency entries cannot matter
    allDat.push_back(std::pair<double, int>(values[i], freqs[i]));
    freqSum += freqs[i];
  }
  if (allDat.empty()) return NA_REAL;  // No observations

  // 1-based rank of the (lower) middle observation within [low, high];
  // rounding up makes odd totals select the true middle element
  int target = (freqSum + 1) / 2;
  int low = 0;
  int high = allDat.size() - 1;
  while (true) {
    // Random pivot; move to the end
    // (rand() is fine for a demo, but R CMD check flags it in package code)
    int rnd = low + (rand() % (high-low+1));
    std::swap(allDat[rnd], allDat[high]);

    // In-place pivot
    int highPos = low;  // Start of values higher than pivot
    int lowSum = 0;  // Sum of frequencies of elements moved below the pivot
    for (int pos=low; pos < high; ++pos) {
      if (allDat[pos].first <= allDat[high].first) {
        lowSum += allDat[pos].second;
        std::swap(allDat[highPos], allDat[pos]);
        ++highPos;
      }
    }
    std::swap(allDat[highPos], allDat[high]);  // Move pivot to "highPos"

    // If we found the element then return; o/w recurse on proper side
    if (lowSum >= target) {
      // Recurse on lower elements
      high = highPos - 1;
    } else if (lowSum + allDat[highPos].second >= target) {
      // Return
      if (target < lowSum + allDat[highPos].second || freqSum % 2 == 1) {
        return allDat[highPos].first;
      } else {
        // Even total split exactly at the pivot: average with the smallest
        // larger value, searching all the way to the end of allDat
        double nextHighest = std::min_element(allDat.begin() + highPos+1, allDat.end())->first;
        return (allDat[highPos].first + nextHighest) / 2;
      }
    } else {
      // Recurse on higher elements
      low = highPos + 1;
      target -= (lowSum + allDat[highPos].second);
    }
  }
}

Benchmarking:

all.equal(median_freq(values, freqs), fast_median_freq(values, freqs))
# [1] TRUE
microbenchmark(median_freq(values, freqs), med.freq.r(values, freqs), fast_median_freq(values, freqs), times=10)
# Unit: milliseconds
#                             expr       min        lq      mean    median        uq       max neval
#       median_freq(values, freqs) 119.57989 122.48622 130.47841 130.48811 132.75421 146.36136    10
#        med.freq.r(values, freqs) 665.72803 690.15016 708.05729 702.65885 731.83936 749.36834    10
#  fast_median_freq(values, freqs)  24.37572  29.39641  31.86144  31.77459  34.88418  36.81606    10

The linear-time approach offers roughly a 4x speedup over the sort-then-iterate Rcpp solution and a 20x speedup over the base R solution.
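
Because the selection logic is delicate, a single all.equal check on one random data set is fairly weak evidence of correctness. One way to gain more confidence is to compare against the naive definition on small inputs: repeat every value by its frequency, sort, and take the middle. A rough sketch in plain C++ follows (std::vector rather than Rcpp types; the name naive_median is hypothetical, and non-negative frequencies are assumed).

```cpp
#include <algorithm>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical reference implementation: expand, sort, take the middle.
// Memory use grows with the sum of freqs, so this is for small tests only.
double naive_median(const std::vector<double>& values,
                    const std::vector<int>& freqs) {
  std::vector<double> expanded;
  for (std::size_t i = 0; i < values.size(); ++i) {
    // Append freqs[i] copies of values[i].
    expanded.insert(expanded.end(),
                    static_cast<std::size_t>(freqs[i]), values[i]);
  }
  if (expanded.empty()) return std::numeric_limits<double>::quiet_NaN();
  std::sort(expanded.begin(), expanded.end());
  const std::size_t n = expanded.size();
  if (n % 2 == 1) return expanded[n / 2];                 // odd count
  return (expanded[n / 2 - 1] + expanded[n / 2]) / 2.0;   // even count
}
```

Small cases with an odd total, tied values, or zero frequencies are the ones most likely to expose an off-by-one in the faster versions.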

Upvotes: 5
