Reputation: 2162
SciPy is a huge library. It is rather frustrating that, to use a single simple feature (computing a truncated distribution), I have to install and import 23 MB of code.
Is there a simpler way to achieve this?
Upvotes: 0
Views: 1288
Reputation: 1034
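You can implement it manually via inverse transform sampling: evaluate the inverse of the cumulative distribution function (CDF) at values drawn from a uniform distribution on [0, 1]. The code below tabulates the CDF numerically on a grid and inverts it by linear interpolation, so only NumPy is needed.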
import numpy as np
def normal(x, mu, sig):
    """Return the Gaussian probability density N(mu, sig**2) evaluated at `x`.

    `x` may be a scalar or a NumPy array (evaluated element-wise).
    """
    scale = np.sqrt(2 * np.pi) * sig
    exponent = -0.5 * np.square(x - mu) / np.square(sig)
    return 1. / scale * np.exp(exponent)
def trunc_normal(x, mu, sig, bounds=None):
    """Evaluate the normal density N(mu, sig**2) at `x`, zeroed outside `bounds`.

    The result is NOT renormalized: inside ``[bounds[0], bounds[1]]``
    (inclusive) it equals ``normal(x, mu, sig)``. It therefore integrates
    to the parent normal's probability mass within the bounds, not to 1.

    Parameters
    ----------
    x : array_like or float
        Evaluation point(s); scalars and lists are accepted as well as arrays.
    mu, sig : float
        Mean and standard deviation of the parent normal.
    bounds : sequence of two floats, optional
        ``(lower, upper)`` truncation limits; ``None`` means ``(-inf, inf)``.

    Returns
    -------
    numpy.ndarray
        Density values with the same shape as ``x`` (0-d for scalar input).
    """
    if bounds is None:
        bounds = (-np.inf, np.inf)
    lower, upper = bounds[0], bounds[1]
    x = np.asarray(x)
    density = normal(x, mu, sig)
    # Build a new array instead of assigning through a boolean mask: mask
    # assignment fails for scalar (0-d) input. NaN inputs compare False on
    # both sides, so they stay NaN exactly as before.
    return np.where((x < lower) | (x > upper), 0., density)
def sample_trunc(n, mu, sig, bounds=None, *, n_grid=10000, width=5., rng=None):
    """Sample `n` points from a normal distribution truncated to `bounds`.

    Uses inverse transform sampling. The CDF of the truncated density is
    tabulated on a grid and inverted by linear interpolation, so only NumPy
    is required.

    Parameters
    ----------
    n : int
        Number of samples to draw.
    mu, sig : float
        Mean and (positive) standard deviation of the parent normal.
    bounds : sequence of two floats, optional
        ``(lower, upper)`` truncation limits; either may be infinite.
        ``None`` means no truncation.
    n_grid : int, optional
        Number of grid points (at least 2) used to tabulate the CDF.
    width : float, optional
        Half-width of the tabulation window in units of `sig`, measured from
        the point of ``[lower, upper]`` nearest `mu`. Mass beyond it is ignored.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of uniform variates (its ``random(n)`` method is used).
        ``None`` uses the global NumPy state via ``np.random.rand``, so
        ``np.random.seed`` still controls the result.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n,)``; every value lies within `bounds`.

    Raises
    ------
    ValueError
        If ``bounds[0] >= bounds[1]``, or `sig`/`width` is not positive, or
        ``n_grid < 2``.
    """
    lower, upper = (-np.inf, np.inf) if bounds is None else (bounds[0], bounds[1])
    if not lower < upper:
        raise ValueError(f"bounds must satisfy bounds[0] < bounds[1], got {bounds!r}")
    if not (sig > 0 and width > 0 and n_grid >= 2):
        raise ValueError("sig and width must be positive and n_grid at least 2")

    # Tabulate only where the truncated density lives. Start from the point
    # of [lower, upper] nearest the mode and extend `width` sigmas each way,
    # clipped to the bounds. Every grid point is then admissible, so no
    # sample can land outside `bounds`. Bounds far in a tail still get a
    # grid instead of an all-zero density.
    anchor = min(max(mu, lower), upper)
    x = np.linspace(max(lower, anchor - width * sig),
                    min(upper, anchor + width * sig), n_grid)

    # Log-density shifted so its peak is 0. The CDF is normalized below
    # anyway, and the shift keeps values from underflowing to 0 in a tail.
    log_pdf = -0.5 * np.square((x - mu) / sig)
    pdf = np.exp(log_pdf - log_pdf.max())

    # Trapezoidal cumulative integral anchored at 0 (the uniform spacing
    # cancels in the normalization), so the CDF runs from exactly 0 at the
    # first grid point to exactly 1 at the last.
    cdf = np.concatenate(([0.], np.cumsum(pdf[1:] + pdf[:-1])))
    cdf /= cdf[-1]

    u = np.random.rand(n) if rng is None else rng.random(n)
    return np.interp(u, cdf, x)
# Example: draw 10,000 samples from a standard normal truncated to [-1, 1]
# and show a 100-bin histogram of them. Guarded so that importing this
# module for its functions neither requires matplotlib nor opens a plot.
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    samples = sample_trunc(10000, 0, 1, (-1, 1))
    plt.hist(samples, bins=100)
    plt.show()
Upvotes: 2