Reputation: 3525
I want to sample from a list until all elements have appeared at least once. We can use tossing a die as an example. A die has six sides: 1 through 6. I keep tossing it until I see all six values at least once, then I stop. Here is my function.
import numpy as np

def sample_until_all(all_values):
    sampled_vals = np.empty(0)
    while True:
        cur_val = np.random.choice(all_values, 1)
        sampled_vals = np.append(sampled_vals, cur_val[0])
        if set(all_values) == set(sampled_vals):
            return len(sampled_vals)

sample_until_all(range(6))
cur_val is the value from the current toss. I keep all sampled values in sampled_vals using np.append, and after each toss I check whether it contains all possible values using set(all_values) == set(sampled_vals). It works, but not efficiently (I believe). Any ideas how to make it faster? Thanks.
I just use this as a toy example. The actual list I need is much larger than just 6 values.
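As context (my addition, not from the question): this is the classic coupon collector's problem, so for n equally likely values the expected number of draws is n times the n-th harmonic number. A quick sketch for estimating how long any of the solutions below should run on average:

```python
def expected_draws(n):
    """Expected number of draws to see all n equally likely values
    at least once (coupon collector): n * (1 + 1/2 + ... + 1/n)."""
    return n * sum(1 / k for k in range(1, n + 1))

expected_draws(6)  # 14.7 draws on average for a six-sided die
```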
Upvotes: 1
Views: 103
Reputation: 2967
The following creates a set a, which will be a collection of the unique objects in all_values. This set represents the elements that we have not yet seen. As we randomly choose elements from all_values, if the chosen element is one we have not seen before, we remove the corresponding object from the set. We continue doing this until the set a is empty.
from random import choice

def sample_until_all(all_values):
    a = set(all_values)
    count = 0
    while a:
        r = choice(all_values)
        print(r)
        count += 1
        if r in a:
            a.remove(r)
    print(f"\nIt took {count} draws.")
    return count
Example session (calling sample_until_all(range(6))):
1
5
1
3
5
2
4
0
It took 8 draws.
Timings
Obtained with perfplot.show() with setup = lambda n: range(n) and n_range = [2**k for k in range(11)]. I removed the print statements from my function before doing the timings.
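For reference, a print-free variant suitable for timing might look like this (my reconstruction; the answer mentions removing the prints but doesn't show the result):

```python
from random import choice

def sample_until_all(all_values):
    a = set(all_values)          # values not yet seen
    count = 0
    while a:
        r = choice(all_values)
        count += 1
        if r in a:               # first time seeing r: cross it off
            a.remove(r)
    return count

sample_until_all(range(6))       # number of draws needed this run
```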
Upvotes: 1
Reputation: 199
Numpy arrays work great when they have a definite size, but not so much when that's not the case.
The core part of your program is checking for the existence of elements in sampled_vals, and that is a task at which sets excel. Converting an array to a set in every loop iteration, however, is unnecessarily costly. We can thus simplify your code as such:
from random import choice

def new_sample_until_all(all_values):
    sampled_vals = set()
    universe_set = set(all_values)
    n = 0
    while sampled_vals != universe_set:
        cur_val = choice(all_values)
        sampled_vals.add(cur_val)
        n += 1
    return n
whose key improvements are:
- sampled_vals is kept as a set, not as a variable-length numpy array that had to be converted to a set on every iteration, and
- an n counter keeps track of the number of times you had to sample from all_values.

In a simple test on my machine, I get, with values = list(range(10**3)):
%timeit sample_until_all(values)
>>> 2.6 s ± 1.22 s
%timeit new_sample_until_all(values)
>>> 2.37 ms ± 11.8 µs
which means the new code is roughly 1000 times faster. Not bad!
Upvotes: 1
Reputation: 26915
Here's another approach that doesn't rely on building and modifying lists (except initially to get a list from the specified range):
import random

def sample_until_all(all_values):
    vlist = list(all_values)   # list of values to choose from
    target = sum(vlist)        # target sum
    seen = set()               # values seen so far
    total = 0                  # running total
    while total != target:
        choice = random.choice(vlist)
        if choice not in seen:
            seen.add(choice)
            total += choice

sample_until_all(range(6))
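A caveat (my addition, not from the answer): the sum-based stopping test is only safe when no proper subset of the distinct values adds up to the full sum. With range(6), for instance, 1+2+3+4+5 already equals the target of 15, so the loop can stop without ever drawing 0; duplicates in all_values would instead make it loop forever. Comparing the count of seen values avoids both problems:

```python
import random

def sample_until_all(all_values):
    vlist = list(all_values)     # list of values to choose from
    n_unique = len(set(vlist))   # how many distinct values must be seen
    seen = set()                 # values seen so far
    draws = 0
    while len(seen) != n_unique:
        seen.add(random.choice(vlist))
        draws += 1
    return draws
```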
Upvotes: 1
Reputation: 260335
I don't think numpy is really useful here as you generate your values one by one.
What about using a Counter?
from collections import Counter
import random

def sample_until_all(all_values):
    all_values = list(all_values)
    c = Counter()
    while len(c) < len(all_values):
        c[random.choice(all_values)] += 1
    return sum(c.values())

sample_until_all(range(6))
This code is about 50 times faster for range(6), and the difference is even greater when the range is larger.
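If the per-value draw counts are of interest too, the same idea can return the Counter itself (a variant of mine, not part of the answer):

```python
from collections import Counter
import random

def sample_until_all_counts(all_values):
    # Same loop as the answer's, but return the Counter so the
    # per-value draw counts can be inspected as well as the total.
    all_values = list(all_values)
    c = Counter()
    while len(c) < len(all_values):
        c[random.choice(all_values)] += 1
    return c

counts = sample_until_all_counts(range(6))
total_draws = sum(counts.values())  # the number the original function returns
```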
Upvotes: 2