Portable

Reputation: 323

What is wrong with my simple linear regression?

I can't recover the correct theta with this code.

I added plotting code to help visualize the issue.

Please help me find the bug in this short block of code.

Thanks

The linear function below uses the normal equation:

import numpy as np
import matplotlib.pyplot as plt

N = 20

def arr(n):
    return np.arange(n) + 1

def linear(features, y):
    x = np.vstack(features).T
    xT = np.transpose(x)
    xTx = xT.dot(x)
    return np.linalg.inv(xTx).dot(xT).dot(y)

def plot(x, y, dots_y):
    plt.plot(x, y)
    plt.plot(x, dots_y, marker='o', linestyle=' ', color='r')
    plt.show()    

y = arr(N) ** 2 + 3    
theta = linear((np.ones(N), arr(N), arr(N) ** 2), y)

plot(arr(N), arr(N) ** theta[1] + theta[0], y)

[Figure: output plot]

Upvotes: 1

Views: 177

Answers (1)

user6655984

The error is in the plotting line, which should be

plot(arr(N), arr(N)**2 * theta[2] + arr(N) * theta[1] + theta[0], y)

according to the quadratic polynomial model.
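To see why the original plotting line goes wrong, here is a minimal sketch that recomputes theta and evaluates both expressions without plotting. It assumes the same data as in the question; the names X, fitted and buggy are mine, and np.linalg.solve is swapped in for the explicit inverse.

```python
import numpy as np

N = 20
x = np.arange(1, N + 1)
y = x**2 + 3

# Design matrix with columns 1, x, x^2 (same features as in the question).
X = np.vstack((np.ones(N), x, x**2)).T

# Normal equation: solve (X^T X) theta = X^T y.
# np.linalg.solve is used here instead of an explicit inverse.
theta = np.linalg.solve(X.T @ X, X.T @ y)

# Quadratic model: theta[0] + theta[1]*x + theta[2]*x^2, i.e. X @ theta.
fitted = X @ theta

# Expression from the question's plotting line, for comparison.
buggy = x ** theta[1] + theta[0]

print(np.round(theta, 6))      # coefficients of 1, x, x^2
print(np.allclose(fitted, y))  # does the quadratic model reproduce y?
print(np.round(buggy[:3], 3))  # first few values of the original expression
```

Since y = x**2 + 3, theta should come out close to (3, 0, 1), so the theta itself is fine. The original expression uses theta[1] as an exponent and ignores theta[2], which is why its curve stays nearly flat instead of following the data.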

Also, I suppose you computed the least-squares solution this way for expository reasons. In practice, a linear least-squares fit would be obtained with np.linalg.lstsq, which makes the code much shorter and more efficient:

import numpy as np
import matplotlib.pyplot as plt

N = 20
x = np.arange(1, N+1)
y = x**2 + 3
basis = np.vstack((x**0, x**1, x**2)).T  # basis for the space of quadratic polynomials
theta = np.linalg.lstsq(basis, y, rcond=None)[0]  # least squares approximation to y in this basis
plt.plot(x, y, 'ro')                   # original points
plt.plot(x, basis.dot(theta))          # best fit
plt.show()

[Figure: fit]
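np.linalg.lstsq returns more than the coefficients, which can help when checking a fit. A small sketch, assuming the same x, y and basis as above (the unpacked variable names are my own):

```python
import numpy as np

N = 20
x = np.arange(1, N + 1)
y = x**2 + 3
basis = np.vstack((x**0, x**1, x**2)).T

# lstsq returns the solution, the sum of squared residuals,
# the rank of the basis matrix, and its singular values.
theta, residuals, rank, singular_values = np.linalg.lstsq(basis, y, rcond=None)

print(np.round(theta, 6))  # coefficients of 1, x, x^2
print(rank)                # number of linearly independent basis columns
```

A rank equal to the number of basis columns indicates the columns are linearly independent, so the coefficients are uniquely determined.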

Upvotes: 1
