raju

Connect Chainlit to existing ChromaDb

I am trying to create a RAG application using chainlit.

This is the code I got from an existing tutorial, and it works fine. The only problem is that the user has to choose a PDF file every time. Instead, I want Chainlit to be connected to a persistent Chroma vector DB that is created only once for all users.

from typing import List
import PyPDF2
from langchain_community.embeddings import OllamaEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.chat_models import ChatOllama
from langchain.docstore.document import Document
from langchain.memory import ChatMessageHistory, ConversationBufferMemory

import chainlit as cl

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)


@cl.on_chat_start
async def on_chat_start():
    files = None

    # Wait for the user to upload a file
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a pdf file to begin!",
            accept=["application/pdf"],
            max_size_mb=20,
            timeout=180,
            max_files=2,
        ).send()

    file = files[0]
    print(file)

    msg = cl.Message(content=f"Processing `{file.name}`...")
    await msg.send()

    # Read the PDF file
    pdf = PyPDF2.PdfReader(file.path)
    pdf_text = ""
    for page in pdf.pages:
        pdf_text += page.extract_text()

    # Split the text into chunks
    texts = text_splitter.split_text(pdf_text)

    # Create a metadata for each chunk
    metadatas = [{"source": f"{i}-pl"} for i in range(len(texts))]

    # Create a Chroma vector store
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    docsearch = await cl.make_async(Chroma.from_texts)(
        texts, embeddings, metadatas=metadatas
    )

    message_history = ChatMessageHistory()

    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        chat_memory=message_history,
        return_messages=True,
    )

    # Create a chain that uses the Chroma vector store
    chain = ConversationalRetrievalChain.from_llm(
        ChatOllama(model="mistral"),
        chain_type="stuff",
        retriever=docsearch.as_retriever(),
        memory=memory,
        return_source_documents=True,
    )

    # Let the user know that the system is ready
    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()

    cl.user_session.set("chain", chain)


@cl.on_message
async def main(message: cl.Message):
    chain = cl.user_session.get("chain")  # type: ConversationalRetrievalChain
    cb = cl.AsyncLangchainCallbackHandler()

    res = await chain.ainvoke(message.content, config={"callbacks": [cb]})
    answer = res["answer"]
    source_documents = res["source_documents"]  # type: List[Document]

    text_elements = []  # type: List[cl.Text]

    if source_documents:
        for source_idx, source_doc in enumerate(source_documents):
            source_name = f"source_{source_idx}"
            # Create the text element referenced in the message
            text_elements.append(
                cl.Text(content=source_doc.page_content, name=source_name)
            )
        source_names = [text_el.name for text_el in text_elements]

        if source_names:
            answer += f"\nSources: {', '.join(source_names)}"
        else:
            answer += "\nNo sources found"

    await cl.Message(content=answer, elements=text_elements).send()


Answers (1)

cottontail

TL;DR: Use the persist_directory kwarg.

First of all, just so we're on the same page, the code below uses langchain==0.1.20, langchain-community==0.0.38, langchain-core==0.1.52 and chromadb==0.5.0.

If you specify a persistent directory, a SQLite database backing the vector store is created in that directory. So in your code, when you create a vector store with Chroma.from_texts(), you can also pass the persistent directory and a collection name for that file. Then the next time the same file is uploaded, instead of building the vector store again, you can pull up the existing collection from the persistent directory by its name.
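
A rough sketch of that round trip (the "vectorstore/" directory and the "mydocument" collection name are just placeholders, and I use OpenAIEmbeddings here only because that is what the full code below uses):

from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()

# First run: embed the chunks and persist them to disk
# (this creates the SQLite database under "vectorstore/").
Chroma.from_texts(
    ["first chunk", "second chunk"],
    embeddings,
    collection_name="mydocument",
    persist_directory="vectorstore/",
)

# Later runs: reload the same collection without re-embedding anything.
docsearch = Chroma(
    embedding_function=embeddings,
    collection_name="mydocument",
    persist_directory="vectorstore/",
)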

Now, to actually find an already created collection, we need to look it up in the database. Chroma() returns a Chroma vector store, and its ._client attribute gives you the underlying chromadb client. With that client we can call .list_collections() and read off the existing collection names.
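
For example, with the same persistent directory (again just a sketch; this is exactly what the handler below does on every chat start):

from langchain_community.vectorstores import Chroma

# Reach through the LangChain wrapper to the underlying chromadb client
# and list the collections stored in the persistent directory.
chroma_client = Chroma(persist_directory="vectorstore/")._client
collection_names = {col.name for col in chroma_client.list_collections()}
print(collection_names)  # e.g. {'mydocument'} once a file has been indexed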

The full code is as follows. N.B. I used OpenAIEmbeddings and ChatOpenAI from langchain-openai==0.1.5 because I don't have Ollama, but I suspect it works the same either way.

from typing import List
import pypdf
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import (
    ConversationalRetrievalChain,
)
from langchain.docstore.document import Document
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
import chainlit as cl

PERSIST_DIRECTORY = "vectorstore/"                         # <--- the directory to store the vector store

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)


@cl.on_chat_start
async def on_chat_start():

    files = None

    # Wait for the user to upload a file
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a pdf file to begin!",
            accept=["application/pdf"],
            max_size_mb=20,
            timeout=180,
            max_files=2,
        ).send()

    file = files[0]
    print(file)

    msg = cl.Message(content=f"Processing `{file.name}`...")
    await msg.send()

    embeddings = OpenAIEmbeddings()

    # get chromadb client
    chroma_client = Chroma(persist_directory=PERSIST_DIRECTORY)._client
    # get collection names in the database
    collection_names = {col.name for col in chroma_client.list_collections()}
    # get the specific collection name corresponding to the uploaded file
    collection_name = file.name.split()[0].replace('-', '').replace('.', '')

    if collection_name in collection_names:
        docsearch = Chroma(
            embedding_function=embeddings,
            collection_name=collection_name,               # <--- get existing collection by name
            persist_directory=PERSIST_DIRECTORY,           # <--- from the persistent directory
        )
    else:
        # Read the PDF file
        pdf = pypdf.PdfReader(file.path)
        pdf_text = ""
        for page in pdf.pages:
            pdf_text += page.extract_text()

        # Split the text into chunks
        texts = text_splitter.split_text(pdf_text)
        print(texts)

        # Create metadata for each chunk
        metadatas = [{"source": f"{i}-pl"} for i in range(len(texts))]

        # Create a Chroma vector store
        docsearch = await Chroma.afrom_texts(              # <--- Chroma has its native async method
            texts, 
            embeddings, 
            metadatas=metadatas,
            collection_name=collection_name,               # <--- create vector store under the collection
            persist_directory=PERSIST_DIRECTORY,           # <--- name in the persistent directory
        )

    message_history = ChatMessageHistory()

    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        chat_memory=message_history,
        return_messages=True,
    )

    # Create a chain that uses the Chroma vector store
    chain = ConversationalRetrievalChain.from_llm(
        ChatOpenAI(),
        chain_type="stuff",
        retriever=docsearch.as_retriever(),
        memory=memory,
        return_source_documents=True,
    )

    # Let the user know that the system is ready
    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()

    cl.user_session.set("chain", chain)


@cl.on_message
async def main(message: cl.Message):
    chain = cl.user_session.get("chain")  # type: ConversationalRetrievalChain
    cb = cl.AsyncLangchainCallbackHandler()

    res = await chain.ainvoke(message.content, config={"callbacks": [cb]})
    answer = res["answer"]
    source_documents = res["source_documents"]  # type: List[Document]

    text_elements = []  # type: List[cl.Text]

    if source_documents:
        for source_idx, source_doc in enumerate(source_documents):
            source_name = f"source_{source_idx}"
            # Create the text element referenced in the message
            text_elements.append(
                cl.Text(content=source_doc.page_content, name=source_name)
            )
        source_names = [text_el.name for text_el in text_elements]

        if source_names:
            answer += f"\nSources: {', '.join(source_names)}"
        else:
            answer += "\nNo sources found"

    await cl.Message(content=answer, elements=text_elements).send()
