Ngọc Minh

Reputation: 1

Why does my NLP model reload many times when processing a question?

After receiving a question, my program calls the run_predict function, which finds the paragraph that best matches the question. After that, the model keeps being reloaded, and I can't figure out why.

from flask import Flask, render_template, request, jsonify
from flask_socketio import SocketIO, emit
import os
import json
import logging
from simpletransformers.question_answering import QuestionAnsweringModel, QuestionAnsweringArgs
from multiprocessing import freeze_support
from models.find_top_paragraphs import main as find_top_paragraphs

app = Flask(__name__)
app.config['SECRET_KEY'] = 'secret!'
socketio = SocketIO(app)

# Configure the model
model_args = QuestionAnsweringArgs()
model_args.eval_batch_size = 16

# Get the absolute path of the current directory
current_dir = os.path.abspath(os.getcwd())

# Path to the outputs directory
outputs_dir = os.path.join(current_dir, "outputs", "best_model")

print(f"Model directory: {outputs_dir}")

# Global variable to store the model
model = None

def load_model():
    global model
    if model is None:
        model = QuestionAnsweringModel(
            model_type="bert", 
            model_name=outputs_dir, 
            args=model_args,
            use_cuda=False  # Set to True if you have a GPU and want to use it
        )
        print("Load model successfully!")
    else:
        print("Model is already loaded.")
    return model

# Load the model only once and reuse it
model = load_model()

def run_predict(question):
    print("Running predict function")
    # Directly call the find_top_paragraphs function
    find_top_paragraphs(question)

    # Read the result from the top_paragraphs.json file
    output_path = os.path.join(os.path.dirname(__file__), "models", "top_paragraphs.json")
    with open(output_path, "r", encoding='utf-8') as file:
        data = json.load(file)
    
    question = data["question"]
    top_paragraph = data["top_paragraphs"][0]  # Only take the most relevant paragraph
    print(f"Top Paragraph: {top_paragraph}")

    # Make a prediction with the most relevant paragraph
    to_predict = [
        {
            "context": top_paragraph,
            "qas": [
                {
                    "question": question,
                    "id": "0",
                }
            ],
        }
    ]

    # Get the global model
    model = load_model()

    # Make a prediction
    answers, probabilities = model.predict(to_predict)

    # Collect the candidate answers; only one question is submitted, so use the first entry
    all_answers = []
    best_answer = None
    if answers:
        answer = answers[0]
        all_answers.extend(answer['answer'])

        # Find the best answer
        try:
            probability = probabilities[0]['probability']
            best_answer_idx = probability.index(max(probability))
            best_answer = answer['answer'][best_answer_idx]
            print(f"Best Answer: {best_answer}")
        except Exception as e:
            print(f"An error occurred while selecting the best answer: {e}")

    # Always return a tuple so the caller's unpacking never fails
    return all_answers, best_answer

@app.route('/')
def index():
    return render_template('index.html')

@socketio.on('send_message')
def handle_message(data):
    question = data['message']
    all_answers, best_answer = run_predict(question)
    response = {
        'all_answers': all_answers,
        'best_answer': best_answer
    }
    emit('receive_message', response)

if __name__ == '__main__':
    freeze_support()
    socketio.run(app, debug=False)

[Screenshot: my terminal output while the question is being processed]

I also tried loading the model in a separate file and importing it into this file, but the problem persists. There may be something wrong with my code, but I can't work out what.

Upvotes: 0

Views: 31

Answers (1)

EliasK93

Reputation: 3174

You are calling load_model() once directly at module level and again inside run_predict(); the second call shouldn't be necessary.
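A minimal sketch of that suggestion: the model is created once at import time and `run_predict` uses it directly. `FakeQAModel` is a hypothetical stand-in for `QuestionAnsweringModel` so the sketch runs without simpletransformers; it only mimics the `(answers, probabilities)` structure that the question's code indexes into.

```python
class FakeQAModel:
    """Hypothetical stand-in for QuestionAnsweringModel; counts how often it is constructed."""

    instances = 0

    def __init__(self):
        FakeQAModel.instances += 1

    def predict(self, to_predict):
        # Return the same (answers, probabilities) structure the question's code reads.
        answers = [{"id": item["qas"][0]["id"], "answer": ["stub answer"]} for item in to_predict]
        probabilities = [{"id": a["id"], "probability": [1.0]} for a in answers]
        return answers, probabilities


# Load the model exactly once, when the module is imported.
model = FakeQAModel()


def run_predict(question, context):
    to_predict = [{"context": context, "qas": [{"question": question, "id": "0"}]}]
    # Reuse the module-level model; there is no second load_model() call here.
    answers, _probabilities = model.predict(to_predict)
    return answers[0]["answer"]
```

In the real app, `model` would be the object returned by the single top-level `load_model()` call.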

I did exactly the same thing before (using a simpletransformers model in Flask). The way I solved it was to write a model utilities class (in a separate file) whose constructor loads the model and which then provides methods to use it. I then create a single instance of that class, e.g. model_utilities = load_model_utilities(). Check out https://github.com/EliasK93/debertav3-for-aspect-based-sentiment-analysis/blob/master/aspect_based_sentiment_analysis/web_app.py. This way I never had any issues with the model being loaded more than once.
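A rough sketch of that pattern, written for this question rather than copied from the linked repository: `ModelUtilities`, its `answer` method, the `loader` parameter and the default `"outputs/best_model"` path are all hypothetical. The constructor is the only place the model is loaded; simpletransformers is imported lazily so the module can be imported (and tested with a fake loader) without it.

```python
from typing import Any, Callable, List, Optional, Tuple


def _default_loader(model_dir: str) -> Any:
    # Imported lazily so this file can be imported without simpletransformers installed.
    from simpletransformers.question_answering import QuestionAnsweringModel
    return QuestionAnsweringModel("bert", model_dir, use_cuda=False)


class ModelUtilities:
    """Owns a single model instance and exposes helper methods that use it."""

    def __init__(self, model_dir: str, loader: Callable[[str], Any] = _default_loader):
        # The model is loaded here and nowhere else.
        self.model = loader(model_dir)

    def answer(self, question: str, context: str) -> Tuple[List[str], Optional[str]]:
        to_predict = [{"context": context, "qas": [{"question": question, "id": "0"}]}]
        answers, probabilities = self.model.predict(to_predict)
        candidates = list(answers[0]["answer"]) if answers else []
        scores = probabilities[0]["probability"] if probabilities else []
        if not candidates or not scores:
            return candidates, None
        # Pick the candidate with the highest probability.
        best_idx = scores.index(max(scores))
        return candidates, candidates[best_idx]


def load_model_utilities(model_dir: str = "outputs/best_model", **kwargs) -> ModelUtilities:
    return ModelUtilities(model_dir, **kwargs)
```

In the Flask file, `model_utilities = load_model_utilities()` would run once at import time, and the Socket.IO handler would call `model_utilities.answer(question, top_paragraph)` instead of loading anything itself.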

Upvotes: 0
