I have a deep learning model and I want to load it only once, when my API starts. Right now the model is loaded on every request made to my API service (with the test image), so each request takes a fair bit of time to return an output.
I tried loading my model in a FastAPI startup event handler, @app.on_event('startup'):
main.py
@app.on_event('startup')
def init_data():
    print("init call")
    model = load_model(model_path)
    return model
Then I want to import this model variable into another Python file, classification.py:
#classification.py
pred = model.predict(image)
I am not sure whether this is the correct way to do it. Any help would be appreciated. Thanks!
What you need is to wrap the prediction in a function and pass the already-loaded model to it as an argument.
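# main.py
from fastapi import FastAPI

import classification
# Other imports (load_model, model_path, ...)

app = FastAPI()

# The model is a module-level (global) name, so it is loaded once,
# when main.py is imported at server start, not on every request.
model = load_model(model_path)

@app.get('/whatever/page')
def predict(image):
    # Reading a global needs no `global` statement; that is only
    # required when assigning to it. Call classification's predict.
    result = classification.predict(image, model)
    return result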
# classification.py
# import ...

# predict takes the image to classify and the already-loaded model as parameters.
def predict(image, model):
    pred = model.predict(image)
    # ...
    return pred
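To see the data flow without a real network or FastAPI, here is a minimal stdlib-only sketch that also adds the type hints the answer mentions. DummyModel, the Model protocol, the counting load_model stand-in, handle_request and the "model.h5" path are all hypothetical; the only claim being illustrated is that a model loaded once at module level is reused by every request.

```python
from typing import Any, Protocol


class Model(Protocol):
    """Anything with a predict(image) method (assumed, Keras-style)."""
    def predict(self, image: Any) -> Any: ...


class DummyModel:
    """Hypothetical stand-in for the real deep learning model."""
    def predict(self, image: Any) -> str:
        return f"label for {image}"


LOAD_CALLS = 0


def load_model(model_path: str) -> Model:
    """Hypothetical loader that counts calls so we can check it runs once."""
    global LOAD_CALLS
    LOAD_CALLS += 1
    return DummyModel()


# classification.py equivalent: the model is an explicit, typed parameter.
def predict(image: Any, model: Model) -> Any:
    return model.predict(image)


# main.py equivalent: load once at module level (hypothetical path)...
model = load_model("model.h5")


# ...then reuse the same object in every "request" handler call.
def handle_request(image: Any) -> Any:
    return predict(image, model)


results = [handle_request(name) for name in ("cat.jpg", "dog.jpg", "car.jpg")]
print(LOAD_CALLS, results)
```

Three simulated requests produce three predictions while LOAD_CALLS stays at 1.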
You could refine this with type hints for the parameters and so on, but it shows the data flow. Another approach is to load the model at the top level of classification.py, so you don't need to pass it around at all. Module-level code runs only once per process, on the first import; Python caches imported modules in sys.modules, so later imports reuse the same module object and the same model.
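The "load it in the root of classification.py" approach relies on Python's import cache. The sketch below checks that behavior with the standard library only: it writes a hypothetical classification_demo module whose "model" is a dummy object, imports it twice, and confirms the module body (and therefore the model load) ran once. The module name, DummyModel and LOAD_COUNT are assumptions made for the demonstration.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Hypothetical stand-in for classification.py with a dummy model.
# LOAD_COUNT records how many times the module body loaded the model.
MODULE_SOURCE = '''
LOAD_COUNT = 0

class DummyModel:
    def predict(self, image):
        return f"prediction for {image}"

def _load_model():
    global LOAD_COUNT
    LOAD_COUNT += 1
    return DummyModel()

# Module-level statement: runs once, on the first import.
model = _load_model()

def predict(image):
    return model.predict(image)
'''


def demo():
    sys.modules.pop("classification_demo", None)  # start from a clean cache
    with tempfile.TemporaryDirectory() as tmp:
        Path(tmp, "classification_demo.py").write_text(MODULE_SOURCE)
        sys.path.insert(0, tmp)
        importlib.invalidate_caches()  # make sure the new file is found
        try:
            first = importlib.import_module("classification_demo")
            # The second import is served from sys.modules; the body does not re-run.
            second = importlib.import_module("classification_demo")
            results = [first.predict("img1"), second.predict("img2")]
            return first is second, first.LOAD_COUNT, results
        finally:
            sys.path.remove(tmp)
            sys.modules.pop("classification_demo", None)


same_module, load_count, results = demo()
print(same_module, load_count, results)
```

Both imports return the same module object and load_count is 1. In a real server with several worker processes, each process would still load its own copy once.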