Reputation: 367
I want to draw N random samples from a distribution that is the sum of two truncated normal distributions. I get what I want by subclassing the rv_continuous class from scipy.stats and providing a pdf that is the mean of the two given pdfs:
import numpy as np
from scipy import stats
my_lim = [0.05, 7] # lower and upper limit
my_loc = [1.2, 3] # loc values of the two truncated normal distributions
my_scale = [0.6, 2] # scale values of the two truncated normal distributions
class sum_truncnorm(stats.rv_continuous):
    def _pdf(self, x):
        return (stats.truncnorm.pdf(x,
                                    a=(my_lim[0] - my_loc[0]) / my_scale[0],
                                    b=(my_lim[1] - my_loc[0]) / my_scale[0],
                                    loc=my_loc[0],
                                    scale=my_scale[0]) +
                stats.truncnorm.pdf(x,
                                    a=(my_lim[0] - my_loc[1]) / my_scale[1],
                                    b=(my_lim[1] - my_loc[1]) / my_scale[1],
                                    loc=my_loc[1],
                                    scale=my_scale[1])) / 2
However, using:
my_dist = sum_truncnorm()
my_rvs = my_dist.rvs(size=10)
is very slow and takes about 5 seconds per random value.
I'm sure that this can be done much faster, but I am not sure how. Should I maybe define my distribution as a sum of (non-truncated) normal distributions and enforce the truncation afterwards? I did some tests in this direction, but this was only about 10x faster and thus still way too slow.
Google told me that I probably need to use inverse transform sampling and override the _rvs method, but I failed to make this work for my truncated distributions.
Upvotes: 1
Views: 225
Reputation: 26030
First, you're going to have to make sure _pdf is normalized. The framework does not check it, and silently produces nonsense otherwise.
Second, to make drawing variates fast, you need to implement _ppf or _rvs. With only a _pdf, it goes through the generic code path (numerical integration plus root-finding), which is why your current version is so slow.
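Since your pdf is an equal-weight mixture of two truncated normals, a fast _rvs can pick a component for each sample and delegate to truncnorm.rvs, avoiding the generic code path entirely. A minimal sketch (the _rvs keyword signature assumes a reasonably recent SciPy; older versions used self._size and self._random_state instead):

```python
import numpy as np
from scipy import stats

my_lim = [0.05, 7]   # lower and upper limit
my_loc = [1.2, 3]    # loc of the two components
my_scale = [0.6, 2]  # scale of the two components

# standardized truncation bounds for each component
a = [(my_lim[0] - my_loc[i]) / my_scale[i] for i in (0, 1)]
b = [(my_lim[1] - my_loc[i]) / my_scale[i] for i in (0, 1)]

class sum_truncnorm(stats.rv_continuous):
    def _pdf(self, x):
        # equal-weight mixture: mean of the two pdfs, so it is normalized
        return (stats.truncnorm.pdf(x, a[0], b[0], loc=my_loc[0], scale=my_scale[0]) +
                stats.truncnorm.pdf(x, a[1], b[1], loc=my_loc[1], scale=my_scale[1])) / 2

    def _rvs(self, size=None, random_state=None):
        # choose a mixture component per sample (equal weights), then
        # draw from that component's truncated normal directly
        pick = random_state.uniform(size=size) < 0.5
        out = np.empty(size)
        for i, mask in enumerate((pick, ~pick)):
            out[mask] = stats.truncnorm.rvs(
                a[i], b[i], loc=my_loc[i], scale=my_scale[i],
                size=mask.sum(), random_state=random_state)
        return out

my_dist = sum_truncnorm()
my_rvs = my_dist.rvs(size=10_000)  # fast: no numerical integration or root-finding
```

This sidesteps inverse transform sampling altogether; a closed-form _ppf for a mixture generally does not exist, but component-wise sampling is exact and vectorized.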
Upvotes: 1