Reputation: 555
I am using the ParentDocumentRetriever from LangChain. I first want to build my vector database and then retrieve from it later.
Here is my file that builds the database:
# =========================
# Module: Vector DB Build
# =========================
import box
import yaml
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.storage import InMemoryStore
from langchain.retrievers import ParentDocumentRetriever
from langchain.vectorstores import Chroma
# Import config vars
with open('config/config.yml', 'r', encoding='utf8') as ymlfile:
    cfg = box.Box(yaml.safe_load(ymlfile))
# Build vector database
def run_db_build():
    loader = DirectoryLoader(cfg.DATA_PATH,
                             glob='*.pdf',
                             loader_cls=PyPDFLoader)
    documents = loader.load()

    embeddings = HuggingFaceEmbeddings(model_name=cfg.EMBEDDING_MODEL_NAME,
                                       model_kwargs={'device': 'mps'},
                                       encode_kwargs={'device': 'mps', 'batch_size': 32})

    parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
    child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)

    store = InMemoryStore()
    vectorstore = Chroma(collection_name="split_parents",
                         embedding_function=embeddings,
                         persist_directory="chroma_db/")

    big_chunks_retriever = ParentDocumentRetriever(
        vectorstore=vectorstore,
        docstore=store,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
    )
    big_chunks_retriever.add_documents(documents)


if __name__ == "__main__":
    run_db_build()
So I am saving the Chroma database in the folder "chroma_db". However, I also want to save the ParentDocumentRetriever (big_chunks_retriever) with the added documents so that I can use it later when building a RetrievalQA chain. So how do I load big_chunks_retriever in the following code?
def build_retrieval_qa(llm, prompt, vectordb):
    chain_type_kwargs = {
        # "verbose": True,
        "prompt": prompt,
        "memory": ConversationBufferMemory(
            memory_key="history",
            input_key="question")}
    dbqa = RetrievalQA.from_chain_type(llm=llm,
                                       chain_type='stuff',
                                       retriever="HOW TO SET PARENTDOCUMENTRETRIEVER HERE?",
                                       return_source_documents=cfg.RETURN_SOURCE_DOCUMENTS,
                                       chain_type_kwargs=chain_type_kwargs,
                                       )
    return dbqa
Upvotes: 3
Views: 5803
Reputation: 11
EDIT (Mar 13th 2024): added the save_to_pickle function.
Alternatively, you can get the store inside the docstore and save it into a pickle file using the code below, as it seems to be the only valuable part of the docstore for my project with MultiVectorRetriever.
import pickle

def save_to_pickle(obj, filename):
    with open(filename, "wb") as file:
        pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)

def load_from_pickle(filename):
    with open(filename, "rb") as file:
        return pickle.load(file)

save_to_pickle(retriever.byte_store.store, docstore_path)
For building:
import os
import uuid

# Import paths may vary slightly across LangChain versions
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.storage import InMemoryByteStore
from langchain.retrievers import MultiVectorRetriever
from langchain_core.documents import Document


def construct_vectorstore(
    doc_list,
    hypo_list,
    save_path,
    embedding_model="text-embedding-3-large",
    include_product_type=False,
):
    os.makedirs(save_path, exist_ok=True)
    vector_store_path = os.path.join(save_path, "chroma/")
    docstore_path = os.path.join(save_path, "docstore.pkl")

    vectorstore = Chroma(
        embedding_function=OpenAIEmbeddings(model=embedding_model),
        persist_directory=vector_store_path,
    )
    store = InMemoryByteStore()
    id_key = "doc_id"

    # The retriever (empty to start)
    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        byte_store=store,
        id_key=id_key,
    )

    doc_ids = [str(uuid.uuid4()) for _ in doc_list]
    ingest_data = []
    for i, question_list in enumerate(hypo_list):
        for question in question_list:
            if include_product_type:
                product_type = doc_list[i].metadata["product_type"]
                augmented_content = question + " " + product_type
            else:
                augmented_content = question
            ingest_data.append(
                Document(page_content=augmented_content, metadata={id_key: doc_ids[i]})
            )

    retriever.vectorstore.add_documents(ingest_data)
    retriever.docstore.mset(list(zip(doc_ids, doc_list)))

    # Save the vectorstore and docstore to disk
    retriever.vectorstore.persist()
    save_to_pickle(retriever.byte_store.store, docstore_path)
    return retriever
For loading:
def load_retriever(load_path, embedding_model):
    """Loads the vector store and document store, initializing the retriever."""
    vector_store_path = os.path.join(load_path, "chroma")
    db3 = Chroma(
        persist_directory=vector_store_path, embedding_function=embedding_model
    )
    store_dict = load_from_pickle(os.path.join(load_path, "docstore.pkl"))
    store = InMemoryByteStore()
    store.mset(list(store_dict.items()))

    retriever = MultiVectorRetriever(
        vectorstore=db3,
        byte_store=store,
        id_key="doc_id",
        search_kwargs={"k": 4},
    )
    return retriever
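For what it's worth, here is a minimal usage sketch of the load side, using the same imports as above; the index path, embedding model, and query are only illustrative, and the embeddings object must match whatever was used in construct_vectorstore:
# Rebuild the retriever from the persisted Chroma index and the pickled docstore,
# then query it. The path, model name and query below are placeholders.
retriever = load_retriever("my_index", OpenAIEmbeddings(model="text-embedding-3-large"))
docs = retriever.get_relevant_documents("example question about the ingested documents")
for doc in docs:
    print(doc.metadata.get("product_type"), doc.page_content[:100])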
Upvotes: 1
Reputation: 297
EDIT: SQLDocStore is now available in LangChain (0.1.4): https://github.com/langchain-ai/langchain/releases/tag/v0.1.4
--
An approach that uses a persistent remote docstore would be to use an SQLDocStore instead of an InMemoryStore.
You can adapt your code by replacing the InMemoryStore:
COLLECTION_NAME = "test"
CONNECTION_STRING = "postgresql+psycopg2://user:pass@localhost:5432/db"
store = SQLDocStore(
collection_name=COLLECTION_NAME,
connection_string=CONNECTION_STRING,
)
vectorstore = Chroma(collection_name="split_parents", embedding_function=embeddings,
persist_directory="chroma_db/")
big_chunks_retriever = ParentDocumentRetriever(
vectorstore=vectorstore,
docstore=store,
child_splitter=child_splitter,
parent_splitter=parent_splitter,
)
big_chunks_retriever.add_documents(documents)
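Because both the Chroma index and the SQL docstore persist their data, you can rebuild the retriever later in your query-time script and pass it to RetrievalQA without re-adding any documents. Here is a minimal sketch, assuming the same connection string, collection name and persist directory as above, and that embeddings, the splitters, llm and chain_type_kwargs are recreated exactly as in the question:
# Query-time script: re-create the stores from their persisted backends and
# rebuild the ParentDocumentRetriever without calling add_documents again.
store = SQLDocStore(
    collection_name=COLLECTION_NAME,
    connection_string=CONNECTION_STRING,
)
vectorstore = Chroma(collection_name="split_parents",
                     embedding_function=embeddings,  # same embedding model as at build time
                     persist_directory="chroma_db/")

big_chunks_retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
)

# Plug it straight into the RetrievalQA chain from the question.
dbqa = RetrievalQA.from_chain_type(llm=llm,
                                   chain_type='stuff',
                                   retriever=big_chunks_retriever,
                                   chain_type_kwargs=chain_type_kwargs)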
For the SQLDocStore implementation, I wrote this PR for the LangChain project, which might be merged soon. Until the pull request is officially accepted, you can use the following code (extracted from the PR) to use SQLDocStore:
"""SQL storage that persists data in a SQL database
and supports data isolation using collections."""
from __future__ import annotations
import uuid
from typing import Any, Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar
import sqlalchemy
from sqlalchemy import JSON, UUID
from sqlalchemy.orm import Session, relationship
try:
from sqlalchemy.orm import declarative_base
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from langchain_core.documents import Document
from langchain_core.load import Serializable, dumps, loads
from langchain_core.stores import BaseStore
V = TypeVar("V")
ITERATOR_WINDOW_SIZE = 1000
Base = declarative_base() # type: Any
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
class BaseModel(Base):
"""Base model for the SQL stores."""
__abstract__ = True
uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
_classes: Any = None
def _get_storage_stores() -> Any:
global _classes
if _classes is not None:
return _classes
class CollectionStore(BaseModel):
"""Collection store."""
__tablename__ = "langchain_storage_collection"
name = sqlalchemy.Column(sqlalchemy.String)
cmetadata = sqlalchemy.Column(JSON)
items = relationship(
"ItemStore",
back_populates="collection",
passive_deletes=True,
)
@classmethod
def get_by_name(
cls, session: Session, name: str
) -> Optional["CollectionStore"]:
# type: ignore
return session.query(cls).filter(cls.name == name).first()
@classmethod
def get_or_create(
cls,
session: Session,
name: str,
cmetadata: Optional[dict] = None,
) -> Tuple["CollectionStore", bool]:
"""
Get or create a collection.
Returns [Collection, bool] where the bool is True if the collection was created.
""" # noqa: E501
created = False
collection = cls.get_by_name(session, name)
if collection:
return collection, created
collection = cls(name=name, cmetadata=cmetadata)
session.add(collection)
session.commit()
created = True
return collection, created
class ItemStore(BaseModel):
"""Item store."""
__tablename__ = "langchain_storage_items"
collection_id = sqlalchemy.Column(
UUID(as_uuid=True),
sqlalchemy.ForeignKey(
f"{CollectionStore.__tablename__}.uuid",
ondelete="CASCADE",
),
)
collection = relationship(CollectionStore, back_populates="items")
content = sqlalchemy.Column(sqlalchemy.String, nullable=True)
# custom_id : any user defined id
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
_classes = (ItemStore, CollectionStore)
return _classes
class SQLBaseStore(BaseStore[str, V], Generic[V]):
"""SQL storage
Args:
connection_string: SQL connection string that will be passed to SQLAlchemy.
collection_name: The name of the collection to use. (default: langchain)
NOTE: Collections are useful to isolate your data in a given a database.
This is not the name of the table, but the name of the collection.
The tables will be created when initializing the store (if not exists)
So, make sure the user has the right permissions to create tables.
pre_delete_collection: If True, will delete the collection if it exists.
(default: False). Useful for testing.
engine_args: SQLAlchemy's create engine arguments.
Example:
.. code-block:: python
from langchain_community.storage import SQLDocStore
from langchain_community.embeddings.openai import OpenAIEmbeddings
# example using an SQLDocStore to store Document objects for
# a ParentDocumentRetriever
CONNECTION_STRING = "postgresql+psycopg2://user:pass@localhost:5432/db"
COLLECTION_NAME = "state_of_the_union_test"
docstore = SQLDocStore(
collection_name=COLLECTION_NAME,
connection_string=CONNECTION_STRING,
)
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
vectorstore = ...
retriever = ParentDocumentRetriever(
vectorstore=vectorstore,
docstore=docstore,
child_splitter=child_splitter,
)
# example using an SQLStrStore to store strings
# same example as in "InMemoryStore" but using SQL persistence
store = SQLDocStore(
collection_name=COLLECTION_NAME,
connection_string=CONNECTION_STRING,
)
store.mset([('key1', 'value1'), ('key2', 'value2')])
store.mget(['key1', 'key2'])
# ['value1', 'value2']
store.mdelete(['key1'])
list(store.yield_keys())
# ['key2']
list(store.yield_keys(prefix='k'))
# ['key2']
# delete the COLLECTION_NAME collection
docstore.delete_collection()
"""
def __init__(
self,
connection_string: str,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
collection_metadata: Optional[dict] = None,
pre_delete_collection: bool = False,
connection: Optional[sqlalchemy.engine.Connection] = None,
engine_args: Optional[dict[str, Any]] = None,
) -> None:
self.connection_string = connection_string
self.collection_name = collection_name
self.collection_metadata = collection_metadata
self.pre_delete_collection = pre_delete_collection
self.engine_args = engine_args or {}
# Create a connection if not provided, otherwise use the provided connection
self._conn = connection if connection else self.__connect()
self.__post_init__()
def __post_init__(
self,
) -> None:
"""Initialize the store."""
ItemStore, CollectionStore = _get_storage_stores()
self.CollectionStore = CollectionStore
self.ItemStore = ItemStore
self.__create_tables_if_not_exists()
self.__create_collection()
def __connect(self) -> sqlalchemy.engine.Connection:
engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args)
conn = engine.connect()
return conn
def __create_tables_if_not_exists(self) -> None:
with self._conn.begin():
Base.metadata.create_all(self._conn)
def __create_collection(self) -> None:
if self.pre_delete_collection:
self.delete_collection()
with Session(self._conn) as session:
self.CollectionStore.get_or_create(
session, self.collection_name, cmetadata=self.collection_metadata
)
def delete_collection(self) -> None:
with Session(self._conn) as session:
collection = self.__get_collection(session)
if not collection:
return
session.delete(collection)
session.commit()
def __get_collection(self, session: Session) -> Any:
return self.CollectionStore.get_by_name(session, self.collection_name)
def __del__(self) -> None:
if self._conn:
self._conn.close()
def __serialize_value(self, obj: V) -> str:
if isinstance(obj, Serializable):
return dumps(obj)
return obj
def __deserialize_value(self, obj: V) -> str:
try:
return loads(obj)
except Exception:
return obj
def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
"""Get the values associated with the given keys.
Args:
keys (Sequence[str]): A sequence of keys.
Returns:
A sequence of optional values associated with the keys.
If a key is not found, the corresponding value will be None.
"""
with Session(self._conn) as session:
collection = self.__get_collection(session)
items = (
session.query(self.ItemStore.content, self.ItemStore.custom_id)
.where(
sqlalchemy.and_(
self.ItemStore.custom_id.in_(keys),
self.ItemStore.collection_id == (collection.uuid),
)
)
.all()
)
ordered_values = {key: None for key in keys}
for item in items:
v = item[0]
val = self.__deserialize_value(v) if v is not None else v
k = item[1]
ordered_values[k] = val
return [ordered_values[key] for key in keys]
def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
"""Set the values for the given keys.
Args:
key_value_pairs (Sequence[Tuple[str, V]]): A sequence of key-value pairs.
Returns:
None
"""
with Session(self._conn) as session:
collection = self.__get_collection(session)
if not collection:
raise ValueError("Collection not found")
for id, item in key_value_pairs:
content = self.__serialize_value(item)
item_store = self.ItemStore(
content=content,
custom_id=id,
collection_id=collection.uuid,
)
session.add(item_store)
session.commit()
def mdelete(self, keys: Sequence[str]) -> None:
"""Delete the given keys and their associated values.
Args:
keys (Sequence[str]): A sequence of keys to delete.
"""
with Session(self._conn) as session:
collection = self.__get_collection(session)
if not collection:
raise ValueError("Collection not found")
if keys is not None:
stmt = sqlalchemy.delete(self.ItemStore).where(
sqlalchemy.and_(
self.ItemStore.custom_id.in_(keys),
self.ItemStore.collection_id == (collection.uuid),
)
)
session.execute(stmt)
session.commit()
def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
"""Get an iterator over keys that match the given prefix.
Args:
prefix (str, optional): The prefix to match. Defaults to None.
Returns:
Iterator[str]: An iterator over keys that match the given prefix.
"""
with Session(self._conn) as session:
collection = self.__get_collection(session)
start = 0
while True:
stop = start + ITERATOR_WINDOW_SIZE
query = session.query(self.ItemStore.custom_id).where(
self.ItemStore.collection_id == (collection.uuid)
)
if prefix is not None:
query = query.filter(self.ItemStore.custom_id.startswith(prefix))
items = query.slice(start, stop).all()
if len(items) == 0:
break
for item in items:
yield item[0]
start += ITERATOR_WINDOW_SIZE
SQLDocStore = SQLBaseStore[Document]
SQLStrStore = SQLBaseStore[str]
If you want more details, I've written an article about this SQL storage.
Upvotes: 1
Reputation: 11
I ran into almost exactly the same issue. I fixed it by using pickle, as explained in this popular Stack Overflow post.
Somewhere in your db_build file, you should add:
import pickle

def save_object(obj, filename):
    with open(filename, 'wb') as outp:  # Overwrites any existing file.
        pickle.dump(obj, outp, pickle.HIGHEST_PROTOCOL)
Then at the end of said file, save the retriever to a local file by adding the following line:
save_object(big_chunks_retriever, 'retriever.pkl')
Now in the other file, load the retriever by adding:
with open('retriever.pkl', 'rb') as inp:
    big_chunks_retriever = pickle.load(inp)
And finally define your build_retrieval_qa() as follows:
def build_retrieval_qa(llm, prompt):
    chain_type_kwargs = {
        # "verbose": True,
        "prompt": prompt,
        "memory": ConversationBufferMemory(
            memory_key="history",
            input_key="question")}
    dbqa = RetrievalQA.from_chain_type(llm=llm,
                                       chain_type='stuff',
                                       retriever=big_chunks_retriever,
                                       return_source_documents=cfg.RETURN_SOURCE_DOCUMENTS,
                                       chain_type_kwargs=chain_type_kwargs,
                                       )
    return dbqa
In this snippet I removed vectordb as a parameter, since it isn't used anywhere anymore.
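As a quick sanity check, here is a minimal usage sketch of the rebuilt chain; llm and prompt are assumed to come from your existing config-driven setup, and the query string is only an example:
# Hypothetical usage of the chain built from the unpickled retriever.
dbqa = build_retrieval_qa(llm, prompt)
result = dbqa({"query": "What does the report say about revenue?"})

print(result["result"])                  # the generated answer
for doc in result["source_documents"]:   # present when cfg.RETURN_SOURCE_DOCUMENTS is True
    print(doc.metadata)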
Hope this helped you with your issue!
Upvotes: 1