Reputation: 12792
I created a package (livelossplot) that connects to other libraries. It has many optional dependencies (deep learning frameworks), and I don't want to force people to install them.
Right now I use conditional imports, in the spirit of:
```python
try:
    from .keras_plot import PlotLossesKeras
except ImportError:
    # import the Keras plot only if Keras is installed
    pass
```
However, this means the big libraries get imported whenever they are installed, even if the user never intends to use them. The question is: how can I import a library only when a particular object is created?
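If part of the goal is only to know whether an optional framework is installed (for example, to decide which features to advertise) without paying its import cost, `importlib.util.find_spec` can answer that. A minimal sketch; `is_available` is a hypothetical helper name:

```python
import importlib.util


def is_available(module_name: str) -> bool:
    """Return True if `module_name` can be found, without importing it.

    For a top-level name such as "keras", find_spec only locates the module
    and does not execute it. (For dotted names, parent packages are imported.)
    """
    return importlib.util.find_spec(module_name) is not None
```

This only tells you the module is present; the heavy import still has to happen once the class is actually needed.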
For Python functions, it is simple:
```python
def function_using_keras():
    import keras
    ...
```
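To make the function case concrete: a local import is deferred until the first call, and Python's module cache keeps later calls cheap. A small sketch, with the stdlib `decimal` module standing in for keras and `function_using_decimal` a hypothetical name:

```python
def function_using_decimal():
    # Stand-in for `import keras`: nothing is imported until the first call.
    # The import statement still runs on every call, but once the module is
    # in sys.modules it is only a dictionary lookup, not a reload.
    import decimal

    return decimal.Decimal("1.10") + decimal.Decimal("2.20")
```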
What is good practice for classes that inherit from other classes?
It seems that the parent class has to be imported before the subclass can be defined:
```python
from keras.callbacks import Callback

class PlotLossesKeras(Callback):
    ...
```
Upvotes: 7
Views: 1321
Reputation: 396
If you want to use Nils Werner's answer in a transparent way, you can do the following:
```python
# mylib/__init__.py
def generic_stuff():
    pass


def PlotLossesKeras(*args, **kwargs):
    # On the first call, rebind the module-level name to the real class.
    global PlotLossesKeras
    from .keras import PlotLosses as PlotLossesKeras
    return PlotLossesKeras(*args, **kwargs)


def PlotLossesTensorflow(*args, **kwargs):
    global PlotLossesTensorflow
    from .tensorflow import PlotLosses as PlotLossesTensorflow
    return PlotLossesTensorflow(*args, **kwargs)
```
These are two functions that, when first called, import the required module and then replace themselves (in the module's namespace) with the class you want.
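This is a bit hacky, but can be useful if you don't want to change the API. One caveat: a name imported with `from mylib import PlotLossesKeras` before the first call keeps pointing to the wrapper function, so it cannot be used in `isinstance` checks or subclassed.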
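On Python 3.7+, a module-level `__getattr__` (PEP 562) gives the same lazy behaviour without functions that rebind themselves: it is called only for names not already in the module, and `from mylib import PlotLossesKeras` then receives the real class. A minimal sketch; the helper `install_lazy_attributes` is hypothetical, and the usage comment assumes the `mylib` layout from the other answer:

```python
import importlib


def install_lazy_attributes(module_globals, lazy):
    """Make the names in `lazy` importable from a module without eager imports.

    `lazy` maps a public name to (module_path, attribute). Relative paths such
    as ".keras" are resolved against the module's __package__.
    """

    def __getattr__(name):
        if name not in lazy:
            raise AttributeError(
                f"module {module_globals.get('__name__')!r} has no attribute {name!r}"
            )
        module_path, attribute = lazy[name]
        # The heavy import happens here, on first access only.
        module = importlib.import_module(module_path, module_globals.get("__package__"))
        value = getattr(module, attribute)
        module_globals[name] = value  # cache: later lookups bypass __getattr__
        return value

    module_globals["__getattr__"] = __getattr__


# Hypothetical usage at the bottom of mylib/__init__.py:
#     install_lazy_attributes(globals(), {
#         "PlotLossesKeras": (".keras", "PlotLosses"),
#         "PlotLossesTensorflow": (".tensorflow", "PlotLosses"),
#     })
```

Because the cached value is the class itself, `isinstance` checks and subclassing work for every caller, unlike with the wrapper functions.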
Upvotes: 0
Reputation: 36749
The most straightforward and most easily understood solution would be to split your library into submodules.
It has several advantages over trying to do imports on object initialization:
- Anyone reading `import my_lib.keras` can tell that it very likely depends on keras.
- Switching frameworks is just a matter of changing `import my_lib.keras` to `import my_lib.tensorflow`.
Such a solution could look like this:
```python
# mylib/__init__.py
class SomethingGeneric:
    pass

def something_else():
    pass
```
and then
```python
# mylib/keras.py
import keras

class PlotLosses:
    pass
```
and
```python
# mylib/tensorflow.py
import tensorflow

class PlotLosses:
    pass
```
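A quick way to check the property this layout relies on (importing the package does not import the framework; importing the submodule does) is to build a throwaway copy of it and inspect `sys.modules`. A sketch under assumptions: `mylib_demo` and `fake_heavy_framework` are hypothetical names, the latter standing in for keras.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path


def write_demo_layout(root: Path) -> None:
    """Write a toy copy of the mylib layout (light package + heavy submodule)."""
    (root / "fake_heavy_framework.py").write_text("LOADED = True\n")
    pkg = root / "mylib_demo"
    pkg.mkdir()
    # Light top-level package: no framework import.
    (pkg / "__init__.py").write_text(textwrap.dedent("""\
        class SomethingGeneric:
            pass
    """))
    # Framework-specific submodule: imports the heavy dependency at module level.
    (pkg / "keras.py").write_text(textwrap.dedent("""\
        import fake_heavy_framework

        class PlotLosses:
            pass
    """))


def check_lazy_layout():
    """Return (framework absent after `import mylib_demo`,
    framework present after `import mylib_demo.keras`)."""
    names = ("mylib_demo.keras", "mylib_demo", "fake_heavy_framework")
    with tempfile.TemporaryDirectory() as tmp:
        write_demo_layout(Path(tmp))
        importlib.invalidate_caches()  # files were created at runtime
        sys.path.insert(0, tmp)
        try:
            importlib.import_module("mylib_demo")
            light = "fake_heavy_framework" not in sys.modules
            importlib.import_module("mylib_demo.keras")
            heavy = "fake_heavy_framework" in sys.modules
        finally:
            sys.path.remove(tmp)
            for name in names:
                sys.modules.pop(name, None)
    return light, heavy
```

Calling `check_lazy_layout()` should report that the stand-in framework stays unloaded until the submodule is imported.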
Upvotes: 4