Reputation: 177
I am trying to build a server that makes predictions (regression) for a given input. To avoid loading the model on every request (which would save about 1.8 seconds), I preload a shared Keras model (TensorFlow backend) once. However, when I try to predict anything from a thread, the program just freezes, even though only one thread is accessing the model during my test. I know TensorFlow is not designed for this, but since the model is only predicting there should be a workaround. I have tried using _make_predict_function, but that did not work.
This is the main function:
keras_model = keras_model_for_threads()

def thread_function(conn, addr, alive):
    print('Connected by', addr)
    start = time.time()
    sent = conn.recv(1024)
    x_pred = preproc(sent)
    conn.sendall(keras_model.predict_single(x_pred))
    conn.close()

import socket
HOST = ''
PORT = xxxxx
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind((HOST, PORT))
s.listen(1000)
print('Ready for listening')
while alive.get():
    conn, addr = s.accept()
    Process(target=thread_function, args=(conn, addr, alive)).start()
with keras_model_for_threads defined as:
class keras_model_for_threads():
    def __init__(self):
        self.model = load_model(model_path)
        self.model._make_predict_function()

    def predict_single(self, x_pred):
        return self.model.predict(x_pred)
If I call this directly, it runs and returns a prediction, but when it is called through the Process running thread_function, it freezes on self.model.predict.
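A possible explanation for the difference, offered as an assumption rather than something confirmed in this thread: on POSIX systems `Process` typically forks, and a forked child inherits the parent's memory but not its running threads. A library that relies on background threads created at load time can then wait forever in the child. The sketch below illustrates only that mechanism with a hypothetical helper thread; it does not use Keras or TensorFlow, and it requests the "fork" start method explicitly.

```python
import multiprocessing as mp
import threading


def report_thread_count(out):
    # Runs in the child: count the threads that exist after the fork.
    out.put(threading.active_count())


def compare_thread_counts():
    # Assumption: a POSIX system where the "fork" start method is available.
    ctx = mp.get_context("fork")
    stop = threading.Event()
    # Hypothetical stand-in for a worker thread a library starts at load time.
    helper = threading.Thread(target=stop.wait, daemon=True)
    helper.start()
    in_parent = threading.active_count()

    out = ctx.Queue()
    child = ctx.Process(target=report_thread_count, args=(out,))
    child.start()
    in_child = out.get(timeout=10)
    child.join()

    # Let the helper thread finish so the parent exits cleanly.
    stop.set()
    helper.join()
    return in_parent, in_child
```

If the child reports fewer threads than the parent, any work the helper thread was supposed to do never happens in the child, which is one way a call can block indefinitely.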
Upvotes: 0
Views: 1174
Reputation: 177
After some more searching I found an answer that works: use a manager to handle the prediction. This changes the original Keras code to:
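from multiprocessing.managers import BaseManager
from multiprocessing import Lock

class KerasModelForThreads():
    def __init__(self):
        self.lock = Lock()
        self.model = None

    def load_model(self):
        # Import Keras here so it is only loaded when this method runs,
        # i.e. inside the manager's process.
        from keras.models import load_model
        # model_path is assumed to be defined elsewhere in the project.
        self.model = load_model(model_path)

    def predict_single(self, x_pred):
        # Serialize predictions so only one caller uses the model at a time.
        with self.lock:
            # self.const.offset comes from the original project and is not defined in this excerpt.
            return (self.model.predict(x_pred) + self.const.offset)[0][0]

class KerasManager(BaseManager):
    pass

KerasManager.register('KerasModelForThreads', KerasModelForThreads)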
And the main code changes to:
import socket
import time
from multiprocessing import Process

from keras_for_threads import KerasManager

# Start the manager process first, then create and load the model inside it.
keras_manager = KerasManager()
keras_manager.start()
keras_model = keras_manager.KerasModelForThreads()
keras_model.load_model()

def thread_function(conn, addr, alive):
    # preproc and alive are defined elsewhere in the project.
    print('Connected by', addr)
    start = time.time()
    sent = conn.recv(1024)
    x_pred = preproc(sent)
    # keras_model is a proxy; the prediction runs in the manager process.
    conn.sendall(keras_model.predict_single(x_pred))
    conn.close()

HOST = ''
PORT = xxxxx
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind((HOST, PORT))
s.listen(1000)
print('Ready for listening')
while alive.get():
    conn, addr = s.accept()
    Process(target=thread_function, args=(conn, addr, alive)).start()
This is a stripped-down version (without the Flask parts, just the Keras part) of the GitHub project I found here: https://github.com/Dref360/tuto_keras_web
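The manager pattern above can be tried without Keras installed. The following is a minimal sketch, not the linked project's code: `StandInModel`, `ModelManager`, `worker`, and `predict_from_children` are hypothetical names, the "model" just doubles its input, and the "fork" start method is assumed so that it runs as a single script. It shows that the object lives in the manager's server process and that child processes reach it through a proxy.

```python
import multiprocessing as mp
import threading
from multiprocessing.managers import BaseManager


class StandInModel:
    """Hypothetical stand-in for the Keras wrapper; it doubles its input."""

    def __init__(self):
        # The manager server handles each client in its own thread,
        # so a lock keeps predictions serialized.
        self.lock = threading.Lock()
        self.scale = None

    def load_model(self):
        # A real server would call keras.models.load_model(...) here, so the
        # model only ever exists inside the manager's server process.
        self.scale = 2.0

    def predict_single(self, x):
        with self.lock:
            return self.scale * x


class ModelManager(BaseManager):
    pass


ModelManager.register("StandInModel", StandInModel)


def worker(model_proxy, x, results):
    # Runs in a child process; the call is forwarded to the manager process.
    results.put(model_proxy.predict_single(x))


def predict_from_children(inputs):
    # Assumption: POSIX "fork" start method, so the sketch runs as one script.
    ctx = mp.get_context("fork")
    with ModelManager(ctx=ctx) as manager:
        model = manager.StandInModel()  # proxy to the object in the server
        model.load_model()
        results = ctx.Queue()
        procs = [ctx.Process(target=worker, args=(model, x, results))
                 for x in inputs]
        for p in procs:
            p.start()
        # Children finish in no fixed order, so sort before returning.
        outputs = sorted(results.get(timeout=10) for _ in procs)
        for p in procs:
            p.join()
    return outputs
```

Calling `predict_from_children([1.0, 2.0, 3.0])` starts one child process per input, in the same way the server starts one `Process` per connection, while the model itself is loaded only once.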
Upvotes: 1