Lorenzo Bottaccioli

Reputation: 441

Piecewise regression in Python

Hi, I'm trying to figure out how to fit these values with a piecewise linear function. I have read this question (How to apply piecewise linear fit in Python?) but I can't get any further. That example shows how to implement a piecewise function for a 2-segment case, but I need to do it for a three-segment case, as in the figure below. (figure: three-segment data)

I have written this code:

from scipy import optimize
import matplotlib.pyplot as plt
import numpy as np


x1 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15,16,17,18,19,20,21], dtype=float)
y1 = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03,145,147,149,151,153,155])



x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03])

def piecewise(x,x0,x1,y0,y1,k0,k1,k2):
    return np.piecewise(x , [x <= x0, (x>= x1)] , [lambda x:k0*x + y0-k0*x0, lambda x:k1*(x-(x1+x0))-y1, lambda x:k2*x + y1-k2*x1])

p , e = optimize.curve_fit(piecewise, x1, y1)
xd = np.linspace(0, 15, 100)
plt.figure()
plt.plot(x1, y1, "o")
plt.plot(xd, piecewise(xd, *p))

but this is the output

(plot of the incorrect fit)

Any suggestions? I believe that the problem is in return np.piecewise(x , [x <= x0, (x>= x1)] , [lambda x:k0*x + y0-k0*x0, lambda x:k1*(x-(x1+x0))-y1, lambda x:k2*x + y1-k2*x1]), in particular in the second lambda.

EDIT 1:

If I try the solution provided by A. L. on different data, I don't get good results.


I get this result:

(plot of the poor fit on the new data)

with

x=[ 16.01690476,  16.13801587,  14.63628571,  15.32664399,
        15.8145    ,  15.71507143,  15.56107143,  15.553     ,
        15.08734524,  14.97275   ,  15.51958333,  16.61981859,
        16.36589286,  14.78708333,  14.41565476,  13.47763158,
        13.42412281,  12.95551378,  13.66601504,  13.63315789,
        13.21463659,  13.53464286,  14.60130952,  14.7774881 ,
        13.04319048,  12.53385965,  12.65745614,  13.90535714,
        14.82412281,  14.6565    ,  15.09541667,  13.41434524,
        13.66033333,  14.57964286,  13.55416667,  13.43041667,
        13.01137566,  12.76429825,  11.55241667,  11.0634881 ,
        10.92729762,  11.21625   ,  10.72092857,  11.80380952,
        12.55233333,  12.11307143,  11.78892857,  12.45458333,
        11.05539286,  10.69214286,  10.32566667,  11.3439881 ,
         9.69563492,  10.72535714,  10.26180272,   7.77272727,
         6.37704082,   8.49666667,   8.5389881 ,   5.68547619,
         7.00616667,   8.22015873,  10.20315476,  15.35736842,
        12.25158333,  11.09622153,  10.4118254 ,   9.8602381 ,
        10.16727273,  15.10858333,  13.82215539,  12.44719298,
        10.92341667,  11.44565476,  11.43333333,  10.5045    ,
        11.14357143,  10.37625   ,   8.93421769,   9.48444444,
        10.43483333,  10.8659881 ,  10.96166667,  10.12872619,
         9.64663265,   9.29979762,   9.67173469,   8.978322  ,
         9.10419501,   9.45411565,  10.46411565,   7.95739229,
         8.72616667,   7.03892857,   7.32547619,   7.56441667,
         6.61022676,   9.09014739,  10.78141667,  10.85918367,
        11.11665476,  10.141     ,   9.17760771,   8.27968254,
        11.02625   ,  12.34809524,  11.17807018,  11.25416667,
        11.29236905,   9.28357143,   9.77033333,  11.52086168,
         9.8625    ,  12.60281955,  12.42785714,  12.11902256,
        13.1       ,  13.02791667,  13.87779449,  15.09857143,
        13.93935185,  13.69821429,  13.39880952,  12.45692982,
        12.76921053,  13.23708333,  13.71666667,  15.39807143,
        15.27916667,  14.66464286,  13.38694444,  10.97555556,
        10.02191667,  11.99608333,  14.26325   ,  15.40991667,
        15.12908333,  15.76265476,  12.12763158,  15.01641667,
        14.39602381,  12.98532143,  14.98807018,  18.30547619,
        16.7564966 ,  16.82982143,  19.8487013 ,  19.18600907]

and

y=[ 2.36846863,  2.73722628,  2.77177583,  2.63930636,  2.80864749,
        2.57066667,  2.65277287,  2.57162347,  2.76295667,  2.79835391,
        2.60431154,  2.17326401,  2.67740698,  2.47138153,  2.49882574,
        2.60987338,  2.69935565,  2.60755362,  2.77702029,  2.62996942,
        2.45959517,  2.52750434,  2.73833005,  2.52009   ,  2.80933226,
        1.63807085,  2.49230099,  2.55441614,  3.19256506,  2.52609288,
        1.02931596,  2.40266963,  2.3306463 ,  2.69094276,  2.60779985,
        2.48351648,  2.45131766,  2.40526763,  2.03952569,  1.86217009,
        1.79971848,  1.91772218,  1.85895421,  2.32725731,  2.28189713,
        2.11835833,  2.09636517,  2.2230303 ,  1.85863317,  1.77550406,
        1.68862391,  1.79187765,  1.70887476,  1.81911193,  1.74802483,
        1.65776432,  1.58012849,  1.67781494,  1.62451541,  1.60555884,
        1.56172214,  1.60083809,  1.65256994,  2.74794704,  2.27089627,
        1.80364982,  1.51412482,  1.77738757,  1.56979564,  2.46538633,
        2.37679625,  2.40389294,  2.04165763,  1.82086407,  1.90609219,
        1.87480978,  1.8877854 ,  1.76080074,  1.68369028,  1.57419297,
        1.66470126,  1.74522552,  1.72459756,  1.65510503,  1.72131148,
        1.6254417 ,  1.57091907,  1.68755268,  1.70307911,  1.59445121,
        1.74393783,  1.72913779,  1.66883237,  1.59859545,  1.62335831,
        1.73378184,  1.62621588,  1.79532164,  1.78289992,  1.79475101,
        1.7826266 ,  1.68778918,  1.64484127,  1.62332696,  1.75372393,
        1.99038021,  1.87268137,  1.86124502,  1.82435911,  1.62927102,
        1.66443723,  1.86743516,  1.62745098,  2.20200312,  2.09641026,
        2.26649111,  2.63271605,  2.18050721,  2.57138433,  2.51833359,
        2.74684184,  2.57209998,  2.63762019,  2.30027877,  2.28471286,
        2.40323668,  2.37103313,  2.16414489,  1.01027109,  2.64181007,
        2.45467765,  2.05773672,  1.73624917,  2.05233688,  2.70820669,
        2.65594222,  2.67445635,  2.37212985,  2.48221803,  2.77655216,
        2.62839879,  2.26481307,  2.58005799,  2.1188172 ,  2.14017268,
        2.16459571,  1.95083406,  1.46224418]

Upvotes: 4

Views: 3625

Answers (2)

chasmani

Reputation: 2510

The piecewise-regression Python library can fit models with different numbers of breakpoints.
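(It should be installable with pip install piecewise-regression.)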

First of all, for demonstration purposes generate some data with 2 breakpoints:

import numpy as np

# true slopes of the three segments and the two breakpoints
gradients = [2.5, 12, 2]
constant = 0
breakpoints = [6, 15]
n_points = 100

np.random.seed(1)
xx = np.linspace(0, 25, n_points)
# first segment plus Gaussian noise
yy = constant + gradients[0]*xx + np.random.normal(size=n_points)*10
# add the change in slope after each breakpoint
for bp_n in range(len(breakpoints)):
    yy += (gradients[bp_n+1] - gradients[bp_n]) * np.maximum(xx - breakpoints[bp_n], 0)

To fit and plot the model:

import piecewise_regression
import matplotlib.pyplot as plt

pw_fit = piecewise_regression.Fit(xx, yy, n_breakpoints=2)
pw_fit.plot()
plt.xlabel("x")
plt.ylabel("y")
plt.show()

(plot of the fitted piecewise model)

It also gives you a statistical analysis:

pw_fit.summary()

(summary table of the fit)

It won't work well with the data you provided in your edit, because there are outliers that dominate the error cost function. This will be an issue with whichever method you use to fit the data; you need to decide how to handle the outliers in this case.
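As a rough illustration of one way to do that (a minimal sketch, not part of the library; the median-filter smoothing, the kernel size and the quantile threshold are arbitrary choices of mine), you could screen out points that deviate strongly from a smoothed version of the series before fitting:

import numpy as np
import piecewise_regression
from scipy.signal import medfilt

# x, y: the lists from the question's edit
x_arr = np.array(x)
y_arr = np.array(y)

# sort by x so the smoothing runs along the x-axis
order = np.argsort(x_arr)
x_s, y_s = x_arr[order], y_arr[order]

# crude outlier screen: compare each point with a median-filtered version
# of the series and keep only the 90% of points with the smallest deviation
y_smooth = medfilt(y_s, kernel_size=7)
resid = np.abs(y_s - y_smooth)
keep = resid < np.quantile(resid, 0.9)

pw_fit = piecewise_regression.Fit(x_s[keep], y_s[keep], n_breakpoints=2)
pw_fit.summary()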

Upvotes: 0

A. L.

Reputation: 121

Fitting a piecewise linear function is a nonlinear optimization problem which may have local optima. The result you see is probably one of the local optima where your optimization algorithm got stuck.

One way to solve this problem is to repeat the optimization with different initial values and take the best fit. To compare the different fits against each other I used the absolute error, summed over all points:

perr = np.sum(np.abs(y1-piecewise(x1, *p)))

I also changed your piecewise function because it was a bit confusing to me, but it is still a piecewise linear function as before.

Furthermore, I think you forgot to extend the x and xd arrays to 21 (that's why the green line ends early).

from scipy import optimize
import matplotlib.pyplot as plt
import numpy as np


def piecewise(x, x0, x1, y0, y1, k0, k1, k2):
    # three linear segments joined at the breakpoints x0 and x1
    return np.piecewise(x, [x <= x0, np.logical_and(x0 < x, x <= x1), x > x1],
                        [lambda x: k0*x + y0,
                         lambda x: k1*(x - x0) + y1 + k0*x0,
                         lambda x: k2*(x - x1) + y0 + y1 + k0*x0 + k1*(x1 - x0)])

x1 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15,16,17,18,19,20,21], dtype=float)
y1 = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03,145,147,149,151,153,155])
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15,16,17,18,19,20,21], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03,145,147,149,151,153,155])


perr_min = np.inf
p_best = None
for n in range(100):
    # random restart: try a random initial guess and keep the best fit
    k = np.random.rand(7)*20
    p , e = optimize.curve_fit(piecewise, x1, y1, p0=k)
    perr = np.sum(np.abs(y1 - piecewise(x1, *p)))
    if(perr < perr_min):
        perr_min = perr
        p_best = p

xd = np.linspace(0, 21, 100)
plt.figure()
plt.plot(x1, y1, "o")
y_out = piecewise(xd, *p_best)
plt.plot(xd, y_out)
plt.show()

This gives me: (plot of the fitted three-segment line)

with p = [ 6.34259491 15.00000023 2.97272604 7.05498314 2.00751828 13.88881542 1.99960597]
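Given the parameter order (x0, x1, y0, y1, k0, k1, k2), this corresponds to breakpoints near x ≈ 6.3 and x = 15 and slopes of roughly 2, 13.9 and 2 for the three segments.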

Edit 1

You edited your question, and this is the answer to the edited one. Sorry, I am new at Stack Overflow and not sure if I should post another answer instead.

In your second dataset you added noise to the data. In my opinion there are two kinds of noise: Gaussian noise, which places the points close to the underlying piecewise line, and outlier noise, which places points far away from the original underlying line.

Under the hood, the optimization algorithm you use minimizes the following with respect to p: E = sum(square(y - piecewise(x, p))) (see http://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html#scipy.optimize.curve_fit).

The Gaussian noise is not very problematic. The optimization you use implicitly assumes this Gaussian noise (by minimizing the least-squares error) and fits the line as well as possible. The real problem comes with the outliers.

The problem is that the outliers are far away from the original function. Even if the optimization tries the true parameters, the error function E will not be minimal there, because the outliers' distances are squared, which shifts the minimum of E far away from the true parameters of your function.

So what's the solution? Get rid of the outliers.
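As a side note: instead of removing the outliers you could also down-weight them with a robust loss. Here is a minimal sketch (not the approach I use below), assuming the same piecewise function as above and using scipy.optimize.least_squares with its soft_l1 loss; the f_scale value and the random starting point are arbitrary choices.

from scipy import optimize
import numpy as np

def residuals(p, x, y):
    # residuals of the same piecewise model as above
    return piecewise(np.asarray(x), *p) - np.asarray(y)

p0 = np.random.rand(7) * 20  # rough random starting point
# soft_l1 grows roughly linearly for large residuals, so outliers
# pull on the fit much less than with plain least squares
res = optimize.least_squares(residuals, p0, args=(x, y), loss='soft_l1', f_scale=0.5)
print(res.x)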

An automated approach to removing the outliers is RANSAC: https://en.wikipedia.org/wiki/RANSAC.

In brief: you choose a random subset of the original data and hope that it contains no outliers. You fit your function to the subset and discard the points that are far away from the fitted function. If enough points survive this step, you take all the surviving points and repeat the fit. The error on this "inlier" set is a measure of the quality of your fit. Then you repeat the whole process and take the best final fit.

I adjusted my script accordingly:

from scipy import optimize
import matplotlib.pyplot as plt
import numpy as np
def piecewise(x, x0, x1, y0, y1, k0, k1, k2):
    # same three-segment piecewise linear function as above
    return np.piecewise(x, [x <= x0, np.logical_and(x0 < x, x <= x1), x > x1],
                        [lambda x: k0*x + y0,
                         lambda x: k1*(x - x0) + y1 + k0*x0,
                         lambda x: k2*(x - x1) + y0 + y1 + k0*x0 + k1*(x1 - x0)])

# x and y are the lists from the question's edit
x = np.array(x)
y = np.array(y)

x1 = x
y1 = y



perr_min = np.inf
p_best = None
for n in range(100):
    # fit on a small random subset and hope that it contains no outliers
    idx = np.random.choice(np.arange(len(x)), 10, replace=False)
    x_sample = x[idx]
    y_sample = y[idx]
    k = np.random.rand(7)*20
    try:
        p , e = optimize.curve_fit(piecewise, x_sample, y_sample, p0=k)
        # keep only the points that lie close to this preliminary fit
        each_error = np.abs(y - piecewise(x, *p))
        x_inlier = x[each_error < 1]
        y_inlier = y[each_error < 1]
        # discard the candidate if too many points were rejected
        if(x_inlier.shape[0] < 0.8 * x.shape[0]):
            continue

        # refit on the inlier set and keep the best fit over all repetitions
        p_inlier , e_inlier = optimize.curve_fit(piecewise, x_inlier, y_inlier, p0=p)
        perr = np.sum(np.abs(y - piecewise(x, *p_inlier)))

        if(perr < perr_min):
            perr_min = perr
            p_best = p_inlier
    except RuntimeError:
        pass

xd = np.linspace(0, 21, 100)
plt.figure()
plt.plot(x, y, "o")
y_out = piecewise(xd, *p_best)
plt.plot(xd, y_out)
print(p_best)
plt.show()

With 100 repetitions I get the following result: (plot: fitting the curve with RANSAC and least squares)

Upvotes: 3
