Jagdeesh R

Reputation: 41

Generalised additive model - Python

I'm trying to fit a nonlinear model using a generalized additive model (GAM). How do I determine the number of splines to use? Is there a specific way to choose it? I am fitting 3rd-order (cubic) splines. Below is the code.

import matplotlib.pyplot as plt
from pygam import LinearGAM
from pygam.utils import generate_X_grid

# Curve fitting using a GAM - penalised spline curve.
def modeltrain(time, value):
    return LinearGAM(n_splines=58, spline_order=3).gridsearch(time, value)

# t1, x1: the time and value arrays being fitted
model = modeltrain(t1, x1)

# samples random x-values for prediction
XX = generate_X_grid(model)

# plots for visualisation
plt.plot(XX, model.predict(XX), 'r--')
plt.plot(XX, model.prediction_intervals(XX, width=0.25), color='b', ls='--')
plt.scatter(t1, x1)
plt.show()

This is the expected result


Original data scatter plot


If the number of splines is not chosen correctly, I get an incorrect fit.

Could you suggest a method for choosing the number of splines reliably?

Upvotes: 2

Views: 5992

Answers (1)

dswah

Reputation: 141

Typically for splines you choose a fairly high number of splines (~25) and you let the lambda smoothing parameter do the work of reducing the flexibility of the model.
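To make the "let lambda do the work" idea concrete, here is a rough numpy-only sketch. It is not pyGAM's implementation: Gaussian bumps stand in for B-splines, the helper effective_dof is hypothetical, and the second-order difference penalty is an assumption chosen for illustration. It keeps the basis fixed at 25 functions and shows the effective degrees of freedom (the trace of the hat matrix) shrinking as the penalty weight grows.

```python
import numpy as np

def effective_dof(x, n_basis, lam):
    """Effective degrees of freedom of a penalised basis fit (trace of the hat matrix)."""
    centers = np.linspace(x.min(), x.max(), n_basis)
    width = centers[1] - centers[0]
    # Gaussian bumps stand in for B-splines (assumption, to stay numpy-only).
    B = np.exp(-0.5 * ((x[:, None] - centers[None, :]) / width) ** 2)
    # Second-order difference penalty on neighbouring coefficients.
    D = np.diff(np.eye(n_basis), n=2, axis=0)
    P = D.T @ D
    # Hat matrix of the penalised least-squares fit: B (B'B + lam P)^-1 B'
    hat = B @ np.linalg.solve(B.T @ B + lam * P, B.T)
    return np.trace(hat)

x = np.linspace(0.0, 1.0, 200)
lams = np.logspace(-3, 3, 7)
edofs = [effective_dof(x, n_basis=25, lam=lam) for lam in lams]
for lam, edof in zip(lams, edofs):
    print(f"lam={lam:g}  effective dof={edof:.2f}")
```

The number of basis functions never changes in this sketch; only the penalty does. That is why a generous n_splines combined with a tuned lam is usually enough.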

For your use-case I would choose the default n_splines=25 and then do a gridsearch over the lambda parameter lam to find the best amount of smoothing:

import numpy as np

def modeltrain(time, value):
    return LinearGAM(n_splines=25, spline_order=3).gridsearch(time, value, lam=np.logspace(-3, 3, 11))

This will try 11 models from lam = 1e-3 to 1e3.
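If it helps to see the candidate values themselves, this small snippet builds the same grid with plain numpy (the name lam_grid is just for illustration). The 11 points are evenly spaced on a log scale, so each is 10**0.6 (about 3.98) times the previous one.

```python
import numpy as np

# 11 candidate smoothing penalties, log-spaced between 1e-3 and 1e3
lam_grid = np.logspace(-3, 3, 11)

# ratio between consecutive candidates (constant on a log-spaced grid)
ratios = lam_grid[1:] / lam_grid[:-1]

print(lam_grid)
print(ratios)
```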

I think your choice of n_splines=58 is too high, because it looks like it produces one spline per data point.

If you really want to do a search over n_splines then you could do:

# start the range at 4: n_splines must be larger than spline_order (3 here)
LinearGAM(n_splines=25, spline_order=3).gridsearch(time, value, n_splines=np.arange(4, 50))

Note: the function generate_X_grid does NOT do random sampling for prediction. It just builds a dense, linearly spaced grid over the range of your X-values (time), so that you can visualize how the learned model interpolates between data points.
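As a rough analogue of that behaviour in plain numpy (a sketch only, not pyGAM's actual code; the helper dense_grid and its default n are made up for illustration):

```python
import numpy as np

def dense_grid(x, n=100):
    """Evenly spaced points spanning the observed range of x (hypothetical helper)."""
    return np.linspace(np.min(x), np.max(x), n)

# unevenly spaced observation times
t = np.array([0.0, 1.5, 2.0, 7.0, 10.0])

# the grid ignores where the observations fall and only uses their range
grid = dense_grid(t, n=11)
print(grid)
```

Predicting on such a grid gives a smooth curve across the whole range, including the gaps between observations, which is what you want for a plot.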

Upvotes: 4
