I want to save the interpolator object generated by scipy.interpolate.InterpolatedUnivariateSpline to a file, so that I can load it afterwards and use it.
This is the result on the console:
>>> interpolator
<scipy.interpolate.fitpack2.InterpolatedUnivariateSpline object at 0x11C27170>
>>> np.save("interpolator", np.array(interpolator))
>>> f = np.load("interpolator.npy")
>>> f
array(<scipy.interpolate.fitpack2.InterpolatedUnivariateSpline object at 0x11C08FB0>, dtype=object)
These are the results of trying to call the loaded interpolator f with a generic value:
>>> f(10)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: 'numpy.ndarray' object is not callable
or:
>>> f[0](10)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
IndexError: too many indices for array
How can I save/load it properly?
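The interpolator object is not an array, so np.save wraps it in an object array, and it falls back on pickle to save elements that aren't arrays. So what you get back from np.load is a 0-d array holding one object. Because the array has no dimensions, integer indexing such as f[0] fails, and the array wrapper itself is not callable.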
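A minimal sketch of this wrapping with no file involved. Dummy is a hypothetical stand-in for the spline (any callable, non-array Python object behaves the same way); the sketch reproduces both errors from the question and shows the empty-tuple index that recovers the object.

```python
import numpy as np


class Dummy:
    """Hypothetical stand-in for any non-array Python object, e.g. a spline."""

    def __call__(self, x):
        return 2 * x


obj = Dummy()
wrapped = np.array(obj)  # NumPy wraps the object in a 0-d object array

print(wrapped.shape, wrapped.ndim, wrapped.dtype)  # () 0 object

try:
    wrapped[0]  # integer indexing needs at least one dimension
except IndexError as err:
    print("IndexError:", err)

try:
    wrapped(10)  # the ndarray wrapper itself is not callable
except TypeError as err:
    print("TypeError:", err)

inner = wrapped[()]  # an empty-tuple index returns the element of a 0-d array
print(inner is obj, inner(10))  # True 20
```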
To illustrate with a simple dictionary object:
In [280]: np.save('test.npy',{'one':1})
In [281]: x=np.load('test.npy')
In [282]: x
Out[282]: array({'one': 1}, dtype=object)
In [283]: x[0]
...
IndexError: 0-d arrays can't be indexed
In [284]: x[()]
Out[284]: {'one': 1}
In [285]: x.item()
Out[285]: {'one': 1}
In [288]: x.item()['one']
Out[288]: 1
So either x.item() or x[()] will retrieve the object from the array. You should then be able to use it as you would before the save.
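The same retrieval applies to the spline itself. This is a minimal sketch, assuming a current NumPy and SciPy; the sample data and the temporary file path are made up for illustration. Note that since NumPy 1.16.3, np.load defaults to allow_pickle=False and refuses to load object arrays unless you pass allow_pickle=True, so the session above may fail as written on a newer install.

```python
import os
import tempfile

import numpy as np
from scipy.interpolate import InterpolatedUnivariateSpline

# Illustrative data, not from the question
x = np.linspace(-3, 3, 50)
y = np.exp(-x**2)
interpolator = InterpolatedUnivariateSpline(x, y)

path = os.path.join(tempfile.mkdtemp(), "interpolator.npy")
# np.save wraps the non-array object in a 0-d object array and pickles it
np.save(path, interpolator, allow_pickle=True)

# Newer NumPy refuses object arrays unless allow_pickle=True is passed.
# Only do this for files you trust: unpickling can execute arbitrary code.
loaded = np.load(path, allow_pickle=True)
f = loaded.item()  # equivalently: loaded[()]

print(f(1.0), interpolator(1.0))
```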
Calling pickle yourself (pickle.dump and pickle.load) is also fine.
It looks like numpy.save followed by numpy.load wraps the scipy InterpolatedUnivariateSpline object in a numpy object array. Later versions of numpy.save/numpy.load have an allow_pickle argument that controls whether such object arrays may be pickled and unpickled. This isn't present in my version of numpy (1.9.2), and I assume maybe not in yours either. With spl = numpy.load("file"), what you get back is the array wrapper rather than the spline, so calling spl as a function fails. As numpy.save is mainly designed for binary arrays of data, the most general solution is probably to use pickle. As a minimal example:
import matplotlib.pyplot as plt
from scipy.interpolate import InterpolatedUnivariateSpline
import numpy as np
try:
    import cPickle as pickle
except ImportError:
    import pickle
x = np.linspace(-3, 3, 50)
y = np.exp(-x**2) + 0.1 * np.random.randn(50)
spl = InterpolatedUnivariateSpline(x, y)
plt.plot(x, y, 'ro', ms=5)
xs = np.linspace(-3, 3, 1000)
#Plot before save
plt.plot(xs, spl(xs), 'g', lw=3, alpha=0.7)
#Save, load and plot again (NOTE CAUSES ERROR)
#np.save("interpolator",spl)
#spl_loaded = np.load("interpolator.npy")
#plt.plot(xs, spl_loaded(xs), 'k--', lw=3, alpha=0.7)
#Pickle, unpickle and then plot again
with open('interpolator.pkl', 'wb') as f:
    pickle.dump(spl, f)
with open('interpolator.pkl', 'rb') as f:
    spl_loaded = pickle.load(f)
plt.plot(xs, spl_loaded(xs), 'k--', lw=3, alpha=0.7)
plt.show()
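If you don't need the plot, the round trip can be checked numerically instead. A minimal sketch, assuming Python 3 (where pickle uses its C accelerator automatically, so the cPickle fallback is unnecessary). The data are illustrative, and pickle.dumps/pickle.loads work on bytes in memory rather than on a file.

```python
import pickle

import numpy as np
from scipy.interpolate import InterpolatedUnivariateSpline

# Illustrative data, not from the original answer
x = np.linspace(-3, 3, 50)
y = np.exp(-x**2)
spl = InterpolatedUnivariateSpline(x, y)

# Serialize to bytes; pickle.dump(spl, f) does the same for an open file
data = pickle.dumps(spl)
spl_loaded = pickle.loads(data)

# The unpickled spline is a real spline object again and evaluates the same
xs = np.linspace(-3, 3, 1000)
print(type(spl_loaded).__name__)  # InterpolatedUnivariateSpline
print(np.allclose(spl(xs), spl_loaded(xs)))  # True
```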