I have created a RAG app using Ollama, LangChain and pgvector. It takes about 4-5 seconds to retrieve an answer from the llama3.1:7b model.
I have followed the LangChain documentation and added profiling to my code. The profiling output is as follows:
I tried using concurrent.futures, but that didn't make any difference.
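For reference, this is roughly the kind of thing I tried with concurrent.futures inside run_retrieval_chain (an illustrative sketch, not my exact code): fetching the pgvector store and constructing the ChatOllama client in parallel threads instead of sequentially.

from concurrent.futures import ThreadPoolExecutor

from langchain_ollama import ChatOllama

# app is the RAGApp instance defined further down
with ThreadPoolExecutor(max_workers=2) as executor:
    kb_future = executor.submit(app.get_knowledge_base_pg)
    llm_future = executor.submit(
        ChatOllama, base_url=app.BASE_URL, model=app.LLM_MODEL, temperature=0
    )
    knowledge_base_pg = kb_future.result()
    llm = llm_future.result()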
I amended my postgresql.conf file to allow PostgreSQL to use additional RAM. My PC has an Nvidia A6000 GPU.
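For context, the memory-related settings I changed in postgresql.conf were roughly the following (values are illustrative, not a recommendation):

# postgresql.conf (approximate values I used)
shared_buffers = 8GB
work_mem = 64MB
maintenance_work_mem = 1GB
effective_cache_size = 24GB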
My code is as follows:
import os
from dotenv import load_dotenv
from langchain_ollama import ChatOllama
from langchain_community.document_loaders import (
    UnstructuredWordDocumentLoader,
    TextLoader,
)
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.embeddings import OllamaEmbeddings
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from typing import Any, Dict, List
from langchain_postgres.vectorstores import PGVector
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
import sentry_sdk
import sentry_sdk.profiler
from concurrent.futures import ThreadPoolExecutor


class RAGApp:
    def __init__(self):
        load_dotenv()
        sentry_sdk.init(
            dsn="https://928cdfed3291d120bd9972df53e4d90d@o4505835180130304.ingest.us.sentry.io/4508009613164544",
            # Set traces_sample_rate to 1.0 to capture 100%
            # of transactions for tracing.
            traces_sample_rate=1.0,
            # Set profiles_sample_rate to 1.0 to profile 100%
            # of sampled transactions.
            # We recommend adjusting this value in production.
            profiles_sample_rate=1.0,
        )

        # Constants
        self.BASE_URL = os.getenv("BASE_URL")
        self.LLM_MODEL = os.getenv("LLM_MODEL")
        self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
        self.DOCS_FOLDER = os.getenv("DOCS_FOLDER")
        self.DEFAULT_DOCUMENT = os.getenv("DEFAULT_DOCUMENT")
        self.CONNECTION_STRING = os.getenv("CONNECTION_STRING")
        self.COLLECTION_NAME = "exaba_llm_history"

    def create_knowledge_base(self, document_path):
        try:
            loader = UnstructuredWordDocumentLoader(document_path)
            documents = loader.load()
            text_splitter = CharacterTextSplitter(
                separator="\n", chunk_size=1000, chunk_overlap=100
            )
            text_chunks = text_splitter.split_documents(documents)
            embeddings = OllamaEmbeddings(
                base_url=self.BASE_URL, model=self.EMBEDDING_MODEL
            )
            knowledge_base = Chroma.from_documents(text_chunks, embeddings)
            return knowledge_base
        except Exception as e:
            return {"error": f"Error creating knowledge base: {e}"}

    def create_knowledge_base_pg(self, document_path):
        try:
            sentry_sdk.profiler.start_profiler()
            with sentry_sdk.start_transaction(op="task", name="Create knowledge"):
                with sentry_sdk.start_span(description="Store knowledge in pgvector"):
                    connection = self.CONNECTION_STRING
                    collection_name = self.COLLECTION_NAME
                    # loader = UnstructuredWordDocumentLoader(document_path)
                    if not os.path.exists(document_path):
                        return {
                            "error": f"Document path does not exist: {document_path}"
                        }
                    with open(document_path, "r", encoding="utf-8") as file:
                        data = file.read()
                        # print(data)  # Check if the file loads correctly
                    loader = TextLoader(document_path, encoding="utf-8")
                    documents = loader.load()
                    text_splitter = CharacterTextSplitter(
                        separator="\n", chunk_size=1000, chunk_overlap=100
                    )
                    text_chunks = text_splitter.split_documents(documents)
                    embeddings = OllamaEmbeddings(
                        base_url=self.BASE_URL, model=self.EMBEDDING_MODEL
                    )
                    # knowledge_base = Chroma.from_documents(text_chunks, embeddings)
                    vector_store = PGVector.from_documents(
                        embedding=embeddings,
                        connection=connection,
                        collection_name=collection_name,
                        documents=text_chunks,
                        use_jsonb=True,
                        create_extension=True,
                        embedding_length=768,
                        distance_strategy="cosine",
                        engine_args={"pool_size": 10, "max_overflow": 20},
                    )
                    print("vector store created")
            sentry_sdk.profiler.stop_profiler()
            return vector_store
        except Exception as e:
            print(f"Error creating knowledge base: {e}")
            return {"error": f"Error creating knowledge base: {e}"}

    def get_knowledge_base_pg(self):
        try:
            sentry_sdk.profiler.start_profiler()
            with sentry_sdk.start_transaction(op="task", name="Get knowledge"):
                with sentry_sdk.start_span(description="Get knowledge from pgvector"):
                    connection = self.CONNECTION_STRING
                    collection_name = self.COLLECTION_NAME
                    embeddings = OllamaEmbeddings(
                        base_url=self.BASE_URL, model=self.EMBEDDING_MODEL
                    )
                    knowledge_base = PGVector(
                        connection=connection,
                        collection_name=collection_name,
                        embeddings=embeddings,
                        embedding_length=768,
                        distance_strategy="cosine",
                        engine_args={"pool_size": 10, "max_overflow": 20},
                    )
            sentry_sdk.profiler.stop_profiler()
            return knowledge_base
        except Exception as e:
            print(f"Error while getting knowledge base: {e}")
            return {"error": f"Error while getting knowledge base: {e}"}

    def run_retrieval_chain(
        self,
        question: str,
        chat_history: Any,
        llm_model=None,
        base_url=None,
    ) -> Any:
        try:
            llm_model = llm_model or self.LLM_MODEL
            base_url = base_url or self.BASE_URL
            sentry_sdk.profiler.start_profiler()
            with sentry_sdk.start_transaction(op="task", name="Run retrieval chain"):
                with sentry_sdk.start_span(description="Generate answer"):
                    # Set up the LLM
                    llm = ChatOllama(
                        base_url=base_url,
                        model=llm_model,
                        temperature=0,
                        callbacks=CallbackManager([StreamingStdOutCallbackHandler()]),
                    )
                    # Create the knowledge base
                    # This should be replaced with a more robust way of creating the knowledge base in the future
                    # default_document_path = os.path.join(
                    #     self.DOCS_FOLDER, self.DEFAULT_DOCUMENT
                    # )
                    # knowledge_base = self.create_knowledge_base(default_document_path)
                    # create_knowledge_base_pg = self.create_knowledge_base_pg(default_document_path)
                    knowledge_base_pg = self.get_knowledge_base_pg()
                    if not knowledge_base_pg:
                        # This is a generator, so the error has to be yielded, not returned
                        yield "Knowledge base missing."
                        return

                # Create retriever
                with sentry_sdk.start_span(description="Create retriever"):
                    retriever_pg = knowledge_base_pg.as_retriever(
                        search_kwargs={"k": 3}
                    )
                    system_prompt = (
                        "You are an assistant for question-answering tasks. "
                        "Use the following pieces of retrieved context to answer "
                        "the question. If you don't know the answer, say that you "
                        "don't know."
                        "\n\n"
                        "{context}"
                    )
                    raw_prompt = ChatPromptTemplate.from_messages(
                        [
                            ("system", system_prompt),
                            ("human", "{input}"),
                        ]
                    )
                    retriever_prompt = ChatPromptTemplate.from_messages(
                        [
                            MessagesPlaceholder(variable_name="chat_history"),
                            ("system", system_prompt),
                            ("human", "{input}"),
                        ]
                    )
                    question_answer_chain = create_stuff_documents_chain(
                        llm, raw_prompt
                    )
                    # Create history-aware retriever
                    history_aware_retriever = create_history_aware_retriever(
                        llm=llm, retriever=retriever_pg, prompt=retriever_prompt
                    )
                    # Create and run RAG chain
                    rag_chain = create_retrieval_chain(
                        history_aware_retriever, question_answer_chain
                    )

                with sentry_sdk.start_span(description="Stream rag chain"):
                    chain = rag_chain.pick("answer")
                    # Pass chat_history through so the history-aware retriever can use it
                    for chunk in chain.stream(
                        {"input": question, "chat_history": chat_history}
                    ):
                        yield chunk
            sentry_sdk.profiler.stop_profiler()
        except Exception as e:
            print(e)
            # Also yielded rather than returned, because this is a generator
            yield f"An error occurred while generating the answer. {e}"