VivekS

Reputation: 13

How can I improve the response time of a RAG chain built with Ollama, LangChain, and pgvector?

I have created a RAG app using Ollama, LangChain, and pgvector. It takes about 4-5 seconds to retrieve an answer from the llama3.1:7b model.

I followed the LangChain documentation and added profiling to my code. The profiling output is as follows:

[Sentry profile showing the response times]

I tried using concurrent.futures, but it didn't make any difference.
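My attempt looked roughly like this (an illustrative sketch, not my exact code; RAGApp is the class shown further down, and the question string is a placeholder):

from concurrent.futures import ThreadPoolExecutor

app = RAGApp()

def answer_question(question: str) -> str:
    # Drain the streaming generator into a single string.
    return "".join(app.run_retrieval_chain(question, chat_history=[]))

# A thread pool does not make a single request faster; it only helps
# when several independent requests are served concurrently.
with ThreadPoolExecutor(max_workers=4) as executor:
    future = executor.submit(answer_question, "What does the document say about X?")
    print(future.result())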

I amended my postgresql.conf file to let PostgreSQL use additional RAM. My PC has an Nvidia A6000 GPU.
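For context, the memory-related settings I touched in postgresql.conf were along these lines (the values below are illustrative, not my exact ones):

# postgresql.conf -- memory-related settings (example values)
shared_buffers = 8GB            # cache for table and index pages
work_mem = 64MB                 # per-sort/hash-join working memory
effective_cache_size = 24GB     # planner hint for total available cache
maintenance_work_mem = 1GB      # used when building pgvector indexes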

My code is as follows:

import os
from dotenv import load_dotenv
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain_community.document_loaders import (
    UnstructuredWordDocumentLoader,
    TextLoader,
)
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from typing import Any, Iterator
from langchain_postgres.vectorstores import PGVector
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
import sentry_sdk
import sentry_sdk.profiler
from concurrent.futures import ThreadPoolExecutor


class RAGApp:
    def __init__(self):
        load_dotenv()

        sentry_sdk.init(
            dsn="https://928cdfed3291d120bd9972df53e4d90d@o4505835180130304.ingest.us.sentry.io/4508009613164544",
            # Set traces_sample_rate to 1.0 to capture 100%
            # of transactions for tracing.
            traces_sample_rate=1.0,
            # Set profiles_sample_rate to 1.0 to profile 100%
            # of sampled transactions.
            # We recommend adjusting this value in production.
            profiles_sample_rate=1.0,
        )

        # Constants
        self.BASE_URL = os.getenv("BASE_URL")
        self.LLM_MODEL = os.getenv("LLM_MODEL")
        self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
        self.DOCS_FOLDER = os.getenv("DOCS_FOLDER")
        self.DEFAULT_DOCUMENT = os.getenv("DEFAULT_DOCUMENT")
        self.CONNECTION_STRING = os.getenv("CONNECTION_STRING")

        self.COLLECTION_NAME = "exaba_llm_history"

    def create_knowledge_base(self, document_path):
        try:
            loader = UnstructuredWordDocumentLoader(document_path)
            documents = loader.load()

            text_splitter = CharacterTextSplitter(
                separator="\n", chunk_size=1000, chunk_overlap=100
            )
            text_chunks = text_splitter.split_documents(documents)

            embeddings = OllamaEmbeddings(
                base_url=self.BASE_URL, model=self.EMBEDDING_MODEL
            )
            knowledge_base = Chroma.from_documents(text_chunks, embeddings)

            return knowledge_base
        except Exception as e:
            return {"error": f"Error creating knowledge base: {e}"}

    def create_knowledge_base_pg(self, document_path):
        try:
            sentry_sdk.profiler.start_profiler()

            with sentry_sdk.start_transaction(op="task", name="Create knowledge"):
                with sentry_sdk.start_span(description="Store knowledge in pgvector"):
                    connection = self.CONNECTION_STRING
                    collection_name = self.COLLECTION_NAME

                    # loader = UnstructuredWordDocumentLoader(document_path)
                    if not os.path.exists(document_path):
                        return {
                            "error": f"Document path does not exist: {document_path}"
                        }

                    loader = TextLoader(document_path, encoding="utf-8")
                    documents = loader.load()

                    text_splitter = CharacterTextSplitter(
                        separator="\n", chunk_size=1000, chunk_overlap=100
                    )

                    text_chunks = text_splitter.split_documents(documents)

                    embeddings = OllamaEmbeddings(
                        base_url=self.BASE_URL, model=self.EMBEDDING_MODEL
                    )
                    # knowledge_base = Chroma.from_documents(text_chunks, embeddings)

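                    # Embed the chunks with Ollama and store them in pgvector.
                    # embedding_length must match the vector dimensionality of
                    # EMBEDDING_MODEL (768 here).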
                    vector_store = PGVector.from_documents(
                        embedding=embeddings,
                        connection=connection,
                        collection_name=collection_name,
                        documents=text_chunks,
                        use_jsonb=True,
                        create_extension=True,
                        embedding_length=768,
                        distance_strategy="cosine",
                        engine_args={"pool_size": 10, "max_overflow": 20},
                    )

                    print("vector store created")

            sentry_sdk.profiler.stop_profiler()

            return vector_store
        except Exception as e:
            print(f"Error creating knowledge base: {e}")
            return {"error": f"Error creating knowledge base: {e}"}

    def get_knowledge_base_pg(self):
        try:
            sentry_sdk.profiler.start_profiler()
            with sentry_sdk.start_transaction(op="task", name="Get knowledge"):
                with sentry_sdk.start_span(description="Get knowledge from pgvector"):
                    connection = self.CONNECTION_STRING
                    collection_name = self.COLLECTION_NAME

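                    # Re-attach to the existing collection; the embeddings object
                    # is only used to embed incoming queries at search time.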
                    embeddings = OllamaEmbeddings(
                        base_url=self.BASE_URL, model=self.EMBEDDING_MODEL
                    )

                    knowledge_base = PGVector(
                        connection=connection,
                        collection_name=collection_name,
                        embeddings=embeddings,
                        embedding_length=768,
                        distance_strategy="cosine",
                        engine_args={"pool_size": 10, "max_overflow": 20},
                    )
                sentry_sdk.profiler.stop_profiler()

                return knowledge_base

        except Exception as e:
            print(f"Error while getting knowledge base: {e}")
            return {"error": f"Error while getting knowledge base: {e}"}

    def run_retrieval_chain(
        self,
        question: str,
        chat_history: Any,
        llm_model=None,
        base_url=None,
    ) -> Iterator[str]:
        try:
            llm_model = llm_model or self.LLM_MODEL
            base_url = base_url or self.BASE_URL

            sentry_sdk.profiler.start_profiler()

            with sentry_sdk.start_transaction(op="task", name="Run retrieval chain"):
                with sentry_sdk.start_span(description="Generate answer"):
                    # Set up the LLM
                    llm = ChatOllama(
                        base_url=base_url,
                        model=llm_model,
                        temperature=0,
                        callbacks=CallbackManager([StreamingStdOutCallbackHandler()]),
                    )

                    # Create the knowledge base
                    # This should be replaced with a more robust way of creating the knowledge base in the future
                    # default_document_path = os.path.join(
                    #     self.DOCS_FOLDER, self.DEFAULT_DOCUMENT
                    # )

                    # knowledge_base = self.create_knowledge_base(default_document_path)

                    # create_knowledge_base_pg = self.create_knowledge_base_pg(default_document_path)

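                    # Note: this builds a fresh PGVector client (and connection
                    # pool) on every request rather than reusing one.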
                    knowledge_base_pg = self.get_knowledge_base_pg()

                    if not knowledge_base_pg or isinstance(knowledge_base_pg, dict):
                        # get_knowledge_base_pg returns an error dict on failure.
                        yield "Knowledge base missing."
                        return

                    # Create retriever
                    with sentry_sdk.start_span(description="Create retriever"):
                        retriever_pg = knowledge_base_pg.as_retriever(
                            search_kwargs={"k": 3}
                        )

                    system_prompt = (
                        "You are an assistant for question-answering tasks. "
                        "Use the following pieces of retrieved context to answer "
                        "the question. If you don't know the answer, say that you "
                        "don't know."
                        "\n\n"
                        "{context}"
                    )

                    raw_prompt = ChatPromptTemplate.from_messages(
                        [
                            ("system", system_prompt),
                            ("human", "{input}"),
                        ]
                    )

                    retriever_prompt = ChatPromptTemplate.from_messages(
                        [
                            MessagesPlaceholder(variable_name="chat_history"),
                            ("system", system_prompt),
                            ("human", "{input}"),
                        ]
                    )

                    question_answer_chain = create_stuff_documents_chain(
                        llm, raw_prompt
                    )

                    # Create history-aware retriever
                    history_aware_retriever = create_history_aware_retriever(
                        llm=llm, retriever=retriever_pg, prompt=retriever_prompt
                    )

                    # Create and run RAG chain
                    rag_chain = create_retrieval_chain(
                        history_aware_retriever, question_answer_chain
                    )

                    with sentry_sdk.start_span(description="Stream rag chain"):
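                        # .pick("answer") narrows the streamed output dict to
                        # just the answer text chunks.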
                        chain = rag_chain.pick("answer")
                        for chunk in chain.stream(
                            {"input": question, "chat_history": chat_history}
                        ):
                            yield chunk

            sentry_sdk.profiler.stop_profiler()

        except Exception as e:
            print(e)
            # In a generator, a `return value` is not delivered to the caller;
            # yield the error message instead.
            yield f"An error occurred while generating the answer. {e}"

Upvotes: 0

Views: 212

Answers (0)
