Santi Peñate-Vera

Reputation: 1186

Solve least squares regression in java

I am implementing a pandas DataFrame clone in Java, and one of the features I need is resampling. I have found a nice method to do it here. The referenced link uses Python, specifically the lstsq function from NumPy, which takes a matrix A and a vector b exactly as in the formulas I need to implement.

However, the least squares page on the Apache Commons Math website describes an API that looks nothing like Least_squares(A, b); it is completely different and much more complicated.

I'd like to know how to solve the least squares non-linear regression in Java by passing only a matrix A and a vector b, as in Python.

Upvotes: 4

Views: 4994

Answers (1)

Manos Nikolaidis

Reputation: 22244

The Least Squares package in Apache Commons uses numeric minimization algorithms like Gauss-Newton and Levenberg-Marquardt for non-linear curve fitting (non-linear least squares).

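numpy.linalg.lstsq, on the other hand, solves linear least squares problems, of which line fitting is the simplest case. For fitting a straight line, the equivalent in Apache Commons Math is SimpleRegression. For a general matrix A with several columns, OLSMultipleLinearRegression or the solver returned by QRDecomposition is the closer match.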

In both cases you have a line fitting problem y = mx + c, where x and y are known vectors of equal length containing the data points (multiple pairs of (x, y) scalar values). With lstsq you have to rewrite the problem as y = Ap, where A = [x 1] (x as the first column and a column of ones as the second) and p = [[m], [c]]. With SimpleRegression, one option is to combine x and y into a double[][] matrix with 2 columns and 1 data point per row.

Here is the example from the lstsq docs, rewritten for SimpleRegression:

import org.apache.commons.math3.stat.regression.SimpleRegression;

public class StackOverflow {
    public static void main(String[] args) {
        // creating regression object, passing true to have intercept term
        SimpleRegression simpleRegression = new SimpleRegression(true);

        // passing data to the model
        // model will be fitted automatically by the class
        simpleRegression.addData(new double[][]{
                {0, -1},
                {1, 0.2},
                {2, 0.9},
                {3, 2.1}
        });

        // querying for model parameters
        System.out.println("slope = " + simpleRegression.getSlope());
        System.out.println("intercept = " + simpleRegression.getIntercept());
    }
}

And of course you get the same result:

slope = 1.0
intercept = -0.95
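If you specifically want the lstsq-style call that takes A and b, here is a minimal dependency-free sketch of what that computation looks like for the same data. It is not how NumPy or Apache Commons Math implement it: it solves the normal equations (A^T A) p = A^T b with Gaussian elimination, which assumes A has full column rank and is less numerically robust than a QR or SVD based solver. The class name LstsqSketch and the method lstsq are hypothetical names chosen for illustration.

```java
public class LstsqSketch {

    /**
     * Hypothetical helper: returns p minimizing ||A p - b||_2 by solving the
     * normal equations (A^T A) p = A^T b. Assumes A has full column rank.
     */
    public static double[] lstsq(double[][] a, double[] b) {
        int rows = a.length;
        int cols = a[0].length;

        // Build the augmented system [A^T A | A^T b].
        double[][] m = new double[cols][cols + 1];
        for (int i = 0; i < cols; i++) {
            for (int j = 0; j < cols; j++) {
                for (int k = 0; k < rows; k++) {
                    m[i][j] += a[k][i] * a[k][j];
                }
            }
            for (int k = 0; k < rows; k++) {
                m[i][cols] += a[k][i] * b[k];
            }
        }

        // Forward elimination with partial pivoting.
        for (int col = 0; col < cols; col++) {
            int pivot = col;
            for (int r = col + 1; r < cols; r++) {
                if (Math.abs(m[r][col]) > Math.abs(m[pivot][col])) {
                    pivot = r;
                }
            }
            double[] tmp = m[col];
            m[col] = m[pivot];
            m[pivot] = tmp;
            if (Math.abs(m[col][col]) < 1e-12) {
                throw new IllegalArgumentException("A does not have full column rank");
            }
            for (int r = col + 1; r < cols; r++) {
                double factor = m[r][col] / m[col][col];
                for (int c = col; c <= cols; c++) {
                    m[r][c] -= factor * m[col][c];
                }
            }
        }

        // Back substitution on the upper-triangular system.
        double[] p = new double[cols];
        for (int i = cols - 1; i >= 0; i--) {
            double sum = m[i][cols];
            for (int j = i + 1; j < cols; j++) {
                sum -= m[i][j] * p[j];
            }
            p[i] = sum / m[i][i];
        }
        return p;
    }

    public static void main(String[] args) {
        double[] x = {0, 1, 2, 3};
        double[] y = {-1, 0.2, 0.9, 2.1};

        // A = [x 1]: one row per data point, with a column of ones for the intercept.
        double[][] a = new double[x.length][2];
        for (int i = 0; i < x.length; i++) {
            a[i][0] = x[i];
            a[i][1] = 1.0;
        }

        double[] p = lstsq(a, y);
        System.out.println("slope = " + p[0]);
        System.out.println("intercept = " + p[1]);
    }
}
```

Up to floating-point rounding, this gives the same slope and intercept as the SimpleRegression example. For anything beyond small, well-conditioned problems, a library solver based on QR or SVD is the safer choice.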

Upvotes: 4
