skleene

Reputation: 389

Improving the efficiency of randsample in MATLAB for a Markov chain simulation.

I am using MATLAB to simulate an accumulation process in which several random walks accumulate towards a threshold in parallel. To select which random walk increments at time t, randsample is used. If the vector V represents the active random walks and the vector P the probability with which each should be selected, the call to randsample looks like this:

randsample(V, 1, true, P);

The problem is that the simulations are slow, and randsample is the bottleneck: approximately 80% of the runtime is spent in the randsample call.
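
For context, the inner loop looks roughly like this (the numbers and the stopping rule below are simplified placeholders, not my actual model):

nWalks = 10;                        % number of parallel random walks (placeholder)
counts = zeros(1, nWalks);          % accumulated value of each walk
thresh = 100;                       % common threshold (placeholder)
V = 1:nWalks;                       % active random walks
P = ones(1, nWalks)/nWalks;         % selection probabilities

while all(counts < thresh)
    k = randsample(V, 1, true, P);  % the call that dominates the runtime
    counts(k) = counts(k) + 1;
end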

Is there a relatively straightforward way to improve the efficiency of randsample? Are there alternatives that might be faster?

Upvotes: 0

Views: 1241

Answers (2)

Amro

Reputation: 124543

As I mentioned in the comments, the bottleneck is probably caused by the fact that you are sampling one value at a time; it would be faster to vectorize the randsample call (assuming, of course, that the probabilities vector is constant).

Here is a quick benchmark:

function testRandSample()
    v = 1:5;
    w = rand(numel(v),1); w = w ./ sum(w);
    n = 50000;

    % timeit
    t(1) = timeit(@() func1(v, w, n));
    t(2) = timeit(@() func2(v, w, n));
    t(3) = timeit(@() func3(v, w, n));
    disp(t)

    % check distribution of samples (should be close to w)
    tabulate(func1(v, w, n))
    tabulate(func2(v, w, n))
    tabulate(func3(v, w, n))
    disp(w*100)
end


% one vectorized randsample call for all n samples
function s = func1(v, w, n)
    s = randsample(v, n, true, w);
end

% inverse-CDF sampling: bin n uniform draws against the cumulative weights
function s = func2(v, w, n)
    [~,idx] = histc(rand(n,1), [0;cumsum(w(:))./sum(w)]);
    s = v(idx);
end

% inverse-CDF sampling in a loop, one draw at a time
function s = func3(v, w, n)
    cw = cumsum(w) / sum(w);
    s = zeros(n,1);
    for i=1:n
        s(i) = find(rand() <= cw, 1, 'first');
    end
    s = v(s);

    % equivalent arrayfun one-liner:
    %s = v(arrayfun(@(~)find(rand() <= cw, 1, 'first'), 1:n));
end

The output (annotated):

% measured elapsed times for func1/2/3 respectively
  0.0016    0.0015    0.0790

% distribution of random sample from func1
  Value    Count   Percent
      1     4939      9.88%
      2    15049     30.10%
      3     7450     14.90%
      4    11824     23.65%
      5    10738     21.48%

% distribution of random sample from func2
  Value    Count   Percent
      1     4814      9.63%
      2    15263     30.53%
      3     7479     14.96%
      4    11743     23.49%
      5    10701     21.40%

% distribution of random sample from func3
  Value    Count   Percent
      1     4985      9.97%
      2    15132     30.26%
      3     7275     14.55%
      4    11905     23.81%
      5    10703     21.41%

% true population distribution
    9.7959
   30.4149
   14.7414
   23.4949
   21.5529

As you can see, randsample itself is pretty well optimized. The bottleneck you observed in your code is probably due to the lack of vectorization, as explained above.

To see how slow it can get, replace func1 with a looped version that samples one value at a time:

function s = func1(v, w, n)
    s = zeros(n,1);
    for i=1:n
        s(i) = randsample(v, 1, true, w);
    end
end
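
If your probability vector really is constant within a run, you can exploit the vectorized call directly by pre-drawing a batch of selections and consuming them inside the simulation loop (a sketch; the variables and the stopping condition below are placeholders, not from your code):

V = 1:10; P = ones(1,10)/10;              % placeholder walks and probabilities
counts = zeros(size(V)); thresh = 100;    % placeholder state and threshold
maxSteps = 1e5;                           % assumed upper bound on the number of steps
picks = randsample(V, maxSteps, true, P); % one vectorized call instead of one per step
for t = 1:maxSteps
    k = picks(t);                         % walk selected at time t
    counts(k) = counts(k) + 1;
    if counts(k) >= thresh, break; end
end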

Upvotes: 2

Luis Mendo

Reputation: 112659

Maybe this will be faster:

V(find(rand <= cumsum(P), 1)) %// gives the same as randsample(V, 1, true, P)

I'm assuming P contains probabilities, i.e. they sum to 1. Otherwise, normalize P first:

V(find(rand <= cumsum(P)/sum(P), 1)) %// gives the same as randsample(V, 1, true, P)

If P is always the same, precompute cumsum(P)/sum(P) to save time:

cp = cumsum(P)/sum(P);   %// precompute (just once)
V(find(rand <= cp, 1))   %// gives the same as randsample(V, 1, true, P)
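
For example, inside the simulation loop it would be used like this (just a sketch; the variables below are placeholders, not from the question):

V = 1:5; P = [0.1 0.3 0.15 0.25 0.2];    %// placeholder walks and probabilities
counts = zeros(size(V)); thresh = 50;    %// placeholder state and threshold
cp = cumsum(P)/sum(P);                   %// precompute once per run
while all(counts < thresh)
    k = V(find(rand <= cp, 1));          %// replaces randsample(V, 1, true, P)
    counts(k) = counts(k) + 1;
end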

Upvotes: 1
