SparkOn

Reputation: 8946

How to stream LangChain, LangGraph's final generation

How do we stream LangGraph's final generation using Vercel's AI SDK? It works well with LangChain LCEL, as explained in this blog, but how do we do the same with LangGraph and the Vercel AI SDK?

I am getting the error TypeError: stream.pipeThrough is not a function when calling LangChainAdapter.toDataStreamResponse(final_generation) while handling the graph's stream.

When the user clicks submit, the POST handler below is invoked, which in turn runs the graph flow. This is where the error occurs.

//src/routes/api/chat
import { LangChainAdapter } from 'ai';
import type { RequestHandler } from './$types';
import type { Message } from 'ai/svelte';
import { Workflow } from '$lib/server/graph/workflow';

//server endpoint for chatGpt Stream Chat
export const POST: RequestHandler = async ({ request }) => {

    const { messages }: { messages: Message[] } = await request.json();
    let final_generation =  null;
    const eventStream = await Workflow.getCompiledStateGraph().streamEvents({'question': messages.pop(), 'chat_history': messages}, { version: "v2"});
    for await (const { event, tags, data } of eventStream) {
        if (event === "on_chat_model_stream") {
            console.log("tags:", tags)
            console.log("data", data);
            console.log("event", event);
            if (data.chunk.content) {
                final_generation = data.chunk
            }
        }
    }
    // the TypeError is thrown here: final_generation is a single AIMessageChunk, not a stream
    return LangChainAdapter.toDataStreamResponse(final_generation);
};

A simple CompiledStateGraph

import { END, START, StateGraph, type CompiledStateGraph } from '@langchain/langgraph';
import { GraphState } from '$lib/server/graph/state';
// retrieveDocuments, generate and routeQuestion are the project's node/edge functions

export class Workflow {
    // @ts-ignore -- the generic parameters of CompiledStateGraph are omitted here
    private static COMPILED_STATE_GRAPH: CompiledStateGraph | null = null;
    
    private constructor() {}
    
    // compile the graph once and reuse it across requests
    public static getCompiledStateGraph() {
        if (!Workflow.COMPILED_STATE_GRAPH) {
            const graph = new StateGraph(GraphState)
            .addNode("retrieve", retrieveDocuments)
            .addNode("llm_search", generate)
            .addConditionalEdges(START, routeQuestion)
            .addEdge("llm_search", END)
            .addEdge("retrieve", END);
            Workflow.COMPILED_STATE_GRAPH = graph.compile();
        }
        return Workflow.COMPILED_STATE_GRAPH;
    }
}
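
routeQuestion is the conditional edge that decides which node handles the question first; the actual routing condition doesn't matter here, so this is just a simplified, illustrative sketch:

import { GraphState } from '$lib/server/graph/state';

// illustrative only: send document-style questions to the retriever, everything else to the LLM
export const routeQuestion = async (state: typeof GraphState.State): Promise<'retrieve' | 'llm_search'> => {
    return state.question.toLowerCase().includes('document') ? 'retrieve' : 'llm_search';
};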

generate node

import { GraphState } from '$lib/server/graph/state';
import { LLMClient } from '$lib/server/llm-client';
import { StringOutputParser } from '@langchain/core/output_parsers';
import { ChatPromptTemplate } from '@langchain/core/prompts';

export const generate = async (state: typeof GraphState.State): Promise<Partial<typeof GraphState.State>> => {
    console.log("---LLM Inference---");
    const PROMPT_TEMPLATE = 'You are a helpful assistant!';
    const prompt = ChatPromptTemplate.fromMessages([
        ['system', PROMPT_TEMPLATE],
        ['human', "{question}"],
    ]);
    const generationChain = prompt.pipe(LLMClient.getClient()).pipe(new StringOutputParser());
    const generation = await generationChain.invoke({ question: state.question });
    return { generation };
};

State

import type { DocumentInterface } from '@langchain/core/documents';
import { Annotation, MessagesAnnotation } from '@langchain/langgraph';

export const GraphState = Annotation.Root({
    documents: Annotation<DocumentInterface[]>({
        reducer: (x, y) => y ?? x ?? []
    }),
    question: Annotation<string>({
        reducer: (x, y) => y ?? x ?? ''
    }),
    generation: Annotation<string>({
        reducer: (x, y) => y ?? x,
        default: () => ''
    }),
    ...MessagesAnnotation.spec
});

The input binding and user-message handling are done by useChat() from the Vercel AI SDK.

The rendering part: the HumanInput component binds to the input and submits the user query via handleSubmit, which in turn invokes the POST server function above.

<script lang="ts">
    import HumanInput from "$lib/components/HumanInput.svelte";
    import MaxWidthWrapper from '$lib/components/MaxWidthWrapper.svelte';
    import DisplayMessages from "$lib/components/DisplayMessages.svelte";
    import {useChat} from '@ai-sdk/svelte';
    const { input, handleSubmit, messages } = useChat();
</script>

<div class="flex flex-col h-screen">
    <div class="flex-grow overflow-hidden">
        <MaxWidthWrapper class_="h-full flex flex-col">
            <DisplayMessages {messages} />
            <HumanInput {input} {handleSubmit}/>
        </MaxWidthWrapper>
    </div>
</div>

Upvotes: 1

Views: 606

Answers (1)

SparkOn

Reputation: 8946

The main concept to understand here is how the Vercel AI SDK and LangChain handle messages. The AI SDK works with Message from the ai package, while LangChain deals with subtypes of BaseMessage from the @langchain/core/messages package.

The trick to solving this issue is to translate messages between these two formats.
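
To make the difference concrete, the two shapes look roughly like this (illustrative sketch, not part of the actual solution code):

import type { Message } from 'ai/svelte';
import { HumanMessage } from '@langchain/core/messages';

// Vercel AI SDK shape: a plain object with an id, a role and string content
const vercelMessage: Message = { id: '1', role: 'user', content: 'Hello!' };

// LangChain shape: a class instance wrapping the same content
const langchainMessage = new HumanMessage({ content: 'Hello!' });

console.log(vercelMessage.role, langchainMessage.getType()); // "user" "human"

With that distinction in mind, the endpoint becomes: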

//src/api/chat/+server.ts
import { LangChainAdapter } from 'ai';
import type { Message } from 'ai/svelte';
import { Workflow } from '$lib/server/graph/workflow';
import { convertLangChainMessageToVercelMessage, convertVercelMessageToLangChainMessage } from '$lib/utils/utility';

export const POST = async ({ request, params }) => {
    const config = { configurable: { thread_id: params.id }, version: "v2" };
    const messages: { messages: Message[] } = await request.json();
    // the last message is the current user query
    const userQuery = messages.messages[messages.messages.length - 1].content;
    
    // convert the remaining user/assistant messages to LangChain's format
    const history = (messages.messages ?? [])
    .slice(0, -1)
    .filter(
        (message: Message) =>
            message.role === 'user' || message.role === 'assistant'
    )
    .map(convertVercelMessageToLangChainMessage);
    
    const compiledStateGraph = Workflow.getCompiledStateGraph();
    const stream = await compiledStateGraph.streamEvents({ question: userQuery, messages: history }, config);
    
    // re-emit only the chunks produced by the LLM call tagged "llm_inference"
    const transformStream = new ReadableStream({
        async start(controller) {
            for await (const { event, data, tags } of stream) {
                if (event === 'on_chat_model_stream') {
                    if (!!data.chunk.content && tags.includes("llm_inference")) {
                        const aiMessage = convertLangChainMessageToVercelMessage(data.chunk);
                        controller.enqueue(aiMessage);
                    }
                }
            }
            controller.close();
        }
    });
    return LangChainAdapter.toDataStreamResponse(transformStream);
};

Before invoking the graph flow, I extract the current user query from the message array of useChat() and convert the remaining messages into LangChain's format using the convertVercelMessageToLangChainMessage function. Similarly, once I receive the AIMessageChunks from the LangChain stream, I convert them back into the Vercel AI format using the convertLangChainMessageToVercelMessage function before returning the stream for useChat() to handle the new messages.

import type { Message } from 'ai/svelte';
import { AIMessage, BaseMessage, ChatMessage, HumanMessage } from '@langchain/core/messages';

/**
 * Converts a Vercel message to a LangChain message.
 * @param message - The message to convert.
 * @returns The converted LangChain message.
 */
export const convertVercelMessageToLangChainMessage = (message: Message): BaseMessage => {
  switch (message.role) {
    case 'user':
      return new HumanMessage({ content: message.content });
    case 'assistant':
      return new AIMessage({ content: message.content });
    default:
      return new ChatMessage({ content: message.content, role: message.role });
  }
};

/**
 * Converts a LangChain message to a Vercel message.
 * @param message - The message to convert.
 * @returns The converted Vercel message.
 */
export const convertLangChainMessageToVercelMessage = (message: BaseMessage) => {
  switch (message.getType()) {
    case 'human':
      return { content: message.content, role: 'user' };
    case 'ai':
      return {
        content: message.content,
        role: 'assistant',
        tool_calls: (message as AIMessage).tool_calls
      };
    default:
      return { content: message.content, role: message.getType() };
  }
};
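
For example, a quick round trip through the two helpers (illustrative usage only):

import { convertLangChainMessageToVercelMessage, convertVercelMessageToLangChainMessage } from '$lib/utils/utility';

// Vercel -> LangChain: applied to the chat history before invoking the graph
const lcMessage = convertVercelMessageToLangChainMessage({ id: '1', role: 'user', content: 'Hi there' });
console.log(lcMessage.getType()); // "human"

// LangChain -> Vercel: applied to each streamed chunk before it is enqueued
console.log(convertLangChainMessageToVercelMessage(lcMessage)); // { content: 'Hi there', role: 'user' }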

Also notice if (!!data.chunk.content && tags.includes("llm_inference")): this is how we filter for the final generation. Because the graph's last node tags its LLM call via config (tags), we can later use that tag to pick out only the chunks produced by that node.

import { GraphState } from '$lib/server/graph/state';
import { LLMClient } from '$lib/server/llm-client';
import { ChatPromptTemplate } from '@langchain/core/prompts';

export const generate = async (state: typeof GraphState.State) => {
    console.log("---LLM Inference---");
    
    const PROMPT_TEMPLATE = 'You are a helpful assistant! Please answer the user query. Use the chat history to provide context.';
    
    const prompt = ChatPromptTemplate.fromMessages([
        ['system', PROMPT_TEMPLATE],
        ['human', "{question}"],
        ['human', "Chat History: {messages}"],
    ]);
    
    // tag the LLM call so its chunks can be identified in streamEvents
    const inferenceChain = prompt.pipe(LLMClient.getClient().withConfig({ tags: ["llm_inference"] }));
    const generation = await inferenceChain.invoke(state);
    return { messages: [generation] };
};

Upvotes: 1
