Reputation: 35
I tried writing a program that fits a linear regression to some sample data using gradient descent. The theta values I get do not give the best fit for the data, even though I have already normalized the data.
public class OneVariableRegression {
    public static void main(String[] args) {
        double x1[] = {-1.605793084, -1.436762233, -1.267731382, -1.098700531, -0.92966968, -0.760638829, -0.591607978, -0.422577127, -0.253546276, -0.084515425, 0.084515425, 0.253546276, 0.422577127, 0.591607978, 0.760638829, 0.92966968, 1.098700531, 1.267731382, 1.436762233, 1.605793084};
        double y[] = {0.3, 0.2, 0.24, 0.33, 0.35, 0.28, 0.61, 0.38, 0.38, 0.42, 0.51, 0.6, 0.55, 0.56, 0.53, 0.61, 0.65, 0.68, 0.74, 0.87};
        double theta0 = 0.5;
        double theta1 = 0.5;
        double temp0;
        double temp1;
        double alpha = 1.5;
        double m = x1.length;
        System.out.println(m);
        double derivative0 = 0;
        double derivative1 = 0;
        do {
            for (int i = 0; i < x1.length; i++) {
                derivative0 = (derivative0 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m);
                derivative1 = (derivative1 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m) * x1[i];
            }
            temp0 = theta0 - (alpha * derivative0);
            temp1 = theta1 - (alpha * derivative1);
            theta0 = temp0;
            theta1 = temp1;
            //System.out.println("Derivative0 = " + derivative0);
            //System.out.println("Derivative1 = " + derivative1);
        } while (derivative0 > 0.0001 || derivative1 > 0.0001);
        System.out.println();
        System.out.println("theta 0 = " + theta0);
        System.out.println("theta 1 = " + theta1);
    }
}
Upvotes: 2
Views: 371
Reputation: 55448
The derivative you're using comes from the squared error function, which is convex, so it has no local minima other than the single global minimum. (In fact, this type of problem even admits a closed-form solution, the normal equation; it's just not numerically tractable for large problems, hence the use of gradient descent.)
And the correct answer is around theta0 = 0.4895 and theta1 = 0.1652; this is trivial to check on any statistical computing environment. (See the bottom of this answer if you're skeptical.)
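Or, if you'd rather verify those numbers yourself without a statistics package, here's a minimal standalone sketch of the closed-form solution for the one-variable case. The class name ClosedFormCheck is mine; everything else is just the usual least-squares formulas (theta1 = Sxy / Sxx, theta0 = yMean - theta1 * xMean) applied to the data from the question:

// Hypothetical standalone check, not part of the original program
public class ClosedFormCheck {
    public static void main(String[] args) {
        // same data as in the question
        double[] x = {-1.605793084, -1.436762233, -1.267731382, -1.098700531, -0.92966968, -0.760638829, -0.591607978, -0.422577127, -0.253546276, -0.084515425, 0.084515425, 0.253546276, 0.422577127, 0.591607978, 0.760638829, 0.92966968, 1.098700531, 1.267731382, 1.436762233, 1.605793084};
        double[] y = {0.3, 0.2, 0.24, 0.33, 0.35, 0.28, 0.61, 0.38, 0.38, 0.42, 0.51, 0.6, 0.55, 0.56, 0.53, 0.61, 0.65, 0.68, 0.74, 0.87};
        double n = x.length;
        double xMean = 0, yMean = 0;
        for (int i = 0; i < x.length; i++) {
            xMean += x[i] / n;
            yMean += y[i] / n;
        }
        double sxy = 0, sxx = 0;
        for (int i = 0; i < x.length; i++) {
            sxy += (x[i] - xMean) * (y[i] - yMean); // sum of centered cross-products
            sxx += (x[i] - xMean) * (x[i] - xMean); // sum of squared deviations of x
        }
        double theta1 = sxy / sxx;              // slope
        double theta0 = yMean - theta1 * xMean; // intercept
        System.out.println("theta0 = " + theta0 + ", theta1 = " + theta1);
    }
}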
Below I point out the mistakes in your code; after fixing them, you'll get the correct answer above to within 4 decimal places. So you are right to expect it to converge to the global minimum, but there are problems in the implementation.
Each time you recalculate derivative0 and derivative1, you forgot to reset them to 0; what you were doing was accumulating the derivatives across iterations of the do {} while (). You need this at the top of the do-while loop:
do {
    derivative0 = 0;
    derivative1 = 0;
    ...
}
Next is this:

derivative0 = (derivative0 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m);
derivative1 = (derivative1 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m) * x1[i];

The x1[i] factor should be applied to (theta0 + (theta1 * x1[i]) - y[i]) alone.
Your attempt is slightly confusing, so let's write it in a clearer manner, much closer to its mathematical equation, (1/m) * sum((y_hat_i - y_i) * x_i):
// You need fresh vars; don't accumulate the derivatives across gradient descent iterations
derivative0 = 0;
derivative1 = 0;
for (int i = 0; i < m; i++) {
    derivative0 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i]);
    derivative1 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i]) * x1[i];
}
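For reference, written out in full, the two quantities this loop computes are the standard partial derivatives of the squared-error cost, with y_hat_i = theta0 + theta1 * x1[i]:

dJ/dtheta0 = (1/m) * sum(y_hat_i - y_i)
dJ/dtheta1 = (1/m) * sum((y_hat_i - y_i) * x_i)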
That should get you close enough. However, I find your learning rate alpha to be a tad big. When it's too big, your gradient descent will have trouble zeroing in on your global optimum; it will hang around there, but won't quite reach it.
double alpha = 0.5;
Run it and compare the result to the answer from statistical software. Here's a gist on GitHub of your fixed .java file.
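In case the gist doesn't load, the fixed file should look roughly like this. This is my reconstruction (your original program with the two fixes and the smaller alpha applied), not necessarily the literal gist contents:

public class OneVariableRegression {
    public static void main(String[] args) {
        double x1[] = {-1.605793084, -1.436762233, -1.267731382, -1.098700531, -0.92966968, -0.760638829, -0.591607978, -0.422577127, -0.253546276, -0.084515425, 0.084515425, 0.253546276, 0.422577127, 0.591607978, 0.760638829, 0.92966968, 1.098700531, 1.267731382, 1.436762233, 1.605793084};
        double y[] = {0.3, 0.2, 0.24, 0.33, 0.35, 0.28, 0.61, 0.38, 0.38, 0.42, 0.51, 0.6, 0.55, 0.56, 0.53, 0.61, 0.65, 0.68, 0.74, 0.87};
        double theta0 = 0.5;
        double theta1 = 0.5;
        double alpha = 0.5; // reduced from 1.5
        double m = x1.length;
        System.out.println(m);
        double derivative0;
        double derivative1;
        do {
            // reset the accumulators every gradient descent iteration
            derivative0 = 0;
            derivative1 = 0;
            for (int i = 0; i < x1.length; i++) {
                derivative0 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i]);
                derivative1 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i]) * x1[i];
            }
            theta0 = theta0 - (alpha * derivative0);
            theta1 = theta1 - (alpha * derivative1);
        } while (derivative0 > 0.0001 || derivative1 > 0.0001);
        System.out.println();
        System.out.println("theta 0 = " + theta0);
        System.out.println("theta 1 = " + theta1);
    }
}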
➜ ~ javac OneVariableRegression.java && java OneVariableRegression
20.0
theta 0 = 0.48950064086914064
theta 1 = 0.16520139788757973
I compared it with R
> x
[1] -1.60579308 -1.43676223 -1.26773138 -1.09870053 -0.92966968 -0.76063883
[7] -0.59160798 -0.42257713 -0.25354628 -0.08451543 0.08451543 0.25354628
[13] 0.42257713 0.59160798 0.76063883 0.92966968 1.09870053 1.26773138
[19] 1.43676223 1.60579308
> y
[1] 0.30 0.20 0.24 0.33 0.35 0.28 0.61 0.38 0.38 0.42 0.51 0.60 0.55 0.56 0.53
[16] 0.61 0.65 0.68 0.74 0.87
> lm(y ~ x)
Call:
lm(formula = y ~ x)
Coefficients:
(Intercept) x
0.4895 0.1652
Now your code gives the correct answer to at least 4 decimal places.
Upvotes: 2
Reputation: 77860
Yes, there's an error in your formulas. For some reason, you included derivative0 and derivative1 in the multiplications, which seriously skewed the results. Simply remove the extra parentheses and try again:
derivative0 = derivative0 + (theta0 + (theta1 * x1[i]) - y[i]) * (1/m);
derivative1 = derivative1 + (theta0 + (theta1 * x1[i]) - y[i]) * (1/m) * x1[i];
Output:
20.0
Derivative0 = 0.010499999999999995
Derivative1 = 0.31809711251208517
Derivative0 = 0.0052500000000000185
Derivative1 = 0.1829058398064968
Derivative0 = -0.007874999999999993
Derivative1 = -0.2129262545589219
theta 0 = 0.4881875
theta 1 = 0.06788495336050987
Is this more like what you expected?
Upvotes: 1