Reputation: 971
I have a simple RAG app and cannot figure out how to store memory with streaming. Should save_context
be part of the chain? Or do I have to handle it using some callback?
At the end of the example is answer_chain, where the last step is commented out. I believe something should go there at the end, but I cannot figure out what. I want to run a callback when streaming is finished.
Also, I split the chain into two steps, because when there is one big streaming chain it streams the documents and so on to stdout, which does not make sense; I only want the messages. Is handling it with two separate chains the proper way?
Any ideas?
import uuid
from typing import Iterator
import dotenv
from langchain_core.messages import get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, format_document
from langchain_core.runnables import RunnableLambda, RunnablePassthrough, RunnableParallel
from langchain_core.runnables.utils import Output
from document_index.vector import get_retriever
from operator import itemgetter
from memory import get_memory
from model import get_model
dotenv.load_dotenv()
model = get_model()
condense_question_prompt = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(condense_question_prompt)
initial_prompt = """
You are a helpful AI assistant.
Answer the question based only on the context below.
### Context start ###
{context}
### Context end ###
Question: {question}
"""
ANSWER_PROMPT = ChatPromptTemplate.from_template(initial_prompt)
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
retriever = get_retriever()
def _get_memory_with_session_id(session_id):
    return get_memory(session_id)


def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)
def search(session_id, query) -> Iterator[Output]:
    memory = _get_memory_with_session_id(session_id)

    def _save_context(inputs, answer):
        memory.save_context(inputs, {"answer": answer})

    loaded_memory = RunnablePassthrough.assign(
        chat_history=RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
    )
    standalone_question = {
        "standalone_question": {
            "question": lambda x: x["question"],
            "chat_history": lambda x: get_buffer_string(x["chat_history"]),
        }
        | CONDENSE_QUESTION_PROMPT
        | model
        | StrOutputParser()
    }
    retrieved_documents = {
        "docs": itemgetter("standalone_question") | retriever,
        "question": lambda x: x["standalone_question"],
    }
    preparation_chain = loaded_memory | standalone_question | retrieved_documents

    memory.load_memory_variables({})

    inputs = {"question": query}
    docs = preparation_chain.invoke(inputs)

    answer_chain = (
        {"docs": RunnablePassthrough()}
        | {
            "context": lambda x: _combine_documents(x["docs"]),
            "question": itemgetter("question"),
        }
        | ANSWER_PROMPT
        | model
        | StrOutputParser()
        # | RunnableLambda(_save_context, ????query_argument, ????MODEL_ANSWER)
    )
    return answer_chain.stream(docs)
if __name__ == "__main__":
    session_id = str(uuid.uuid4())
    query = "Where to buy beer?"
    for result in search(session_id, query):
        print(result, end="")
Upvotes: 1
Views: 626
Reputation: 151
There's an open issue on this case, so you might have to find a workaround. I came up with my own solution, which consists of creating a custom runnable lambda that lets the input pass through the stream and calls the lambda with the collected output once the stream is exhausted. In there you can save the memory context.
For some reason, memory state does not seem to be preserved between chain invocations. This is what I've come up with; it's early, rough stuff, but it should help you get where you want:
from operator import itemgetter
from typing import Any, Callable, Dict, Iterator, cast

from langchain.memory import ConversationBufferMemory
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import (
    RunnableConfig,
    RunnableLambda,
    RunnablePassthrough,
    RunnableSerializable,
)
from langchain_core.runnables.config import call_func_with_variable_args
from langchain_core.runnables.utils import AddableDict, Input, Output
from langchain_openai import ChatOpenAI


class RunnableCollector(RunnableLambda):
    """Passes stream chunks through and calls `self.func` with the accumulated output."""

    def _transform(
        self,
        input: Iterator[Input],
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Iterator[Output]:
        final: Input
        got_first_val = False
        for ichunk in input:
            yield cast(Output, ichunk)
            if not got_first_val:
                final = ichunk
                got_first_val = True
            else:
                try:
                    final = final + ichunk  # type: ignore[operator]
                except TypeError:
                    final = ichunk
        # The stream is exhausted: call the wrapped function with the collected output.
        call_func_with_variable_args(
            self.func, cast(Input, final), config, run_manager, **kwargs
        )
class RunnableFilter(RunnableLambda):
    """Only yields the stream chunks for which `filter` returns True."""

    def __init__(self, filter: Callable[[Input], bool], **kwargs) -> None:
        super().__init__(func=lambda _: None, **kwargs)
        self.filter = filter

    def _transform(
        self,
        input: Iterator[Input],
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Iterator[Output]:
        for ichunk in input:
            if self.filter(ichunk):
                yield ichunk
class RunnableMap(RunnableLambda):
    """Applies `mapping` to every stream chunk."""

    def __init__(self, mapping: Callable[[Input], Output], **kwargs) -> None:
        super().__init__(func=lambda _: None, **kwargs)
        self.mapping = mapping

    def _transform(
        self,
        input: Iterator[Input],
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Iterator[Output]:
        for ichunk in input:
            yield self.mapping(ichunk)
def setup_chain(
    model="gpt-3.5-turbo",
) -> RunnableSerializable:
    memory = ConversationBufferMemory(memory_key="history", return_messages=True)

    def save_into_mem(input_output: Dict[str, Any]):
        # Called by RunnableCollector once the whole answer has been streamed.
        message = input_output.pop("output")
        memory.save_context(input_output, {"output": message.content})
        print("\n\n2.\t", memory.load_memory_variables({}))

    chain = (
        RunnablePassthrough.assign(
            history=RunnableLambda(memory.load_memory_variables)
            | itemgetter("history"),
            dummy=RunnableLambda(lambda x: print("input?", x)),
        )
        | ChatPromptTemplate.from_messages(
            [
                ("system", SYSTEM_MESSAGE),  # SYSTEM_MESSAGE: your own system prompt, defined elsewhere
                MessagesPlaceholder(variable_name="history"),
                ("user", "{input}"),
            ]
        )
        | ChatOpenAI(
            model=model,
            temperature=0,
            streaming=True,
        )
    )

    chain = (
        RunnablePassthrough.assign(
            output=chain,
        )
        # Keep only the chunks that carry the input or the streamed output.
        | RunnableFilter(
            filter=lambda chunk: type(chunk) is AddableDict
            and ("output" in chunk or "input" in chunk)
        )
        # Collect the full stream and save it into memory when it finishes.
        | RunnableCollector(save_into_mem)
        # Forward only the model output chunks to the caller.
        | RunnableFilter(
            filter=lambda chunk: type(chunk) is AddableDict
            and "output" in chunk
            and "input" not in chunk
        )
        | RunnableMap(mapping=lambda chunk: cast(Output, chunk["output"]))
    )
    return chain
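Roughly, usage would look like the sketch below (assuming SYSTEM_MESSAGE is defined and your OpenAI API key is set; the example inputs are made up). Streaming the chain yields only the model's output chunks, and save_into_mem runs once the stream is exhausted, so the second turn can see the first exchange via the history placeholder:
if __name__ == "__main__":
    chain = setup_chain()

    # First turn: stream the answer; save_into_mem fires when the stream ends.
    for chunk in chain.stream({"input": "Hi, my name is Bob."}):
        print(chunk.content, end="", flush=True)
    print()

    # Second turn: the saved exchange is injected via the "history" variable.
    for chunk in chain.stream({"input": "What is my name?"}):
        print(chunk.content, end="", flush=True)
    print()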
Upvotes: 0