Reputation: 24651
Python has my_sample = random.sample(range(100), 10) to randomly sample without replacement from [0, 100).

Suppose I have sampled n such numbers, and now I want to sample one more without replacement (without including any of the previously sampled n). How to do so super efficiently?
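For reference, the straightforward approach would be something like the sketch below (hypothetical baseline, not what I'm after): rebuild the remaining pool and draw from it, which costs O(n) time and memory for every extra number.

import random

my_sample = random.sample(range(100), 10)

# Baseline: recompute the unused numbers and pick one of them.
remaining = list(set(range(100)) - set(my_sample))
one_more = random.choice(remaining)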
Upvotes: 36
Views: 41253
Reputation: 2673
It's surprising this is not already implemented in one of the core functions, but here is a clean version that returns the sampled values and the list with those values removed:
import random

def sample_n_points_without_replacement(n, set_of_points):
    sampled_point_indices = random.sample(range(len(set_of_points)), n)
    sampled_point_indices.sort(reverse=True)
    sampled_points = [set_of_points[sampled_point_index] for sampled_point_index in sampled_point_indices]
    # Delete from the highest index down so earlier deletions don't shift later ones
    for sampled_point_index in sampled_point_indices:
        del set_of_points[sampled_point_index]
    return sampled_points, set_of_points
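A quick usage sketch (hypothetical values; note that the input list is mutated in place):

points = list(range(10))
sampled, rest = sample_n_points_without_replacement(3, points)
print(sampled)   # e.g. [7, 4, 1]
print(rest)      # the 7 remaining points; `points` itself has shrunk too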
Upvotes: 1
Reputation: 1964
This is a side note: suppose you want to solve exactly the same problem of sampling without replacement from a list (that I'll call sample_space), but instead of sampling uniformly over the set of elements you have not sampled already, you are given an initial probability distribution p that tells you the probability of sampling the i-th element of the distribution, were you to sample over the whole space.

Then the following implementation using numpy is numerically stable:
import numpy as np
def iterative_sampler(sample_space, p=None):
    """
    Samples elements from a sample space (a list)
    with a given probability distribution p (numpy array)
    without replacement. If called until StopIteration is raised,
    effectively produces a permutation of the sample space.
    """
    if p is None:
        p = np.array([1/len(sample_space) for _ in sample_space])
    try:
        assert isinstance(sample_space, list)
        assert isinstance(p, np.ndarray)
    except AssertionError:
        raise TypeError("Required types: \nsample_space: list \np type: np.ndarray")

    # Main loop
    n = len(sample_space)
    idxs_left = list(range(n))
    for i in range(n):
        # Renormalize the weights of the not-yet-sampled indices at each step
        idx = np.random.choice(
            range(n-i),
            p=p[idxs_left] / p[idxs_left].sum()
        )
        yield sample_space[idxs_left[idx]]
        del idxs_left[idx]
It's short and concise, I like it. Let me know what you guys think!
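For reference, a quick usage sketch with hypothetical weights (the generator yields every element exactly once, biased towards the heavier ones early on):

space = ['a', 'b', 'c', 'd']
weights = np.array([0.1, 0.2, 0.3, 0.4])
for item in iterative_sampler(space, weights):
    print(item)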
Upvotes: 0
Reputation: 70735
This is a rewritten version of @necromancer's cool solution. Wraps it in a class to make it much easier to use correctly, and uses more dict methods to cut the lines of code.
from random import randrange
class Sampler:
    def __init__(self, n):
        self.n = n  # number remaining from original range(n)
        # i is a key iff i < n and i already returned;
        # in that case, state[i] is a value to return
        # instead of i.
        self.state = dict()

    def get(self):
        n = self.n
        if n <= 0:
            raise ValueError("range exhausted")
        result = i = randrange(n)
        state = self.state
        # Most of the fiddling here is just to get
        # rid of state[n-1] (if it exists). It's a
        # space optimization.
        if i == n - 1:
            if i in state:
                result = state.pop(i)
        elif i in state:
            result = state[i]
            if n - 1 in state:
                state[i] = state.pop(n - 1)
            else:
                state[i] = n - 1
        elif n - 1 in state:
            state[i] = state.pop(n - 1)
        else:
            state[i] = n - 1
        self.n = n-1
        return result
Here's a basic driver:
s = Sampler(100)
allx = [s.get() for _ in range(100)]
assert sorted(allx) == list(range(100))

from collections import Counter
c = Counter()
for i in range(6000):
    s = Sampler(3)
    one = tuple(s.get() for _ in range(3))
    c[one] += 1
for k, v in sorted(c.items()):
    print(k, v)
and sample output:
(0, 1, 2) 1001
(0, 2, 1) 991
(1, 0, 2) 995
(1, 2, 0) 1044
(2, 0, 1) 950
(2, 1, 0) 1019
By eyeball, that distribution is fine (run a chi-squared test if you're skeptical). Some of the solutions here don't give each permutation with equal probability (even though they return each k-subset of n with equal probability), so are unlike random.sample() in that respect.
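If you want a number rather than an eyeball judgment, a quick check on the counts above might look like this (scipy is assumed to be available):

from scipy.stats import chisquare

counts = [1001, 991, 995, 1044, 950, 1019]   # the sample output above
stat, p_value = chisquare(counts)            # null hypothesis: all 6 orders equally likely
print(stat, p_value)                         # a large p-value is consistent with uniformity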
Upvotes: 5
Reputation: 70833
This is my version of the Knuth shuffle, which was first posted by Tim Peters, prettified by Eric, and then nicely space-optimized by necromancer.

This one is based on Eric's version, since I indeed found his code very pretty :).
import random

def shuffle_gen(n):
    # this is used like a range(n) list, but we don't store
    # those entries where state[i] = i.
    state = dict()
    for remaining in xrange(n, 0, -1):
        i = random.randrange(remaining)
        yield state.get(i, i)
        state[i] = state.get(remaining - 1, remaining - 1)
        # Cleanup – we don't need this information anymore
        state.pop(remaining - 1, None)
usage:
out = []
gen = shuffle_gen(100)
for n in range(100):
    out.append(gen.next())
print out, len(set(out))
Upvotes: 6
Reputation: 70833
Note to readers from OP: Please consider looking at the originally accepted answer to understand the logic, and then understand this answer.

Aaaaaand for completeness sake: This is the concept of necromancer's answer, but adapted so it takes a list of forbidden numbers as input. This is just the same code as in my previous answer, but we build a state from forbid before we generate numbers.

This takes time O(f+k) and memory O(f+k). Obviously this is the fastest thing possible without requirements towards the format of forbid (sorted/set). I think this makes this a winner in some way ^^.

If forbid is a set, the repeated guessing method is faster with O(k⋅n/(n-(f+k))), which is very close to O(k) for f+k not very close to n.

If forbid is sorted, my ridiculous algorithm (see my other answer) is faster.
import random

def sample_gen(n, forbid):
    state = dict()
    track = dict()
    for (i, o) in enumerate(forbid):
        x = track.get(o, o)
        t = state.get(n-i-1, n-i-1)
        state[x] = t
        track[t] = x
        state.pop(n-i-1, None)
        track.pop(o, None)
    del track
    for remaining in xrange(n-len(forbid), 0, -1):
        i = random.randrange(remaining)
        yield state.get(i, i)
        state[i] = state.get(remaining - 1, remaining - 1)
        state.pop(remaining - 1, None)
usage:
gen = sample_gen(10, [1, 2, 4, 8])
print gen.next()
print gen.next()
print gen.next()
print gen.next()
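A quick sanity check (hypothetical, Python 2 like the code above): the forbidden values never come out, and every remaining value appears exactly once:

gen = sample_gen(10, [1, 2, 4, 8])
out = [gen.next() for _ in xrange(6)]
assert sorted(out) == [0, 3, 5, 6, 7, 9]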
Upvotes: 11
Reputation: 60207
If the number sampled is much less than the population, just sample, check whether it's been chosen already, and repeat until it hasn't. This might sound silly, but you've got an exponentially decaying probability of choosing the same number, so it's much faster than O(n) if you've got even a small percentage unchosen.
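A minimal sketch of that guess-and-retry idea (the helper name is hypothetical; chosen is the set of previously sampled numbers):

import random

def sample_one_more(n, chosen):
    """Draw one number from range(n) not already in `chosen` by retrying.
    The expected number of draws is n / (n - len(chosen))."""
    while True:
        x = random.randrange(n)
        if x not in chosen:
            chosen.add(x)
            return x

chosen = set(random.sample(range(100), 10))
print(sample_one_more(100, chosen))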
Python uses a Mersenne Twister as its PRNG, which is adequate. We can use something else entirely to be able to generate non-overlapping numbers in a predictable manner.
Quadratic residues, x² mod p, are unique when 2x < p and p is a prime.

If you "flip" the residue, p - (x² % p), given this time also that p ≡ 3 (mod 4), the results will be the remaining spaces.
This isn't a very convincing numeric spread, so you can increase the power, add some fudge constants and then the distribution is pretty good.
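As a quick sanity check of that claim before any fudging, with a small hypothetical prime p ≡ 3 (mod 4):

p = 23  # prime with p % 4 == 3

def residue_or_flip(x):
    r = (x * x) % p
    return r if 2 * x <= p else p - r

# The residues for 2x <= p plus the "flipped" residues cover [0, p) exactly once.
assert sorted(residue_or_flip(x) for x in range(p)) == list(range(p))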
First we need to generate primes:
from itertools import count
from math import ceil
from random import randrange

def modprime_at_least(number):
    if number <= 2:
        return 2

    # Round up to the next number ≡ 3 (mod 4), then step by 4 until a prime is found
    number = (number // 4 * 4) + 3
    for number in count(number, 4):
        if all(number % factor for factor in range(3, ceil(number ** 0.5)+1, 2)):
            return number
You might worry about the cost of generating the primes. For 10⁶ elements this takes a tenth of a millisecond. Running [None] * 10**6 takes longer than that, and since it's only calculated once, this isn't a real problem.

Further, the algorithm doesn't need an exact value for the prime; it only needs something that is at most a constant factor larger than the input number. This is possible by saving a list of values and searching them. If you do a linear scan, that is O(log number), and if you do a binary search it is O(log number of cached primes). In fact, if you use galloping you can bring this down to O(log log number), which is basically constant (log log googol = 2).
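For illustration, here's a sketch of that cached-lookup idea (the cache-building line and names are hypothetical; it reuses modprime_at_least from above and a plain binary search rather than galloping):

from bisect import bisect_left

# Precompute suitable primes at roughly doubling sizes, once.
cached_primes = [modprime_at_least(2 ** i) for i in range(1, 25)]

def cached_modprime_at_least(number):
    i = bisect_left(cached_primes, number)
    # Fall back to the exact search if the cache is too small.
    return cached_primes[i] if i < len(cached_primes) else modprime_at_least(number)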
Then we implement the generator:
def sample_generator(up_to):
    prime = modprime_at_least(up_to+1)

    # Fudge to make it less predictable
    fudge_power = 2**randrange(7, 11)
    fudge_constant = randrange(prime//2, prime)
    fudge_factor = randrange(prime//2, prime)

    def permute(x):
        permuted = pow(x, fudge_power, prime)
        return permuted if 2*x <= prime else prime - permuted

    for x in range(prime):
        res = (permute(x) + fudge_constant) % prime
        res = permute((res * fudge_factor) % prime)

        if res < up_to:
            yield res
And check that it works:
set(sample_generator(10000)) ^ set(range(10000))
#>>> set()
Now, the lovely thing about this is that if you ignore the primality test, which is approximately O(√n) where n is the number of elements, this algorithm has time complexity O(k), where k is the sample size, and O(1) memory usage! Technically this is O(√n + k), but practically it is O(k).
You do not require a proven PRNG. This PRNG is far better than a linear congruential generator (which is popular; Java uses it) but it's not as proven as a Mersenne Twister.

You do not first generate any items with a different function. This avoids duplicates through mathematics, not checks. In the next section I show how to remove this restriction.

The short method must be insufficient (k must approach n). If k is only half of n, just go with my original suggestion.

Extreme memory savings. This takes constant memory... not even O(k)!

Constant time to generate the next item. This is actually rather fast in constant terms, too: it's not as fast as the built-in Mersenne Twister but it's within a factor of 2.

Coolness.
To remove this requirement:
You do not first generate any items with a different function. This avoids duplicates through mathematics, not checks.
I have made the best possible algorithm in time and space complexity, which is a simple extension of my previous generator.
Here's the rundown (n is the length of the pool of numbers, k is the number of "foreign" keys):

O(√n) setup time; O(log log n) for all reasonable inputs

This is the only factor of my algorithm that technically isn't perfect with regards to algorithmic complexity, thanks to the O(√n) cost. In reality this won't be problematic because precalculation brings it down to O(log log n), which is immeasurably close to constant time. The cost is amortized free if you exhaust the iterable by any fixed percentage. This is not a practical problem.

O(1) key generation time

Obviously this cannot be improved upon.

O(k) foreign key generation time

If you have keys generated from the outside, with only the requirement that it must not be a key that this generator has already produced, these are to be called "foreign keys". Foreign keys are assumed to be totally random. As such, any function that is able to select items from the pool can do so. Because there can be any number of foreign keys and they can be totally random, the worst case for a perfect algorithm is O(k).

O(k) worst-case memory

If the foreign keys are assumed totally independent, each represents a distinct item of information. Hence all keys must be stored. The algorithm happens to discard keys whenever it sees one, so the memory cost will clear over the lifetime of the generator.
Well, it's both of my algorithms. It's actually quite simple:
def sample_generator(up_to, previously_chosen=set(), *, prune=True):
    prime = modprime_at_least(up_to+1)

    # Fudge to make it less predictable
    fudge_power = 2**randrange(7, 11)
    fudge_constant = randrange(prime//2, prime)
    fudge_factor = randrange(prime//2, prime)

    def permute(x):
        permuted = pow(x, fudge_power, prime)
        return permuted if 2*x <= prime else prime - permuted

    for x in range(prime):
        res = (permute(x) + fudge_constant) % prime
        res = permute((res * fudge_factor) % prime)

        if res in previously_chosen:
            if prune:
                previously_chosen.remove(res)

        elif res < up_to:
            yield res
The change is as simple as adding:
if res in previously_chosen:
    previously_chosen.remove(res)
You can add to previously_chosen at any time by adding to the set that you passed in. In fact, you can also remove from the set in order to add back to the potential pool, although this will only work if sample_generator has not yet yielded it or skipped it with prune=False.

So there it is. It's easy to see that it fulfils all of the requirements, and it's easy to see that the requirements are absolute. Note that if you don't have a set, it still meets its worst cases by converting the input to a set, although it increases overhead.
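A small usage sketch (hypothetical values): the caller owns the set and can grow it while iterating:

already_taken = {3, 14, 15}            # "foreign keys" chosen elsewhere
gen = sample_generator(100, already_taken)

first = next(gen)                      # never 3, 14 or 15
already_taken.add(42)                  # forbid another value mid-stream...
rest = set(gen)                        # ...it will be skipped too, unless already yielded
assert not {3, 14, 15} & ({first} | rest)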
I became curious how good this PRNG actually is, statistically speaking.
Some quick searches lead me to create these three tests, which all seem to show good results!
Firstly, some random numbers:
N = 1000000
my_gen = list(sample_generator(N))
target = list(range(N))
random.shuffle(target)
control = list(range(N))
random.shuffle(control)
These are "shuffled" lists of 10⁶ numbers from 0
to 10⁶-1
, one using our fun fudged PRNG, the other using a Mersenne Twister as a baseline. The third is the control.
Here's a test which looks at the average distance between two random numbers along the line. The differences are compared with the control:
from collections import Counter

def birthday_calc(randoms):
    return Counter(abs(r1-r2)//10000 for r1, r2 in zip(randoms, randoms[1:]))

def birthday_compare(randoms_1, randoms_2):
    birthday_1 = sorted(birthday_calc(randoms_1).items())
    birthday_2 = sorted(birthday_calc(randoms_2).items())
    return sum(abs(n1 - n2) for (i1, n1), (i2, n2) in zip(birthday_1, birthday_2))
print(birthday_compare(my_gen, target), birthday_compare(control, target))
#>>> 9514 10136
This is less than the variance of each.
Here's a test which takes 5 numbers in turn and sees what order the elements are in. They should be evenly distributed between all 120 possible orders.
def permutations_calc(randoms):
    permutations = Counter()
    for items in zip(*[iter(randoms)]*5):
        sorteditems = sorted(items)
        permutations[tuple(sorteditems.index(item) for item in items)] += 1
    return permutations

def permutations_compare(randoms_1, randoms_2):
    permutations_1 = permutations_calc(randoms_1)
    permutations_2 = permutations_calc(randoms_2)
    keys = sorted(permutations_1.keys() | permutations_2.keys())
    return sum(abs(permutations_1[key] - permutations_2[key]) for key in keys)
print(permutations_compare(my_gen, target), permutations_compare(control, target))
#>>> 5324 5368
This is again less than the variance of each.
Here's a test that sees how long "runs" are, aka. sections of consecutive increases or decreases.
def runs_calc(randoms):
    runs = Counter()
    run = 0
    for item in randoms:
        if run == 0:
            run = 1
        elif run == 1:
            run = 2
            increasing = item > last
        else:
            if (item > last) == increasing:
                run += 1
            else:
                runs[run] += 1
                run = 0
        last = item
    return runs

def runs_compare(randoms_1, randoms_2):
    runs_1 = runs_calc(randoms_1)
    runs_2 = runs_calc(randoms_2)
    keys = sorted(runs_1.keys() | runs_2.keys())
    return sum(abs(runs_1[key] - runs_2[key]) for key in keys)
print(runs_compare(my_gen, target), runs_compare(control, target))
#>>> 1270 975
The variance here is very large, and over several executions I have seen an even-ish spread of both. As such, this test is passed.
A Linear Congruential Generator was mentioned to me, as possibly "more fruitful". I have made a badly implemented LCG of my own, to see whether this is an accurate statement.
LCGs, AFAICT, are like normal generators in that they're not made to be cyclic. Therefore most references I looked at, aka. Wikipedia, covered only what defines the period, not how to make a strong LCG of a specific period. This may have affected results.
Here goes:
from operator import mul
from functools import reduce

# Credit http://stackoverflow.com/a/16996439/1763356
# Meta: Also Tobias Kienzler seems to have credit for my
# edit to the post, what's up with that?
def factors(n):
    d = 2
    while d**2 <= n:
        while not n % d:
            yield d
            n //= d
        d += 1
    if n > 1:
        yield n

def sample_generator3(up_to):
    # Find a modulus/multiplier pair that gives a full-period LCG
    for modulier in count(up_to):
        modulier_factors = set(factors(modulier))
        multiplier = reduce(mul, modulier_factors)
        if not modulier % 4:
            multiplier *= 2
        if multiplier < modulier - 1:
            multiplier += 1
            break

    x = randrange(0, up_to)
    fudge_constant = randrange(0, modulier)
    # Strip factors shared with the modulus so the increment is coprime to it
    for modfact in modulier_factors:
        while not fudge_constant % modfact:
            fudge_constant //= modfact

    for _ in range(modulier):
        if x < up_to:
            yield x
        x = (x * multiplier + fudge_constant) % modulier
We no longer check for primes, but we do need to do some odd things with factors:

modulier ≥ up_to > multiplier, and fudge_constant > 0

multiplier − 1 (the "a − 1" of the standard LCG recurrence) must be divisible by every factor in modulier

fudge_constant must be coprime with modulier

Note that these aren't rules for an LCG but for an LCG with full period, which is obviously equal to the modulier.

I did it as such:

Try every modulier at least up_to, stopping when the conditions are satisfied.

Take the set of its prime factors, 𝐅.

Let multiplier be the product of 𝐅 with duplicates removed.

If multiplier is not less than modulier, continue with the next modulier.

Let fudge_constant be a number less than modulier, chosen randomly.

Remove the factors of fudge_constant that are in 𝐅.

This is not a very good way of generating it, but I don't see why it would ever impinge on the quality of the numbers, aside from the fact that low fudge_constants and multipliers are more common than a perfect generator for these might make.
Anyhow, the results are appalling:
print(birthday_compare(lcg, target), birthday_compare(control, target))
#>>> 22532 10650
print(permutations_compare(lcg, target), permutations_compare(control, target))
#>>> 17968 5820
print(runs_compare(lcg, target), runs_compare(control, target))
#>>> 8320 662
In summary, my RNG is good and a linear congruential generator is not. Considering that Java gets away with a linear congruential generator (although it only uses the lower bits), I would expect my version to be more than sufficient.
Upvotes: 12
Reputation: 70735
OK, one last try ;-) At the cost of mutating the base sequence, this takes no additional space, and requires time proportional to n for each sample(n) call:
class Sampler(object):
    def __init__(self, base):
        self.base = base
        self.navail = len(base)

    def sample(self, n):
        from random import randrange
        if n < 0:
            raise ValueError("n must be >= 0")
        if n > self.navail:
            raise ValueError("fewer than %s unused remain" % n)
        base = self.base
        for _ in range(n):
            i = randrange(self.navail)
            self.navail -= 1
            base[i], base[self.navail] = base[self.navail], base[i]
        return base[self.navail : self.navail + n]
Little driver:
s = Sampler(list(range(100)))
for i in range(9):
    print s.sample(10)
print s.sample(1)
print s.sample(1)
In effect, this implements a resumable random.shuffle(), pausing after n elements have been selected. base is not destroyed, but is permuted.
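A quick way to see that base is only permuted, never destroyed (reusing the class above; the values are arbitrary):

s = Sampler(list(range(20)))
s.sample(5)
s.sample(3)
assert sorted(s.base) == list(range(20))   # same elements, new order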
Upvotes: 9
Reputation: 70735
Here's a way that doesn't build the difference set explicitly. But it does use a form of @Veedrac's "accept/reject" logic. If you're not willing to mutate the base sequence as you go along, I'm afraid that's unavoidable:
def sample(n, base, forbidden):
    # base is iterable, forbidden is a set.
    # Every element of forbidden must be in base.
    # forbidden is updated.
    from random import random
    nusable = len(base) - len(forbidden)
    assert nusable >= n
    result = []
    if n == 0:
        return result
    for elt in base:
        if elt in forbidden:
            continue
        if nusable * random() < n:
            result.append(elt)
            forbidden.add(elt)
            n -= 1
            if n == 0:
                return result
        nusable -= 1
    assert False, "oops!"
Here's a little driver:
base = list(range(100))
forbidden = set()
for i in range(10):
    print sample(10, base, forbidden)
Upvotes: 7
Reputation: 24651
Edit: see cleaner versions below by @TimPeters and @Chronial. A minor edit pushed this to the top.
Here is what I believe is the most efficient solution for incremental sampling. Instead of a list of previously sampled numbers, the state to be maintained by the caller comprises a dictionary that is ready for use by the incremental sampler, and a count of numbers remaining in the range.
The following is a demonstrative implementation. Compared to other solutions:

time is O(log(number_previously_sampled))

space is O(number_previously_sampled)
Code:
import random
def remove(i, n, state):
    if i == n - 1:
        if i in state:
            t = state[i]
            del state[i]
            return t
        else:
            return i
    else:
        if i in state:
            t = state[i]
            if n - 1 in state:
                state[i] = state[n - 1]
                del state[n - 1]
            else:
                state[i] = n - 1
            return t
        else:
            if n - 1 in state:
                state[i] = state[n - 1]
                del state[n - 1]
            else:
                state[i] = n - 1
            return i

s = dict()
for n in range(100, 0, -1):
    print remove(random.randrange(n), n, s)
Upvotes: 5
Reputation: 97661
You can implement a shuffling generator, based on Wikipedia's "Fisher–Yates shuffle#Modern method":
import random

def shuffle_gen(src):
    """ yields random items from `src` without repetition. Clobbers `src`. """
    for remaining in xrange(len(src), 0, -1):
        i = random.randrange(remaining)
        yield src[i]
        src[i] = src[remaining - 1]
Which can then be sliced using itertools.islice:
>>> import itertools
>>> sampler = shuffle_gen(range(100))
>>> sample1 = list(itertools.islice(sampler, 10))
>>> sample1
[37, 1, 51, 82, 83, 12, 31, 56, 15, 92]
>>> sample2 = list(itertools.islice(sampler, 80))
>>> sample2
[79, 66, 65, 23, 63, 14, 30, 38, 41, 3, 47, 42, 22, 11, 91, 16, 58, 20, 96, 32, 76, 55, 59, 53, 94, 88, 21, 9, 90, 75, 74, 29, 48, 28, 0, 89, 46, 70, 60, 73, 71, 72, 93, 24, 34, 26, 99, 97, 39, 17, 86, 52, 44, 40, 49, 77, 8, 61, 18, 87, 13, 78, 62, 25, 36, 7, 84, 2, 6, 81, 10, 80, 45, 57, 5, 64, 33, 95, 43, 68]
>>> sample3 = list(itertools.islice(sampler, 20))
>>> sample3
[85, 19, 54, 27, 35, 4, 98, 50, 67, 69]
Upvotes: 6
Reputation: 70833
Ok, here we go. This should be the fastest possible non-probabilistic algorithm. It has runtime of O(k⋅log²(s) + f⋅log(f)) ⊂ O(k⋅log²(f+k) + f⋅log(f)) and space O(k+f). f is the number of forbidden numbers, s is the length of the longest streak of forbidden numbers. The expectation for that is more complicated, but obviously bounded by f. If you assume that s^log₂(s) is bigger than f, or are just unhappy about the fact that s is once again probabilistic, you can change the log part to a bisection search in forbidden[pos:] to get O(k⋅log(f+k) + f⋅log(f)).
The actual implementation here is O(k⋅(k+f) + f⋅log(f)), as insertion in the list forbid is O(n). This is easy to fix by replacing that list with a blist sortedlist.
I also added some comments, because this algorithm is ridiculously complex. The lin part does the same as the log part, but needs s instead of log²(s) time.
import bisect
import random

def sample(k, end, forbid):
    forbidden = sorted(forbid)
    out = []
    # remove the last block from forbidden if it touches end
    for end in reversed(xrange(end+1)):
        if len(forbidden) > 0 and forbidden[-1] == end:
            del forbidden[-1]
        else:
            break

    for i in xrange(k):
        v = random.randrange(end - len(forbidden) + 1)
        # increase v by the number of values < v
        pos = bisect.bisect(forbidden, v)
        v += pos

        # this number might also be already taken, find the
        # first free spot
        ##### linear
        #while pos < len(forbidden) and forbidden[pos] <= v:
        #    pos += 1
        #    v += 1

        ##### log
        while pos < len(forbidden) and forbidden[pos] <= v:
            step = 2
            # when this is finished, we know that:
            # • forbidden[pos + step/2] <= v + step/2
            # • forbidden[pos + step] > v + step
            # so repeat until (checked by outer loop):
            #   forbidden[pos + step/2] == v + step/2
            while (pos + step <= len(forbidden)) and \
                  (forbidden[pos + step - 1] <= v + step - 1):
                step = step << 1
            pos += step >> 1
            v += step >> 1

        if v == end:
            end -= 1
        else:
            bisect.insort(forbidden, v)
        out.append(v)
    return out
Now to compare that to the "hack" (and the default implementation in Python) that Veedrac proposed, which has space O(f+k) and time O(k⋅n/(n-(f+k))) (where n/(n-(f+k)) is the expected number of "guesses"):
I just plotted this for k=10 and a reasonably big n=10000 (it only gets more extreme for bigger n). And I have to say: I only implemented this because it seemed like a fun challenge, but even I am surprised by how extreme this is:

Let's zoom in to see what's going on:

Yes – the guesses are even faster for the 9998th number you generate. Note that, as you can see in the first plot, even my one-liner is probably faster for bigger f/n (but still has rather horrible space requirements for big n).
To drive the point home: the only thing you are spending time on here is generating the set, as that's the f factor in Veedrac's method.

So I hope my time here was not wasted and I managed to convince you that Veedrac's method is simply the way to go. I can kind of understand why that probabilistic part troubles you, but maybe think of the fact that hashmaps (= Python dicts) and tons of other algorithms work with similar methods and they seem to be doing just fine.
You might be afraid of the variance in the number of repetitions. As noted above, this follows a geometric distribution with p = (n−f)/n. So the standard deviation (= the amount you "should expect" the result to deviate from the expected average) is √(1−p)/p = √(f⋅n)/(n−f), which is basically the same as the mean (since √(f⋅n) < √(n²) = n).
Edit: I just realized that s is actually also n/(n-(f+k)). So a more exact runtime for my algorithm is O(k⋅log²(n/(n-(f+k))) + f⋅log(f)), which is nice since, given the graphs above, it proves my intuition that this is quite a bit faster than O(k⋅log(f+k) + f⋅log(f)). But rest assured that this also does not change anything about the results above, as f⋅log(f) is the absolutely dominant part of the runtime.
Upvotes: 10
Reputation: 70735
If you know in advance that you're going to want multiple samples without overlaps, the easiest thing is to do random.shuffle() on list(range(100)) (Python 3 - you can skip the list() in Python 2), then peel off slices as needed.
s = list(range(100))
random.shuffle(s)
first_sample = s[-10:]
del s[-10:]
second_sample = s[-10:]
del s[-10:]
# etc
Else @Chronial's answer is reasonably efficient.
Upvotes: 28
Reputation: 70833
Reasonably fast one-liner (O(n + m), where n = range size and m = old sample size):
next_sample = random.sample(set(range(100)).difference(my_sample), 10)
Upvotes: 5