astrochris

Reputation: 1866

How to make my pylab.poly1d(fit) pass through zero?

My code below fits a polynomial (with polyfit) to the points in my graph, but I want the fit to always pass through zero. How do I do this?

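import pylab as pl
import numpy as np

# UX2, UY2, UXY, X2, Y2 and XY are NumPy arrays defined earlier (not shown)
y = abs((UX2 - UY2) + (2*UXY))
a = np.mean(y)
y = y - a  # centre y on its mean
x = abs((X2 - Y2) + (2*XY))
b = np.mean(x)
x = x - b  # centre x on its mean
ax = pl.subplot(1, 4, 4)  # plot XY
fit = pl.polyfit(x, y, 1)  # [slope, intercept]
slope4, intercept4 = fit
print(slope4)
fit_fn = pl.poly1d(fit)
x_min = -2
x_max = 5
n = 10000
x_fit = pl.linspace(x_min, x_max, n)
y_fit = fit_fn(x_fit)
q = z = [-2, 5]  # green reference line y = x
scat = pl.plot(x, y, 'o', x_fit, y_fit, '-r', z, q, 'g')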

Upvotes: 4

Views: 2161

Answers (2)

Bitwise

Reputation: 7805

As mentioned, you can't do this directly with polyfit (but you can write your own function).

However, if you still want to use polyfit(), you can try this math hack: add a point at the origin (0, 0), then use the w (weights) argument of polyfit() to give that point a very high weight relative to all the other points. This forces the fitted polynomial to pass through zero, or very close to it.
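A minimal sketch of that weighting trick, assuming NumPy's np.polyfit (which accepts a w argument of per-point weights applied to the residuals) and an arbitrarily chosen origin weight of 1e6. The helper name polyfit_near_origin and the synthetic data are hypothetical, for illustration only:

```python
import numpy as np

def polyfit_near_origin(x, y, deg=1, origin_weight=1e6):
    # Hypothetical helper: append the point (0, 0) to the data
    x_aug = np.append(np.asarray(x, dtype=float), 0.0)
    y_aug = np.append(np.asarray(y, dtype=float), 0.0)
    # Every real point gets weight 1; the added origin gets a huge weight,
    # so any residual at x = 0 is heavily penalised
    w = np.ones_like(x_aug)
    w[-1] = origin_weight
    # np.polyfit returns coefficients highest degree first: [slope, intercept] for deg=1
    return np.polyfit(x_aug, y_aug, deg, w=w)

# Synthetic data with a non-zero intercept (assumed for the example)
rng = np.random.default_rng(0)
x = rng.random(200)
y = 2.5 * x + 0.3 + rng.normal(0, 0.05, 200)

coeffs = polyfit_near_origin(x, y, 1)
print(coeffs)  # intercept coeffs[-1] is tiny, but not exactly 0
```

The intercept shrinks as origin_weight grows but never becomes exactly zero, so the result only approximates a true fit through the origin.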

Upvotes: -1

Jaime

Reputation: 67457

When you fit a degree-n polynomial p(x) = a0 + a1*x + a2*x**2 + ... + an*x**n to a set of data points (x0, y0), (x1, y1), ..., (xm, ym), polyfit calls np.linalg.lstsq with a coefficient matrix that looks like:

[1 x0 x0**2 ... x0**n]
[1 x1 x1**2 ... x1**n]
...               
[1 xm xm**2 ... xm**n]
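To see that this matrix really is what the least-squares fit works with, here is a small check, assuming numpy.polynomial.polynomial.polyvander (which builds exactly these rows, in increasing powers) and synthetic data chosen only for illustration. Solving with np.linalg.lstsq directly should give the same coefficients as polyfit:

```python
import numpy as np
from numpy.polynomial import polynomial as P

# Synthetic data (assumed for the example)
rng = np.random.default_rng(1)
x = rng.random(50)
y = 1 + 3*x - 4*x**2 + rng.normal(0, 0.1, 50)
deg = 2

# Rows are [1, x_i, x_i**2, ..., x_i**deg]: the coefficient matrix described above
A = P.polyvander(x, deg)
coeff_lstsq = np.linalg.lstsq(A, y, rcond=None)[0]

# polyfit returns coefficients in the same increasing-degree order [a0, a1, a2]
coeff_polyfit = P.polyfit(x, y, deg)
print(np.allclose(coeff_lstsq, coeff_polyfit))
```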

If you remove the j-th column from that matrix, you are effectively forcing the coefficient aj of the polynomial to 0. So to get rid of the a0 coefficient (the constant term), you could do the following:

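import numpy as np
import matplotlib.pyplot as plt

def fit_poly_through_origin(x, y, n=1):
    # Coefficient matrix with columns x, x**2, ..., x**n; the column of ones
    # is dropped, which forces the constant term a0 to be 0
    a = x[:, np.newaxis] ** np.arange(1, n+1)
    coeff = np.linalg.lstsq(a, y, rcond=None)[0]
    # Prepend a0 = 0 so the result is [a0, a1, ..., an] (increasing degree),
    # the order expected by np.polynomial.Polynomial
    return np.concatenate(([0], coeff))

n = 1000
x = np.random.rand(n)
y = 1 + 3*x - 4*x**2 + np.random.rand(n)*0.25

c0 = np.polynomial.polynomial.polyfit(x, y, 2)  # unconstrained fit
c1 = fit_poly_through_origin(x, y, 2)  # fit with a0 = 0

p0 = np.polynomial.Polynomial(c0)
p1 = np.polynomial.Polynomial(c1)

plt.plot(x, y, 'kx')
xx = np.linspace(0, 1, 1000)
plt.plot(xx, p0(xx), 'r-')
plt.plot(xx, p1(xx), 'b-')
plt.show()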

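For the straight-line case in the question (n = 1), removing the column of ones leaves the single column [x0, x1, ..., xm], and the least-squares problem has a closed-form answer, a1 = sum(xi*yi) / sum(xi**2). A minimal sketch, where the helper name slope_through_origin and the synthetic data are hypothetical:

```python
import numpy as np

def slope_through_origin(x, y):
    # Hypothetical helper: with only the x column left, the normal equation
    # reduces to the scalar equation (x . x) * a1 = (x . y)
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    return np.dot(x, y) / np.dot(x, x)

# Synthetic data that truly passes through the origin (assumed for the example)
rng = np.random.default_rng(2)
x = rng.random(100)
y = 2*x + rng.normal(0, 0.1, 100)

a1 = slope_through_origin(x, y)
# Same answer as lstsq on the single-column coefficient matrix
a1_lstsq = np.linalg.lstsq(x[:, np.newaxis], y, rcond=None)[0][0]
print(a1, a1_lstsq)
```

In the question's code, the slope could then be turned into a line through zero with pl.poly1d([a1, 0]).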

Upvotes: 6
