Maxl Gemeinderat

Reputation: 555

Langchain ParentDocumentRetriever: Save and load

I am using the ParentDocumentRetriever from Langchain. I first want to build my vector database and then later retrieve documents from it.

Here is my file that builds the database:

# =========================
#  Module: Vector DB Build
# =========================
import box
import yaml
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.storage import InMemoryStore
from langchain.retrievers import ParentDocumentRetriever
from langchain.vectorstores import Chroma

# Import config vars
with open('config/config.yml', 'r', encoding='utf8') as ymlfile:
    cfg = box.Box(yaml.safe_load(ymlfile))

# Build vector database
def run_db_build():
    loader = DirectoryLoader(cfg.DATA_PATH,
                             glob='*.pdf',
                             loader_cls=PyPDFLoader)
    documents = loader.load()
    
    embeddings = HuggingFaceEmbeddings(model_name=cfg.EMBEDDING_MODEL_NAME,
                                       model_kwargs={'device': 'mps'}, encode_kwargs={'device': 'mps', 'batch_size': 32})
    
    parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
    child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
    store = InMemoryStore()

    vectorstore = Chroma(collection_name="split_parents", embedding_function=embeddings,
                         persist_directory="chroma_db/") 
    big_chunks_retriever = ParentDocumentRetriever(
        vectorstore=vectorstore,
        docstore=store,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
    )
    big_chunks_retriever.add_documents(documents)

if __name__ == "__main__":
    run_db_build()

So I am saving the Chroma database in the folder "chroma_db". However, I also want to save the ParentDocumentRetriever (big_chunks_retriever) with the added documents so I can use it later when building a RetrievalQA chain. How do I load big_chunks_retriever in the following code?

def build_retrieval_qa(llm, prompt, vectordb):
    chain_type_kwargs={
        #"verbose": True,
        "prompt": prompt,
        "memory": ConversationBufferMemory(
            memory_key="history",
            input_key="question")}
    
    dbqa = RetrievalQA.from_chain_type(llm=llm,
                                       chain_type='stuff',
                                       retriever="HOW TO SET PARENTDOCUMENTRETRIEVER HERE?",                                       
                                       return_source_documents=cfg.RETURN_SOURCE_DOCUMENTS,
                                       chain_type_kwargs=chain_type_kwargs,
                                      )
    return dbqa

Upvotes: 3

Views: 5803

Answers (3)

user23465155

Reputation: 11

EDIT:

  1. Added the save_to_pickle function (Mar 13th, 2024)
  2. Added the building function (April 9th, 2024)

Alternatively, you can take the underlying store out of the docstore and save it to a pickle file with the code below; for my project with MultiVectorRetriever, that dictionary is the only part of the docstore worth persisting.


import pickle

def save_to_pickle(obj, filename):
    with open(filename, "wb") as file:
        pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)

def load_from_pickle(filename):
    with open(filename, "rb") as file:
        return pickle.load(file)

save_to_pickle(retriever.byte_store.store, docstore_path)

For building:

import os
import uuid

from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.storage import InMemoryByteStore
from langchain.retrievers import MultiVectorRetriever
from langchain_core.documents import Document


def construct_vectorstore(
    doc_list,
    hypo_list,
    save_path,
    embedding_model="text-embedding-3-large",
    include_product_type=False,
):
    os.makedirs(save_path, exist_ok=True)

    vector_store_path = os.path.join(save_path, "chroma/")
    docstore_path = os.path.join(save_path, "docstore.pkl")

    vectorstore = Chroma(
        embedding_function=OpenAIEmbeddings(model=embedding_model),
        persist_directory=vector_store_path,
    )

    store = InMemoryByteStore()
    id_key = "doc_id"
    # The retriever (empty to start)
    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        byte_store=store,
        id_key=id_key,
    )

    doc_ids = [str(uuid.uuid4()) for _ in doc_list]

    ingest_data = []

    # One list of hypothetical questions per document; each question becomes a
    # small Document that points back to its parent via the doc_id metadata.
    for i, question_list in enumerate(hypo_list):
        for question in question_list:
            if include_product_type:
                product_type = doc_list[i].metadata["product_type"]
                augmented_content = question + " " + product_type
            else:
                augmented_content = question

            ingest_data.append(
                Document(page_content=augmented_content, metadata={id_key: doc_ids[i]})
            )

    retriever.vectorstore.add_documents(ingest_data)
    retriever.docstore.mset(list(zip(doc_ids, doc_list)))

    # Save the vectorstore and docstore to disk
    retriever.vectorstore.persist()
    save_to_pickle(retriever.byte_store.store, docstore_path)

    return retriever

For loading:

def load_retriever(load_path, embedding_model):
    """Loads the vector store and document store, initializing the retriever."""
    vector_store_path = os.path.join(load_path, "chroma")
    db3 = Chroma(
        persist_directory=vector_store_path, embedding_function=embedding_model
    )
    store_dict = load_from_pickle(os.path.join(load_path, "docstore.pkl"))

    store = InMemoryByteStore()
    store.mset(list(store_dict.items()))

    retriever = MultiVectorRetriever(
        vectorstore=db3,
        byte_store=store,
        id_key="doc_id",
        search_kwargs={"k": 4},
    )
    return retriever
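
The same idea should carry over to the ParentDocumentRetriever from the question: persist Chroma as you already do, pickle the InMemoryStore's underlying store dict at build time (e.g. save_to_pickle(store.store, "docstore.pkl")), and rebuild the retriever at load time. Below is a minimal, untested sketch that reuses load_from_pickle from above; the function name load_parent_retriever and its parameters are hypothetical:

# Minimal sketch: rebuild the question's ParentDocumentRetriever from disk.
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.retrievers import ParentDocumentRetriever

def load_parent_retriever(embedding_model_name,
                          persist_directory="chroma_db/",
                          docstore_pickle="docstore.pkl"):
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
    vectorstore = Chroma(collection_name="split_parents",
                         embedding_function=embeddings,
                         persist_directory=persist_directory)

    # Restore the parent documents into a fresh InMemoryStore
    store = InMemoryStore()
    store.mset(list(load_from_pickle(docstore_pickle).items()))

    return ParentDocumentRetriever(
        vectorstore=vectorstore,
        docstore=store,
        child_splitter=RecursiveCharacterTextSplitter(chunk_size=400),
        parent_splitter=RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200),
    )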

Upvotes: 1

guibs35

Reputation: 297

EDIT: SQLDocStore is now available in LangChain (0.1.4) https://github.com/langchain-ai/langchain/releases/tag/v0.1.4

--

An approach that uses a persistent remote docstore is to use an SQLDocStore instead of InMemoryStore.

You can adapt your code by replacing InMemoryStore:

COLLECTION_NAME = "test"
CONNECTION_STRING = "postgresql+psycopg2://user:pass@localhost:5432/db"

store = SQLDocStore(
    collection_name=COLLECTION_NAME,
    connection_string=CONNECTION_STRING,
)

vectorstore = Chroma(collection_name="split_parents", embedding_function=embeddings,
                     persist_directory="chroma_db/") 
big_chunks_retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
)
big_chunks_retriever.add_documents(documents)
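
Since both the Chroma collection and the SQL docstore persist on their own, there is nothing left to pickle: in your retrieval script you simply re-instantiate them with the same collection names and connection string and rebuild the retriever. A rough sketch, reusing the names from the snippet above (embeddings and the splitters are assumed to be defined as in your build script):

# Later, when building the RetrievalQA chain: rebuild the retriever
# from the persisted stores instead of re-adding documents.
store = SQLDocStore(
    collection_name=COLLECTION_NAME,
    connection_string=CONNECTION_STRING,
)
vectorstore = Chroma(collection_name="split_parents",
                     embedding_function=embeddings,
                     persist_directory="chroma_db/")
big_chunks_retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
)
# ...then pass retriever=big_chunks_retriever to RetrievalQA.from_chain_type(...)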

For the SQLDocStore implementation, I wrote a PR against the LangChain project that might be merged soon. Until the pull request is officially accepted, you can use the code below (extracted from the PR):

"""SQL storage that persists data in a SQL database
and supports data isolation using collections."""
from __future__ import annotations

import uuid
from typing import Any, Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar

import sqlalchemy
from sqlalchemy import JSON, UUID
from sqlalchemy.orm import Session, relationship

try:
    from sqlalchemy.orm import declarative_base
except ImportError:
    from sqlalchemy.ext.declarative import declarative_base

from langchain_core.documents import Document
from langchain_core.load import Serializable, dumps, loads
from langchain_core.stores import BaseStore

V = TypeVar("V")

ITERATOR_WINDOW_SIZE = 1000

Base = declarative_base()  # type: Any


_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"


class BaseModel(Base):
    """Base model for the SQL stores."""

    __abstract__ = True
    uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)


_classes: Any = None


def _get_storage_stores() -> Any:
    global _classes
    if _classes is not None:
        return _classes

    class CollectionStore(BaseModel):
        """Collection store."""

        __tablename__ = "langchain_storage_collection"

        name = sqlalchemy.Column(sqlalchemy.String)
        cmetadata = sqlalchemy.Column(JSON)

        items = relationship(
            "ItemStore",
            back_populates="collection",
            passive_deletes=True,
        )

        @classmethod
        def get_by_name(
            cls, session: Session, name: str
        ) -> Optional["CollectionStore"]:
            # type: ignore
            return session.query(cls).filter(cls.name == name).first()

        @classmethod
        def get_or_create(
            cls,
            session: Session,
            name: str,
            cmetadata: Optional[dict] = None,
        ) -> Tuple["CollectionStore", bool]:
            """
            Get or create a collection.
            Returns [Collection, bool] where the bool is True if the collection was created.
            """  # noqa: E501
            created = False
            collection = cls.get_by_name(session, name)
            if collection:
                return collection, created

            collection = cls(name=name, cmetadata=cmetadata)
            session.add(collection)
            session.commit()
            created = True
            return collection, created

    class ItemStore(BaseModel):
        """Item store."""

        __tablename__ = "langchain_storage_items"

        collection_id = sqlalchemy.Column(
            UUID(as_uuid=True),
            sqlalchemy.ForeignKey(
                f"{CollectionStore.__tablename__}.uuid",
                ondelete="CASCADE",
            ),
        )
        collection = relationship(CollectionStore, back_populates="items")

        content = sqlalchemy.Column(sqlalchemy.String, nullable=True)

        # custom_id : any user defined id
        custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)

    _classes = (ItemStore, CollectionStore)

    return _classes


class SQLBaseStore(BaseStore[str, V], Generic[V]):
    """SQL storage

    Args:
        connection_string: SQL connection string that will be passed to SQLAlchemy.
        collection_name: The name of the collection to use. (default: langchain)
            NOTE: Collections are useful to isolate your data in a given database.
            This is not the name of the table, but the name of the collection.
            The tables will be created when initializing the store (if not exists)
            So, make sure the user has the right permissions to create tables.
        pre_delete_collection: If True, will delete the collection if it exists.
            (default: False). Useful for testing.
        engine_args: SQLAlchemy's create engine arguments.

    Example:
        .. code-block:: python

            from langchain_community.storage import SQLDocStore
            from langchain_community.embeddings.openai import OpenAIEmbeddings

            # example using an SQLDocStore to store Document objects for
            # a ParentDocumentRetriever
            CONNECTION_STRING = "postgresql+psycopg2://user:pass@localhost:5432/db"
            COLLECTION_NAME = "state_of_the_union_test"
            docstore = SQLDocStore(
                collection_name=COLLECTION_NAME,
                connection_string=CONNECTION_STRING,
            )
            child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
            vectorstore = ...

            retriever = ParentDocumentRetriever(
                vectorstore=vectorstore,
                docstore=docstore,
                child_splitter=child_splitter,
            )

            # example using an SQLStrStore to store strings
            # same example as in "InMemoryStore" but using SQL persistence
            store = SQLDocStore(
                collection_name=COLLECTION_NAME,
                connection_string=CONNECTION_STRING,
            )
            store.mset([('key1', 'value1'), ('key2', 'value2')])
            store.mget(['key1', 'key2'])
            # ['value1', 'value2']
            store.mdelete(['key1'])
            list(store.yield_keys())
            # ['key2']
            list(store.yield_keys(prefix='k'))
            # ['key2']

            # delete the COLLECTION_NAME collection
            docstore.delete_collection()
    """

    def __init__(
        self,
        connection_string: str,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        collection_metadata: Optional[dict] = None,
        pre_delete_collection: bool = False,
        connection: Optional[sqlalchemy.engine.Connection] = None,
        engine_args: Optional[dict[str, Any]] = None,
    ) -> None:
        self.connection_string = connection_string
        self.collection_name = collection_name
        self.collection_metadata = collection_metadata
        self.pre_delete_collection = pre_delete_collection
        self.engine_args = engine_args or {}
        # Create a connection if not provided, otherwise use the provided connection
        self._conn = connection if connection else self.__connect()
        self.__post_init__()

    def __post_init__(
        self,
    ) -> None:
        """Initialize the store."""
        ItemStore, CollectionStore = _get_storage_stores()
        self.CollectionStore = CollectionStore
        self.ItemStore = ItemStore
        self.__create_tables_if_not_exists()
        self.__create_collection()

    def __connect(self) -> sqlalchemy.engine.Connection:
        engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args)
        conn = engine.connect()
        return conn

    def __create_tables_if_not_exists(self) -> None:
        with self._conn.begin():
            Base.metadata.create_all(self._conn)

    def __create_collection(self) -> None:
        if self.pre_delete_collection:
            self.delete_collection()
        with Session(self._conn) as session:
            self.CollectionStore.get_or_create(
                session, self.collection_name, cmetadata=self.collection_metadata
            )

    def delete_collection(self) -> None:
        with Session(self._conn) as session:
            collection = self.__get_collection(session)
            if not collection:
                return
            session.delete(collection)
            session.commit()

    def __get_collection(self, session: Session) -> Any:
        return self.CollectionStore.get_by_name(session, self.collection_name)

    def __del__(self) -> None:
        if self._conn:
            self._conn.close()

    def __serialize_value(self, obj: V) -> str:
        if isinstance(obj, Serializable):
            return dumps(obj)
        return obj

    def __deserialize_value(self, obj: V) -> str:
        try:
            return loads(obj)
        except Exception:
            return obj

    def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
        """Get the values associated with the given keys.

        Args:
            keys (Sequence[str]): A sequence of keys.

        Returns:
            A sequence of optional values associated with the keys.
            If a key is not found, the corresponding value will be None.
        """
        with Session(self._conn) as session:
            collection = self.__get_collection(session)

            items = (
                session.query(self.ItemStore.content, self.ItemStore.custom_id)
                .where(
                    sqlalchemy.and_(
                        self.ItemStore.custom_id.in_(keys),
                        self.ItemStore.collection_id == (collection.uuid),
                    )
                )
                .all()
            )

        ordered_values = {key: None for key in keys}
        for item in items:
            v = item[0]
            val = self.__deserialize_value(v) if v is not None else v
            k = item[1]
            ordered_values[k] = val

        return [ordered_values[key] for key in keys]

    def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
        """Set the values for the given keys.

        Args:
            key_value_pairs (Sequence[Tuple[str, V]]): A sequence of key-value pairs.

        Returns:
            None
        """
        with Session(self._conn) as session:
            collection = self.__get_collection(session)
            if not collection:
                raise ValueError("Collection not found")
            for id, item in key_value_pairs:
                content = self.__serialize_value(item)
                item_store = self.ItemStore(
                    content=content,
                    custom_id=id,
                    collection_id=collection.uuid,
                )
                session.add(item_store)
            session.commit()

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the given keys and their associated values.

        Args:
            keys (Sequence[str]): A sequence of keys to delete.
        """
        with Session(self._conn) as session:
            collection = self.__get_collection(session)
            if not collection:
                raise ValueError("Collection not found")
            if keys is not None:
                stmt = sqlalchemy.delete(self.ItemStore).where(
                    sqlalchemy.and_(
                        self.ItemStore.custom_id.in_(keys),
                        self.ItemStore.collection_id == (collection.uuid),
                    )
                )
                session.execute(stmt)
            session.commit()

    def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
        """Get an iterator over keys that match the given prefix.

        Args:
            prefix (str, optional): The prefix to match. Defaults to None.

        Returns:
            Iterator[str]: An iterator over keys that match the given prefix.
        """
        with Session(self._conn) as session:
            collection = self.__get_collection(session)
            start = 0
            while True:
                stop = start + ITERATOR_WINDOW_SIZE
                query = session.query(self.ItemStore.custom_id).where(
                    self.ItemStore.collection_id == (collection.uuid)
                )
                if prefix is not None:
                    query = query.filter(self.ItemStore.custom_id.startswith(prefix))
                items = query.slice(start, stop).all()

                if len(items) == 0:
                    break
                for item in items:
                    yield item[0]
                start += ITERATOR_WINDOW_SIZE


SQLDocStore = SQLBaseStore[Document]
SQLStrStore = SQLBaseStore[str]

If you want more details, I've written an article on the SQL storage.

Upvotes: 1

Dennis V

Reputation: 11

I ran into almost exactly the same issue. I fixed it by using pickle, as explained in this popular Stack Overflow post.

Somewhere in your db_build file, you should add:

import pickle

def save_object(obj, filename):
    with open(filename, 'wb') as outp:  # Overwrites any existing file.
        pickle.dump(obj, outp, pickle.HIGHEST_PROTOCOL)

Then at the end of said file, save the retriever to a local file by adding the following line:

save_object(big_chunks_retriever, 'retriever.pkl')

Now in the other file, load the retriever by adding:

with open('retriever.pkl', 'rb') as inp:
    big_chunks_retriever = pickle.load(inp)

And finally define your build_retrieval_qa() as follows:

def build_retrieval_qa(llm, prompt):
    chain_type_kwargs={
        #"verbose": True,
        "prompt": prompt,
        "memory": ConversationBufferMemory(
            memory_key="history",
            input_key="question")}
    
    dbqa = RetrievalQA.from_chain_type(llm=llm,
                                       chain_type='stuff',
                                       retriever=big_chunks_retriever,                                       
                                       return_source_documents=cfg.RETURN_SOURCE_DOCUMENTS,
                                       chain_type_kwargs=chain_type_kwargs,
                                      )
    return dbqa

In this snippet I removed vectordb as a parameter, since it isn't used anywhere anymore.
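
For completeness, using the chain afterwards could look like this ("query" is RetrievalQA's default input key; the example question is just a placeholder):

dbqa = build_retrieval_qa(llm, prompt)
result = dbqa({"query": "What does the document say about topic X?"})
print(result["result"])
# "source_documents" is only present when RETURN_SOURCE_DOCUMENTS is True
print(result["source_documents"])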

Hope this helped you with your issue!

Upvotes: 1
