charles

Reputation: 733

Save scipy object to file

I want to save the object interpolator, generated from scipy.interpolate.InterpolatedUnivariateSpline, to a file so that I can load it later and use it. This is what I get in the console:

>>> interpolator
<scipy.interpolate.fitpack2.InterpolatedUnivariateSpline object at 0x11C27170>
>>> np.save("interpolator", np.array(interpolator))
>>> f = np.load("interpolator.npy")
>>> f
array(<scipy.interpolate.fitpack2.InterpolatedUnivariateSpline object at 0x11C08FB0>, dtype=object)

This is what happens when I try to call the loaded interpolator f with an arbitrary value:

>>> f(10)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: 'numpy.ndarray' object is not callable

or:

>>> f[0](10)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: too many indices for array

How can I save/load it properly?

Upvotes: 9

Views: 8443

Answers (2)

hpaulj

Reputation: 231415

The interpolator object is not an array, so np.save wraps it in an object array and falls back on pickle to save the element, since that element is not an array itself. What you get back from np.load is therefore a 0-d array holding that one object.

To illustrate with a simple dictionary object:

In [280]: np.save('test.npy',{'one':1})
In [281]: x=np.load('test.npy')
In [282]: x
Out[282]: array({'one': 1}, dtype=object)
In [283]: x[0]
...
IndexError: 0-d arrays can't be indexed
In [284]: x[()]
Out[284]: {'one': 1}
In [285]: x.item()
Out[285]: {'one': 1}
In [288]: x.item()['one']
Out[288]: 1

So either x.item() or x[()] will retrieve the object from the 0-d array. You should then be able to use it just as you did before the save.

Using your own pickle calls is fine.
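
Applied to the interpolator from the question, the same retrieval might look like the sketch below. The sample data, the temporary file path and the variable names wrapped, loaded and restored are made up for illustration. Passing allow_pickle=True to np.load is an assumption about your NumPy version: newer releases refuse to load object arrays unless it is given, while very old ones do not have the argument at all.

```python
import os
import tempfile

import numpy as np
from scipy.interpolate import InterpolatedUnivariateSpline

# Hypothetical sample data, only for illustration.
x = np.linspace(0.0, 10.0, 20)
y = np.sin(x)
interpolator = InterpolatedUnivariateSpline(x, y)

# Wrap the spline explicitly in a 0-d object array (what np.save
# would otherwise do implicitly for a non-array object).
wrapped = np.empty((), dtype=object)
wrapped[()] = interpolator

path = os.path.join(tempfile.mkdtemp(), "interpolator.npy")
np.save(path, wrapped)  # the object element is stored via pickle

# Loading returns the 0-d object array again, not the spline itself.
loaded = np.load(path, allow_pickle=True)

# .item() (or loaded[()]) pulls the spline back out of the array.
restored = loaded.item()
print(restored(10))
```

Since this route goes through pickle anyway, only load such files from sources you trust.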

Upvotes: 8

Ed Smith

Reputation: 13216

It looks like numpy.save followed by numpy.load converts the scipy InterpolatedUnivariateSpline object into a numpy object array. Numpy save/load apparently has an allow_pickle=True argument, which should preserve object information. It isn't present in my version of numpy (1.9.2), and I assume it may be missing from yours too. With spl = numpy.load("file"), what comes back is a numpy array rather than the spline itself, so calling spl as a function fails. As numpy.save is mainly designed for binary arrays of data, the most general solution is probably to use pickle. As a minimal example:

import matplotlib.pyplot as plt
from scipy.interpolate import InterpolatedUnivariateSpline
import numpy as np
try:
    import cPickle as pickle
except ImportError:
    import pickle

x = np.linspace(-3, 3, 50)
y = np.exp(-x**2) + 0.1 * np.random.randn(50)
spl = InterpolatedUnivariateSpline(x, y)
plt.plot(x, y, 'ro', ms=5)


xs = np.linspace(-3, 3, 1000)

#Plot before save
plt.plot(xs, spl(xs), 'g', lw=3, alpha=0.7)

# Save, load and plot again (NOTE: causes an error, np.load returns an array that is not callable)
#np.save("interpolator",spl)
#spl_loaded = np.load("interpolator.npy")
#plt.plot(xs, spl_loaded(xs), 'k--', lw=3, alpha=0.7)

#Pickle, unpickle and then plot again
with open('interpolator.pkl', 'wb') as f:
    pickle.dump(spl, f)
with open('interpolator.pkl', 'rb') as f:
    spl_loaded = pickle.load(f)
plt.plot(xs, spl_loaded(xs), 'k--', lw=3, alpha=0.7)

plt.show()
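
If you would rather check the round trip numerically than by eye, something along these lines should do. This is only a sketch: the seeded random generator, the temporary path and the name matches are my own choices for illustration, not part of the example above.

```python
import os
import pickle
import tempfile

import numpy as np
from scipy.interpolate import InterpolatedUnivariateSpline

# Same kind of noisy data as above, with a fixed seed (an assumption
# made here so that the run is repeatable).
rng = np.random.default_rng(0)
x = np.linspace(-3, 3, 50)
y = np.exp(-x**2) + 0.1 * rng.standard_normal(50)
spl = InterpolatedUnivariateSpline(x, y)

# Pickle to a temporary file and read it back.
path = os.path.join(tempfile.mkdtemp(), "interpolator.pkl")
with open(path, "wb") as f:
    pickle.dump(spl, f)
with open(path, "rb") as f:
    spl_loaded = pickle.load(f)

# Compare the original and the reloaded spline on a fine grid.
xs = np.linspace(-3, 3, 1000)
matches = np.allclose(spl(xs), spl_loaded(xs))
print("round trip matches:", matches)
```

Keep in mind that a pickle file should only be loaded from a trusted source, and that a file written with one scipy version is not guaranteed to load with another.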

Upvotes: 4
