I am trying to build an agentic AI application using LangGraph. I have a FastAPI endpoint that accepts some parameters, and I use them to create a graph (I can use a checkpointer as well). The graph is created inside a GraphCreator class and held in a local variable. After the graph has executed up to a specific step, I want to get human feedback, so I interrupt the graph at that stage. I have a thread ID that is used to resume the graph execution. I return the current state to the user through the API, and the user then calls another API with the same thread ID so that the graph resumes. How do I do this?
I explored different ways to persist the graph by serializing it and deserializing it in the second API call, but none of them work; I have tried pickle.dumps and json.dumps. Is it possible to create a new graph with the same checkpointer? How do I recreate the same graph in this scenario?
Here's how to handle generic graph persistence and resumption with LangGraph across API calls; you may have to adapt it to your needs:
```python
import uuid
from typing import Any, Dict, TypedDict

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END

# Store active compiled graphs in memory, keyed by thread_id
active_graphs: Dict[str, Any] = {}


class WorkflowState(TypedDict, total=False):
    # State schema for the graph (StateGraph requires one)
    params: Dict[str, Any]
    human_input: Dict[str, Any]
    status: str


class GraphState(BaseModel):
    # Response model returned by the API
    thread_id: str
    current_step: str
    state_data: Dict[str, Any]


class GraphCreator:
    def __init__(self):
        # The checkpointer saves the graph state after every step, per thread_id
        self.checkpointer = MemorySaver()

    def create_graph(self, params: Dict[str, Any]):
        # Example graph creation
        workflow = StateGraph(WorkflowState)
        # Add nodes
        workflow.add_node("start", self.start_task)
        workflow.add_node("need_human_input", self.pause_for_human)
        workflow.add_node("process_human_input", self.process_input)
        # Define edges
        workflow.add_edge(START, "start")
        workflow.add_edge("start", "need_human_input")
        workflow.add_edge("need_human_input", "process_human_input")
        workflow.add_edge("process_human_input", END)
        # Pause before process_human_input so the API can return to the user
        return workflow.compile(
            checkpointer=self.checkpointer,
            interrupt_before=["process_human_input"],
        )

    async def start_task(self, state: WorkflowState):
        # Initial task logic
        return {"status": "started"}

    async def pause_for_human(self, state: WorkflowState):
        # Last step before the interrupt; the checkpointer saves this state
        return {"status": "awaiting_human_input"}

    async def process_input(self, state: WorkflowState):
        # Process the human input (available as state["human_input"]) and finish
        return {"status": "done"}


app = FastAPI()


def _config(thread_id: str) -> Dict[str, Any]:
    # LangGraph looks up checkpoints by the thread_id in "configurable"
    return {"configurable": {"thread_id": thread_id}}


@app.post("/start_workflow")
async def start_workflow(params: Dict[str, Any]):
    creator = GraphCreator()
    graph = creator.create_graph(params)
    # Generate a unique thread ID
    thread_id = f"thread_{uuid.uuid4().hex}"
    # Store the compiled graph (and with it, its checkpointer)
    active_graphs[thread_id] = graph
    # Run until the interrupt before process_human_input
    result = await graph.ainvoke({"params": params}, _config(thread_id))
    snapshot = await graph.aget_state(_config(thread_id))
    return GraphState(
        thread_id=thread_id,
        current_step=snapshot.next[0] if snapshot.next else END,
        state_data=result,
    )


@app.post("/resume_workflow/{thread_id}")
async def resume_workflow(thread_id: str, human_input: Dict[str, Any]):
    if thread_id not in active_graphs:
        raise HTTPException(status_code=404, detail="Thread not found")
    # Retrieve the stored graph
    graph = active_graphs[thread_id]
    config = _config(thread_id)
    # Write the human input into the checkpointed state
    await graph.aupdate_state(
        config, {"human_input": human_input}, as_node="need_human_input"
    )
    # Resume from the checkpoint: passing None as input means "continue"
    result = await graph.ainvoke(None, config)
    snapshot = await graph.aget_state(config)
    # Clean up if the workflow completed
    if not snapshot.next:
        del active_graphs[thread_id]
    return GraphState(
        thread_id=thread_id,
        current_step=snapshot.next[0] if snapshot.next else END,
        state_data=result,
    )
```
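Instead of serializing/deserializing the graph, store the compiled graph instances in memory in a dictionary keyed by thread ID. The checkpointer attached to each graph holds the paused state, so the second API call only needs the thread ID to resume. Keep in mind that this dictionary lives inside a single process: it is lost on restart and is not shared between multiple workers.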