Reputation: 1568
I am fairly new to curve_fit with scipy. I have many distributions, some that look like y and some that do not. Most of the distributions that look like y are beta distributions. My approach: if I can fit the beta function to each of my unique IDs (each with its own distribution), I can get the coefficients of the beta function, look for coefficients that are close in magnitude, and so effectively filter out all distributions that look like y.
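A minimal sketch of that filtering idea, assuming each ID's data is a 1-D array of samples already scaled into (0, 1). The names data_by_id, reference and rtol are hypothetical, and the 25% relative tolerance is an arbitrary choice. It uses scipy.stats.beta.fit with loc and scale fixed, so only the two shape parameters are estimated and flagged IDs can then be dropped:

```python
import numpy as np
from scipy import stats


def fit_beta_shapes(samples):
    """Fit the beta shape parameters (a, b), with the support fixed to [0, 1]."""
    a, b, _loc, _scale = stats.beta.fit(samples, floc=0, fscale=1)
    return a, b


def ids_matching_reference(data_by_id, reference, rtol=0.25):
    """Return the IDs whose fitted (a, b) both lie within rtol of the reference (a, b)."""
    ref = np.asarray(reference, dtype=float)
    matches = []
    for key, samples in data_by_id.items():
        shapes = np.asarray(fit_beta_shapes(samples))
        if np.all(np.abs(shapes - ref) <= rtol * ref):
            matches.append(key)
    return matches


# Hypothetical data: two IDs drawn from Beta(2, 5) and one from Beta(5, 2)
rng = np.random.default_rng(0)
data_by_id = {
    "id_1": rng.beta(2, 5, size=2000),
    "id_2": rng.beta(2, 5, size=2000),
    "id_3": rng.beta(5, 2, size=2000),
}
print(ids_matching_reference(data_by_id, reference=(2, 5)))
```

Comparing against a single reference pair is only one way to define "close in magnitude"; clustering all fitted (a, b) pairs would be another.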
y looks like this (the same data appears in the code below, where array is numpy.array). However, I am having some trouble getting started.
y = array([[ 0.50423378, 0.50423378, 0.50423378, 0.50254455, 0.50423378, 0.50254455, 0.50423378, 0.50507627, 0.50507627, 0.50423378,0.50507627, 0.50507627, 0.50423378, 0.50423378, 0.50423378, 0.50423378, 0.50423378, 0.50423378, 0.50254455, 0.50254455, 0.50254455, 0.50423378, 0.50423378, 0.50507627, 0.50507627,0.50507627, 0.50507627, 0.50507627, 0.50423378, 0.50423378, 0.50423378, 0.50507627, 0.50507627, 0.50423378, 0.50507627, 0.50507627, 0.50507627, 0.50423378, 0.50423378, 0.50423378,0.50423378, 0.50423378, 0.50254455, 0.50254455, 0.5, 0.50254455, 0.50254455, 0.50254455, 0.50423378, 0.50423378,0.50423378, 0.50423378, 0.50423378, 0.50254455, 0.50423378, 0.50254455, 0.50254455, 0.50423378, 0.50423378, 0.50254455,0.5 , 0.5 , 0.50254455, 0.50254455, 0.5 ,0.49658699, 0.49228746, 0.49228746, 0.48707792, 0.48092881,0.48707792, 0.48092881, 0.48092881, 0.48092881, 0.48092881,0.48092881, 0.48092881, 0.47380354, 0.47380354, 0.48092881,0.48707792, 0.48707792, 0.48092881, 0.48092881, 0.48092881,0.48092881, 0.48092881, 0.48092881, 0.47380354, 0.48092881,0.48092881, 0.48092881, 0.48707792, 0.48707792, 0.48707792,0.49228746, 0.49228746, 0.49228746, 0.49228746, 0.48707792,0.48707792, 0.48707792, 0.49228746, 0.48707792, 0.48707792,0.48707792, 0.48707792, 0.48707792, 0.49228746, 0.49228746,0.48707792, 0.48707792, 0.49228746, 0.49658699, 0.49658699,0.49658699, 0.49228746, 0.49228746, 0.49658699, 0.49228746,0.49658699, 0.5 , 0.50254455, 0.50423378, 0.50423378,0.50254455, 0.50423378, 0.50423378, 0.50254455, 0.5 ,0.5 , 0.5 , 0.5 , 0.5 , 0.50254455,0.50254455, 0.5 , 0.50254455, 0.5 , 0.5 ,0.5 , 0.5 , 0.5 , 0.5 , 0.49658699,0.49228746, 0.48707792, 0.48707792, 0.48707792, 0.49228746,0.49228746, 0.48707792, 0.48707792, 0.49228746, 0.48707792,0.48707792, 0.48707792, 0.48092881, 0.48092881, 0.48707792,0.48707792, 0.48092881, 0.47380354, 0.48092881, 0.48092881,0.48707792, 0.49228746, 0.48707792, 0.49228746, 0.48707792,0.48092881, 0.47380354, 0.46565731, 0.46565731, 0.46565731,0.45643546, 
0.45643546, 0.45643546, 0.45643546, 0.45643546,0.45643546, 0.45643546, 0.46565731, 0.45643546, 0.45643546,0.45643546, 0.44607129, 0.45643546, 0.45643546, 0.45643546,0.44607129, 0.44607129, 0.43448304, 0.43448304, 0.43448304,0.44607129, 0.45643546, 0.45643546, 0.45643546, 0.46565731,0.47380354, 0.48092881, 0.48092881, 29.38186886, 29.38186886,29.38186886, 29.37898909, 29.45299206, 29.52449116, 29.74083063,29.73771398, 29.73771398, 29.74083063, 29.74083063, 29.74083063,29.74083063, 29.73771398, 29.74083063, 29.73771398, 29.73771398,29.73771398, 29.73771398, 29.74083063, 29.74083063, 29.74083063,30.12527698, 30.48367189, 30.8169243 , 30.8169243 , 30.8169243 ,30.8169243 , 30.82153203, 30.8169243 , 30.81230208, 30.81230208,30.80766536, 30.81230208, 30.81230208, 30.80766536, 30.80301414,30.80301414, 30.80301414, 30.80301414, 30.80301414, 30.80766536,30.81230208, 30.81230208, 30.81230208, 30.81230208, 30.8169243 ,30.82153203, 30.82612528, 10.51949923, 10.51949923, 10.51436497,10.51436497, 10.22456193, 9.91464422, 9.36922158, 9.37416663,9.36922158, 9.36922158, 9.36922158, 9.37416663, 9.37906375,9.383913 , 9.383913 , 9.38871446, 9.383913 , 9.37906375,9.37416663, 9.36922158, 9.36422851, 9.35918734, 7.72711675,5.53121937, 0.5 , 0.50254455, 0.50254455, 0.50254455,0.50254455, 0.50254455, 0.5 , 0.5 , 0.49658699,0.5 , 0.5 , 0.5 , 0.49658699, 0.49658699,0.5 , 0.50254455, 0.50423378, 0.50423378, 0.50423378,0.50507627, 0.50507627, 0.50423378, 0.50423378, 0.50423378,0.50423378, 0.50423378, 0.50254455, 0.50254455, 0.5 ,0.5 , 0.5 , 0.49658699, 0.5 , 0.49658699,0.49658699, 0.49658699, 0.49658699, 0.49658699, 0.49658699,0.49658699, 0.49658699, 0.49228746, 0.48707792, 0.48707792,0.48092881, 0.47380354, 0.47380354, 0.46565731, 0.46565731,0.47380354, 0.46565731, 0.47380354, 0.47380354, 0.47380354, 0.47380354, 0.48092881]])
Using this example from scipy, how do I get the x array, plug it in to get my coefficients, and then plot the curve_fit result on my distribution?
import numpy as np
from scipy.optimize import curve_fit
from scipy.special import gamma as gamma
def betafunc(x, a, b, cst):
    # Beta probability density on (0, 1), scaled by cst
    return cst * gamma(a + b) * (x**(a - 1)) * ((1 - x)**(b - 1)) / (gamma(a) * gamma(b))
x = np.array( [0.1, 0.3, 0.5, 0.7, 0.9, 1.1])
y = np.array( [0.45112234, 0.56934313, 0.3996803 , 0.28982859, 0.19682153, 0.] )
popt2,pcov2 = curve_fit(betafunc,x[:-1],y[:-1],p0=(0.5,1.5,0.5))
print(popt2)
print(pcov2)
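If y is a curve (one value per position) rather than a set of samples, one way to get an x array, which is an assumption and not part of the example above, is to place the n positions at bin midpoints inside (0, 1), where betafunc is defined. The helper fit_curve_on_unit_interval is hypothetical, and the positivity bounds are an added choice. The sketch checks itself on a noise-free synthetic curve:

```python
import numpy as np
from scipy.optimize import curve_fit
from scipy.special import gamma


def betafunc(x, a, b, cst):
    # Beta probability density on (0, 1), scaled by cst
    return cst * gamma(a + b) * x**(a - 1) * (1 - x)**(b - 1) / (gamma(a) * gamma(b))


def fit_curve_on_unit_interval(y, p0=(2.0, 2.0, 1.0)):
    """Place the n values of y at bin midpoints in (0, 1) and fit betafunc to them."""
    y = np.asarray(y, dtype=float).ravel()  # also accepts a (1, n) array like the one in the question
    n = y.size
    x = (np.arange(n) + 0.5) / n  # strictly inside (0, 1), so both power terms stay finite
    # Lower bound keeps a, b and cst positive; passing bounds makes curve_fit use the "trf" method
    popt, pcov = curve_fit(betafunc, x, y, p0=p0, bounds=(1e-6, np.inf))
    return x, popt, pcov


# Self-check on a noise-free synthetic curve with known parameters a=2.5, b=4, cst=3
n = 300
x_true = (np.arange(n) + 0.5) / n
y_true = betafunc(x_true, 2.5, 4.0, 3.0)
x, popt, pcov = fit_curve_on_unit_interval(y_true)
print(popt)
```

For the question's data this would be x, popt, pcov = fit_curve_on_unit_interval(y); whether the coefficients mean anything depends on y actually being beta-shaped.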
Upvotes: 3
Views: 1856
Reputation: 331
For the first part of your question: if you have a set of observations, you can use numpy.histogram to get the histogram. To get the center of each bin, proceed as in my code below; you can then use those bin centers (x) and counts (y) for the fitting. However, a betafunc cannot be fit to the data you provided, because it simply doesn't fit: betafunc is only defined for 0 < x < 1, while every bin center of your data lies above 1, where (1-x)**(b-1) is NaN for non-integer b.
import numpy as np
from matplotlib import pyplot as plt
from scipy.optimize import curve_fit
from scipy.special import gamma as gamma
def betafunc(x, a, b, cst):
    # Beta probability density on (0, 1), scaled by cst
    return cst * gamma(a + b) * (x**(a - 1)) * ((1 - x)**(b - 1)) / (gamma(a) * gamma(b))
y_data=np.array([[ 0.50423378, 0.50423378, 0.50423378, 0.50254455, 0.50423378, 0.50254455, 0.50423378, 0.50507627, 0.50507627, 0.50423378,0.50507627, 0.50507627, 0.50423378, 0.50423378, 0.50423378, 0.50423378, 0.50423378, 0.50423378, 0.50254455, 0.50254455, 0.50254455, 0.50423378, 0.50423378, 0.50507627, 0.50507627,0.50507627, 0.50507627, 0.50507627, 0.50423378, 0.50423378, 0.50423378, 0.50507627, 0.50507627, 0.50423378, 0.50507627, 0.50507627, 0.50507627, 0.50423378, 0.50423378, 0.50423378,0.50423378, 0.50423378, 0.50254455, 0.50254455, 0.5, 0.50254455, 0.50254455, 0.50254455, 0.50423378, 0.50423378,0.50423378, 0.50423378, 0.50423378, 0.50254455, 0.50423378, 0.50254455, 0.50254455, 0.50423378, 0.50423378, 0.50254455,0.5 , 0.5 , 0.50254455, 0.50254455, 0.5 ,0.49658699, 0.49228746, 0.49228746, 0.48707792, 0.48092881,0.48707792, 0.48092881, 0.48092881, 0.48092881, 0.48092881,0.48092881, 0.48092881, 0.47380354, 0.47380354, 0.48092881,0.48707792, 0.48707792, 0.48092881, 0.48092881, 0.48092881,0.48092881, 0.48092881, 0.48092881, 0.47380354, 0.48092881,0.48092881, 0.48092881, 0.48707792, 0.48707792, 0.48707792,0.49228746, 0.49228746, 0.49228746, 0.49228746, 0.48707792,0.48707792, 0.48707792, 0.49228746, 0.48707792, 0.48707792,0.48707792, 0.48707792, 0.48707792, 0.49228746, 0.49228746,0.48707792, 0.48707792, 0.49228746, 0.49658699, 0.49658699,0.49658699, 0.49228746, 0.49228746, 0.49658699, 0.49228746,0.49658699, 0.5 , 0.50254455, 0.50423378, 0.50423378,0.50254455, 0.50423378, 0.50423378, 0.50254455, 0.5 ,0.5 , 0.5 , 0.5 , 0.5 , 0.50254455,0.50254455, 0.5 , 0.50254455, 0.5 , 0.5 ,0.5 , 0.5 , 0.5 , 0.5 , 0.49658699,0.49228746, 0.48707792, 0.48707792, 0.48707792, 0.49228746,0.49228746, 0.48707792, 0.48707792, 0.49228746, 0.48707792,0.48707792, 0.48707792, 0.48092881, 0.48092881, 0.48707792,0.48707792, 0.48092881, 0.47380354, 0.48092881, 0.48092881,0.48707792, 0.49228746, 0.48707792, 0.49228746, 0.48707792,0.48092881, 0.47380354, 0.46565731, 0.46565731, 0.46565731,0.45643546, 
0.45643546, 0.45643546, 0.45643546, 0.45643546,0.45643546, 0.45643546, 0.46565731, 0.45643546, 0.45643546,0.45643546, 0.44607129, 0.45643546, 0.45643546, 0.45643546,0.44607129, 0.44607129, 0.43448304, 0.43448304, 0.43448304,0.44607129, 0.45643546, 0.45643546, 0.45643546, 0.46565731,0.47380354, 0.48092881, 0.48092881, 29.38186886, 29.38186886,29.38186886, 29.37898909, 29.45299206, 29.52449116, 29.74083063,29.73771398, 29.73771398, 29.74083063, 29.74083063, 29.74083063,29.74083063, 29.73771398, 29.74083063, 29.73771398, 29.73771398,29.73771398, 29.73771398, 29.74083063, 29.74083063, 29.74083063,30.12527698, 30.48367189, 30.8169243 , 30.8169243 , 30.8169243 ,30.8169243 , 30.82153203, 30.8169243 , 30.81230208, 30.81230208,30.80766536, 30.81230208, 30.81230208, 30.80766536, 30.80301414,30.80301414, 30.80301414, 30.80301414, 30.80301414, 30.80766536,30.81230208, 30.81230208, 30.81230208, 30.81230208, 30.8169243 ,30.82153203, 30.82612528, 10.51949923, 10.51949923, 10.51436497,10.51436497, 10.22456193, 9.91464422, 9.36922158, 9.37416663,9.36922158, 9.36922158, 9.36922158, 9.37416663, 9.37906375,9.383913 , 9.383913 , 9.38871446, 9.383913 , 9.37906375,9.37416663, 9.36922158, 9.36422851, 9.35918734, 7.72711675,5.53121937, 0.5 , 0.50254455, 0.50254455, 0.50254455,0.50254455, 0.50254455, 0.5 , 0.5 , 0.49658699,0.5 , 0.5 , 0.5 , 0.49658699, 0.49658699,0.5 , 0.50254455, 0.50423378, 0.50423378, 0.50423378,0.50507627, 0.50507627, 0.50423378, 0.50423378, 0.50423378,0.50423378, 0.50423378, 0.50254455, 0.50254455, 0.5 ,0.5 , 0.5 , 0.49658699, 0.5 , 0.49658699,0.49658699, 0.49658699, 0.49658699, 0.49658699, 0.49658699,0.49658699, 0.49658699, 0.49228746, 0.48707792, 0.48707792,0.48092881, 0.47380354, 0.47380354, 0.46565731, 0.46565731,0.47380354, 0.46565731, 0.47380354, 0.47380354, 0.47380354, 0.47380354, 0.48092881]])
hist=np.histogram(y_data[0],bins=20)
x=(hist[1][1:]+hist[1][:-1])/2
y=hist[0]
print(x,y)
plt.step(x,y,label='Manual calculation of the center of the bins')
plt.hist(y_data[0],bins=20,histtype='bar',label='Automatic plot with plt.hist')
plt.legend()
plt.show()
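# With this data every bin center is > 1, so betafunc returns NaN and this fit cannot give a meaningful result
popt2, pcov2 = curve_fit(betafunc, x[:-1], y[:-1], p0=(0.5, 1.5, 0.5))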
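To see why that fit fails, and one possible workaround, here is a sketch. The workaround is my own assumption, not part of the answer: min-max rescale the samples into (0, 1) before building the histogram (with density=True), so every bin center lies where betafunc is defined. fit_rescaled_histogram and its pad argument are hypothetical, and the synthetic samples stand in for y_data[0]:

```python
import numpy as np
from scipy.optimize import curve_fit
from scipy.special import gamma


def betafunc(x, a, b, cst):
    # Beta probability density on (0, 1), scaled by cst
    return cst * gamma(a + b) * x**(a - 1) * (1 - x)**(b - 1) / (gamma(a) * gamma(b))


# Why the fit above breaks: for x > 1, (1 - x) is negative, and NumPy returns
# NaN when a negative float is raised to a non-integer power such as b - 1 = 0.5.
with np.errstate(invalid="ignore"):
    print(betafunc(np.array([0.5, 1.2]), 0.5, 1.5, 0.5))  # second entry is nan


def fit_rescaled_histogram(samples, bins=20, pad=1e-3):
    """Rescale samples into (0, 1), histogram them as a density and fit betafunc."""
    samples = np.asarray(samples, dtype=float).ravel()
    lo, hi = samples.min(), samples.max()
    span = hi - lo
    # Min-max rescaling with a small pad so no sample lands exactly on 0 or 1
    u = (samples - lo + pad * span) / (span * (1 + 2 * pad))
    counts, edges = np.histogram(u, bins=bins, range=(0, 1), density=True)
    centers = (edges[1:] + edges[:-1]) / 2  # bin centers, computed as in the answer above
    popt, pcov = curve_fit(betafunc, centers, counts, p0=(2.0, 2.0, 1.0), bounds=(1e-6, np.inf))
    return centers, counts, popt


# Synthetic stand-in for y_data[0]: Beta(2, 5) samples stretched onto roughly (0.4, 30.4)
rng = np.random.default_rng(1)
samples = 0.4 + 30 * rng.beta(2, 5, size=5000)
centers, counts, popt = fit_rescaled_histogram(samples)
print(popt)
```

The fitted shapes then describe the rescaled data, and the rescaling depends on each sample's min and max, so every ID would need the same convention before coefficients are compared. It also cannot make non-beta-shaped data (like the y above) fit well.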
For the second part of your question: to plot the function with the optimal fit parameters, you only need to add the last four lines of the code below.
import numpy as np
from scipy.optimize import curve_fit
from scipy.special import gamma as gamma
def betafunc(x, a, b, cst):
    # Beta probability density on (0, 1), scaled by cst
    return cst * gamma(a + b) * (x**(a - 1)) * ((1 - x)**(b - 1)) / (gamma(a) * gamma(b))
x = np.array( [0.1, 0.3, 0.5, 0.7, 0.9, 1.1])
y = np.array( [0.45112234, 0.56934313, 0.3996803 , 0.28982859, 0.19682153, 0.] )
popt2,pcov2 = curve_fit(betafunc,x[:-1],y[:-1],p0=(0.5,1.5,0.5))
print(popt2)
print(pcov2)
from matplotlib import pyplot as plt
plt.plot(x,betafunc(x,*popt2))
plt.plot(x,y)
plt.show()
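A variation on the plotting step, assuming you want a smooth curve rather than straight segments between the six data points: evaluate the fitted betafunc on a dense grid strictly inside (0, 1), which also avoids the NaN the fitted curve gives at x = 1.1 for non-integer b. The Agg backend and the output file name beta_fit.png are assumptions for running without a display; with an interactive backend you can call plt.show() instead of savefig:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; remove to use your default interactive backend
from matplotlib import pyplot as plt
from scipy.optimize import curve_fit
from scipy.special import gamma


def betafunc(x, a, b, cst):
    # Beta probability density on (0, 1), scaled by cst
    return cst * gamma(a + b) * x**(a - 1) * (1 - x)**(b - 1) / (gamma(a) * gamma(b))


x = np.array([0.1, 0.3, 0.5, 0.7, 0.9, 1.1])
y = np.array([0.45112234, 0.56934313, 0.3996803, 0.28982859, 0.19682153, 0.])
popt2, pcov2 = curve_fit(betafunc, x[:-1], y[:-1], p0=(0.5, 1.5, 0.5))

# Dense grid strictly inside (0, 1), where betafunc is finite
xs = np.linspace(0.005, 0.995, 200)

fig, ax = plt.subplots()
ax.plot(x, y, "o", label="data")
(line,) = ax.plot(xs, betafunc(xs, *popt2), label="fit: a=%.2f, b=%.2f, cst=%.2f" % tuple(popt2))
ax.legend()
fig.savefig("beta_fit.png")  # hypothetical output file name
```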
Upvotes: 2
Reputation: 101
If you are not constrained to use curve_fit, I would suggest that you take a look at scipy.stats.beta. One possible solution is:
import numpy as np
from numpy import array  # the data below is pasted from a NumPy repr
from scipy.stats import beta
y = array([[ 0.50423378, 0.50423378, 0.50423378, 0.50254455, 0.50423378, 0.50254455, 0.50423378, 0.50507627, 0.50507627, 0.50423378,0.50507627, 0.50507627, 0.50423378, 0.50423378, 0.50423378, 0.50423378, 0.50423378, 0.50423378, 0.50254455, 0.50254455, 0.50254455, 0.50423378, 0.50423378, 0.50507627, 0.50507627,0.50507627, 0.50507627, 0.50507627, 0.50423378, 0.50423378, 0.50423378, 0.50507627, 0.50507627, 0.50423378, 0.50507627, 0.50507627, 0.50507627, 0.50423378, 0.50423378, 0.50423378,0.50423378, 0.50423378, 0.50254455, 0.50254455, 0.5, 0.50254455, 0.50254455, 0.50254455, 0.50423378, 0.50423378,0.50423378, 0.50423378, 0.50423378, 0.50254455, 0.50423378, 0.50254455, 0.50254455, 0.50423378, 0.50423378, 0.50254455,0.5 , 0.5 , 0.50254455, 0.50254455, 0.5 ,0.49658699, 0.49228746, 0.49228746, 0.48707792, 0.48092881,0.48707792, 0.48092881, 0.48092881, 0.48092881, 0.48092881,0.48092881, 0.48092881, 0.47380354, 0.47380354, 0.48092881,0.48707792, 0.48707792, 0.48092881, 0.48092881, 0.48092881,0.48092881, 0.48092881, 0.48092881, 0.47380354, 0.48092881,0.48092881, 0.48092881, 0.48707792, 0.48707792, 0.48707792,0.49228746, 0.49228746, 0.49228746, 0.49228746, 0.48707792,0.48707792, 0.48707792, 0.49228746, 0.48707792, 0.48707792,0.48707792, 0.48707792, 0.48707792, 0.49228746, 0.49228746,0.48707792, 0.48707792, 0.49228746, 0.49658699, 0.49658699,0.49658699, 0.49228746, 0.49228746, 0.49658699, 0.49228746,0.49658699, 0.5 , 0.50254455, 0.50423378, 0.50423378,0.50254455, 0.50423378, 0.50423378, 0.50254455, 0.5 ,0.5 , 0.5 , 0.5 , 0.5 , 0.50254455,0.50254455, 0.5 , 0.50254455, 0.5 , 0.5 ,0.5 , 0.5 , 0.5 , 0.5 , 0.49658699,0.49228746, 0.48707792, 0.48707792, 0.48707792, 0.49228746,0.49228746, 0.48707792, 0.48707792, 0.49228746, 0.48707792,0.48707792, 0.48707792, 0.48092881, 0.48092881, 0.48707792,0.48707792, 0.48092881, 0.47380354, 0.48092881, 0.48092881,0.48707792, 0.49228746, 0.48707792, 0.49228746, 0.48707792,0.48092881, 0.47380354, 0.46565731, 0.46565731, 0.46565731,0.45643546, 
0.45643546, 0.45643546, 0.45643546, 0.45643546,0.45643546, 0.45643546, 0.46565731, 0.45643546, 0.45643546,0.45643546, 0.44607129, 0.45643546, 0.45643546, 0.45643546,0.44607129, 0.44607129, 0.43448304, 0.43448304, 0.43448304,0.44607129, 0.45643546, 0.45643546, 0.45643546, 0.46565731,0.47380354, 0.48092881, 0.48092881, 29.38186886, 29.38186886,29.38186886, 29.37898909, 29.45299206, 29.52449116, 29.74083063,29.73771398, 29.73771398, 29.74083063, 29.74083063, 29.74083063,29.74083063, 29.73771398, 29.74083063, 29.73771398, 29.73771398,29.73771398, 29.73771398, 29.74083063, 29.74083063, 29.74083063,30.12527698, 30.48367189, 30.8169243 , 30.8169243 , 30.8169243 ,30.8169243 , 30.82153203, 30.8169243 , 30.81230208, 30.81230208,30.80766536, 30.81230208, 30.81230208, 30.80766536, 30.80301414,30.80301414, 30.80301414, 30.80301414, 30.80301414, 30.80766536,30.81230208, 30.81230208, 30.81230208, 30.81230208, 30.8169243 ,30.82153203, 30.82612528, 10.51949923, 10.51949923, 10.51436497,10.51436497, 10.22456193, 9.91464422, 9.36922158, 9.37416663,9.36922158, 9.36922158, 9.36922158, 9.37416663, 9.37906375,9.383913 , 9.383913 , 9.38871446, 9.383913 , 9.37906375,9.37416663, 9.36922158, 9.36422851, 9.35918734, 7.72711675,5.53121937, 0.5 , 0.50254455, 0.50254455, 0.50254455,0.50254455, 0.50254455, 0.5 , 0.5 , 0.49658699,0.5 , 0.5 , 0.5 , 0.49658699, 0.49658699,0.5 , 0.50254455, 0.50423378, 0.50423378, 0.50423378,0.50507627, 0.50507627, 0.50423378, 0.50423378, 0.50423378,0.50423378, 0.50423378, 0.50254455, 0.50254455, 0.5 ,0.5 , 0.5 , 0.49658699, 0.5 , 0.49658699,0.49658699, 0.49658699, 0.49658699, 0.49658699, 0.49658699,0.49658699, 0.49658699, 0.49228746, 0.48707792, 0.48707792,0.48092881, 0.47380354, 0.47380354, 0.46565731, 0.46565731,0.47380354, 0.46565731, 0.47380354, 0.47380354, 0.47380354, 0.47380354, 0.48092881]])
params = beta.fit(y[0])  # y has shape (1, n); fit the 1-D row of samples
y2 = np.loadtxt("other_data_file.dat") # other distribution file
params2 = beta.fit(y2)
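You can then compare the parameters individually by comparing params and params2. Please note that scipy.stats.beta defines its probability density function in standardized form: beta.fit returns (a, b, loc, scale), where loc and scale shift and stretch the distribution away from the standard [0, 1] support.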
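To make the standardized form concrete, here is a short check with illustrative values a=2, b=5, loc=0.4, scale=30 (not taken from the question): beta.pdf with loc and scale equals the standardized pdf at (x - loc) / scale, divided by scale. If you want params and params2 to be comparable shape-for-shape, one option (my assumption, not part of this answer) is to fix loc and scale with floc and fscale so that only a and b are estimated:

```python
import numpy as np
from scipy.stats import beta

# Illustrative values, not taken from the question's data
a, b, loc, scale = 2.0, 5.0, 0.4, 30.0
x = np.linspace(1.0, 30.0, 7)

# Standardized form: the shifted/scaled pdf equals the standard pdf at
# (x - loc) / scale, divided by scale
shifted = beta.pdf(x, a, b, loc=loc, scale=scale)
standardized = beta.pdf((x - loc) / scale, a, b) / scale
print(np.allclose(shifted, standardized))  # True

# Fixing loc and scale means beta.fit only estimates the shape parameters
rng = np.random.default_rng(0)
samples = rng.beta(a, b, size=2000)
params = beta.fit(samples, floc=0, fscale=1)  # (a_hat, b_hat, 0, 1)
print(params)
```

With loc and scale left free, two fits can trade shape against location and scale, which makes a direct comparison of a and b less reliable.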
Upvotes: 0