Ben

Reputation: 377

Batch make_smoothing_spline in scipy

In SciPy, the function scipy.interpolate.make_interp_spline() can be batched: its x argument must be one-dimensional with shape (m,), but its y argument can have shape (m, ...).
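For reference, a minimal sketch of that batched behavior, assuming any recent SciPy; the three curves sampled at 20 points are arbitrary illustration data.

```python
import numpy as np
from scipy.interpolate import make_interp_spline

# x must be 1-D with shape (m,)
x = np.linspace(0, 2 * np.pi, 20)
# y can carry extra trailing dimensions: shape (m, 3) here, i.e. 3 curves
y = np.stack([np.sin(x), np.cos(x), x**2], axis=-1)

# One call fits all three curves; the result is a single BSpline object
spl = make_interp_spline(x, y, k=3)

xnew = np.linspace(0, 2 * np.pi, 50)
ynew = spl(xnew)
print(ynew.shape)  # (50, 3)
```

The interpolation axis defaults to `axis=0`; every other dimension of `y` is treated as a batch dimension.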

However, the function scipy.interpolate.make_smoothing_spline() only accepts a y argument of shape (m,).

Is there a simple way to batch the behavior of make_smoothing_spline() so it has the same behavior as make_interp_spline()?

I was thinking of using numpy.vectorize(), but here I'm not batching an operation over an array; I need a single function as output.

I guess I could just implement a loop and make a nested list of splines, but I was wondering if there would be a neater way.
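The loop approach can still be wrapped so that the caller gets back a single callable. A minimal sketch, assuming SciPy 1.10 or later (where `make_smoothing_spline` is available); the helper name `batched_smoothing_spline` is hypothetical, and it only handles evaluation, not `derivative()` or `integrate()`.

```python
import numpy as np
from scipy.interpolate import make_smoothing_spline


def batched_smoothing_spline(x, y, lam=None):
    """Hypothetical helper: fit one smoothing spline per trailing slice of y.

    y has shape (m, ...); the returned function maps xnew to an array of
    shape xnew.shape + y.shape[1:].
    """
    y = np.asarray(y)
    m, trailing = y.shape[0], y.shape[1:]
    cols = y.reshape(m, -1)  # flatten batch dims: (m, n_batch)
    splines = [make_smoothing_spline(x, cols[:, j], lam=lam)
               for j in range(cols.shape[1])]

    def evaluate(xnew):
        xnew = np.asarray(xnew)
        # Each spline returns xnew.shape; stacking gives xnew.shape + (n_batch,)
        vals = np.stack([s(xnew) for s in splines], axis=-1)
        return vals.reshape(xnew.shape + trailing)  # restore batch dims

    return evaluate
```

Flattening the trailing dimensions with `reshape(m, -1)` and restoring them after evaluation gives the output shape `xnew.shape + y.shape[1:]`, which matches what `make_interp_spline` returns for the default `axis=0`.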

Probably some combination of decorators but I'm twisting my brain in knots...

EDIT: Developers seem to be aware of this issue here.

Upvotes: 3

Views: 65

Answers (1)

Matt Haberland

Reputation: 3873

The PR that added batch support to make_smoothing_spline happened to be merged a few hours before this post. https://github.com/scipy/scipy/pull/22484

The feature will be available in SciPy 1.16, or you can get it early in the next nightly wheels. https://anaconda.org/scientific-python-nightly-wheels/scipy

See also the BatchSpline class used in the tests of that PR.

```python
import copy

import numpy as np


class BatchSpline:
    # BSpline-like class with reference batch behavior:
    # fits one 1-D spline per slice of y along `axis`.
    def __init__(self, x, y, axis, *, spline, **kwargs):
        # Move the interpolation axis to the end so each row is one 1-D dataset
        y = np.moveaxis(y, axis, -1)
        self._batch_shape = y.shape[:-1]
        self._splines = [spline(x, yi, **kwargs) for yi in y.reshape(-1, y.shape[-1])]
        self._axis = axis

    def __call__(self, x):
        x = np.asarray(x)  # accept lists and Python scalars, which lack .shape
        y = [spline(x) for spline in self._splines]
        y = np.reshape(y, self._batch_shape + x.shape)
        # Move the (last) evaluation axis back to the original interpolation axis
        return np.moveaxis(y, -1, self._axis) if x.shape else y

    def integrate(self, a, b, extrapolate=None):
        y = [spline.integrate(a, b, extrapolate) for spline in self._splines]
        return np.reshape(y, self._batch_shape)

    def derivative(self, nu=1):
        res = copy.deepcopy(self)
        res._splines = [spline.derivative(nu) for spline in res._splines]
        return res

    def antiderivative(self, nu=1):
        res = copy.deepcopy(self)
        res._splines = [spline.antiderivative(nu) for spline in res._splines]
        return res
```

Upvotes: 2
