Reputation: 1748
We need to do a logistic regression in Java. We used this code in Python http://blog.smellthedata.com/2009/06/python-logistic-regression-with-l2.html and basically want the same thing in Java. I was directed to Weka, but the license is non-commercial.
I found that the Omegahat API has a BFGS minimizer like SciPy's, but I can't figure out the API: http://www.omegahat.org/api/org/omegahat/Numerics/Optimizers/OptimizerAlgorithmBFGS.html I want to implement a class that holds the model and supplies the likelihood functions. But model.eval takes a ModelPointNumeric, which also has an eval. In any case, it does not map onto the math as clearly as the Python code using NumPy does. Is the Omegahat API still used or maintained? I could not find a mailing list for it.
Upvotes: 8
Views: 28728
Reputation: 2428
As mentioned, you can use Apache Commons Math to fit a logistic curve to data. The Logistic function from Apache Commons Math is more general than the standard logistic function. It has 6 parameters (k, m, b, q, a, n), whereas the standard logistic function has 3 (k, m, b). However, if q = 1.0, a = 0.0, and n = 1.0, the generalized function simplifies to the 3-parameter function. The values of q, a, and n mainly affect the offset and symmetry of the curve.
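To make the reduction concrete, here is a minimal plain-Java sketch with no Commons Math dependency. It assumes the generalized form a + (k - a) / (1 + q * exp(b * (m - x)))^(1/n), which is how I understand the Commons Math Logistic class to be defined; the class and method names below are made up for illustration.

```java
/** Plain-Java sketch of the generalized logistic curve (illustrative names only). */
final class GeneralizedLogistic {

    /** Assumed generalized form: a + (k - a) / (1 + q * exp(b * (m - x)))^(1/n). */
    static double generalized(double x, double k, double m, double b,
                              double q, double a, double n) {
        return a + (k - a) / Math.pow(1.0 + q * Math.exp(b * (m - x)), 1.0 / n);
    }

    /** Standard three-parameter form: k / (1 + exp(-b * (x - m))). */
    static double standard(double x, double k, double m, double b) {
        return k / (1.0 + Math.exp(-b * (x - m)));
    }

    public static void main(String[] args) {
        // With q = 1, a = 0, n = 1 the two columns should agree.
        for (double x = -6.0; x <= 6.0; x += 3.0) {
            double g = generalized(x, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0);
            double s = standard(x, 1.0, 0.0, 1.0);
            System.out.printf("x=%5.1f generalized=%.9f standard=%.9f%n", x, g, s);
        }
    }
}
```

Setting a to something other than 0 shifts the lower asymptote, and changing n makes the curve asymmetric about its inflection point.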
The example code shows how to fit the function to data generated using the standard function. The fit is trivial in this case (as the data is from the logistic equation). If you wish, you can play with the xvalues and yvalues to introduce noise or to distort the curve to give a more realistic scenario.
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
import org.apache.commons.math3.analysis.function.Logistic;
import org.apache.commons.math3.fitting.SimpleCurveFitter;
import org.apache.commons.math3.fitting.WeightedObservedPoint;

public class LogisticFit {

    public static void main(String[] args) {
        double[] xvalues = new double[]{-6.0, -5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0};

        /*
         * These are Y values for values of x for a "standard" logistic equation
         */
        double[] yvalues = new double[]{0.002472623, 0.006692851, 0.01798621, 0.047425873, 0.119202922, 0.268941421,
                0.5, 0.731058579, 0.880797078, 0.952574127, 0.98201379, 0.993307149, 0.997527377};

        List<WeightedObservedPoint> weightedObservedPoints = new ArrayList<>();
        for (int x = 0; x < yvalues.length; x++) {
            weightedObservedPoints.add(new WeightedObservedPoint(1.0, xvalues[x], yvalues[x]));
        }

        /* Starting Estimates */

        /* Lower asymptote. A reasonable estimate is the minimum observed value */
        double a = yvalues[0]; // assumes observations are sorted

        /* Upper asymptote. The 'carrying capacity'. A reasonable estimate is the maximum observed value */
        double k = yvalues[yvalues.length - 1];

        /* Growth rate. For a standard logistic curve this is 1 so 1 is a reasonable estimate */
        double b = 1.0;

        /* Parameter that affects near which asymptote maximum growth occurs. 1.0 if we assume the curve is symmetric */
        double n = 1.0;

        /* Parameter that affects the position of the curve along the ordinate axis. */
        double q = 1.0;

        /*
         * Abscissa of maximum growth. The x value where inflection point of the curve occurs. The value of x when the
         * population is halfway to the maximum. A reasonable estimate is halfway along the x axis if we assume symmetry
         */
        double m = xvalues[xvalues.length / 2];

        double[] estimates = new double[]{k, m, b, q, a, n};

        /* the logistic function we want to fit */
        ParametricUnivariateFunction logisticFunction = new Logistic.Parametric();

        SimpleCurveFitter curveFitter = SimpleCurveFitter.create(logisticFunction, estimates);
        final double[] fit = curveFitter.fit(weightedObservedPoints);

        System.out.println("estimated k = " + fit[0] + ", True value = 1.0");
        System.out.println("estimated m = " + fit[1] + ", True value = 0.0");
        System.out.println("estimated b = " + fit[2] + ", True value = 1.0");
        System.out.println("estimated q = " + fit[3] + ", True value = 1.0");
        System.out.println("estimated a = " + fit[4] + ", True value = 0.0");
        System.out.println("estimated n = " + fit[5] + ", True value = 1.0");
        System.out.println("value of y at estimated curve inflection point (m) = " + logisticFunction.value(m, fit));
    }
}
The output will be similar to:
estimated k = 0.9999999999617879, True value = 1.0
estimated m = 0.05131427607556755, True value = 0.0
estimated b = 1.0000000013063237, True value = 1.0
estimated q = 0.949980068678136, True value = 1.0
estimated a = 4.555478390914705E-12, True value = 0.0
estimated n = 1.0000000008645784, True value = 1.0
value of y at estimated curve inflection point (m) = 0.4999999999837729
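The estimates for m and q look slightly off even though the data is exact. A likely explanation, assuming the curve has the form a + (k - a) / (1 + q * exp(b * (m - x)))^(1/n), is that m and q are not separately identifiable: q * exp(b * (m - x)) equals exp(b * (m + ln(q)/b - x)), so only the combination m + ln(q)/b is pinned down by the data. The sketch below (class and method names are made up for illustration) computes that combined midpoint from the fitted values printed above.

```java
/** Sketch: combine the fitted m and q into a single effective midpoint (illustrative names). */
final class EffectiveMidpoint {

    /** The x at which q * exp(b * (m - x)) equals 1, namely m + ln(q) / b. */
    static double effectiveMidpoint(double m, double b, double q) {
        return m + Math.log(q) / b;
    }

    public static void main(String[] args) {
        // Fitted values copied from the output shown above.
        double m = 0.05131427607556755;
        double b = 1.0000000013063237;
        double q = 0.949980068678136;
        // Should come out very close to the true midpoint of 0.
        System.out.println("effective midpoint = " + effectiveMidpoint(m, b, q));
    }
}
```

If that reading is right, fixing q at 1.0 (or fitting only the 3-parameter form) would remove the ambiguity.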
Upvotes: 2
Reputation: 3097
There is a simple implementation for Java on GitHub, in just 2 classes (plus one utility). It's probably not optimized but there are no dependencies to download.
I created a pull request which simplifies it further to a single file.
Upvotes: 2
Reputation: 1748
Thanks for the input. After much searching I found this: http://mallet.cs.umass.edu/optimization.php It is almost a 1:1 translation of how the NumPy implementation works, and it lets us do logistic regression ourselves with the mathematical formulas. So I can take our Python class, implement the 4-5 methods necessary, and then pass it to the BFGS solver to perform our logistic regression.
It worked great. The only thing we had to realize was that Mallet maximizes the function, whereas the Python (NumPy/SciPy) side uses a minimizer, so the sign of the objective and its gradient has to be flipped.
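For anyone following the same route, here is a rough, self-contained sketch of what such an objective might look like when written for a maximizer. The method names loosely mirror Mallet's Optimizable.ByGradientValue, but this class does not implement that interface and has no Mallet dependency. The labels in {-1, +1}, the penalty name alpha, and the plain gradient-ascent loop in main are assumptions for illustration, not our actual code.

```java
import java.util.Arrays;

/** Sketch of an L2-penalized logistic regression objective for a maximizer (illustrative only). */
final class LogisticObjective {
    private final double[][] x;  // one row of features per example
    private final double[] y;    // labels, assumed to be -1 or +1
    private final double alpha;  // L2 penalty strength (hypothetical name)
    private double[] beta;       // current parameters

    LogisticObjective(double[][] x, double[] y, double alpha) {
        this.x = x;
        this.y = y;
        this.alpha = alpha;
        this.beta = new double[x[0].length];
    }

    void setParameters(double[] params) { beta = params.clone(); }

    double[] getParameters() { return beta.clone(); }

    /** Penalized log-likelihood: larger is better, so a maximizer can use it as is. */
    double getValue() {
        double value = 0.0;
        for (int i = 0; i < x.length; i++) {
            value += logSigmoid(y[i] * dot(x[i], beta));
        }
        for (double bj : beta) {
            value -= 0.5 * alpha * bj * bj;
        }
        return value;
    }

    /** Gradient of getValue() with respect to beta, written into buffer. */
    void getValueGradient(double[] buffer) {
        Arrays.fill(buffer, 0.0);
        for (int i = 0; i < x.length; i++) {
            double z = y[i] * dot(x[i], beta);
            double weight = y[i] / (1.0 + Math.exp(z)); // y * sigmoid(-z)
            for (int j = 0; j < beta.length; j++) {
                buffer[j] += weight * x[i][j];
            }
        }
        for (int j = 0; j < beta.length; j++) {
            buffer[j] -= alpha * beta[j];
        }
    }

    /** What a minimizer would be handed instead: the same objective with the sign flipped. */
    double getNegatedValue() { return -getValue(); }

    /** Numerically stable log(1 / (1 + exp(-z))). */
    private static double logSigmoid(double z) {
        return z >= 0 ? -Math.log1p(Math.exp(-z)) : z - Math.log1p(Math.exp(z));
    }

    private static double dot(double[] u, double[] v) {
        double sum = 0.0;
        for (int j = 0; j < u.length; j++) {
            sum += u[j] * v[j];
        }
        return sum;
    }

    public static void main(String[] args) {
        double[][] x = {{1.0, -2.0}, {1.0, -1.0}, {1.0, 1.0}, {1.0, 2.0}};
        double[] y = {-1.0, -1.0, 1.0, 1.0};
        LogisticObjective objective = new LogisticObjective(x, y, 0.1);

        // Plain gradient ascent stands in for the BFGS solver here.
        double[] gradient = new double[2];
        for (int iter = 0; iter < 500; iter++) {
            objective.getValueGradient(gradient);
            double[] params = objective.getParameters();
            for (int j = 0; j < params.length; j++) {
                params[j] += 0.1 * gradient[j];
            }
            objective.setParameters(params);
        }
        System.out.println("beta = " + Arrays.toString(objective.getParameters()));
        System.out.println("penalized log-likelihood = " + objective.getValue());
    }
}
```

Handing getNegatedValue() and a negated gradient to a minimizer should land on the same parameters as maximizing getValue().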
Upvotes: 10
Reputation: 28492
Weka is also available under a commercial license; see this page for details.
However, if logistic regression is the only data mining technique you need, take a look at LIBLINEAR, which is distributed under the BSD license.
Upvotes: 4
Reputation: 47665
If you do not find anything else, take a look at Apache Commons Math: it is a library of lightweight, self-contained mathematics and statistics components addressing the most common problems not available in the Java programming language or Commons Lang.
Good luck.
Upvotes: 2