rtyshyk

Reputation: 971

LangChain: How to store memory with streaming?

I have a simple RAG app and cannot figure out how to store memory with streaming. Should save_context be part of the chain? Or do I have to handle it using some callback?

At the end of the example is answer_chain, where the last step is commented out. I believe something should go there, but I cannot figure out what. I want to run a callback when streaming is finished.

Also, I split the chain into two steps, because when there is one big streaming chain it streams the documents and other intermediate values to stdout, which does not make sense; I only want the messages. Is handling it with two separate chains the proper way?

Any ideas?

import uuid
from typing import Iterator

import dotenv
from langchain_core.messages import get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, format_document
from langchain_core.runnables import RunnableLambda, RunnablePassthrough, RunnableParallel
from langchain_core.runnables.utils import Output

from document_index.vector import get_retriever
from operator import itemgetter
from memory import get_memory
from model import get_model

dotenv.load_dotenv()

model = get_model()
condense_question_prompt = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(condense_question_prompt)

initial_prompt = """
You are a helpful AI assistant.
Answer the question based only on the context below.

### Context start ###
{context}

### Context end ###

Question: {question}
"""
ANSWER_PROMPT = ChatPromptTemplate.from_template(initial_prompt)

DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

retriever = get_retriever()


def _get_memory_with_session_id(session_id):
    return get_memory(session_id)


def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)


def search(session_id, query) -> Iterator[Output]:
    memory = _get_memory_with_session_id(session_id)

    def _save_context(inputs, answer):
        memory.save_context(inputs, {"answer": answer})

    loaded_memory = RunnablePassthrough.assign(
        chat_history=RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
    )
    standalone_question = {
        "standalone_question": {
            "question": lambda x: x["question"],
            "chat_history": lambda x: get_buffer_string(x["chat_history"]),
        }
        | CONDENSE_QUESTION_PROMPT
        | model
        | StrOutputParser()
    }

    retrieved_documents = {
        "docs": itemgetter("standalone_question") | retriever,
        "question": lambda x: x["standalone_question"],
    }

    preparation_chain = loaded_memory | standalone_question | retrieved_documents
    memory.load_memory_variables({})
    inputs = {"question": query}
    docs = preparation_chain.invoke(inputs)

    answer_chain = (
        RunnablePassthrough()  # input is the dict produced by preparation_chain
        | {
            "context": lambda x: _combine_documents(x["docs"]),
            "question": itemgetter("question"),
        }
        | ANSWER_PROMPT
        | model
        | StrOutputParser()
        # | RunnableLambda(_save_context, ????query_argument, ????MODEL_ANSWER)
    )

    return answer_chain.stream(docs)


if __name__ == "__main__":
    session_id = str(uuid.uuid4())
    query = "Where to buy beer?"
    for result in search(session_id, query):
        print(result, end="")

Upvotes: 1

Views: 626

Answers (1)

Roberto Montalti

Reputation: 151

There's an open issue on this case, so you might have to find a workaround. I came up with my own solution, which consists in creating a custom runnable lambda that lets the input pass through the stream and then calls the lambda with the collected output once streaming is finished. In there you can save the memory context.

For some reason, memory state does not seem to persist between chain invocations. This is what I've come up with; it's early, rough stuff, but it should help you get where you want:

from operator import itemgetter
from typing import Any, Callable, Dict, Iterator, cast

from langchain.memory import ConversationBufferMemory
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import (
    RunnableConfig,
    RunnableLambda,
    RunnablePassthrough,
    RunnableSerializable,
)
from langchain_core.runnables.config import call_func_with_variable_args
from langchain_core.runnables.utils import AddableDict, Input, Output
from langchain_openai import ChatOpenAI

# Placeholder so the snippet is self-contained; substitute your own system prompt.
SYSTEM_MESSAGE = "You are a helpful assistant."


class RunnableCollector(RunnableLambda):
    """Passes stream chunks through untouched and, once the stream is
    exhausted, calls self.func with the accumulated output."""
    def _transform(
        self,
        input: Iterator[Input],
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Iterator[Output]:
        final: Input
        got_first_val = False
        for ichunk in input:
            yield cast(Output, ichunk)

            if not got_first_val:
                final = ichunk
                got_first_val = True
            else:
                try:
                    final = final + ichunk  # type: ignore[operator]
                except TypeError:
                    final = ichunk

        call_func_with_variable_args(
            self.func, cast(Input, final), config, run_manager, **kwargs
        )


class RunnableFilter(RunnableLambda):
    """Yields only the stream chunks for which the given filter predicate is true."""
    def __init__(self, filter: Callable[[Input], bool], **kwargs) -> None:
        super().__init__(func=lambda _: None, **kwargs)
        self.filter = filter

    def _transform(
        self,
        input: Iterator[Input],
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Iterator[Output]:
        for ichunk in input:
            if self.filter(ichunk):
                yield ichunk


class RunnableMap(RunnableLambda):
    """Applies a mapping function to every chunk that flows through the stream."""
    def __init__(self, mapping: Callable[[Input], Output], **kwargs) -> None:
        super().__init__(func=lambda _: None, **kwargs)
        self.mapping = mapping

    def _transform(
        self,
        input: Iterator[Input],
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Iterator[Output]:
        for ichunk in input:
            yield self.mapping(ichunk)


def setup_chain(
    model="gpt-3.5-turbo",
) -> RunnableSerializable:
    memory = ConversationBufferMemory(memory_key="history", return_messages=True)

    def save_into_mem(input_output: Dict[str, Any]):
        message = input_output.pop("output")
        memory.save_context(input_output, {"output": message.content})
        print("\n\n2.\t", memory.load_memory_variables({}))

    chain = (
        RunnablePassthrough.assign(
            history=RunnableLambda(memory.load_memory_variables)
            | itemgetter("history"),
            dummy=RunnableLambda(lambda x: print("input?", x)),
        )
        | ChatPromptTemplate.from_messages(
            [
                ("system", SYSTEM_MESSAGE),
                MessagesPlaceholder(variable_name="history"),
                ("user", "{input}"),
            ]
        )
        | ChatOpenAI(
            model=model,
            temperature=0,
            streaming=True,
        )
    )

    chain = (
        RunnablePassthrough.assign(
            output=chain,
        )
        | RunnableFilter(
            # keep only the chunks that carry the user input or the model output
            filter=lambda chunk: type(chunk) is AddableDict
            and ("output" in chunk or "input" in chunk)
        )
        | RunnableCollector(save_into_mem)
        | RunnableFilter(
            # after the context is saved, forward only the model output chunks
            filter=lambda chunk: type(chunk) is AddableDict
            and "output" in chunk
            and "input" not in chunk
        )
        | RunnableMap(mapping=lambda chunk: cast(Output, chunk["output"]))
    )

    return chain
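
For completeness, a minimal usage sketch (it assumes an OPENAI_API_KEY in the environment and the placeholder SYSTEM_MESSAGE defined above). The items this chain streams are message chunks, so their .content is printed; the second call shows that the saved history is loaded back into the prompt:

chain = setup_chain()

for chunk in chain.stream({"input": "Hi, my name is Bob."}):
    print(chunk.content, end="", flush=True)

for chunk in chain.stream({"input": "What is my name?"}):
    print(chunk.content, end="", flush=True)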

Upvotes: 0
