alessandro308

Reputation: 2162

How to sample from a truncated Gaussian distribution without using SciPy?

SciPy is a huge library. It is quite frustrating that, to use one simple feature, i.e. sampling from a truncated distribution, I have to install (and import) 23 MB of code.

Is there a simpler way to solve this problem?

Upvotes: 0

Views: 1288

Answers (1)

clemisch

Reputation: 1034

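You can implement it manually via inverse transform sampling: evaluate the inverse of the cumulative distribution function (CDF) at values drawn from a uniform distribution between 0 and 1. Plain NumPy has no normal CDF or inverse CDF, so the code below tabulates the truncated pdf on a fine grid, accumulates it into an approximate CDF, and inverts that CDF by linear interpolation with np.interp.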
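For reference, the truncated normal's CDF has a standard closed form in terms of the standard normal CDF Φ, so its inverse can also be written out directly. Here a < b are the truncation bounds and U is a uniform draw on (0, 1); the grid-based code approximates this F numerically.

```latex
% a < b: truncation bounds, \Phi: standard normal CDF, U \sim \mathrm{Uniform}(0, 1)
\alpha = \frac{a - \mu}{\sigma}, \qquad \beta = \frac{b - \mu}{\sigma}

F(x) = \frac{\Phi\left(\frac{x - \mu}{\sigma}\right) - \Phi(\alpha)}{\Phi(\beta) - \Phi(\alpha)}, \qquad a \le x \le b

X = F^{-1}(U) = \mu + \sigma \, \Phi^{-1}\big(\Phi(\alpha) + U \, (\Phi(\beta) - \Phi(\alpha))\big)
```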

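import numpy as np


def normal(x, mu, sig):
    """ Gaussian pdf with mean `mu` and standard deviation `sig` """
    return 1. / (np.sqrt(2 * np.pi) * sig) * np.exp(-0.5 * np.square(x - mu) / np.square(sig))


def trunc_normal(x, mu, sig, bounds=None):
    """ Gaussian pdf set to zero outside `bounds` (not renormalized) """
    if bounds is None:
        bounds = (-np.inf, np.inf)

    norm = normal(x, mu, sig)
    norm[x < bounds[0]] = 0
    norm[x > bounds[1]] = 0

    return norm


def sample_trunc(n, mu, sig, bounds=None):
    """ Sample `n` points from truncated normal distribution """
    if bounds is None:
        bounds = (-np.inf, np.inf)

    # Put the whole grid inside the truncation interval, clipped to
    # mu +/- 5 sigma (the tails beyond that carry negligible mass)
    lo = max(bounds[0], mu - 5. * sig)
    hi = min(bounds[1], mu + 5. * sig)
    x = np.linspace(lo, hi, 10000)
    y = trunc_normal(x, mu, sig, bounds)

    # Approximate CDF via a cumulative sum, rescaled to run from 0 to 1
    y_cum = np.cumsum(y)
    y_cum = (y_cum - y_cum[0]) / (y_cum[-1] - y_cum[0])

    # Inverse transform: map uniform draws through the inverse CDF
    yrand = np.random.rand(n)
    sample = np.interp(yrand, y_cum, x)

    return sample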


# Example
import matplotlib.pyplot as plt
samples = sample_trunc(10000, 0, 1, (-1, 1))
plt.hist(samples, bins=100)
plt.show()
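If you'd rather skip both SciPy and the grid approximation, Python's standard library (3.8+) has `statistics.NormalDist`, whose `cdf` and `inv_cdf` methods let you evaluate the closed-form inverse directly, with no third-party dependency at all. This is a minimal sketch; the name `sample_trunc_exact` is hypothetical. One caveat: for bounds far out in a tail, Φ(α) and Φ(β) both round to 1 (or 0) in double precision, so the interval collapses. The sketch raises in that case instead of returning garbage.

```python
import random
from statistics import NormalDist


def sample_trunc_exact(n, mu, sig, bounds=(-float("inf"), float("inf")), rng=None):
    """Draw n samples from N(mu, sig^2) truncated to [bounds[0], bounds[1]]."""
    if rng is None:
        rng = random.Random()
    a, b = bounds
    dist = NormalDist(mu, sig)
    # CDF values at the bounds; infinite bounds map to 0.0 and 1.0
    lo_p, hi_p = dist.cdf(a), dist.cdf(b)
    if not lo_p < hi_p:
        raise ValueError("truncation interval has (numerically) zero probability")
    out = []
    while len(out) < n:
        # Uniform draw on [F(a), F(b)], then map it through the inverse CDF
        u = rng.uniform(lo_p, hi_p)
        if 0.0 < u < 1.0:  # inv_cdf is only defined on the open interval (0, 1)
            x = dist.inv_cdf(u)
            out.append(min(max(x, a), b))  # guard against rounding just outside
    return out
```

For example, `sample_trunc_exact(10000, 0, 1, (-1, 1))` returns a plain Python list that you can pass straight to `plt.hist`. Each draw costs a few Python-level operations, so this is slower than the vectorized NumPy version for large `n`. In exchange, it has no grid resolution or ±5σ cutoff.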

Upvotes: 2
