jereesh thomas

Reputation: 171

LangGraph human in the loop with checkpointer

I am trying to build an agentic AI application using LangGraph. I have a FastAPI service which accepts some parameters, from which I create a graph; I can use a checkpointer as well. The graph is created inside a graph-creator class and used as a local variable. Once the graph has executed up to a specific step, I want to get human feedback, so I interrupt the graph at that stage. I have a thread ID which is used to resume graph execution. I return the current state information to the user through the API, and the user then calls another API with the same thread ID so that the graph can resume. How do I do this?

I was exploring different ways to persist the graph by serializing and deserializing it in the second API call, but none of them worked; I have tried pickle.dumps and json.dumps. Is it possible to create a new graph with the same checkpointer? How do I recreate the same graph in this scenario?

Upvotes: 0

Views: 48

Answers (1)

Aryan Raj

Reputation: 170

Here's how to handle generic graph persistence and resumption with LangGraph across API calls; you may have to adapt it to your needs:

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Dict, Any
from langgraph.graph import StateGraph, END
from datetime import datetime

# Store active graphs in memory
active_graphs = {}

class GraphState(BaseModel):
    thread_id: str
    current_step: str
    state_data: Dict[str, Any]

class GraphCreator:
    def __init__(self):
        self.graph = None
        self.checkpointer = None
    
    def create_graph(self, params: Dict[str, Any]):
        # Example graph creation; StateGraph requires a state schema
        workflow = StateGraph(dict)
        
        # Add nodes and edges
        workflow.add_node("start", self.start_task)
        workflow.add_node("need_human_input", self.pause_for_human)
        workflow.add_node("process_human_input", self.process_input)
        
        # Define edges (an entry point is required before compiling)
        workflow.set_entry_point("start")
        workflow.add_edge("start", "need_human_input")
        workflow.add_edge("need_human_input", "process_human_input")
        workflow.add_edge("process_human_input", END)
        
        # Keep a reference so node callbacks (e.g. pause_for_human) can store it
        self.graph = workflow.compile()
        return self.graph

    async def start_task(self, state):
        # Initial task logic
        return {"next": "need_human_input", "state": state}

    async def pause_for_human(self, state):
        # Pause execution and save state
        thread_id = state.get("thread_id")
        active_graphs[thread_id] = {
            "graph": self.graph,
            "state": state,
            "step": "need_human_input"
        }
        return {"next": "pause", "state": state}

    async def process_input(self, state):
        # Process human input and continue
        return {"next": END, "state": state}

app = FastAPI()

@app.post("/start_workflow")
async def start_workflow(params: Dict[str, Any]):
    creator = GraphCreator()
    graph = creator.create_graph(params)
    
    # Generate unique thread ID
    thread_id = f"thread_{datetime.now().timestamp()}"
    
    # Initialize state
    state = {"thread_id": thread_id, "params": params}
    
    # Store graph instance
    active_graphs[thread_id] = {
        "graph": graph,
        "state": state,
        "step": "start"
    }
    
    # Run until human input needed
    result = await graph.ainvoke(state)
    
    return GraphState(
        thread_id=thread_id,
        current_step="need_human_input",
        state_data=result
    )

@app.post("/resume_workflow/{thread_id}")
async def resume_workflow(thread_id: str, human_input: Dict[str, Any]):
    if thread_id not in active_graphs:
        raise HTTPException(status_code=404, detail="Thread not found")
    
    # Retrieve stored graph and state
    stored = active_graphs[thread_id]
    graph = stored["graph"]
    state = stored["state"]
    
    # Update state with human input
    state["human_input"] = human_input
    
    # Resume execution
    result = await graph.ainvoke(state)
    
    # Clean up if workflow completed
    if result.get("next") == END:
        del active_graphs[thread_id]
    
    return GraphState(
        thread_id=thread_id,
        current_step=result.get("next", END),
        state_data=result
    )

Instead of serializing/deserializing the graph, store active graph instances in memory in a dictionary keyed by thread ID. Note that this only works while the process is alive and with a single worker; for multiple workers or restarts you would need a shared checkpointer (e.g. one backed by a database) so any process can resume the thread.
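The registry idea itself does not depend on LangGraph. A minimal stdlib-only sketch (hypothetical function names, no web framework or LangGraph imports) of parking a paused workflow by thread ID and resuming it on a later call:

```python
from typing import Any, Dict

# Hypothetical in-memory registry: thread_id -> paused workflow state.
# This dies with the process; a database-backed checkpointer would survive restarts.
active_workflows: Dict[str, Dict[str, Any]] = {}

def start_workflow(thread_id: str, params: Dict[str, Any]) -> Dict[str, Any]:
    """Run until the human-input step, then park the state in the registry."""
    state = {"thread_id": thread_id, "params": params, "step": "need_human_input"}
    active_workflows[thread_id] = state
    return state

def resume_workflow(thread_id: str, human_input: Dict[str, Any]) -> Dict[str, Any]:
    """Look the paused state up by thread_id, apply the feedback, and finish."""
    if thread_id not in active_workflows:
        raise KeyError(f"unknown thread_id: {thread_id}")
    state = active_workflows.pop(thread_id)  # remove on completion
    state["human_input"] = human_input
    state["step"] = "done"
    return state
```

The two functions mirror the two API endpoints above: the first call populates the registry and returns the paused state to the caller; the second call keys back into the registry with the same thread ID and completes the run.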

Upvotes: 0
