Reputation: 47
I need to generate a random list of 8 integers in the range [0, 4], drawn with weights, such that their sum is 12.
Something like this:
    from random import choices

    while True:
        lst = choices(population=[0, 1, 2, 3, 4], weights=[0.20, 0.30, 0.30, 0.15, 0.05], k=8)
        if sum(lst) == 12:
            print(lst)
            break
Is there a smarter way to do that?
Upvotes: 0
Views: 83
Reputation: 16184
Severin's solution is simple and should be quick for a large part of the parameter space, but might get slow out at the edges of the distribution. For example, generating 30 values that have to sum to 100 will probably take a few seconds until it randomly stumbles on a valid solution.
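To put a rough number on how often a rejection loop has to retry, here is a minimal sketch that computes the exact acceptance probability for the weighted `choices` loop from the question by convolving the single-draw distribution. The helper name `sum_distribution` is made up for illustration, and this models the question's loop rather than the multinomial variant, which I'd only assume shows the same qualitative trend.

```python
def sum_distribution(population, weights, k):
    """Exact distribution of the sum of k independent weighted draws."""
    norm = sum(weights)
    probs = [w / norm for w in weights]
    dist = {0: 1.0}  # before any draw the sum is 0 with certainty
    for _ in range(k):
        nxt = {}
        for s, p in dist.items():
            for v, q in zip(population, probs):
                nxt[s + v] = nxt.get(s + v, 0.0) + p * q
        dist = nxt
    return dist


population = [0, 1, 2, 3, 4]
weights = [0.20, 0.30, 0.30, 0.15, 0.05]

# probability that one attempt of the rejection loop is accepted
p_typical = sum_distribution(population, weights, 8).get(12, 0.0)
p_edge = sum_distribution(population, weights, 30).get(100, 0.0)

print(f"8 values summing to 12:   p = {p_typical:.3g}, about {1 / p_typical:.0f} tries on average")
print(f"30 values summing to 100: p = {p_edge:.3g}")
```

The expected number of attempts is 1/p, so a target sum far from the mean of the distribution makes the loop wait a very long time.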
The following code makes sure it only samples from values that are valid, so has a more deterministic runtime:
    from random import choices

    def sample_values(population, k, total, *, weights=None):
        if weights is None:
            # weights not probabilities, so no need to sum to 1
            weights = [1] * len(population)

        # ensure population is a sorted list, with weights in consistent order
        population, weights = zip(*sorted(zip(population, weights)))
        population = list(population)
        weights = list(weights)

        result = []
        for _ in range(k):
            # population values that would take us past the running total should be excluded
            while population[-1] > total:
                del population[-1]
                del weights[-1]

            # maintain k as the number of remaining items
            k -= 1

            # remove anything where just using it and then maximal values wouldn't get us to the total
            remain_lim = total - max(population) * k
            while population[0] < remain_lim:
                del population[0]
                del weights[0]

            # sample next value
            n, = choices(population, weights)
            result.append(n)

            # maintain total as the remaining total
            total -= n

        return result
The inner call to zip (the one pairing population with weights) really wants strict=True, which is available from Python 3.10, but I have left that out as I presume you're not using that yet.
The above can, e.g., be used as:
    sample_values(range(5), 8, 12, weights=[0.20, 0.30, 0.30, 0.15, 0.05])

and runs in ~12µs, which is comparable to Severin's multinomial solution, which takes ~18µs for these parameters.
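The pruning done at each step of the loop can also be looked at in isolation. The following is only a sketch: `feasible_values` is a hypothetical helper, not part of the code above, and it assumes consecutive non-negative integers such as `range(5)`. For a population with gaps it may keep values that cannot actually complete the sum.

```python
def feasible_values(population, remaining, total):
    """Values that may be drawn now so that `remaining` draws
    (this one included) can still add up to exactly `total`."""
    # drop anything that would overshoot the remaining total
    values = [v for v in sorted(population) if v <= total]
    if not values:
        return []
    # after this draw, remaining - 1 draws are left, each at most max(values)
    lower = total - max(values) * (remaining - 1)
    return [v for v in values if v >= lower]


print(feasible_values(range(5), 8, 12))  # first draw: nothing is excluded yet
print(feasible_values(range(5), 2, 7))   # two draws left, 7 still to go
print(feasible_values(range(5), 1, 3))   # last draw must hit the total exactly
```

With one draw left the list collapses to the single value that completes the sum, which is why the loop never needs to reject a finished sample.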
Upvotes: 1
Reputation: 20080
You could sample from a multinomial distribution, which automatically gets the right sum, and reject samples containing out-of-range values.
Something along these lines (Python 3.9.1, Windows 10 x64):
    import numpy as np

    rng = np.random.default_rng()

    def smpl(rng):
        while True:
            q = rng.multinomial(12, [1./8.]*8)
            if np.any(q > 4):
                continue
            return q
Upvotes: 0