Alex I

Reputation: 20287

Efficiently picking random indexes from a numpy array?

I'd like to pick, at random, the index of one element of a numpy array that meets a condition. My arrays are typically 2D with a few million elements in total; the condition is computed over the whole array, and only relatively few elements (less than one percent) come out true. Because of how the data is used, the random choice has to be unbiased (every true element picked with the same probability), and I only pick one element per array on each pass (so no calculations can be reused).

Slow code which does the right thing by building a list of all candidate indexes explicitly:

import numpy as np

# prepare sample data
img = np.zeros((2048,2048), dtype=bool)
for n in range(10000):
    i, j = np.random.randint(img.shape[0]), np.random.randint(img.shape[1])
    img[i,j] = True

def pick(img):
    indexes = np.argwhere(img)
    k = np.random.randint(len(indexes))
    return indexes[k]

pick(img) # around 8ms

This seems to take a stupidly long time to pick one element out of 10000. The culprit is, of course, np.argwhere(), which is where most of the time is spent. I don't need the whole list it returns; I just need one element from a random shuffle of that list, and the calculation can stop early at that point.

How do I do the same thing, but faster?

P.S. The elements may be clustered - it is entirely possible for all of the true values to be in one corner of the array. So any speedup which relies on dividing areas probably won't work :)

Upvotes: 0

Views: 2028

Answers (2)

Paul Panzer

Reputation: 53029

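It's probably much more efficient to randomly subsample; in fact, I see a ~300x speedup on the example. The idea is to draw a batch of random flat indexes, return the first one that lands on a True element, and repeat if the batch contains none. Since the draws are uniform and independent, every True element is equally likely to be returned.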

import numpy as np
from timeit import timeit

# prepare sample data
img = np.zeros((2048,2048), dtype=bool)
for n in range(10000):
    i, j = np.random.randint(img.shape[0]), np.random.randint(img.shape[1])
    img[i,j] = True

def pick(img):
    indexes = np.argwhere(img)
    k = np.random.randint(len(indexes))
    return indexes[k]

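def pp(img,maxiter=100,batchsize=1000):
    imf = img.reshape(-1)  # flattened array
    for i in range(maxiter):
        # draw a batch of uniformly random flat indexes
        idx = np.random.randint(0,imf.size,batchsize)
        # argmax gives the first True in the batch (or 0 if there is none)
        hit = np.argmax(imf[idx])
        if imf[idx[hit]]:
            return np.unravel_index(idx[hit],img.shape)
    else:
        # every batch missed; only reached after maxiter failed batches
        raise RuntimeError("no luck")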

print('OP',timeit(lambda:pick(img),number=100)*10,'ms')
print('pp',timeit(lambda:pp(img),number=100)*10,'ms')

# sanity check
samples = np.array([pp(img) for _ in range(200000)])
histo = np.bincount(np.ravel_multi_index((*samples.T,),img.shape))
assert (histo.nonzero()[0] == img.reshape(-1).nonzero()[0]).all()
print(np.bincount(histo)[1:])

Sample run:

OP 14.76260277966503 ms
pp 0.045300929923541844 ms
[  0   0   0   0   0   2   5  15  24  59  96 176 241 394 513 661 796 812
 908 922 822 751 648 548 457 365 265 155 141  88  55  29  26   9   5   1
   2   1]

The last output lists, starting from a count of 1, the number of True positions that were picked that many times in 200000 trials, i.e. there are 2 positions that were picked 6 times, 5 that were picked 7 times, etc. The expected result peaks at about 20 and looks roughly bell shaped.
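
As a rough cross-check of that expectation, here is a minimal sketch that computes the expected histogram under the assumption that every pick is uniform and independent, so each position's count follows a binomial distribution. The helper name expected_histogram is hypothetical and not part of the code above, and 10000 is only an approximation of the number of True positions (the sample data can contain slightly fewer, because the random positions may repeat).

```python
import math

def expected_histogram(trials, positions, max_count):
    """Expected number of positions picked exactly c times, for c = 1..max_count.

    Assumption: each pick is uniform over `positions` candidates, so the
    count for one position is Binomial(trials, 1/positions).
    """
    p = 1.0 / positions
    out = []
    for c in range(1, max_count + 1):
        # log of the binomial pmf; working in logs avoids overflow
        log_pmf = (math.lgamma(trials + 1) - math.lgamma(c + 1)
                   - math.lgamma(trials - c + 1)
                   + c * math.log(p) + (trials - c) * math.log1p(-p))
        out.append(positions * math.exp(log_pmf))
    return out

# same numbers as the sanity check: 200000 trials, roughly 10000 True positions
expected = expected_histogram(200000, 10000, 38)
# count at which the expected histogram is largest (list index 0 is count 1)
mode = 1 + max(range(len(expected)), key=expected.__getitem__)
print(mode, [round(e) for e in expected])
```

Comparing the printed list with the measured histogram gives a quick visual check that the sampler is not obviously biased; it is not a formal statistical test.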

Upvotes: 1

alkasm

Reputation: 23002

I was able to find a significant, but not quite an order of magnitude speedup with:

np.unravel_index(np.random.choice(np.flatnonzero(img)), img.shape)

Here np.flatnonzero (docs) gives the linearized indexes of the non-zero entries of the array:

>>> np.flatnonzero(img)
array([    276,     548,    1053, ..., 4193808, 4194060, 4194198])

And then np.random.choice (docs) to choose just one of those values (using flatnonzero instead of nonzero lets me avoid having to choose one index per axis, and instead just use a single index):

>>> np.random.choice(np.flatnonzero(img))
3123039

Then we just need to convert between this linearized index and the multidimensional index, which can be achieved with np.unravel_index (docs):

>>> np.unravel_index(3123039, img.shape)
(1524, 1887)

With pick defined as in your example and pick2 as:

def pick2(img):
    return np.unravel_index(np.random.choice(np.flatnonzero(img)), img.shape)

the timeit module shows >5x speedup:

>>> timeit(lambda: pick(img), number=100)
1.208479189008358
>>> timeit(lambda: pick2(img), number=100)
0.17817635700339451
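
One edge case worth noting: np.random.choice raises an error when the array has no true elements at all. Below is a minimal sketch of a guarded variant; the name pick2_safe, the choice of returning None, and the use of np.random.default_rng are my assumptions for illustration, not something the question asks for.

```python
import numpy as np

def pick2_safe(img, rng=None):
    """Return a uniformly random index tuple where img is true, or None."""
    rng = np.random.default_rng() if rng is None else rng
    flat = np.flatnonzero(img)          # linear indexes of all true elements
    if flat.size == 0:
        return None                     # nothing to pick from
    k = flat[rng.integers(flat.size)]   # one uniform draw among the candidates
    return np.unravel_index(k, img.shape)

# tiny example with two true elements
img = np.zeros((4, 5), dtype=bool)
img[1, 2] = img[3, 4] = True
result = pick2_safe(img, np.random.default_rng(0))
print(result)
```

Passing an explicit generator makes the pick reproducible in tests; the cost is still dominated by np.flatnonzero scanning the whole array, so this does not change the timing picture above.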

Upvotes: 1
