I can't get the right theta from this code.
I added plotting code to help visualize the issue.
Please help me find the bug in this short block of code.
Thanks
import numpy as np
import matplotlib.pyplot as plt
N = 20
def arr(n):
    return np.arange(n) + 1

def linear(features, y):
    x = np.vstack(features).T
    xT = np.transpose(x)
    xTx = xT.dot(x)
    return np.linalg.inv(xTx).dot(xT).dot(y)

def plot(x, y, dots_y):
    plt.plot(x, y)
    plt.plot(x, dots_y, marker='o', linestyle=' ', color='r')
    plt.show()

y = arr(N) ** 2 + 3
theta = linear((np.ones(N), arr(N), arr(N) ** 2), y)
plot(arr(N), arr(N) ** theta[1] + theta[0], y)
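The error is in the plotting line. Under the quadratic polynomial model the fitted curve is theta[0] + theta[1]*x + theta[2]*x**2, so the line should be
plot(arr(N), arr(N)**2 * theta[2] + arr(N) * theta[1] + theta[0], y)
Your version, arr(N) ** theta[1] + theta[0], uses theta[1] as an exponent and ignores theta[2] entirely.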
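A quick check of this, as a self-contained sketch that reproduces the question's `arr` and `linear` (without the plotting): the fitted theta should come out close to [3, 0, 1], and only the corrected expression reproduces y. The variable names `wrong` and `fitted` are just illustrative.

```python
import numpy as np

N = 20

def arr(n):
    # 1, 2, ..., n
    return np.arange(n) + 1

def linear(features, y):
    # Design matrix with one column per feature
    x = np.vstack(features).T
    # Normal-equation solution (X^T X)^-1 X^T y, as in the question
    return np.linalg.inv(x.T.dot(x)).dot(x.T).dot(y)

y = arr(N) ** 2 + 3
theta = linear((np.ones(N), arr(N), arr(N) ** 2), y)

# Original plotting expression: theta[1] used as an exponent, theta[2] ignored
wrong = arr(N) ** theta[1] + theta[0]
# Corrected expression: theta[0] + theta[1]*x + theta[2]*x**2
fitted = arr(N) ** 2 * theta[2] + arr(N) * theta[1] + theta[0]

print(np.round(theta, 6))
print("corrected matches y:", np.allclose(fitted, y))
print("original matches y:", np.allclose(wrong, y))
```

Since theta[1] is essentially 0, the original expression evaluates to roughly 1 + 3 = 4 at every point, which is why the plotted line looked flat.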
Also, I suppose you computed the least-squares solution this way for expository reasons, but in practice a linear least-squares fit would be obtained with np.linalg.lstsq,
as follows, with much shorter and numerically more robust code, since lstsq does not explicitly invert xTx:
import numpy as np
import matplotlib.pyplot as plt

N = 20
x = np.arange(1, N+1)
y = x**2 + 3
basis = np.vstack((x**0, x**1, x**2)).T  # basis for the space of quadratic polynomials
theta = np.linalg.lstsq(basis, y, rcond=None)[0]  # least-squares approximation to y in this basis
plt.plot(x, y, 'ro')  # original points
plt.plot(x, basis.dot(theta))  # best fit
plt.show()
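To confirm that the two approaches give the same coefficients, here is a sketch that fits the same quadratic basis both ways: with np.linalg.lstsq (whose full return value is the solution, residual sum of squares, rank, and singular values) and with the explicit normal-equation inverse used in the question. The names `theta_lstsq` and `theta_normal` are just illustrative.

```python
import numpy as np

N = 20
x = np.arange(1, N + 1)
y = x**2 + 3

# Columns 1, x, x^2: the quadratic polynomial basis
basis = np.vstack((x**0, x**1, x**2)).T

# Least squares via lstsq; rcond=None selects NumPy's current default cutoff
theta_lstsq, residuals, rank, sv = np.linalg.lstsq(basis, y, rcond=None)

# Normal equations with an explicit inverse, mirroring the question's linear()
xTx = basis.T.dot(basis)
theta_normal = np.linalg.inv(xTx).dot(basis.T).dot(y)

print("lstsq:  ", theta_lstsq)
print("normal: ", theta_normal)
print("rank:   ", rank)
```

Both should be close to [3, 0, 1]. The difference shows up on larger or worse-conditioned problems, where forming and inverting xTx squares the condition number of the basis matrix.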