1 核心概念
1.1 StateGraph状态图
01.基础概念
a.状态图原理
a.功能说明
StateGraph是LangGraph的核心抽象,将AI工作流建模为有向图。节点代表处理步骤,边代表数据流转。每个节点接收状态、执行操作、返回更新的状态。状态图支持条件路由、循环处理、并行执行等复杂逻辑。相比Chain的线性结构,Graph提供更灵活的控制流,适合多步骤、多分支的复杂任务。
b.代码示例
---
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated
import operator
# 1. 定义状态
class State(TypedDict):
messages: Annotated[list, operator.add] # 消息列表,自动合并
current_step: str # 当前步骤
result: str # 最终结果
# 2. 创建状态图
workflow = StateGraph(State)
# 3. 定义节点函数
def node_a(state: State) -> State:
"""节点A:初始处理"""
print(f"执行节点A,当前步骤:{state['current_step']}")
return {
"messages": ["节点A完成"],
"current_step": "node_b"
}
def node_b(state: State) -> State:
"""节点B:后续处理"""
print(f"执行节点B,历史消息:{state['messages']}")
return {
"messages": ["节点B完成"],
"result": "处理完成"
}
# 4. 添加节点
workflow.add_node("node_a", node_a)
workflow.add_node("node_b", node_b)
# 5. 添加边
workflow.add_edge("node_a", "node_b") # A -> B
workflow.add_edge("node_b", END) # B -> END
# 6. 设置入口
workflow.set_entry_point("node_a")
# 7. 编译图
app = workflow.compile()
# 8. 执行
initial_state = {
"messages": [],
"current_step": "node_a",
"result": ""
}
result = app.invoke(initial_state)
print(f"最终结果:{result}")
---
b.与Chain对比
a.功能说明
LangChain的Chain是线性执行流,适合简单的Pipeline。LangGraph的Graph支持复杂的控制流,包括条件分支、循环、并行等。Graph更灵活但也更复杂,Chain更简单但能力有限。选择Graph还是Chain取决于任务复杂度:简单任务用Chain,复杂工作流用Graph。
b.代码示例
---
# Chain方式(线性)
from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.schema.output_parser import StrOutputParser
chain = (
ChatPromptTemplate.from_template("分析:{text}")
| ChatOpenAI()
| StrOutputParser()
)
result = chain.invoke({"text": "产品评论"})
# Graph方式(灵活控制流)
from langgraph.graph import StateGraph, END
class AnalysisState(TypedDict):
text: str
sentiment: str
summary: str
actions: list
workflow = StateGraph(AnalysisState)
def analyze_sentiment(state):
# 情感分析
sentiment = sentiment_chain.invoke({"text": state["text"]})
return {"sentiment": sentiment}
def generate_summary(state):
# 生成摘要
summary = summary_chain.invoke({"text": state["text"]})
return {"summary": summary}
def decide_action(state):
# 根据情感决定后续动作
if state["sentiment"] == "负面":
return {"actions": ["转人工处理", "发送安抚消息"]}
elif state["sentiment"] == "正面":
return {"actions": ["感谢用户", "邀请分享"]}
else:
return {"actions": ["记录反馈"]}
# 构建图
workflow.add_node("analyze", analyze_sentiment)
workflow.add_node("summarize", generate_summary)
workflow.add_node("action", decide_action)
workflow.set_entry_point("analyze")
workflow.add_edge("analyze", "summarize")
workflow.add_edge("summarize", "action")
workflow.add_edge("action", END)
app = workflow.compile()
# Graph支持条件路由
def route_by_sentiment(state):
if state["sentiment"] == "负面":
return "urgent_handler"
else:
return "normal_handler"
workflow.add_conditional_edges(
"analyze",
route_by_sentiment,
{
"urgent_handler": "urgent_node",
"normal_handler": "normal_node"
}
)
---
02.图的结构
a.节点Node
a.功能说明
节点是状态图中的处理单元,每个节点是一个函数,接收当前状态、执行操作、返回状态更新。节点可以调用LLM、执行工具、处理数据等。节点函数必须返回字典,用于更新状态。支持同步和异步节点,支持错误处理和重试。
b.代码示例
---
from langgraph.graph import StateGraph
from typing import TypedDict
class State(TypedDict):
input: str
output: str
metadata: dict
# 1. 基础节点
def simple_node(state: State) -> State:
"""简单节点"""
result = process(state["input"])
return {"output": result}
# 2. LLM节点
from langchain.chat_models import ChatOpenAI
def llm_node(state: State) -> State:
"""LLM处理节点"""
llm = ChatOpenAI(temperature=0)
response = llm.predict(state["input"])
return {
"output": response,
"metadata": {"tokens": len(response)}
}
# 3. 工具调用节点
def tool_node(state: State) -> State:
"""工具调用节点"""
from langchain.tools import Tool
search_tool = Tool(
name="search",
func=search_function,
description="搜索信息"
)
result = search_tool.run(state["input"])
return {"output": result}
# 4. 异步节点
async def async_node(state: State) -> State:
"""异步节点"""
import aiohttp
async with aiohttp.ClientSession() as session:
async with session.get(f"https://api.example.com?q={state['input']}") as resp:
data = await resp.json()
return {"output": data["result"]}
# 5. 错误处理节点
def resilient_node(state: State) -> State:
"""带错误处理的节点"""
try:
result = risky_operation(state["input"])
return {"output": result, "metadata": {"error": None}}
except Exception as e:
print(f"节点执行失败:{e}")
return {
"output": "处理失败",
"metadata": {"error": str(e)}
}
# 6. 多状态更新节点
def multi_update_node(state: State) -> State:
"""更新多个状态字段"""
return {
"output": "结果",
"metadata": {
"timestamp": datetime.now(),
"processed": True
},
"input": state["input"].upper() # 也可以修改输入
}
---
b.边Edge
a.功能说明
边定义节点间的连接和数据流转。普通边表示无条件转移,条件边根据状态动态选择下一个节点。边可以连接到END表示流程结束。支持配置多个出边实现分支,支持循环边实现迭代处理。合理设计边的结构决定了工作流的执行逻辑。
b.代码示例
---
from langgraph.graph import StateGraph, END
workflow = StateGraph(State)
# 1. 普通边(无条件转移)
workflow.add_edge("node_a", "node_b") # A完成后必然执行B
workflow.add_edge("node_b", "node_c")
workflow.add_edge("node_c", END) # C完成后结束
# 2. 条件边(动态路由)
def route_function(state: State) -> str:
"""路由函数:根据状态决定下一步"""
if state.get("needs_review"):
return "review_node"
elif state.get("needs_approval"):
return "approval_node"
else:
return "finish_node"
workflow.add_conditional_edges(
"process_node", # 源节点
route_function, # 路由函数
{
"review_node": "review_node",
"approval_node": "approval_node",
"finish_node": END
}
)
# 3. 循环边
def should_continue(state: State) -> str:
"""判断是否继续循环"""
if state["iteration"] < 3:
return "continue"
else:
return "done"
workflow.add_conditional_edges(
"loop_node",
should_continue,
{
"continue": "loop_node", # 回到自己,形成循环
"done": END
}
)
# 4. 多分支边
def multi_branch(state: State) -> list:
"""多分支路由"""
branches = []
if state["need_a"]:
branches.append("branch_a")
if state["need_b"]:
branches.append("branch_b")
return branches if branches else ["default"]
# 注意:多分支需要并行执行支持
# 5. 带条件的结束
def maybe_end(state: State) -> str:
"""可能结束的条件边"""
if state["is_complete"]:
return END
else:
return "next_node"
workflow.add_conditional_edges(
"check_node",
maybe_end,
{
END: END,
"next_node": "next_node"
}
)
---
1.2 状态管理
01.状态定义
a.TypedDict状态
a.功能说明
使用TypedDict定义状态结构,提供类型提示和IDE支持。状态是所有节点共享的数据容器,节点读取状态、更新状态。使用Annotated和operator定义状态字段的合并策略,如add追加列表、覆盖替换等。良好的状态设计是Graph应用的基础。
b.代码示例
---
from typing import TypedDict, Annotated, Sequence
import operator
from langchain.schema import BaseMessage
# 1. 基础状态定义
class SimpleState(TypedDict):
"""简单状态"""
input: str # 输入文本
output: str # 输出结果
step_count: int # 步骤计数
# 2. 带合并策略的状态
class AgentState(TypedDict):
"""Agent状态"""
messages: Annotated[Sequence[BaseMessage], operator.add] # 消息列表,自动追加
next_action: str # 下一步动作
intermediate_steps: Annotated[list, operator.add] # 中间步骤,追加
final_answer: str # 最终答案,覆盖
# 3. 复杂状态定义
class WorkflowState(TypedDict):
"""工作流状态"""
# 用户输入
user_query: str
user_id: str
# 处理状态
current_node: str
visited_nodes: Annotated[list, operator.add]
# 数据状态
retrieved_docs: list
analysis_results: dict
# 控制状态
needs_human: bool
max_iterations: int
current_iteration: int
# 输出状态
final_response: str
metadata: dict
# 4. 嵌套状态
from typing import Optional
class UserInfo(TypedDict):
user_id: str
role: str
permissions: list
class AdvancedState(TypedDict):
user: UserInfo
data: dict
status: str
# 5. 状态初始化
def create_initial_state(user_query: str, user_id: str) -> WorkflowState:
"""创建初始状态"""
return {
"user_query": user_query,
"user_id": user_id,
"current_node": "start",
"visited_nodes": [],
"retrieved_docs": [],
"analysis_results": {},
"needs_human": False,
"max_iterations": 10,
"current_iteration": 0,
"final_response": "",
"metadata": {}
}
# 6. 状态验证
def validate_state(state: WorkflowState) -> bool:
"""验证状态完整性"""
required_fields = ["user_query", "user_id", "current_node"]
for field in required_fields:
if field not in state or not state[field]:
print(f"缺少必需字段:{field}")
return False
if state["current_iteration"] > state["max_iterations"]:
print("超过最大迭代次数")
return False
return True
---
b.状态更新
a.功能说明
节点通过返回字典更新状态,LangGraph根据合并策略自动合并状态。add策略追加列表元素,默认策略覆盖字段值。节点只需返回要更新的字段,未返回的字段保持不变。支持部分更新、条件更新、批量更新等模式。
b.代码示例
---
from langgraph.graph import StateGraph
# 1. 基础状态更新
def node_update_basic(state: AgentState) -> AgentState:
"""基础更新:覆盖字段"""
return {
"next_action": "search", # 覆盖
"final_answer": "处理中" # 覆盖
}
# 2. 列表追加更新
def node_update_append(state: AgentState) -> AgentState:
"""追加更新:使用operator.add"""
from langchain.schema import HumanMessage
return {
"messages": [HumanMessage(content="新消息")], # 自动追加到列表
"intermediate_steps": [("action", "result")] # 自动追加
}
# 3. 条件更新
def node_conditional_update(state: WorkflowState) -> WorkflowState:
"""条件更新"""
updates = {
"current_iteration": state["current_iteration"] + 1,
"visited_nodes": [state["current_node"]]
}
# 根据条件决定更新内容
if state["current_iteration"] >= state["max_iterations"]:
updates["needs_human"] = True
updates["final_response"] = "达到最大迭代次数,需要人工介入"
return updates
# 4. 复杂状态更新
def node_complex_update(state: WorkflowState) -> WorkflowState:
"""复杂更新:更新多个字段"""
# 执行处理
results = perform_analysis(state["user_query"])
return {
"analysis_results": results,
"visited_nodes": ["analysis_node"],
"metadata": {
"timestamp": datetime.now().isoformat(),
"confidence": results.get("confidence", 0)
},
"current_node": "decision_node"
}
# 5. 增量更新
def node_incremental_update(state: AgentState) -> AgentState:
"""增量更新:基于现有状态"""
# 读取现有状态
existing_steps = state.get("intermediate_steps", [])
# 添加新步骤
new_step = ("tool_call", "tool_result")
return {
"intermediate_steps": [new_step], # operator.add会自动追加
"next_action": determine_next_action(existing_steps + [new_step])
}
# 6. 批量更新
def node_batch_update(state: WorkflowState) -> WorkflowState:
"""批量更新多个字段"""
updates = {}
# 更新文档
if state.get("needs_retrieval"):
updates["retrieved_docs"] = retrieve_documents(state["user_query"])
# 更新分析
if state.get("needs_analysis"):
updates["analysis_results"] = analyze_data(state["retrieved_docs"])
# 更新响应
if state.get("needs_response"):
updates["final_response"] = generate_response(state)
# 更新控制状态
updates["visited_nodes"] = [state["current_node"]]
updates["current_iteration"] = state["current_iteration"] + 1
return updates
---
02.状态传递
a.跨节点传递
a.功能说明
状态在节点间自动传递,每个节点接收完整的当前状态。节点更新会合并到状态中,传递给下一个节点。状态传递是只读的原始状态+写的更新,确保数据一致性。支持状态快照、状态回滚、状态分支等高级特性。
b.代码示例
---
from langgraph.graph import StateGraph, END
# 1. 状态传递示例
class DataState(TypedDict):
raw_data: str
processed_data: str
enriched_data: str
final_result: str
def node_1_raw(state: DataState) -> DataState:
"""节点1:处理原始数据"""
print(f"节点1收到:{state}")
return {"processed_data": process(state["raw_data"])}
def node_2_enrich(state: DataState) -> DataState:
"""节点2:丰富数据"""
print(f"节点2收到:{state}")
# state包含raw_data和processed_data
return {"enriched_data": enrich(state["processed_data"])}
def node_3_finalize(state: DataState) -> DataState:
"""节点3:生成最终结果"""
print(f"节点3收到:{state}")
# state包含所有之前的数据
return {"final_result": finalize(state["enriched_data"])}
# 构建图
workflow = StateGraph(DataState)
workflow.add_node("raw", node_1_raw)
workflow.add_node("enrich", node_2_enrich)
workflow.add_node("final", node_3_finalize)
workflow.set_entry_point("raw")
workflow.add_edge("raw", "enrich")
workflow.add_edge("enrich", "final")
workflow.add_edge("final", END)
app = workflow.compile()
# 执行
result = app.invoke({"raw_data": "原始数据"})
# 结果包含所有阶段的数据
print(result)
# 2. 状态过滤传递
def node_selective_read(state: WorkflowState) -> WorkflowState:
"""节点选择性读取状态"""
# 只使用需要的字段
query = state["user_query"]
docs = state.get("retrieved_docs", [])
# 处理
result = process_with_context(query, docs)
return {"analysis_results": result}
# 3. 状态转换传递
def node_transform_state(state: WorkflowState) -> WorkflowState:
"""节点转换状态格式"""
# 从一种格式转换为另一种格式
analysis = state["analysis_results"]
# 转换为响应格式
response = {
"answer": analysis.get("summary"),
"confidence": analysis.get("score"),
"sources": [doc["source"] for doc in state["retrieved_docs"]]
}
return {
"final_response": json.dumps(response),
"metadata": {"format": "json"}
}
# 4. 状态分支传递
def node_branch_state(state: WorkflowState) -> WorkflowState:
"""节点处理分支状态"""
# 根据不同分支有不同的状态结构
if state["current_node"] == "branch_a":
return {
"analysis_results": analyze_type_a(state),
"visited_nodes": ["branch_a"]
}
else:
return {
"analysis_results": analyze_type_b(state),
"visited_nodes": ["branch_b"]
}
---
1.3 图的执行
01.编译与调用
a.编译图
a.功能说明
使用compile()将StateGraph编译为可执行的应用。编译过程验证图的完整性,检查节点连接、入口出口等。编译后的应用提供invoke、stream、batch等执行方法。可以配置checkpointer实现状态持久化,配置interrupt_before/after实现人工介入。
b.代码示例
---
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
# 1. 基础编译
workflow = StateGraph(State)
# 添加节点和边
workflow.add_node("step1", node1)
workflow.add_node("step2", node2)
workflow.set_entry_point("step1")
workflow.add_edge("step1", "step2")
workflow.add_edge("step2", END)
# 编译
app = workflow.compile()
# 2. 带持久化编译
checkpointer = MemorySaver()
app_with_memory = workflow.compile(
checkpointer=checkpointer
)
# 3. 带中断编译
app_with_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["approval_node"], # 在approval_node前中断
interrupt_after=["risky_operation"] # 在risky_operation后中断
)
# 4. 调用编译后的应用
# invoke方法
result = app.invoke({"input": "Hello"})
print(result)
# 5. 带配置调用
result = app.invoke(
{"input": "Hello"},
config={"configurable": {"thread_id": "123"}}
)
# 6. 验证图结构
# 编译前验证
try:
app = workflow.compile()
except ValueError as e:
print(f"图结构错误:{e}")
# 错误示例:缺少入口点、有孤立节点、存在死循环等
# 7. 获取图信息
graph = app.get_graph()
print(f"节点:{graph.nodes}")
print(f"边:{graph.edges}")
---
b.执行模式
a.功能说明
Graph支持同步invoke、异步ainvoke、流式stream、批量batch等执行模式。invoke返回最终状态,stream逐步返回中间状态,适合实时展示。batch并行处理多个输入,提升吞吐量。可以配置并发数、超时时间、递归限制等参数。
b.代码示例
---
from langgraph.graph import StateGraph
import asyncio
# 1. 同步调用
app = workflow.compile()
result = app.invoke({"input": "query"})
print(f"最终状态:{result}")
# 2. 异步调用
async def async_execution():
result = await app.ainvoke({"input": "query"})
return result
result = asyncio.run(async_execution())
# 3. 流式执行
for state in app.stream({"input": "query"}):
print(f"当前状态:{state}")
# 每个节点执行后返回一次状态
# 4. 流式执行(详细模式)
for chunk in app.stream(
{"input": "query"},
stream_mode="values" # 或 "updates"
):
print(f"状态更新:{chunk}")
# 5. 批量执行
inputs = [
{"input": "query1"},
{"input": "query2"},
{"input": "query3"}
]
results = app.batch(inputs)
for i, result in enumerate(results):
print(f"结果{i+1}:{result}")
# 6. 配置执行参数
result = app.invoke(
{"input": "query"},
config={
"recursion_limit": 100, # 最大递归深度
"configurable": {
"thread_id": "user_123", # 会话ID
"checkpoint_id": "abc" # 检查点ID
}
}
)
# 7. 超时控制
import signal
def timeout_handler(signum, frame):
raise TimeoutError("执行超时")
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(30) # 30秒超时
try:
result = app.invoke({"input": "query"})
except TimeoutError:
print("图执行超时")
finally:
signal.alarm(0)
# 8. 错误处理
try:
result = app.invoke({"input": "query"})
except Exception as e:
print(f"执行失败:{e}")
# 可以尝试恢复或降级处理
---
02.执行控制
a.递归限制
a.功能说明
防止无限循环,配置递归深度限制。当循环次数超过限制时抛出异常。默认限制为25,可以根据需求调整。适用于包含循环边的图,如ReAct Agent、迭代优化等场景。合理设置限制平衡功能和安全性。
b.代码示例
---
from langgraph.graph import StateGraph, END
# 1. 包含循环的图
class LoopState(TypedDict):
count: int
result: str
should_continue: bool
def loop_node(state: LoopState) -> LoopState:
"""循环节点"""
new_count = state["count"] + 1
print(f"循环第{new_count}次")
return {
"count": new_count,
"should_continue": new_count < 5 # 循环5次
}
def should_loop(state: LoopState) -> str:
"""判断是否继续循环"""
if state["should_continue"]:
return "loop"
else:
return "end"
workflow = StateGraph(LoopState)
workflow.add_node("loop", loop_node)
workflow.set_entry_point("loop")
workflow.add_conditional_edges(
"loop",
should_loop,
{
"loop": "loop", # 回到自己
"end": END
}
)
app = workflow.compile()
# 2. 设置递归限制
result = app.invoke(
{"count": 0, "result": "", "should_continue": True},
config={"recursion_limit": 10} # 最多10次
)
# 3. 超过限制的处理
try:
result = app.invoke(
{"count": 0, "result": "", "should_continue": True},
config={"recursion_limit": 3} # 限制3次,但需要5次
)
except RecursionError as e:
print(f"超过递归限制:{e}")
# 可以增加限制或修改逻辑
# 4. 动态递归限制
def adaptive_recursion(state: LoopState):
"""自适应递归限制"""
# 根据状态动态调整
base_limit = 10
complexity_factor = len(state.get("result", "")) / 100
return int(base_limit * (1 + complexity_factor))
# 5. 监控递归深度
class RecursionMonitor:
"""递归深度监控"""
def __init__(self):
self.depth = 0
self.max_depth = 0
def on_node_start(self):
self.depth += 1
self.max_depth = max(self.max_depth, self.depth)
def on_node_end(self):
self.depth -= 1
def reset(self):
self.depth = 0
self.max_depth = 0
monitor = RecursionMonitor()
# 6. 安全的循环设计
class SafeLoopState(TypedDict):
iteration: int
max_iterations: int
data: list
def safe_loop_node(state: SafeLoopState) -> SafeLoopState:
"""安全的循环节点"""
return {
"iteration": state["iteration"] + 1,
"data": state["data"] + [f"step_{state['iteration']}"]
}
def safe_should_continue(state: SafeLoopState) -> str:
"""安全的继续判断"""
# 双重检查
if (state["iteration"] >= state["max_iterations"] or
len(state["data"]) >= 100): # 额外的安全检查
return "end"
return "continue"
---
2 节点与边
2.1 节点定义
01.节点函数
a.函数签名
a.功能说明
节点函数接收State类型的参数,返回State类型的更新。函数可以是同步或异步,返回值必须是字典格式。函数内部可以调用LLM、执行工具、处理数据等任意操作。支持访问完整状态、部分更新状态、条件返回等模式。节点函数是Graph的基本执行单元。
b.代码示例
---
from typing import TypedDict
from langchain.chat_models import ChatOpenAI
class State(TypedDict):
input: str
output: str
metadata: dict
# 1. 基础同步节点
def sync_node(state: State) -> State:
"""同步节点函数"""
result = process_data(state["input"])
return {"output": result}
# 2. 异步节点
async def async_node(state: State) -> State:
"""异步节点函数"""
import aiohttp
async with aiohttp.ClientSession() as session:
async with session.post(
"https://api.example.com",
json={"data": state["input"]}
) as resp:
data = await resp.json()
return {"output": data["result"]}
# 3. LLM节点
def llm_node(state: State) -> State:
"""LLM处理节点"""
llm = ChatOpenAI(temperature=0)
response = llm.predict(f"处理:{state['input']}")
return {
"output": response,
"metadata": {
"model": "gpt-3.5-turbo",
"tokens": len(response)
}
}
# 4. 多返回字段节点
def multi_field_node(state: State) -> State:
"""更新多个字段"""
processed = process(state["input"])
return {
"output": processed["result"],
"metadata": {
"confidence": processed["confidence"],
"source": processed["source"],
"timestamp": datetime.now().isoformat()
}
}
# 5. 条件返回节点
def conditional_node(state: State) -> State:
"""根据状态条件返回"""
result = analyze(state["input"])
if result["is_valid"]:
return {
"output": result["data"],
"metadata": {"status": "success"}
}
else:
return {
"output": "invalid input",
"metadata": {"status": "error", "reason": result["reason"]}
}
# 6. 带错误处理的节点
def resilient_node(state: State) -> State:
"""带错误处理的节点"""
try:
result = risky_operation(state["input"])
return {"output": result, "metadata": {"error": None}}
except Exception as e:
logger.error(f"节点执行失败:{e}")
return {
"output": "fallback_result",
"metadata": {"error": str(e)}
}
---
b.节点配置
a.功能说明
使用add_node方法注册节点到图中,指定节点名称和函数。节点名称用于边的连接、条件路由等。可以配置节点的retry策略、timeout时间、并发控制等。支持动态添加节点、节点别名、节点组等高级特性。
b.代码示例
---
from langgraph.graph import StateGraph
workflow = StateGraph(State)
# 1. 基础添加节点
workflow.add_node("process", process_node)
workflow.add_node("analyze", analyze_node)
workflow.add_node("finalize", finalize_node)
# 2. 使用lambda节点
workflow.add_node(
"simple_transform",
lambda state: {"output": state["input"].upper()}
)
# 3. 使用偏函数节点
from functools import partial
def parameterized_node(state: State, param: str) -> State:
"""带参数的节点"""
return {"output": f"{state['input']}_{param}"}
workflow.add_node(
"param_node",
partial(parameterized_node, param="value")
)
# 4. 节点分组
# 定义一组相关节点
preprocessing_nodes = {
"clean": clean_node,
"normalize": normalize_node,
"validate": validate_node
}
for name, func in preprocessing_nodes.items():
workflow.add_node(name, func)
# 5. 条件节点工厂
def create_filter_node(threshold: float):
"""创建过滤节点"""
def filter_node(state: State) -> State:
score = calculate_score(state["input"])
if score > threshold:
return {"output": state["input"], "metadata": {"passed": True}}
else:
return {"output": "", "metadata": {"passed": False}}
return filter_node
workflow.add_node("filter_high", create_filter_node(0.8))
workflow.add_node("filter_medium", create_filter_node(0.5))
# 6. 节点装饰器
def log_node(name: str):
"""节点日志装饰器"""
def decorator(func):
def wrapper(state: State) -> State:
logger.info(f"执行节点:{name}")
logger.debug(f"输入状态:{state}")
result = func(state)
logger.info(f"节点{name}完成")
logger.debug(f"输出状态:{result}")
return result
return wrapper
return decorator
@log_node("process")
def logged_process_node(state: State) -> State:
return {"output": process(state["input"])}
workflow.add_node("process", logged_process_node)
---
02.特殊节点
a.Agent节点
a.功能说明
Agent节点封装LangChain的Agent,支持工具调用、多轮对话、ReAct推理等。Agent节点接收消息列表状态,执行Agent逻辑,返回更新的消息列表。适用于需要动态决策、工具使用的场景。可以配置Agent类型、工具集、prompt等。
b.代码示例
---
from langchain.agents import AgentExecutor, create_openai_functions_agent
from langchain.tools import Tool
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import BaseMessage, HumanMessage, AIMessage
class AgentState(TypedDict):
messages: Annotated[Sequence[BaseMessage], operator.add]
next_step: str
# 1. 创建Agent节点
def create_agent_node(tools: list, system_message: str):
"""创建Agent节点"""
llm = ChatOpenAI(temperature=0)
prompt = ChatPromptTemplate.from_messages([
("system", system_message),
MessagesPlaceholder(variable_name="messages"),
MessagesPlaceholder(variable_name="agent_scratchpad")
])
agent = create_openai_functions_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools)
def agent_node(state: AgentState) -> AgentState:
"""Agent执行节点"""
result = agent_executor.invoke({
"messages": state["messages"]
})
return {
"messages": [AIMessage(content=result["output"])]
}
return agent_node
# 2. 使用Agent节点
search_tool = Tool(
name="search",
func=search_function,
description="搜索信息"
)
calculator_tool = Tool(
name="calculator",
func=lambda x: eval(x),
description="计算数学表达式"
)
workflow = StateGraph(AgentState)
# 添加不同的Agent节点
workflow.add_node(
"researcher",
create_agent_node(
[search_tool],
"你是研究助手,负责搜索和整理信息"
)
)
workflow.add_node(
"analyst",
create_agent_node(
[calculator_tool],
"你是数据分析师,负责计算和分析"
)
)
# 3. ReAct Agent节点
from langchain.agents import create_react_agent
def react_agent_node(state: AgentState) -> AgentState:
"""ReAct推理Agent节点"""
from langchain.prompts import PromptTemplate
template = """回答以下问题,你可以使用这些工具:
{tools}
使用以下格式:
Question: 输入问题
Thought: 思考过程
Action: 工具名称
Action Input: 工具输入
Observation: 工具输出
... (重复 Thought/Action/Action Input/Observation)
Thought: 我现在知道最终答案了
Final Answer: 最终答案
开始!
Question: {input}
{agent_scratchpad}
"""
prompt = PromptTemplate.from_template(template)
agent = create_react_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
# 提取最后一条消息作为输入
last_message = state["messages"][-1]
result = agent_executor.invoke({"input": last_message.content})
return {"messages": [AIMessage(content=result["output"])]}
workflow.add_node("react_agent", react_agent_node)
---
b.工具节点
a.功能说明
工具节点封装单个工具的调用,相比Agent节点更轻量。工具节点接收输入、调用工具、返回结果。适用于确定性的工具调用场景,如API调用、数据库查询、文件操作等。可以配置重试、超时、缓存等策略。
b.代码示例
---
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
# 1. 简单工具节点
def search_tool_node(state: State) -> State:
"""搜索工具节点"""
query = state["input"]
# 调用搜索工具
results = search_api(query)
return {
"output": "\n".join(results[:3]),
"metadata": {"tool": "search", "count": len(results)}
}
# 2. API工具节点
def api_tool_node(state: State) -> State:
"""API调用工具节点"""
import requests
try:
response = requests.post(
"https://api.example.com/process",
json={"data": state["input"]},
timeout=10
)
response.raise_for_status()
data = response.json()
return {
"output": data["result"],
"metadata": {"api_status": "success"}
}
except Exception as e:
return {
"output": "API调用失败",
"metadata": {"api_status": "error", "error": str(e)}
}
# 3. 数据库工具节点
def database_tool_node(state: State) -> State:
"""数据库查询工具节点"""
import dmPython
query = state["input"]
try:
conn = dmPython.connect("dm://...")
cursor = conn.cursor()
cursor.execute("SELECT * FROM data WHERE query = ?", (query,))
results = cursor.fetchall()
conn.close()
return {
"output": str(results),
"metadata": {"db_status": "success", "rows": len(results)}
}
except Exception as e:
return {
"output": "数据库查询失败",
"metadata": {"db_status": "error", "error": str(e)}
}
# 4. 向量检索工具节点
from langchain.vectorstores import Milvus
from langchain.embeddings import OpenAIEmbeddings
def retrieval_tool_node(state: State) -> State:
"""向量检索工具节点"""
vectorstore = Milvus(
embedding_function=OpenAIEmbeddings(),
connection_args={"host": "localhost", "port": "19530"}
)
query = state["input"]
docs = vectorstore.similarity_search(query, k=3)
return {
"output": "\n\n".join([doc.page_content for doc in docs]),
"metadata": {
"tool": "retrieval",
"docs_count": len(docs),
"sources": [doc.metadata.get("source") for doc in docs]
}
}
# 5. 带重试的工具节点
from tenacity import retry, stop_after_attempt, wait_exponential
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(min=1, max=10)
)
def call_external_api(data):
"""带重试的API调用"""
response = requests.post("https://api.example.com", json=data)
response.raise_for_status()
return response.json()
def robust_api_tool_node(state: State) -> State:
"""带重试的API工具节点"""
try:
result = call_external_api({"query": state["input"]})
return {"output": result["answer"]}
except Exception as e:
return {"output": f"API调用失败(重试3次后):{e}"}
---
2.2 边类型
01.普通边
a.直接连接
a.功能说明
普通边表示无条件的节点转移,执行完源节点后直接执行目标节点。使用add_edge方法添加普通边,指定源节点名和目标节点名。普通边构成图的基本骨架,适用于确定性的顺序执行流程。
b.代码示例
---
from langgraph.graph import StateGraph, END
workflow = StateGraph(State)
# 添加节点
workflow.add_node("step1", node1)
workflow.add_node("step2", node2)
workflow.add_node("step3", node3)
# 添加普通边
workflow.add_edge("step1", "step2") # step1 -> step2
workflow.add_edge("step2", "step3") # step2 -> step3
workflow.add_edge("step3", END) # step3 -> END
# 设置入口
workflow.set_entry_point("step1")
# 执行:step1 -> step2 -> step3 -> END
app = workflow.compile()
result = app.invoke({"input": "data"})
---
2.3 条件边
01.条件路由
a.路由函数
a.功能说明
条件边根据状态动态选择下一个节点,实现分支逻辑。路由函数接收状态参数,返回下一个节点的名称或END。使用add_conditional_edges方法添加条件边,指定源节点、路由函数、路由映射。条件边是Graph灵活性的核心,支持if-elif-else、switch-case等控制流。
b.代码示例
---
from langgraph.graph import StateGraph, END
# 路由函数
def route_by_score(state: State) -> str:
score = state.get("score", 0)
if score > 0.8:
return "high_path"
elif score > 0.5:
return "medium_path"
else:
return "low_path"
workflow = StateGraph(State)
# 添加条件边
workflow.add_conditional_edges(
"scorer", # 源节点
route_by_score, # 路由函数
{
"high_path": "high_handler",
"medium_path": "medium_handler",
"low_path": "low_handler"
}
)
---
2.4 入口和出口
01.入口点
a.set_entry_point
a.功能说明
入口点是图执行的起始节点,使用set_entry_point方法设置。每个图必须有且仅有一个入口点。入口点接收invoke传入的初始状态,开始图的执行。可以根据不同场景设置不同的入口点,实现多入口图。
b.代码示例
---
from langgraph.graph import StateGraph
workflow = StateGraph(State)
# 设置入口点
workflow.set_entry_point("start_node")
# 执行时从start_node开始
app = workflow.compile()
result = app.invoke({"input": "data"})
---
02.出口点
a.END标记
a.功能说明
END是特殊的出口标记,表示图执行结束。节点可以通过边连接到END,表示执行完该节点后终止流程。支持多个节点连接到END,实现多出口图。END不是节点,只是标记,不能作为边的源节点。
b.代码示例
---
from langgraph.graph import StateGraph, END
workflow = StateGraph(State)
# 多个节点可以连接到END
workflow.add_edge("success_node", END)
workflow.add_edge("error_node", END)
workflow.add_edge("timeout_node", END)
# 条件边也可以路由到END
def maybe_end(state):
if state["is_complete"]:
return "end"
else:
return "continue"
workflow.add_conditional_edges(
"check_node",
maybe_end,
{
"end": END,
"continue": "next_node"
}
)
---
3 Multi-Agent
3.1 Agent协作模式
01.并行协作
a.多Agent并行
a.功能说明
多个Agent并行执行不同任务,最后汇总结果。适用于任务可拆分、互不依赖的场景如多源信息检索、多角度分析、并行审核等。使用RunnableParallel或多分支边实现并行执行,使用聚合节点合并结果。并行协作显著提升效率。
b.代码示例
---
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated, Sequence
import operator
from langchain.schema import BaseMessage, HumanMessage, AIMessage
class MultiAgentState(TypedDict):
messages: Annotated[Sequence[BaseMessage], operator.add]
research_result: str
analysis_result: str
review_result: str
final_output: str
# 研究Agent
def research_agent(state: MultiAgentState):
"""研究Agent:搜索信息"""
query = state["messages"][-1].content
results = search_and_研究(query)
return {
"research_result": results,
"messages": [AIMessage(content=f"研究完成:{results}")]
}
# 分析Agent
def analysis_agent(state: MultiAgentState):
"""分析Agent:数据分析"""
query = state["messages"][-1].content
analysis = perform_analysis(query)
return {
"analysis_result": analysis,
"messages": [AIMessage(content=f"分析完成:{analysis}")]
}
# 审核Agent
def review_agent(state: MultiAgentState):
"""审核Agent:质量审核"""
query = state["messages"][-1].content
review = quality_check(query)
return {
"review_result": review,
"messages": [AIMessage(content=f"审核完成:{review}")]
}
# 聚合节点
def aggregate_results(state: MultiAgentState):
"""聚合所有Agent的结果"""
final = f\"\"\"
综合报告:
研究结果:{state['research_result']}
分析结果:{state['analysis_result']}
审核意见:{state['review_result']}
\"\"\"
return {"final_output": final}
# 构建并行协作图
workflow = StateGraph(MultiAgentState)
workflow.add_node("research", research_agent)
workflow.add_node("analysis", analysis_agent)
workflow.add_node("review", review_agent)
workflow.add_node("aggregate", aggregate_results)
# 从入口分发到三个Agent
workflow.set_entry_point("research") # 简化示例
# 实际应使用并行执行或分支
workflow.add_edge("research", "aggregate")
workflow.add_edge("analysis", "aggregate")
workflow.add_edge("review", "aggregate")
workflow.add_edge("aggregate", END)
---
02.顺序协作
a.流水线模式
a.功能说明
多个Agent按顺序执行,前一个Agent的输出成为后一个Agent的输入。适用于任务有先后依赖的场景如文档处理流程(提取→翻译→总结)、审批流程(初审→复审→终审)等。使用普通边串联Agent节点,状态在Agent间传递和积累。
b.代码示例
---
class PipelineState(TypedDict):
raw_doc: str
extracted_text: str
translated_text: str
summary: str
metadata: dict
# 提取Agent
def extractor_agent(state: PipelineState):
"""提取Agent:从文档提取文本"""
text = extract_from_document(state["raw_doc"])
return {
"extracted_text": text,
"metadata": {"extraction": "completed"}
}
# 翻译Agent
def translator_agent(state: PipelineState):
"""翻译Agent:翻译文本"""
translated = translate_to_english(state["extracted_text"])
return {
"translated_text": translated,
"metadata": {**state["metadata"], "translation": "completed"}
}
# 总结Agent
def summarizer_agent(state: PipelineState):
"""总结Agent:生成摘要"""
summary = generate_summary(state["translated_text"])
return {
"summary": summary,
"metadata": {**state["metadata"], "summary": "completed"}
}
# 构建流水线
workflow = StateGraph(PipelineState)
workflow.add_node("extract", extractor_agent)
workflow.add_node("translate", translator_agent)
workflow.add_node("summarize", summarizer_agent)
workflow.set_entry_point("extract")
workflow.add_edge("extract", "translate")
workflow.add_edge("translate", "summarize")
workflow.add_edge("summarize", END)
# 执行流水线
app = workflow.compile()
result = app.invoke({"raw_doc": "document.pdf", "metadata": {}})
---
3.2 消息传递
01.消息格式
a.BaseMessage类型
a.功能说明
LangGraph使用LangChain的消息类型作为Agent间通信的标准格式。主要消息类型包括HumanMessage(用户消息)、AIMessage(AI回复)、SystemMessage(系统提示)、FunctionMessage(函数调用结果)。消息携带content内容和additional_kwargs元数据。使用Annotated[Sequence[BaseMessage], operator.add]实现消息自动追加。
b.代码示例
---
from langchain.schema import (
BaseMessage,
HumanMessage,
AIMessage,
SystemMessage,
FunctionMessage
)
from typing import TypedDict, Annotated, Sequence
import operator
# 1. 定义消息状态
class MessageState(TypedDict):
messages: Annotated[Sequence[BaseMessage], operator.add]
current_agent: str
# 2. 创建不同类型的消息
human_msg = HumanMessage(
content="请帮我分析这份报告",
additional_kwargs={"user_id": "123"}
)
ai_msg = AIMessage(
content="我将为您分析报告的关键点",
additional_kwargs={"model": "gpt-4", "tokens": 20}
)
system_msg = SystemMessage(
content="你是一位专业的数据分析师"
)
function_msg = FunctionMessage(
name="search_database",
content='{"results": ["data1", "data2"]}',
additional_kwargs={"execution_time": 0.5}
)
# 3. Agent节点使用消息
def agent_node(state: MessageState) -> MessageState:
"""Agent处理消息"""
# 获取最后一条消息
last_message = state["messages"][-1]
# 根据消息类型处理
if isinstance(last_message, HumanMessage):
response = process_human_query(last_message.content)
return {
"messages": [AIMessage(content=response)],
"current_agent": "responder"
}
elif isinstance(last_message, AIMessage):
# AI消息可能需要工具调用
if "需要搜索" in last_message.content:
return {
"messages": [FunctionMessage(
name="search",
content="搜索结果..."
)],
"current_agent": "search_agent"
}
return {"messages": []}
# 4. 消息历史管理
def get_recent_messages(state: MessageState, n: int = 5):
"""获取最近N条消息"""
return state["messages"][-n:]
def filter_messages_by_type(state: MessageState, msg_type):
"""按类型过滤消息"""
return [msg for msg in state["messages"] if isinstance(msg, msg_type)]
def get_conversation_context(state: MessageState):
"""获取对话上下文"""
human_messages = filter_messages_by_type(state, HumanMessage)
ai_messages = filter_messages_by_type(state, AIMessage)
context = {
"turns": len(human_messages),
"last_human": human_messages[-1].content if human_messages else "",
"last_ai": ai_messages[-1].content if ai_messages else ""
}
return context
# 5. 消息元数据
def add_message_metadata(content: str, **metadata):
"""创建带元数据的消息"""
return AIMessage(
content=content,
additional_kwargs={
"timestamp": datetime.now().isoformat(),
"agent": "analyzer",
**metadata
}
)
# 6. 消息转换
def convert_to_langchain_messages(raw_messages: list):
"""将原始消息转换为LangChain消息"""
converted = []
for msg in raw_messages:
if msg["role"] == "user":
converted.append(HumanMessage(content=msg["content"]))
elif msg["role"] == "assistant":
converted.append(AIMessage(content=msg["content"]))
elif msg["role"] == "system":
converted.append(SystemMessage(content=msg["content"]))
return converted
---
b.消息路由
a.功能说明
根据消息内容、类型、元数据等动态路由到不同Agent。使用条件边实现基于消息的路由逻辑。支持意图识别路由、关键词路由、负载均衡路由等策略。消息路由实现智能分发,提升Multi-Agent系统的灵活性。
b.代码示例
---
from langgraph.graph import StateGraph, END
# 1. 基于消息内容路由
def route_by_content(state: MessageState) -> str:
"""根据消息内容路由"""
last_msg = state["messages"][-1]
content = last_msg.content.lower()
# 关键词匹配
if any(kw in content for kw in ["搜索", "查询", "找"]):
return "search_agent"
elif any(kw in content for kw in ["分析", "计算", "统计"]):
return "analysis_agent"
elif any(kw in content for kw in ["总结", "概括", "归纳"]):
return "summary_agent"
else:
return "general_agent"
workflow = StateGraph(MessageState)
workflow.add_conditional_edges(
"router",
route_by_content,
{
"search_agent": "search_agent",
"analysis_agent": "analysis_agent",
"summary_agent": "summary_agent",
"general_agent": "general_agent"
}
)
# 2. 基于消息类型路由
def route_by_message_type(state: MessageState) -> str:
"""根据消息类型路由"""
last_msg = state["messages"][-1]
if isinstance(last_msg, HumanMessage):
return "human_handler"
elif isinstance(last_msg, FunctionMessage):
return "function_handler"
elif isinstance(last_msg, AIMessage):
# 检查是否需要工具调用
if last_msg.additional_kwargs.get("tool_calls"):
return "tool_executor"
return "response_handler"
return "default_handler"
# 3. 基于元数据路由
def route_by_metadata(state: MessageState) -> str:
"""根据消息元数据路由"""
last_msg = state["messages"][-1]
metadata = last_msg.additional_kwargs
# 根据优先级路由
priority = metadata.get("priority", "normal")
if priority == "urgent":
return "urgent_agent"
elif priority == "high":
return "priority_agent"
else:
return "normal_agent"
# 4. LLM意图识别路由
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
def route_by_intent(state: MessageState) -> str:
"""使用LLM识别意图并路由"""
last_msg = state["messages"][-1]
llm = ChatOpenAI(temperature=0)
prompt = ChatPromptTemplate.from_template("""
识别用户意图,返回其中之一:
- research: 需要研究和搜索信息
- analysis: 需要数据分析
- creation: 需要创建内容
- review: 需要审核检查
用户消息:{message}
只返回意图类型:""")
intent = llm.predict(prompt.format(message=last_msg.content))
intent = intent.strip().lower()
return f"{intent}_agent"
workflow.add_conditional_edges(
"intent_router",
route_by_intent,
{
"research_agent": "research_agent",
"analysis_agent": "analysis_agent",
"creation_agent": "creation_agent",
"review_agent": "review_agent"
}
)
# 5. 负载均衡路由
import random
class LoadBalancer:
"""负载均衡器"""
def __init__(self, agents: list):
self.agents = agents
self.current_loads = {agent: 0 for agent in agents}
def route(self, state: MessageState) -> str:
"""选择负载最低的Agent"""
min_load_agent = min(self.current_loads, key=self.current_loads.get)
self.current_loads[min_load_agent] += 1
return min_load_agent
def release(self, agent: str):
"""释放Agent负载"""
self.current_loads[agent] -= 1
balancer = LoadBalancer(["agent_1", "agent_2", "agent_3"])
def load_balanced_route(state: MessageState) -> str:
"""负载均衡路由"""
return balancer.route(state)
# 6. 多级路由
def hierarchical_route(state: MessageState) -> str:
"""分层路由"""
last_msg = state["messages"][-1]
# 第一层:判断类别
if "技术" in last_msg.content:
# 第二层:判断技术类型
if "编程" in last_msg.content:
return "programming_agent"
elif "算法" in last_msg.content:
return "algorithm_agent"
else:
return "tech_general_agent"
elif "商业" in last_msg.content:
if "市场" in last_msg.content:
return "marketing_agent"
elif "财务" in last_msg.content:
return "finance_agent"
else:
return "business_general_agent"
return "general_agent"
---
02.Agent通信
a.直接传递
a.功能说明
Agent间通过共享状态的messages字段直接通信。前一个Agent添加消息到列表,后一个Agent读取并响应。消息按时间顺序排列,保持完整的对话历史。直接传递简单高效,适合顺序协作的场景。
b.代码示例
---
from langgraph.graph import StateGraph, END
class AgentState(TypedDict):
messages: Annotated[Sequence[BaseMessage], operator.add]
task_result: dict
# Agent A: 研究员
def researcher_agent(state: AgentState) -> AgentState:
"""研究员Agent"""
# 读取任务请求
task = state["messages"][-1].content
# 执行研究
research_results = conduct_research(task)
# 发送消息给下一个Agent
message = AIMessage(
content=f"研究完成。发现:{research_results}",
additional_kwargs={
"agent": "researcher",
"data": research_results
}
)
return {"messages": [message]}
# Agent B: 分析师
def analyst_agent(state: AgentState) -> AgentState:
"""分析师Agent"""
# 读取研究员的消息
researcher_msg = state["messages"][-1]
research_data = researcher_msg.additional_kwargs.get("data", {})
# 执行分析
analysis = analyze_data(research_data)
# 发送消息给下一个Agent
message = AIMessage(
content=f"分析完成。结论:{analysis}",
additional_kwargs={
"agent": "analyst",
"analysis": analysis
}
)
return {"messages": [message]}
# Agent C: 撰写者
def writer_agent(state: AgentState) -> AgentState:
"""撰写者Agent"""
# 读取所有之前Agent的消息
messages = state["messages"]
# 提取研究和分析结果
research_msg = [m for m in messages if m.additional_kwargs.get("agent") == "researcher"][-1]
analysis_msg = [m for m in messages if m.additional_kwargs.get("agent") == "analyst"][-1]
# 撰写报告
report = write_report(
research=research_msg.additional_kwargs["data"],
analysis=analysis_msg.additional_kwargs["analysis"]
)
return {
"messages": [AIMessage(content=f"报告完成:{report}")],
"task_result": {"report": report}
}
# 构建通信链
workflow = StateGraph(AgentState)
workflow.add_node("researcher", researcher_agent)
workflow.add_node("analyst", analyst_agent)
workflow.add_node("writer", writer_agent)
workflow.set_entry_point("researcher")
workflow.add_edge("researcher", "analyst")
workflow.add_edge("analyst", "writer")
workflow.add_edge("writer", END)
app = workflow.compile()
# 执行
result = app.invoke({
"messages": [HumanMessage(content="研究AI在医疗领域的应用")],
"task_result": {}
})
print(result["task_result"]["report"])
---
b.广播通信
a.功能说明
一个Agent的输出需要广播给多个后续Agent处理。使用多个出边实现消息广播,每个接收Agent独立处理消息。适用于需要多角度处理、并行审核、多专家评估等场景。广播后通常需要聚合节点汇总结果。
b.代码示例
---
class BroadcastState(TypedDict):
messages: Annotated[Sequence[BaseMessage], operator.add]
expert_opinions: Annotated[list, operator.add]
final_decision: str
# 协调者:发起广播
def coordinator_agent(state: BroadcastState) -> BroadcastState:
"""协调者:发送任务给所有专家"""
task = state["messages"][-1].content
message = AIMessage(
content=f"请各位专家评估:{task}",
additional_kwargs={
"agent": "coordinator",
"broadcast": True,
"task": task
}
)
return {"messages": [message]}
# 专家A:技术专家
def tech_expert_agent(state: BroadcastState) -> BroadcastState:
"""技术专家评估"""
task = state["messages"][-1].additional_kwargs["task"]
opinion = technical_evaluation(task)
return {
"messages": [AIMessage(
content=f"技术评估:{opinion}",
additional_kwargs={"expert": "tech"}
)],
"expert_opinions": [{
"expert": "tech",
"opinion": opinion,
"score": 0.8
}]
}
# 专家B:业务专家
def business_expert_agent(state: BroadcastState) -> BroadcastState:
"""业务专家评估"""
task = state["messages"][-1].additional_kwargs["task"]
opinion = business_evaluation(task)
return {
"messages": [AIMessage(
content=f"业务评估:{opinion}",
additional_kwargs={"expert": "business"}
)],
"expert_opinions": [{
"expert": "business",
"opinion": opinion,
"score": 0.75
}]
}
# 专家C:法律专家
def legal_expert_agent(state: BroadcastState) -> BroadcastState:
"""法律专家评估"""
task = state["messages"][-1].additional_kwargs["task"]
opinion = legal_evaluation(task)
return {
"messages": [AIMessage(
content=f"法律评估:{opinion}",
additional_kwargs={"expert": "legal"}
)],
"expert_opinions": [{
"expert": "legal",
"opinion": opinion,
"score": 0.9
}]
}
# 决策者:汇总专家意见
def decision_maker_agent(state: BroadcastState) -> BroadcastState:
"""决策者:综合所有专家意见"""
opinions = state["expert_opinions"]
# 加权平均
avg_score = sum(op["score"] for op in opinions) / len(opinions)
# 综合决策
if avg_score > 0.8:
decision = "批准通过"
elif avg_score > 0.6:
decision = "有条件通过"
else:
decision = "不通过"
decision_msg = f\"\"\"
综合评估结果:
技术:{opinions[0]['opinion']}
业务:{opinions[1]['opinion']}
法律:{opinions[2]['opinion']}
综合评分:{avg_score:.2f}
最终决策:{decision}
\"\"\"
return {
"messages": [AIMessage(content=decision_msg)],
"final_decision": decision
}
# 注意:实际实现广播需要使用并行执行或特殊的图结构
# 这里展示的是概念示例
workflow = StateGraph(BroadcastState)
workflow.add_node("coordinator", coordinator_agent)
workflow.add_node("tech_expert", tech_expert_agent)
workflow.add_node("business_expert", business_expert_agent)
workflow.add_node("legal_expert", legal_expert_agent)
workflow.add_node("decision_maker", decision_maker_agent)
workflow.set_entry_point("coordinator")
# 广播到三个专家(实际需要并行支持)
workflow.add_edge("coordinator", "tech_expert")
workflow.add_edge("coordinator", "business_expert")
workflow.add_edge("coordinator", "legal_expert")
# 专家意见汇总到决策者
workflow.add_edge("tech_expert", "decision_maker")
workflow.add_edge("business_expert", "decision_maker")
workflow.add_edge("legal_expert", "decision_maker")
workflow.add_edge("decision_maker", END)
---
3.3 状态共享
01.全局状态
a.共享数据结构
a.功能说明
所有Agent共享同一个状态对象,可以读取和更新共享数据。共享状态包含任务信息、中间结果、元数据等。使用TypedDict定义状态结构,使用operator.add等合并策略管理状态更新。全局状态是Multi-Agent协作的数据基础,确保信息一致性。
b.代码示例
---
from typing import TypedDict, Annotated, Sequence
import operator
from langchain.schema import BaseMessage
# 1. 定义共享状态
class SharedState(TypedDict):
# 任务信息
task_id: str
task_description: str
# Agent消息
messages: Annotated[Sequence[BaseMessage], operator.add]
# 共享数据
research_data: dict
analysis_results: dict
review_comments: list
# 进度跟踪
completed_steps: Annotated[list, operator.add]
current_agent: str
# 最终结果
final_output: str
status: str
# 2. Agent读取共享状态
def agent_read_shared(state: SharedState) -> SharedState:
"""Agent读取共享数据"""
# 读取任务描述
task = state["task_description"]
# 读取之前Agent的研究数据
research = state.get("research_data", {})
# 基于共享数据执行操作
my_result = process_with_context(task, research)
return {
"analysis_results": my_result,
"completed_steps": ["analysis"],
"current_agent": "analyst"
}
# 3. Agent更新共享状态
def agent_update_shared(state: SharedState) -> SharedState:
"""Agent更新共享数据"""
# 更新共享的研究数据
new_research = conduct_research(state["task_description"])
# 合并到现有数据
updated_research = {
**state.get("research_data", {}),
"latest": new_research,
"timestamp": datetime.now().isoformat()
}
return {
"research_data": updated_research,
"completed_steps": ["research"],
"current_agent": "researcher"
}
# 4. 多Agent协同更新
def researcher_agent(state: SharedState):
"""研究员更新研究数据"""
return {
"research_data": {"findings": [...], "sources": [...]},
"completed_steps": ["research"]
}
def analyst_agent(state: SharedState):
"""分析师使用研究数据并更新分析结果"""
# 读取研究数据
research = state["research_data"]
# 执行分析
analysis = analyze(research["findings"])
return {
"analysis_results": analysis,
"completed_steps": ["analysis"]
}
def reviewer_agent(state: SharedState):
"""审核员检查所有数据"""
# 读取研究和分析数据
research = state["research_data"]
analysis = state["analysis_results"]
# 审核
comments = review(research, analysis)
return {
"review_comments": comments,
"completed_steps": ["review"],
"status": "reviewed"
}
# 5. 状态同步
def sync_agent_state(state: SharedState) -> SharedState:
"""同步和验证状态"""
# 检查所有必需数据是否完整
required_steps = ["research", "analysis", "review"]
completed = state.get("completed_steps", [])
missing_steps = [s for s in required_steps if s not in completed]
if missing_steps:
return {
"status": "incomplete",
"final_output": f"缺少步骤:{missing_steps}"
}
else:
return {
"status": "complete",
"final_output": "所有步骤已完成"
}
# 6. 状态快照
def create_state_snapshot(state: SharedState) -> dict:
"""创建状态快照用于检查点"""
return {
"task_id": state["task_id"],
"completed_steps": list(state.get("completed_steps", [])),
"current_agent": state.get("current_agent", ""),
"status": state.get("status", "in_progress"),
"timestamp": datetime.now().isoformat()
}
---
b.数据隔离
a.功能说明
虽然使用共享状态,但每个Agent应该有明确的数据边界。Agent只更新自己负责的字段,不随意修改其他Agent的数据。使用命名空间、前缀等方式组织状态字段,避免冲突。良好的数据隔离提升系统的可维护性和可靠性。
b.代码示例
---
from typing import TypedDict, Optional
# 1. 使用命名空间隔离
class NamespacedState(TypedDict):
# 全局数据
task_id: str
messages: Annotated[Sequence[BaseMessage], operator.add]
# Agent A的命名空间
agent_a_data: dict
agent_a_status: str
# Agent B的命名空间
agent_b_data: dict
agent_b_status: str
# Agent C的命名空间
agent_c_data: dict
agent_c_status: str
# 共享结果区
shared_results: dict
# Agent A只更新自己的命名空间
def agent_a(state: NamespacedState) -> NamespacedState:
"""Agent A"""
result = process_a(state["task_id"])
return {
"agent_a_data": result,
"agent_a_status": "completed",
"shared_results": {"a_output": result["summary"]}
}
# Agent B只读取需要的数据
def agent_b(state: NamespacedState) -> NamespacedState:
"""Agent B"""
# 只读取Agent A的共享结果
a_output = state["shared_results"].get("a_output", "")
result = process_b(a_output)
return {
"agent_b_data": result,
"agent_b_status": "completed",
"shared_results": {
**state["shared_results"],
"b_output": result["summary"]
}
}
# 2. 使用嵌套字典隔离
class AgentData(TypedDict):
input: str
output: str
status: str
metadata: dict
class IsolatedState(TypedDict):
task: str
agents: dict # {"agent_name": AgentData}
final_result: str
def isolated_agent(agent_name: str, state: IsolatedState) -> IsolatedState:
"""隔离的Agent实现"""
# 读取自己的数据
my_data = state["agents"].get(agent_name, {
"input": "",
"output": "",
"status": "pending",
"metadata": {}
})
# 处理
result = process(state["task"], my_data)
# 只更新自己的数据
updated_agents = state["agents"].copy()
updated_agents[agent_name] = {
"input": state["task"],
"output": result,
"status": "completed",
"metadata": {"processed_at": datetime.now().isoformat()}
}
return {"agents": updated_agents}
# 3. 访问控制
class ControlledState(TypedDict):
public_data: dict # 所有Agent可读写
readonly_data: dict # 所有Agent只读
private_data: dict # 需要权限才能访问
def agent_with_access_control(
state: ControlledState,
agent_id: str,
permissions: list
) -> ControlledState:
"""带访问控制的Agent"""
updates = {}
# 可以读写公共数据
updates["public_data"] = {
**state.get("public_data", {}),
f"{agent_id}_contribution": "..."
}
# 不能修改只读数据
# state["readonly_data"] = ... # 禁止
# 需要权限才能访问私有数据
if "private_access" in permissions:
private_result = process_private(state["private_data"])
updates["private_data"] = {
**state.get("private_data", {}),
"updated_by": agent_id
}
return updates
# 4. 数据验证
def validate_agent_update(
state: SharedState,
update: dict,
agent_name: str
) -> bool:
"""验证Agent更新的合法性"""
# 定义每个Agent允许修改的字段
allowed_fields = {
"researcher": ["research_data", "completed_steps"],
"analyst": ["analysis_results", "completed_steps"],
"reviewer": ["review_comments", "completed_steps", "status"]
}
agent_allowed = allowed_fields.get(agent_name, [])
# 检查更新的字段是否在允许列表中
for field in update.keys():
if field not in agent_allowed:
logger.warning(
f"Agent {agent_name} 尝试更新未授权字段:{field}"
)
return False
return True
# 5. 事务性更新
class TransactionalState:
"""事务性状态管理"""
def __init__(self, initial_state: dict):
self.state = initial_state.copy()
self.pending_updates = []
def propose_update(self, agent: str, update: dict):
"""Agent提议更新"""
self.pending_updates.append({
"agent": agent,
"update": update,
"timestamp": datetime.now()
})
def commit_updates(self):
"""提交所有更新"""
for item in self.pending_updates:
if validate_agent_update(self.state, item["update"], item["agent"]):
self.state.update(item["update"])
else:
logger.error(f"拒绝 {item['agent']} 的更新")
self.pending_updates.clear()
def rollback(self):
"""回滚所有未提交的更新"""
self.pending_updates.clear()
---
02.状态同步
a.一致性保证
a.功能说明
在Multi-Agent并发执行时,需要保证状态的一致性。LangGraph使用状态合并策略(如operator.add)自动处理并发更新。对于需要原子性的操作,可以使用锁、事务等机制。状态一致性确保Agent协作的正确性,避免数据竞争和冲突。
b.代码示例
---
import threading
from typing import TypedDict
# 1. 使用锁保证原子性
state_lock = threading.Lock()
def atomic_agent_update(state: SharedState) -> SharedState:
"""原子性的状态更新"""
with state_lock:
# 读取当前值
current_count = state.get("counter", 0)
# 更新
new_count = current_count + 1
return {"counter": new_count}
# 2. 使用operator.add自动合并
class MergeableState(TypedDict):
# 使用add策略,多个Agent的更新会自动合并
results: Annotated[list, operator.add]
counts: Annotated[dict, operator.add] # 注意:dict的add是更新而非合并
def agent_concurrent_update(state: MergeableState) -> MergeableState:
"""并发安全的更新"""
return {
"results": [f"result_from_{agent_id}"], # 自动追加
"counts": {agent_id: 1} # 自动合并到dict
}
# 3. 版本控制
class VersionedState(TypedDict):
data: dict
version: int
last_updated_by: str
def versioned_update(state: VersionedState, agent: str) -> VersionedState:
"""带版本控制的更新"""
# 检查版本
current_version = state.get("version", 0)
# 更新数据
new_data = modify_data(state["data"])
return {
"data": new_data,
"version": current_version + 1,
"last_updated_by": agent
}
# 4. 冲突检测
def detect_conflicts(state: SharedState, updates: list) -> list:
"""检测状态更新冲突"""
conflicts = []
# 检查多个Agent是否尝试更新同一字段
updated_fields = {}
for update in updates:
for field in update.keys():
if field in updated_fields:
conflicts.append({
"field": field,
"agents": [updated_fields[field], update["agent"]]
})
updated_fields[field] = update.get("agent", "unknown")
return conflicts
# 5. 最终一致性
class EventuallyConsistentState:
"""最终一致性状态管理"""
def __init__(self):
self.state = {}
self.update_log = []
def apply_update(self, update: dict, timestamp: float):
"""记录更新"""
self.update_log.append({
"update": update,
"timestamp": timestamp
})
def synchronize(self):
"""同步所有更新"""
# 按时间戳排序
sorted_updates = sorted(
self.update_log,
key=lambda x: x["timestamp"]
)
# 依次应用
for item in sorted_updates:
self.state.update(item["update"])
self.update_log.clear()
---
3.4 Supervisor模式
01.Supervisor架构
a.中心协调
a.功能说明
Supervisor模式使用一个中央协调Agent管理多个Worker Agent。Supervisor接收任务、分配给合适的Worker、汇总结果。Supervisor负责任务路由、进度跟踪、错误处理、结果聚合等。Worker Agent专注于执行具体任务,无需了解整体流程。Supervisor模式实现集中控制的Multi-Agent系统。
b.代码示例
---
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated, Sequence, Literal
import operator
from langchain.schema import BaseMessage, HumanMessage, AIMessage
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
# 1. 定义状态
class SupervisorState(TypedDict):
messages: Annotated[Sequence[BaseMessage], operator.add]
next_worker: str
task_queue: list
completed_tasks: list
final_result: str
# 2. 定义Worker Agent
members = ["researcher", "analyst", "writer"]
def create_worker_agent(name: str, role: str):
"""创建Worker Agent"""
def worker_node(state: SupervisorState):
# 获取最后一条消息(来自Supervisor的指令)
last_msg = state["messages"][-1]
# 执行角色特定的任务
if name == "researcher":
result = conduct_research(last_msg.content)
elif name == "analyst":
result = analyze_data(last_msg.content)
elif name == "writer":
result = write_content(last_msg.content)
# 返回结果给Supervisor
return {
"messages": [AIMessage(
content=f"{name}完成:{result}",
additional_kwargs={"worker": name}
)]
}
return worker_node
# 3. 创建Supervisor Agent
def create_supervisor():
"""创建Supervisor Agent"""
llm = ChatOpenAI(model="gpt-4", temperature=0)
options = ["FINISH"] + members
system_prompt = f"""你是团队主管,管理以下团队成员:{members}。
根据任务需求,选择合适的成员执行任务。
可选项:{options}
- 选择团队成员名称:将任务分配给该成员
- 选择FINISH:所有任务完成,结束流程
只返回选择的选项名称。"""
prompt = ChatPromptTemplate.from_messages([
("system", system_prompt),
("human", "{messages}")
])
def supervisor_node(state: SupervisorState):
"""Supervisor决策节点"""
# 分析当前状态和任务
messages = state["messages"]
# 调用LLM决定下一步
response = llm.predict(
prompt.format(messages="\n".join([m.content for m in messages]))
)
next_worker = response.strip()
if next_worker not in options:
next_worker = "FINISH"
# 如果选择了Worker,发送任务指令
if next_worker != "FINISH":
return {
"next_worker": next_worker,
"messages": [AIMessage(
content=f"分配任务给{next_worker}",
additional_kwargs={"supervisor": True}
)]
}
else:
return {"next_worker": "FINISH"}
return supervisor_node
# 4. 路由函数
def supervisor_router(state: SupervisorState) -> Literal["researcher", "analyst", "writer", "FINISH"]:
"""Supervisor路由决策"""
next_worker = state.get("next_worker", "FINISH")
return next_worker
# 5. 构建Supervisor图
workflow = StateGraph(SupervisorState)
# 添加Supervisor节点
workflow.add_node("supervisor", create_supervisor())
# 添加Worker节点
workflow.add_node("researcher", create_worker_agent("researcher", "研究"))
workflow.add_node("analyst", create_worker_agent("analyst", "分析"))
workflow.add_node("writer", create_worker_agent("writer", "撰写"))
# Worker完成后返回Supervisor
for member in members:
workflow.add_edge(member, "supervisor")
# Supervisor路由到Worker或结束
workflow.add_conditional_edges(
"supervisor",
supervisor_router,
{
"researcher": "researcher",
"analyst": "analyst",
"writer": "writer",
"FINISH": END
}
)
# 设置入口
workflow.set_entry_point("supervisor")
# 6. 执行
app = workflow.compile()
result = app.invoke({
"messages": [HumanMessage(content="写一份关于AI趋势的报告")],
"next_worker": "",
"task_queue": [],
"completed_tasks": [],
"final_result": ""
})
print("执行轨迹:")
for msg in result["messages"]:
print(f"- {msg.content}")
---
b.任务调度
a.功能说明
Supervisor根据任务类型、Worker状态、优先级等因素智能调度任务。支持任务队列、负载均衡、优先级排序、动态分配等策略。Supervisor监控Worker执行状态,处理超时、失败、重试等异常情况。任务调度优化Multi-Agent系统的执行效率和可靠性。
b.代码示例
---
from typing import TypedDict, List
from dataclasses import dataclass
from enum import Enum
import heapq
# 1. 任务定义
class TaskPriority(Enum):
LOW = 3
MEDIUM = 2
HIGH = 1
URGENT = 0
@dataclass
class Task:
id: str
description: str
priority: TaskPriority
assigned_to: str = ""
status: str = "pending"
result: str = ""
class SchedulerState(TypedDict):
messages: Annotated[Sequence[BaseMessage], operator.add]
task_queue: list
active_tasks: dict
completed_tasks: list
worker_status: dict
# 2. 任务队列管理
class PriorityTaskQueue:
"""优先级任务队列"""
def __init__(self):
self.queue = []
self.counter = 0
def add_task(self, task: Task):
"""添加任务(按优先级排序)"""
# 使用堆实现优先级队列
heapq.heappush(
self.queue,
(task.priority.value, self.counter, task)
)
self.counter += 1
def get_next_task(self) -> Task:
"""获取下一个最高优先级任务"""
if self.queue:
_, _, task = heapq.heappop(self.queue)
return task
return None
def is_empty(self) -> bool:
return len(self.queue) == 0
# 3. Worker状态跟踪
class WorkerTracker:
"""Worker状态追踪器"""
def __init__(self, workers: List[str]):
self.workers = {
worker: {
"status": "idle",
"current_task": None,
"completed_count": 0,
"load": 0
}
for worker in workers
}
def get_available_worker(self) -> str:
"""获取可用的Worker(负载最低)"""
idle_workers = [
(name, info)
for name, info in self.workers.items()
if info["status"] == "idle"
]
if idle_workers:
# 返回负载最低的Worker
return min(idle_workers, key=lambda x: x[1]["load"])[0]
return None
def assign_task(self, worker: str, task: Task):
"""分配任务给Worker"""
self.workers[worker]["status"] = "busy"
self.workers[worker]["current_task"] = task.id
self.workers[worker]["load"] += 1
def complete_task(self, worker: str):
"""Worker完成任务"""
self.workers[worker]["status"] = "idle"
self.workers[worker]["current_task"] = None
self.workers[worker]["completed_count"] += 1
# 4. 智能调度Supervisor
class SmartSupervisor:
"""智能调度Supervisor"""
def __init__(self, workers: List[str]):
self.task_queue = PriorityTaskQueue()
self.worker_tracker = WorkerTracker(workers)
self.active_tasks = {}
self.completed_tasks = []
def add_task(self, task: Task):
"""添加任务到队列"""
self.task_queue.add_task(task)
def schedule(self) -> tuple:
"""调度下一个任务"""
# 获取可用Worker
available_worker = self.worker_tracker.get_available_worker()
if not available_worker:
return None, "所有Worker都在忙"
# 获取下一个任务
next_task = self.task_queue.get_next_task()
if not next_task:
return None, "任务队列为空"
# 分配任务
next_task.assigned_to = available_worker
next_task.status = "running"
self.worker_tracker.assign_task(available_worker, next_task)
self.active_tasks[next_task.id] = next_task
return available_worker, next_task
def task_completed(self, task_id: str, result: str):
"""任务完成回调"""
task = self.active_tasks.pop(task_id, None)
if task:
task.status = "completed"
task.result = result
self.completed_tasks.append(task)
self.worker_tracker.complete_task(task.assigned_to)
def get_status(self) -> dict:
"""获取系统状态"""
return {
"queued_tasks": len(self.task_queue.queue),
"active_tasks": len(self.active_tasks),
"completed_tasks": len(self.completed_tasks),
"worker_status": self.worker_tracker.workers
}
# 5. 集成到LangGraph
def create_scheduling_supervisor(workers: List[str]):
"""创建调度Supervisor"""
supervisor = SmartSupervisor(workers)
def supervisor_node(state: SchedulerState):
"""Supervisor调度节点"""
# 从状态获取新任务
for task_dict in state.get("task_queue", []):
task = Task(**task_dict)
supervisor.add_task(task)
# 调度任务
worker, task = supervisor.schedule()
if worker and task:
return {
"messages": [AIMessage(
content=f"分配任务{task.id}给{worker}",
additional_kwargs={
"task": task.id,
"worker": worker
}
)],
"active_tasks": supervisor.active_tasks,
"next_worker": worker
}
else:
return {
"messages": [AIMessage(content="没有可调度的任务")],
"next_worker": "FINISH"
}
return supervisor_node
# 6. Worker反馈机制
def worker_with_feedback(worker_name: str, supervisor: SmartSupervisor):
"""带反馈的Worker"""
def worker_node(state: SchedulerState):
# 获取分配的任务
last_msg = state["messages"][-1]
task_id = last_msg.additional_kwargs.get("task")
# 执行任务
result = execute_task(task_id)
# 反馈给Supervisor
supervisor.task_completed(task_id, result)
return {
"messages": [AIMessage(
content=f"{worker_name}完成任务{task_id}",
additional_kwargs={
"worker": worker_name,
"task": task_id,
"result": result
}
)]
}
return worker_node
---
4 持久化
4.1 Checkpointing
01.检查点机制
a.自动保存
a.功能说明
Checkpointing在图执行过程中自动保存状态快照,支持中断恢复、回溯、分支等功能。每执行完一个节点就保存一次检查点,记录完整的状态和执行路径。使用checkpointer参数配置存储后端,如内存、文件、数据库等。检查点是实现人工介入、长时间任务、容错恢复的基础。
b.代码示例
---
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.checkpoint.sqlite import SqliteSaver
from typing import TypedDict
# 1. 使用内存检查点
checkpointer = MemorySaver()
workflow = StateGraph(State)
# ... 添加节点和边 ...
app = workflow.compile(checkpointer=checkpointer)
# 执行时自动保存检查点
result = app.invoke(
{"input": "data"},
config={"configurable": {"thread_id": "conversation_1"}}
)
# 2. 使用SQLite检查点
checkpointer = SqliteSaver.from_conn_string("checkpoints.db")
app = workflow.compile(checkpointer=checkpointer)
# thread_id区分不同的执行实例
result = app.invoke(
{"input": "data"},
config={"configurable": {"thread_id": "user_123_session_1"}}
)
# 3. 查看检查点历史
# 获取所有检查点
checkpoints = list(app.checkpointer.list(
config={"configurable": {"thread_id": "user_123_session_1"}}
))
print(f"共有{len(checkpoints)}个检查点")
for i, checkpoint in enumerate(checkpoints):
print(f"检查点{i+1}:{checkpoint}")
# 4. 从检查点恢复
# 获取最新检查点
latest_checkpoint = app.checkpointer.get(
config={"configurable": {"thread_id": "user_123_session_1"}}
)
if latest_checkpoint:
# 从检查点继续执行
result = app.invoke(
None, # 输入为None表示从检查点恢复
config={"configurable": {"thread_id": "user_123_session_1"}}
)
# 5. 检查点元数据
from datetime import datetime
def create_checkpoint_with_metadata(state: State):
"""创建带元数据的检查点"""
metadata = {
"timestamp": datetime.now().isoformat(),
"node": state.get("current_node", ""),
"step_count": state.get("step_count", 0),
"user_id": state.get("user_id", "")
}
return metadata
# 6. 自定义Checkpointer
from langgraph.checkpoint.base import BaseCheckpointSaver
import json
import redis
class RedisCheckpointer(BaseCheckpointSaver):
"""Redis检查点存储"""
def __init__(self, redis_client):
self.redis = redis_client
def put(self, config, checkpoint, metadata):
"""保存检查点"""
thread_id = config["configurable"]["thread_id"]
checkpoint_id = checkpoint["id"]
key = f"checkpoint:{thread_id}:{checkpoint_id}"
self.redis.set(
key,
json.dumps({
"checkpoint": checkpoint,
"metadata": metadata
}),
ex=86400 * 7 # 7天过期
)
def get(self, config):
"""获取最新检查点"""
thread_id = config["configurable"]["thread_id"]
# 查找所有检查点
keys = self.redis.keys(f"checkpoint:{thread_id}:*")
if not keys:
return None
# 获取最新的
latest_key = sorted(keys)[-1]
data = self.redis.get(latest_key)
return json.loads(data) if data else None
def list(self, config):
"""列出所有检查点"""
thread_id = config["configurable"]["thread_id"]
keys = self.redis.keys(f"checkpoint:{thread_id}:*")
checkpoints = []
for key in sorted(keys):
data = self.redis.get(key)
if data:
checkpoints.append(json.loads(data))
return checkpoints
# 使用Redis检查点
redis_client = redis.Redis(host='localhost', port=6379)
checkpointer = RedisCheckpointer(redis_client)
app = workflow.compile(checkpointer=checkpointer)
---
b.增量保存
a.功能说明
只保存状态的增量变化而非完整状态,减少存储空间和I/O开销。记录每个节点的状态diff,恢复时从初始状态应用所有diff。适用于状态体积大、变化频繁的场景。增量保存平衡了持久化能力和性能开销。
b.代码示例
---
from typing import Any
import json
# 1. 计算状态diff
def compute_state_diff(old_state: dict, new_state: dict) -> dict:
"""计算状态差异"""
diff = {}
for key, new_value in new_state.items():
old_value = old_state.get(key)
if old_value != new_value:
diff[key] = {
"old": old_value,
"new": new_value
}
return diff
# 2. 应用diff恢复状态
def apply_state_diff(base_state: dict, diff: dict) -> dict:
"""应用差异恢复状态"""
restored_state = base_state.copy()
for key, change in diff.items():
restored_state[key] = change["new"]
return restored_state
# 3. 增量Checkpointer
class IncrementalCheckpointer:
"""增量检查点存储"""
def __init__(self, storage):
self.storage = storage
self.base_states = {} # {thread_id: base_state}
self.diffs = {} # {thread_id: [diff1, diff2, ...]}
def save(self, thread_id: str, state: dict, node_name: str):
"""保存增量"""
if thread_id not in self.base_states:
# 首次保存完整状态
self.base_states[thread_id] = state
self.diffs[thread_id] = []
self.storage.set(
f"base:{thread_id}",
json.dumps(state)
)
else:
# 计算并保存diff
base = self.base_states[thread_id]
diff = compute_state_diff(base, state)
self.diffs[thread_id].append({
"node": node_name,
"diff": diff,
"timestamp": datetime.now().isoformat()
})
# 更新基准状态
self.base_states[thread_id] = state
# 保存diff历史
self.storage.set(
f"diffs:{thread_id}",
json.dumps(self.diffs[thread_id])
)
def restore(self, thread_id: str, checkpoint_index: int = -1) -> dict:
"""恢复到指定检查点"""
# 加载基准状态
base_data = self.storage.get(f"base:{thread_id}")
base_state = json.loads(base_data) if base_data else {}
# 加载diffs
diffs_data = self.storage.get(f"diffs:{thread_id}")
diffs = json.loads(diffs_data) if diffs_data else []
# 应用diffs到指定位置
restored_state = base_state
for diff_item in diffs[:checkpoint_index]:
restored_state = apply_state_diff(
restored_state,
diff_item["diff"]
)
return restored_state
# 4. 使用增量检查点
import redis
redis_client = redis.Redis()
incremental_cp = IncrementalCheckpointer(redis_client)
# 在节点中保存
def node_with_checkpoint(state: State) -> State:
result = process(state)
# 保存增量
incremental_cp.save(
thread_id="session_1",
state=result,
node_name="process_node"
)
return result
---
4.2 MemorySaver
01.内存存储
a.基础用法
a.功能说明
MemorySaver是LangGraph提供的内存检查点存储器,将检查点保存在进程内存中。适用于开发测试、短期会话、单机运行等场景。进程重启后数据丢失,不支持分布式访问。MemorySaver简单轻量,无需外部依赖,是快速原型开发的首选。
b.代码示例
---
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from typing import TypedDict
# 1. 创建MemorySaver
checkpointer = MemorySaver()
# 2. 编译图时配置
workflow = StateGraph(State)
# ... 添加节点和边 ...
app = workflow.compile(checkpointer=checkpointer)
# 3. 使用检查点
config = {"configurable": {"thread_id": "session_1"}}
# 首次执行
result1 = app.invoke({"input": "step1"}, config=config)
# 继续执行(从检查点恢复)
result2 = app.invoke({"input": "step2"}, config=config)
# 4. 多会话支持
# 会话1
app.invoke(
{"input": "data1"},
config={"configurable": {"thread_id": "user_1"}}
)
# 会话2(独立的检查点)
app.invoke(
{"input": "data2"},
config={"configurable": {"thread_id": "user_2"}}
)
# 5. 检查点查看
# 获取检查点
checkpoint = checkpointer.get(config)
print(f"当前状态:{checkpoint}")
# 列出所有检查点
all_checkpoints = list(checkpointer.list(config))
print(f"检查点数量:{len(all_checkpoints)}")
# 6. 清理检查点
# 删除特定会话的检查点
def clear_session(thread_id: str):
"""清理会话检查点"""
config = {"configurable": {"thread_id": thread_id}}
# MemorySaver没有直接删除方法,可以通过创建新实例清空
checkpointer = MemorySaver()
# 7. 检查点容量限制
class LimitedMemorySaver(MemorySaver):
"""限制容量的MemorySaver"""
def __init__(self, max_checkpoints: int = 100):
super().__init__()
self.max_checkpoints = max_checkpoints
self.checkpoint_count = 0
def put(self, config, checkpoint, metadata):
"""保存检查点with容量限制"""
if self.checkpoint_count >= self.max_checkpoints:
# 删除最旧的检查点
self._evict_oldest()
super().put(config, checkpoint, metadata)
self.checkpoint_count += 1
def _evict_oldest(self):
"""清除最旧的检查点"""
# 实现LRU策略
pass
checkpointer = LimitedMemorySaver(max_checkpoints=50)
---
b.性能优化
a.功能说明
MemorySaver性能优良但内存占用随检查点数量增长。需要设置合理的保留策略,如只保留最近N个检查点、定期压缩旧检查点等。对于大状态对象,可以只保存关键字段或使用引用。性能优化平衡记忆能力和资源消耗。
b.代码示例
---
from collections import deque
from typing import Any
# 1. 滑动窗口保留
class SlidingWindowMemorySaver(MemorySaver):
"""滑动窗口MemorySaver"""
def __init__(self, window_size: int = 10):
super().__init__()
self.window_size = window_size
self.checkpoints = {} # {thread_id: deque}
def put(self, config, checkpoint, metadata):
"""保存检查点(滑动窗口)"""
thread_id = config["configurable"]["thread_id"]
if thread_id not in self.checkpoints:
self.checkpoints[thread_id] = deque(maxlen=self.window_size)
self.checkpoints[thread_id].append({
"checkpoint": checkpoint,
"metadata": metadata
})
def get(self, config):
"""获取最新检查点"""
thread_id = config["configurable"]["thread_id"]
if thread_id in self.checkpoints and self.checkpoints[thread_id]:
return self.checkpoints[thread_id][-1]["checkpoint"]
return None
# 2. 压缩存储
import pickle
import gzip
class CompressedMemorySaver(MemorySaver):
"""压缩存储的MemorySaver"""
def put(self, config, checkpoint, metadata):
"""压缩保存"""
# 序列化
data = pickle.dumps({"checkpoint": checkpoint, "metadata": metadata})
# 压缩
compressed = gzip.compress(data)
# 存储(使用父类的存储)
thread_id = config["configurable"]["thread_id"]
self._storage[thread_id] = compressed
def get(self, config):
"""解压获取"""
thread_id = config["configurable"]["thread_id"]
compressed = self._storage.get(thread_id)
if compressed:
# 解压
data = gzip.decompress(compressed)
# 反序列化
obj = pickle.loads(data)
return obj["checkpoint"]
return None
# 3. 选择性保存
def selective_checkpoint_saver(important_fields: list):
"""只保存重要字段"""
class SelectiveSaver(MemorySaver):
def put(self, config, checkpoint, metadata):
# 只保存重要字段
slim_checkpoint = {
k: v for k, v in checkpoint.items()
if k in important_fields
}
super().put(config, slim_checkpoint, metadata)
return SelectiveSaver()
# 只保存关键字段
checkpointer = selective_checkpoint_saver([
"messages", "current_step", "final_result"
])
# 4. 延迟保存
class LazyMemorySaver(MemorySaver):
"""延迟保存检查点"""
def __init__(self, save_interval: int = 3):
super().__init__()
self.save_interval = save_interval
self.save_counter = {}
def put(self, config, checkpoint, metadata):
"""每N次才实际保存"""
thread_id = config["configurable"]["thread_id"]
self.save_counter[thread_id] = self.save_counter.get(thread_id, 0) + 1
# 只在第N次时保存
if self.save_counter[thread_id] % self.save_interval == 0:
super().put(config, checkpoint, metadata)
# 每3个节点保存一次
checkpointer = LazyMemorySaver(save_interval=3)
---
4.3 数据库持久化
01.SQLite持久化
a.SqliteSaver
a.功能说明
SqliteSaver将检查点存储到SQLite数据库,提供持久化和跨进程访问能力。SQLite是轻量级嵌入式数据库,无需独立服务,适合单机应用和中小规模部署。支持完整的ACID事务,数据可靠性高。SqliteSaver是生产环境的推荐选择。
b.代码示例
---
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import StateGraph, END
# 1. 创建SQLite Checkpointer
checkpointer = SqliteSaver.from_conn_string("checkpoints.db")
workflow = StateGraph(State)
# ... 添加节点和边 ...
app = workflow.compile(checkpointer=checkpointer)
# 2. 使用持久化检查点
config = {"configurable": {"thread_id": "user_123"}}
# 首次执行
result = app.invoke({"input": "data"}, config=config)
# 进程重启后仍可恢复
# 重新创建应用
checkpointer = SqliteSaver.from_conn_string("checkpoints.db")
app = workflow.compile(checkpointer=checkpointer)
# 从数据库恢复并继续
result = app.invoke(None, config=config)
# 3. 检查点查询
# 列出用户的所有会话
user_sessions = []
# 获取检查点详情
checkpoint = checkpointer.get(config)
if checkpoint:
print(f"检查点ID:{checkpoint.get('id')}")
print(f"状态:{checkpoint.get('channel_values')}")
# 4. 检查点历史
checkpoints = list(checkpointer.list(config))
print(f"会话检查点历史:")
for i, cp in enumerate(checkpoints):
print(f"{i+1}. ID: {cp.get('id')}, 步骤: {cp.get('metadata', {}).get('step')}")
# 5. 数据库维护
import sqlite3
# 清理过期检查点
def cleanup_old_checkpoints(db_path: str, days: int = 7):
"""清理N天前的检查点"""
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# 删除旧记录
cursor.execute("""
DELETE FROM checkpoints
WHERE created_at < datetime('now', ?)
""", (f'-{days} days',))
deleted = cursor.rowcount
conn.commit()
conn.close()
print(f"清理了{deleted}个过期检查点")
# 定期清理
cleanup_old_checkpoints("checkpoints.db", days=7)
# 6. 数据库配置
# 自定义表结构
conn = sqlite3.connect("custom_checkpoints.db")
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS checkpoints (
thread_id TEXT,
checkpoint_id TEXT,
state BLOB,
metadata TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (thread_id, checkpoint_id)
)
""")
# 创建索引加速查询
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_thread_created
ON checkpoints(thread_id, created_at DESC)
""")
conn.commit()
conn.close()
---
b.并发控制
a.功能说明
SQLite支持多读单写的并发模式,需要处理并发访问时的锁冲突。使用WAL模式提升并发性能,配置busy_timeout处理锁等待。对于高并发场景,建议使用PostgreSQL等企业级数据库。合理的并发控制确保数据一致性和系统性能。
b.代码示例
---
import sqlite3
from contextlib import contextmanager
# 1. 配置WAL模式
conn = sqlite3.connect("checkpoints.db")
conn.execute("PRAGMA journal_mode=WAL") # 写前日志模式
conn.execute("PRAGMA synchronous=NORMAL") # 适中的同步级别
conn.execute("PRAGMA busy_timeout=5000") # 5秒超时
conn.close()
# 2. 连接池
class SQLiteConnectionPool:
"""SQLite连接池"""
def __init__(self, db_path: str, pool_size: int = 5):
self.db_path = db_path
self.pool = []
for _ in range(pool_size):
conn = sqlite3.connect(db_path, check_same_thread=False)
conn.execute("PRAGMA journal_mode=WAL")
self.pool.append(conn)
@contextmanager
def get_connection(self):
"""获取连接"""
conn = self.pool.pop() if self.pool else sqlite3.connect(self.db_path)
try:
yield conn
finally:
if len(self.pool) < 5:
self.pool.append(conn)
else:
conn.close()
pool = SQLiteConnectionPool("checkpoints.db")
# 使用连接池
with pool.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM checkpoints WHERE thread_id = ?", ("user_1",))
results = cursor.fetchall()
# 3. 重试机制
from tenacity import retry, stop_after_attempt, wait_fixed
@retry(
stop=stop_after_attempt(3),
wait=wait_fixed(0.5)
)
def save_checkpoint_with_retry(conn, thread_id, data):
"""带重试的保存"""
try:
cursor = conn.cursor()
cursor.execute(
"INSERT INTO checkpoints (thread_id, state) VALUES (?, ?)",
(thread_id, data)
)
conn.commit()
except sqlite3.OperationalError as e:
if "locked" in str(e):
raise # 重试
else:
raise ValueError(f"保存失败:{e}")
# 4. 批量操作
def batch_save_checkpoints(checkpoints: list):
"""批量保存检查点"""
conn = sqlite3.connect("checkpoints.db")
cursor = conn.cursor()
try:
cursor.executemany(
"INSERT INTO checkpoints (thread_id, state) VALUES (?, ?)",
[(cp["thread_id"], cp["state"]) for cp in checkpoints]
)
conn.commit()
except Exception as e:
conn.rollback()
raise e
finally:
conn.close()
---
02.PostgreSQL持久化
a.生产级存储
a.功能说明
PostgreSQL提供企业级的持久化能力,支持高并发、大数据量、分布式访问。使用psycopg2或asyncpg驱动连接数据库,创建自定义Checkpointer实现。PostgreSQL支持完整的ACID事务、复杂查询、索引优化等特性。适用于生产环境、多实例部署、高可用系统。
b.代码示例
---
import psycopg2
from psycopg2.extras import Json
from langgraph.checkpoint.base import BaseCheckpointSaver
import json
# 1. PostgreSQL Checkpointer
class PostgresCheckpointer(BaseCheckpointSaver):
"""PostgreSQL检查点存储"""
def __init__(self, connection_string: str):
self.conn_str = connection_string
self._ensure_table()
def _ensure_table(self):
"""创建表结构"""
conn = psycopg2.connect(self.conn_str)
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS checkpoints (
thread_id VARCHAR(100),
checkpoint_id VARCHAR(100),
checkpoint_data JSONB,
metadata JSONB,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (thread_id, checkpoint_id)
)
""")
# 创建索引
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_thread_created
ON checkpoints(thread_id, created_at DESC)
""")
conn.commit()
conn.close()
def put(self, config, checkpoint, metadata):
"""保存检查点"""
thread_id = config["configurable"]["thread_id"]
checkpoint_id = checkpoint.get("id", "")
conn = psycopg2.connect(self.conn_str)
cursor = conn.cursor()
cursor.execute("""
INSERT INTO checkpoints
(thread_id, checkpoint_id, checkpoint_data, metadata)
VALUES (%s, %s, %s, %s)
ON CONFLICT (thread_id, checkpoint_id)
DO UPDATE SET
checkpoint_data = EXCLUDED.checkpoint_data,
metadata = EXCLUDED.metadata
""", (
thread_id,
checkpoint_id,
Json(checkpoint),
Json(metadata)
))
conn.commit()
conn.close()
def get(self, config):
"""获取最新检查点"""
thread_id = config["configurable"]["thread_id"]
conn = psycopg2.connect(self.conn_str)
cursor = conn.cursor()
cursor.execute("""
SELECT checkpoint_data
FROM checkpoints
WHERE thread_id = %s
ORDER BY created_at DESC
LIMIT 1
""", (thread_id,))
result = cursor.fetchone()
conn.close()
return result[0] if result else None
def list(self, config):
"""列出所有检查点"""
thread_id = config["configurable"]["thread_id"]
conn = psycopg2.connect(self.conn_str)
cursor = conn.cursor()
cursor.execute("""
SELECT checkpoint_data, metadata, created_at
FROM checkpoints
WHERE thread_id = %s
ORDER BY created_at DESC
""", (thread_id,))
results = cursor.fetchall()
conn.close()
return [
{
"checkpoint": row[0],
"metadata": row[1],
"created_at": row[2]
}
for row in results
]
# 2. 使用PostgreSQL
pg_checkpointer = PostgresCheckpointer(
"postgresql://user:pass@localhost:5432/langchain"
)
app = workflow.compile(checkpointer=pg_checkpointer)
# 3. 达梦数据库持久化
import dmPython
class DmCheckpointer(BaseCheckpointSaver):
"""达梦数据库检查点存储"""
def __init__(self, connection_string: str):
self.conn_str = connection_string
self._ensure_table()
def _ensure_table(self):
"""创建表结构"""
conn = dmPython.connect(self.conn_str)
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS checkpoints (
thread_id VARCHAR(100),
checkpoint_id VARCHAR(100),
checkpoint_data TEXT,
metadata TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (thread_id, checkpoint_id)
)
""")
conn.commit()
conn.close()
def put(self, config, checkpoint, metadata):
"""保存检查点"""
thread_id = config["configurable"]["thread_id"]
checkpoint_id = checkpoint.get("id", "")
conn = dmPython.connect(self.conn_str)
cursor = conn.cursor()
# 达梦使用MERGE或先DELETE再INSERT
cursor.execute("""
DELETE FROM checkpoints
WHERE thread_id = ? AND checkpoint_id = ?
""", (thread_id, checkpoint_id))
cursor.execute("""
INSERT INTO checkpoints
(thread_id, checkpoint_id, checkpoint_data, metadata)
VALUES (?, ?, ?, ?)
""", (
thread_id,
checkpoint_id,
json.dumps(checkpoint, ensure_ascii=False),
json.dumps(metadata, ensure_ascii=False)
))
conn.commit()
conn.close()
def get(self, config):
"""获取最新检查点"""
thread_id = config["configurable"]["thread_id"]
conn = dmPython.connect(self.conn_str)
cursor = conn.cursor()
cursor.execute("""
SELECT checkpoint_data
FROM checkpoints
WHERE thread_id = ?
ORDER BY created_at DESC
LIMIT 1
""", (thread_id,))
result = cursor.fetchone()
conn.close()
if result:
return json.loads(result[0])
return None
# 使用达梦数据库
dm_checkpointer = DmCheckpointer(
"dm://SYSDBA:SYSDBA@localhost:5236/TEST"
)
app = workflow.compile(checkpointer=dm_checkpointer)
---
b.数据库优化
a.功能说明
优化数据库检查点存储的性能,包括索引优化、连接池、批量操作、异步I/O等。使用JSONB字段支持高效的JSON查询,创建合适的索引加速检索。配置连接池复用连接,减少连接开销。数据库优化确保大规模使用下的性能。
b.代码示例
---
import psycopg2
from psycopg2 import pool
# 1. 连接池
connection_pool = psycopg2.pool.ThreadedConnectionPool(
minconn=2,
maxconn=10,
dsn="postgresql://user:pass@localhost/langchain"
)
class PooledPostgresCheckpointer(PostgresCheckpointer):
"""使用连接池的Checkpointer"""
def __init__(self, connection_pool):
self.pool = connection_pool
def put(self, config, checkpoint, metadata):
"""使用连接池保存"""
conn = self.pool.getconn()
try:
cursor = conn.cursor()
# ... 执行INSERT ...
conn.commit()
finally:
self.pool.putconn(conn)
# 2. 批量保存
def batch_save_checkpoints(checkpointer, checkpoints: list):
"""批量保存多个检查点"""
conn = psycopg2.connect(checkpointer.conn_str)
cursor = conn.cursor()
# 使用execute_batch提升性能
from psycopg2.extras import execute_batch
execute_batch(cursor, """
INSERT INTO checkpoints
(thread_id, checkpoint_id, checkpoint_data, metadata)
VALUES (%s, %s, %s, %s)
""", [
(cp["thread_id"], cp["id"], Json(cp["data"]), Json(cp["metadata"]))
for cp in checkpoints
])
conn.commit()
conn.close()
# 3. 异步PostgreSQL
import asyncpg
class AsyncPostgresCheckpointer:
"""异步PostgreSQL Checkpointer"""
def __init__(self, dsn: str):
self.dsn = dsn
self.pool = None
async def init_pool(self):
"""初始化连接池"""
self.pool = await asyncpg.create_pool(
self.dsn,
min_size=2,
max_size=10
)
async def put(self, config, checkpoint, metadata):
"""异步保存"""
thread_id = config["configurable"]["thread_id"]
checkpoint_id = checkpoint.get("id")
async with self.pool.acquire() as conn:
await conn.execute("""
INSERT INTO checkpoints
(thread_id, checkpoint_id, checkpoint_data, metadata)
VALUES ($1, $2, $3, $4)
ON CONFLICT (thread_id, checkpoint_id)
DO UPDATE SET
checkpoint_data = EXCLUDED.checkpoint_data,
metadata = EXCLUDED.metadata
""", thread_id, checkpoint_id, json.dumps(checkpoint), json.dumps(metadata))
async def get(self, config):
"""异步获取"""
thread_id = config["configurable"]["thread_id"]
async with self.pool.acquire() as conn:
row = await conn.fetchrow("""
SELECT checkpoint_data
FROM checkpoints
WHERE thread_id = $1
ORDER BY created_at DESC
LIMIT 1
""", thread_id)
return json.loads(row["checkpoint_data"]) if row else None
# 4. 分区表优化
# 按月份分区
cursor.execute("""
CREATE TABLE checkpoints_2024_01 PARTITION OF checkpoints
FOR VALUES FROM ('2024-01-01') TO ('2024-02-01')
""")
cursor.execute("""
CREATE TABLE checkpoints_2024_02 PARTITION OF checkpoints
FOR VALUES FROM ('2024-02-01') TO ('2024-03-01')
""")
---
4.4 状态恢复
01.从检查点恢复
a.继续执行
a.功能说明
从保存的检查点恢复状态并继续执行。使用相同的thread_id和config调用invoke,Graph自动从最新检查点加载状态继续执行。适用于长时间任务中断恢复、人工审核后继续、错误修复后重试等场景。状态恢复确保工作流的连续性和可靠性。
b.代码示例
---
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.sqlite import SqliteSaver
# 1. 基础恢复
checkpointer = SqliteSaver.from_conn_string("checkpoints.db")
app = workflow.compile(checkpointer=checkpointer)
config = {"configurable": {"thread_id": "session_1"}}
# 首次执行(可能被中断)
try:
result = app.invoke({"input": "start task"}, config=config)
except KeyboardInterrupt:
print("执行被中断,状态已保存")
# 恢复并继续
result = app.invoke(None, config=config)
print(f"从检查点恢复,继续执行完成:{result}")
# 2. 检查是否有未完成的会话
def has_unfinished_session(thread_id: str) -> bool:
"""检查是否有未完成的会话"""
config = {"configurable": {"thread_id": thread_id}}
checkpoint = checkpointer.get(config)
if checkpoint:
# 检查是否已到达END
next_node = checkpoint.get("next", [])
return len(next_node) > 0 # 有待执行的节点
return False
# 恢复未完成的会话
if has_unfinished_session("user_123"):
print("发现未完成的会话,继续执行...")
result = app.invoke(
None,
config={"configurable": {"thread_id": "user_123"}}
)
# 3. 带输入的恢复
# 恢复并提供新输入
result = app.invoke(
{"additional_input": "new data"},
config={"configurable": {"thread_id": "session_1"}}
)
# 4. 部分恢复
def resume_from_node(thread_id: str, from_node: str):
"""从指定节点恢复"""
config = {"configurable": {"thread_id": thread_id}}
# 获取检查点
checkpoint = checkpointer.get(config)
if checkpoint:
# 修改next指针到指定节点
checkpoint["next"] = [from_node]
# 保存修改后的检查点
checkpointer.put(config, checkpoint, {})
# 继续执行
return app.invoke(None, config=config)
# 从特定节点恢复
result = resume_from_node("session_1", "retry_node")
# 5. 恢复验证
def validate_before_resume(thread_id: str) -> bool:
"""恢复前验证状态"""
config = {"configurable": {"thread_id": thread_id}}
checkpoint = checkpointer.get(config)
if not checkpoint:
print("没有可恢复的检查点")
return False
# 验证状态完整性
required_fields = ["input", "current_step"]
state = checkpoint.get("channel_values", {})
for field in required_fields:
if field not in state:
print(f"状态不完整,缺少字段:{field}")
return False
# 验证时间(不恢复过期的会话)
created_at = checkpoint.get("metadata", {}).get("created_at")
if created_at:
from datetime import datetime, timedelta
checkpoint_time = datetime.fromisoformat(created_at)
if datetime.now() - checkpoint_time > timedelta(hours=24):
print("检查点已过期(超过24小时)")
return False
return True
# 安全恢复
if validate_before_resume("session_1"):
result = app.invoke(None, config={"configurable": {"thread_id": "session_1"}})
---
b.时光旅行
a.功能说明
LangGraph支持回溯到历史检查点,实现时光旅行debugging。可以查看任意时刻的状态,从过去的检查点重新执行,对比不同路径的结果。适用于调试、实验、AB测试等场景。时光旅行帮助理解Graph的执行过程,优化工作流设计。
b.代码示例
---
from langgraph.graph import StateGraph
from langgraph.checkpoint.sqlite import SqliteSaver
checkpointer = SqliteSaver.from_conn_string("checkpoints.db")
app = workflow.compile(checkpointer=checkpointer)
config = {"configurable": {"thread_id": "debug_session"}}
# 1. 执行并记录历史
result = app.invoke({"input": "task"}, config=config)
# 2. 列出所有历史检查点
history = list(checkpointer.list(config))
print("执行历史:")
for i, checkpoint in enumerate(history):
metadata = checkpoint.get("metadata", {})
print(f"{i+1}. 步骤: {metadata.get('step')}, 节点: {metadata.get('source')}")
# 3. 回到特定检查点
# 获取第3个检查点
checkpoint_3 = history[2]
checkpoint_id = checkpoint_3["checkpoint"]["id"]
# 从第3个检查点重新执行
result = app.invoke(
None,
config={
"configurable": {
"thread_id": "debug_session",
"checkpoint_id": checkpoint_id # 指定检查点
}
}
)
# 4. 对比不同路径
def compare_execution_paths(thread_id: str, checkpoint_id_1: str, checkpoint_id_2: str):
"""对比两个检查点的后续执行"""
# 路径1
result_1 = app.invoke(
None,
config={
"configurable": {
"thread_id": f"{thread_id}_path1",
"checkpoint_id": checkpoint_id_1
}
}
)
# 路径2
result_2 = app.invoke(
None,
config={
"configurable": {
"thread_id": f"{thread_id}_path2",
"checkpoint_id": checkpoint_id_2
}
}
)
return {
"path1": result_1,
"path2": result_2,
"diff": compute_diff(result_1, result_2)
}
# 5. 分支执行
def branch_from_checkpoint(thread_id: str, checkpoint_id: str, branch_name: str):
"""从检查点创建分支"""
# 复制检查点到新分支
original_config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_id": checkpoint_id
}
}
original_checkpoint = checkpointer.get(original_config)
# 创建新分支
branch_config = {
"configurable": {
"thread_id": f"{thread_id}_branch_{branch_name}"
}
}
# 保存到新分支
checkpointer.put(branch_config, original_checkpoint, {})
# 从新分支执行
return app.invoke(None, config=branch_config)
# 创建分支并尝试不同策略
result_strategy_a = branch_from_checkpoint(
"session_1",
"checkpoint_5",
"strategy_a"
)
result_strategy_b = branch_from_checkpoint(
"session_1",
"checkpoint_5",
"strategy_b"
)
# 6. 调试工具
class CheckpointDebugger:
"""检查点调试工具"""
def __init__(self, app, checkpointer):
self.app = app
self.checkpointer = checkpointer
def replay_execution(self, thread_id: str):
"""重放完整执行过程"""
config = {"configurable": {"thread_id": thread_id}}
history = list(self.checkpointer.list(config))
print(f"重放执行(共{len(history)}步):\n")
for i, checkpoint in enumerate(reversed(history)):
state = checkpoint["checkpoint"].get("channel_values", {})
metadata = checkpoint.get("metadata", {})
print(f"步骤{i+1}:")
print(f" 节点: {metadata.get('source')}")
print(f" 状态: {state}")
print()
def find_divergence(self, thread_id_1: str, thread_id_2: str):
"""找到两次执行的分歧点"""
history_1 = list(self.checkpointer.list(
{"configurable": {"thread_id": thread_id_1}}
))
history_2 = list(self.checkpointer.list(
{"configurable": {"thread_id": thread_id_2}}
))
for i, (cp1, cp2) in enumerate(zip(reversed(history_1), reversed(history_2))):
if cp1["checkpoint"] != cp2["checkpoint"]:
return i, cp1, cp2
return None
debugger = CheckpointDebugger(app, checkpointer)
debugger.replay_execution("session_1")
---
5 人工介入
5.1 Human-in-the-Loop
01.中断机制
a.interrupt_before
a.功能说明
在指定节点执行前中断图的执行,等待人工输入或审批。使用interrupt_before参数配置中断点,Graph执行到该节点前会暂停并保存检查点。人工处理后调用invoke继续执行。适用于审批流程、风险决策、质量把关等需要人工判断的场景。
b.代码示例
---
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.sqlite import SqliteSaver
from typing import TypedDict
class ApprovalState(TypedDict):
request: str
analysis: str
approved: bool
final_result: str
# 1. 定义工作流
def analyze_request(state: ApprovalState):
"""分析请求"""
analysis = perform_analysis(state["request"])
return {"analysis": analysis}
def execute_request(state: ApprovalState):
"""执行请求(需要审批)"""
if not state.get("approved"):
return {"final_result": "未经审批,拒绝执行"}
result = execute(state["request"])
return {"final_result": result}
def notify_completion(state: ApprovalState):
"""通知完成"""
send_notification(state["final_result"])
return {}
# 2. 配置中断点
workflow = StateGraph(ApprovalState)
workflow.add_node("analyze", analyze_request)
workflow.add_node("execute", execute_request)
workflow.add_node("notify", notify_completion)
workflow.set_entry_point("analyze")
workflow.add_edge("analyze", "execute")
workflow.add_edge("execute", "notify")
workflow.add_edge("notify", END)
checkpointer = SqliteSaver.from_conn_string("checkpoints.db")
# 在execute节点前中断
app = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["execute"] # 执行前需要审批
)
# 3. 执行到中断点
config = {"configurable": {"thread_id": "approval_001"}}
result = app.invoke(
{"request": "删除用户数据", "analysis": "", "approved": False, "final_result": ""},
config=config
)
print("执行已暂停,等待审批")
print(f"分析结果:{result['analysis']}")
# 4. 人工审批
# 在UI或命令行获取审批决定
approval_decision = input("是否批准?(y/n): ")
# 更新状态
approved = (approval_decision.lower() == 'y')
# 5. 继续执行
final_result = app.invoke(
{"approved": approved}, # 提供审批结果
config=config
)
print(f"执行完成:{final_result['final_result']}")
# 6. 多中断点
app_multi_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["execute", "notify"] # 两个中断点
)
# 第一次执行到第一个中断点
result1 = app_multi_interrupt.invoke({"request": "..."}, config=config)
# 审批后继续,执行到第二个中断点
result2 = app_multi_interrupt.invoke({"approved": True}, config=config)
# 再次确认后完成
final = app_multi_interrupt.invoke({}, config=config)
---
b.interrupt_after
a.功能说明
在指定节点执行后中断,用于检查节点输出、验证结果、人工修正等。节点已执行完毕,状态已更新并保存检查点,可以查看和修改状态后继续。interrupt_after适用于结果验证、数据审核、输出调整等场景。
b.代码示例
---
from langgraph.graph import StateGraph, END
class ReviewState(TypedDict):
draft: str
review_comments: str
approved: bool
final_version: str
# 1. 生成草稿节点
def generate_draft(state: ReviewState):
"""生成初稿"""
draft = create_document(state.get("requirements", ""))
return {"draft": draft}
def finalize_document(state: ReviewState):
"""最终定稿"""
if state.get("approved"):
final = state["draft"]
else:
# 根据审核意见修改
final = revise_document(state["draft"], state["review_comments"])
return {"final_version": final}
# 2. 在draft生成后中断
workflow = StateGraph(ReviewState)
workflow.add_node("draft", generate_draft)
workflow.add_node("finalize", finalize_document)
workflow.set_entry_point("draft")
workflow.add_edge("draft", "finalize")
workflow.add_edge("finalize", END)
app = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["draft"] # 生成草稿后中断审核
)
# 3. 执行并中断
config = {"configurable": {"thread_id": "doc_001"}}
result = app.invoke(
{"draft": "", "review_comments": "", "approved": False, "final_version": ""},
config=config
)
print("草稿已生成,等待审核")
print(f"草稿内容:{result['draft']}")
# 4. 人工审核
draft_content = result["draft"]
# 人工阅读和评论
comments = input("审核意见:")
approved = input("是否批准?(y/n): ") == 'y'
# 5. 继续执行
final_result = app.invoke(
{
"review_comments": comments,
"approved": approved
},
config=config
)
print(f"最终版本:{final_result['final_version']}")
# 6. 修改状态后继续
# 获取当前检查点
checkpoint = checkpointer.get(config)
current_state = checkpoint["channel_values"]
# 人工修改状态
current_state["draft"] = "人工修正后的草稿..."
# 保存修改
checkpointer.put(
config,
{**checkpoint, "channel_values": current_state},
checkpoint.get("metadata", {})
)
# 继续执行
final_result = app.invoke({"approved": True}, config=config)
---
02.人工决策
a.审批流程
a.功能说明
实现多级审批流程,在关键节点需要人工审批才能继续。使用interrupt机制暂停执行,通过状态字段传递审批决定。支持审批意见、条件审批、回退修改等功能。审批流程确保关键操作的合规性和安全性。
b.代码示例
---
from typing import TypedDict, Literal
class ApprovalWorkflowState(TypedDict):
request_id: str
request_content: str
requester: str
# 初审
initial_review: str
initial_approved: bool
# 复审
secondary_review: str
secondary_approved: bool
# 终审
final_review: str
final_approved: bool
# 结果
approval_status: Literal["pending", "approved", "rejected"]
comments: list
# 1. 初审节点
def initial_review_node(state: ApprovalWorkflowState):
"""初审:自动规则检查"""
content = state["request_content"]
# 自动检查
auto_check = run_auto_validation(content)
if not auto_check["passed"]:
return {
"initial_review": auto_check["reason"],
"initial_approved": False,
"approval_status": "rejected",
"comments": [f"初审不通过:{auto_check['reason']}"]
}
return {
"initial_review": "自动检查通过,等待人工审批",
"comments": ["初审:自动检查通过"]
}
# 2. 人工审批节点(在此中断)
def human_approval_node(state: ApprovalWorkflowState):
"""人工审批节点"""
# 此节点在interrupt_before中断,等待人工输入
if state.get("initial_approved"):
return {
"comments": state["comments"] + ["初审批准"],
"approval_status": "pending"
}
else:
return {
"approval_status": "rejected",
"comments": state["comments"] + ["初审拒绝"]
}
# 3. 复审节点
def secondary_review_node(state: ApprovalWorkflowState):
"""复审节点"""
# 需要更高级别的审批
if state.get("secondary_approved"):
return {
"comments": state["comments"] + ["复审批准"],
"approval_status": "pending"
}
else:
return {
"approval_status": "rejected",
"comments": state["comments"] + ["复审拒绝"]
}
# 4. 终审节点
def final_approval_node(state: ApprovalWorkflowState):
"""终审节点"""
if state.get("final_approved"):
return {
"approval_status": "approved",
"comments": state["comments"] + ["终审批准,流程完成"]
}
else:
return {
"approval_status": "rejected",
"comments": state["comments"] + ["终审拒绝"]
}
# 5. 构建审批流程图
workflow = StateGraph(ApprovalWorkflowState)
workflow.add_node("initial", initial_review_node)
workflow.add_node("human_initial", human_approval_node)
workflow.add_node("secondary", secondary_review_node)
workflow.add_node("final", final_approval_node)
workflow.set_entry_point("initial")
workflow.add_edge("initial", "human_initial")
# 初审后的路由
def after_initial(state):
return "next" if state.get("initial_approved") else "end"
workflow.add_conditional_edges(
"human_initial",
after_initial,
{"next": "secondary", "end": END}
)
workflow.add_edge("secondary", "final")
workflow.add_edge("final", END)
# 配置多个中断点
app = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["human_initial", "secondary", "final"]
)
# 6. 执行审批流程
config = {"configurable": {"thread_id": "request_001"}}
# 提交申请
result = app.invoke({
"request_id": "REQ001",
"request_content": "申请内容...",
"requester": "张三",
"approval_status": "pending",
"comments": []
}, config=config)
# 初审人员审批
result = app.invoke({"initial_approved": True}, config=config)
# 复审人员审批
result = app.invoke({"secondary_approved": True}, config=config)
# 终审人员审批
final_result = app.invoke({"final_approved": True}, config=config)
print(f"审批状态:{final_result['approval_status']}")
print("审批历史:")
for comment in final_result["comments"]:
print(f" - {comment}")
---
5.2 中断与恢复
01.中断控制
a.主动中断
a.功能说明
在节点内部根据条件主动触发中断,而非预定义中断点。通过抛出特殊异常或返回中断标记实现。主动中断更灵活,可以根据运行时状态动态决定是否需要人工介入。适用于异常检测、置信度过低、资源不足等动态中断场景。
b.代码示例
---
from langgraph.graph import StateGraph, END
from typing import TypedDict
class DynamicInterruptState(TypedDict):
input: str
confidence: float
result: str
needs_human: bool
# 1. 基于置信度的动态中断
def analyze_with_confidence(state: DynamicInterruptState):
"""分析并计算置信度"""
result = perform_analysis(state["input"])
confidence = result["confidence"]
# 置信度过低,标记需要人工
if confidence < 0.7:
return {
"result": result["data"],
"confidence": confidence,
"needs_human": True
}
else:
return {
"result": result["data"],
"confidence": confidence,
"needs_human": False
}
def check_if_needs_human(state: DynamicInterruptState) -> str:
"""检查是否需要人工"""
if state.get("needs_human"):
return "human_review"
else:
return "auto_complete"
workflow = StateGraph(DynamicInterruptState)
workflow.add_node("analyze", analyze_with_confidence)
workflow.add_node("human_review", lambda s: s) # 人工审核节点
workflow.add_node("auto_complete", lambda s: {"result": "自动完成"})
workflow.set_entry_point("analyze")
workflow.add_conditional_edges(
"analyze",
check_if_needs_human,
{
"human_review": "human_review",
"auto_complete": "auto_complete"
}
)
workflow.add_edge("human_review", END)
workflow.add_edge("auto_complete", END)
app = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["human_review"] # 只有需要时才中断
)
# 2. 基于异常的中断
class RequiresHumanException(Exception):
"""需要人工介入异常"""
def __init__(self, reason: str, data: dict):
self.reason = reason
self.data = data
super().__init__(reason)
def node_with_exception_interrupt(state: State):
"""可能抛出异常的节点"""
try:
result = risky_operation(state["input"])
# 检查结果
if not validate_result(result):
raise RequiresHumanException(
"结果验证失败,需要人工检查",
{"result": result, "validation": "failed"}
)
return {"output": result}
except RequiresHumanException as e:
# 标记需要人工
return {
"output": str(e.data),
"needs_human": True,
"interrupt_reason": e.reason
}
# 3. 多条件中断
def complex_interrupt_check(state: State) -> dict:
"""复杂的中断判断"""
reasons = []
# 检查多个条件
if state.get("confidence", 1.0) < 0.6:
reasons.append("置信度过低")
if state.get("cost") > 1000:
reasons.append("成本超限")
if len(state.get("errors", [])) > 0:
reasons.append("存在错误")
return {
"needs_human": len(reasons) > 0,
"interrupt_reasons": reasons
}
# 4. 超时中断
import signal
class TimeoutInterrupt(Exception):
pass
def timeout_handler(signum, frame):
raise TimeoutInterrupt("执行超时")
def node_with_timeout(state: State, timeout_seconds: int = 30):
"""带超时的节点"""
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(timeout_seconds)
try:
result = long_running_operation(state["input"])
signal.alarm(0)
return {"output": result}
except TimeoutInterrupt:
signal.alarm(0)
return {
"output": "超时",
"needs_human": True,
"interrupt_reason": "操作超时,需要人工处理"
}
---
02.恢复策略
a.状态修改恢复
a.功能说明
中断后可以修改状态再恢复执行,实现人工修正、数据补充、参数调整等。获取检查点、修改状态字段、保存检查点、继续执行。状态修改恢复让人类能够干预和优化AI的执行过程,提升最终结果质量。
b.代码示例
---
from langgraph.checkpoint.sqlite import SqliteSaver
import json
# 1. 获取并修改状态
config = {"configurable": {"thread_id": "session_1"}}
# 获取当前检查点
checkpoint = checkpointer.get(config)
current_state = checkpoint["channel_values"]
print("当前状态:", current_state)
# 2. 人工修改
# 修改分析结果
current_state["analysis"] = "人工修正后的分析..."
# 修改置信度
current_state["confidence"] = 0.95
# 添加人工备注
current_state["human_notes"] = "已人工审核并修正"
# 3. 保存修改后的状态
updated_checkpoint = checkpoint.copy()
updated_checkpoint["channel_values"] = current_state
checkpointer.put(
config,
updated_checkpoint,
{**checkpoint.get("metadata", {}), "modified_by": "human"}
)
# 4. 继续执行
result = app.invoke(None, config=config)
# 5. 状态修改工具
class StateEditor:
"""状态编辑工具"""
def __init__(self, app, checkpointer):
self.app = app
self.checkpointer = checkpointer
def get_state(self, thread_id: str) -> dict:
"""获取当前状态"""
config = {"configurable": {"thread_id": thread_id}}
checkpoint = self.checkpointer.get(config)
return checkpoint["channel_values"] if checkpoint else {}
def update_field(self, thread_id: str, field: str, value):
"""更新单个字段"""
config = {"configurable": {"thread_id": thread_id}}
checkpoint = self.checkpointer.get(config)
if checkpoint:
state = checkpoint["channel_values"]
state[field] = value
updated_checkpoint = checkpoint.copy()
updated_checkpoint["channel_values"] = state
self.checkpointer.put(config, updated_checkpoint, checkpoint.get("metadata", {}))
def resume(self, thread_id: str, additional_input: dict = None):
"""恢复执行"""
config = {"configurable": {"thread_id": thread_id}}
return self.app.invoke(additional_input, config=config)
# 使用状态编辑器
editor = StateEditor(app, checkpointer)
# 查看状态
state = editor.get_state("session_1")
print(state)
# 修改字段
editor.update_field("session_1", "approved", True)
editor.update_field("session_1", "comments", "人工批准")
# 恢复执行
result = editor.resume("session_1")
# 6. 回退到上一步
def rollback_to_previous(thread_id: str):
"""回退到上一个检查点"""
config = {"configurable": {"thread_id": thread_id}}
history = list(checkpointer.list(config))
if len(history) < 2:
print("没有更早的检查点")
return None
# 获取倒数第二个检查点
previous = history[1]
previous_id = previous["checkpoint"]["id"]
# 从该检查点恢复
return app.invoke(
None,
config={
"configurable": {
"thread_id": thread_id,
"checkpoint_id": previous_id
}
}
)
---
5.3 人工审核节点
01.审核节点设计
a.审核逻辑
a.功能说明
设计专门的人工审核节点,封装审核逻辑、状态检查、决策接收等功能。审核节点在interrupt机制中断前等待人工输入,接收审批决定后更新状态。可以配置审核权限、超时机制、默认决策等。审核节点是Human-in-the-Loop的核心组件。
b.代码示例
---
from langgraph.graph import StateGraph, END
from typing import TypedDict, Literal
from datetime import datetime, timedelta
class ReviewState(TypedDict):
content: str
reviewer: str
review_result: Literal["pending", "approved", "rejected"]
review_comments: str
reviewed_at: str
# 1. 基础审核节点
def review_node(state: ReviewState):
"""人工审核节点"""
# 此节点在interrupt_before中断
# 执行到这里表示已收到审批决定
result = state.get("review_result", "pending")
if result == "approved":
return {
"review_result": "approved",
"reviewed_at": datetime.now().isoformat()
}
elif result == "rejected":
return {
"review_result": "rejected",
"reviewed_at": datetime.now().isoformat()
}
else:
# 仍然pending,继续等待
return {}
# 2. 带权限检查的审核节点
def privileged_review_node(state: ReviewState):
"""带权限检查的审核节点"""
reviewer = state.get("reviewer", "")
# 检查审核权限
if not has_review_permission(reviewer, state.get("content_type", "")):
return {
"review_result": "rejected",
"review_comments": "审核人员无权限",
"reviewed_at": datetime.now().isoformat()
}
# 记录审核人
return {
"reviewer": reviewer,
"reviewed_at": datetime.now().isoformat()
}
def has_review_permission(reviewer: str, content_type: str) -> bool:
"""检查审核权限"""
permissions = {
"manager": ["document", "report"],
"admin": ["document", "report", "code", "data"],
"auditor": ["report", "financial"]
}
user_role = get_user_role(reviewer)
return content_type in permissions.get(user_role, [])
# 3. 多级审核节点
class MultiLevelReviewState(TypedDict):
content: str
level1_reviewer: str
level1_approved: bool
level1_comments: str
level2_reviewer: str
level2_approved: bool
level2_comments: str
final_approved: bool
def level1_review_node(state: MultiLevelReviewState):
"""一级审核"""
if state.get("level1_approved"):
return {
"level1_approved": True,
"level1_comments": state.get("level1_comments", "一级审核通过")
}
else:
return {
"level1_approved": False,
"final_approved": False
}
def level2_review_node(state: MultiLevelReviewState):
"""二级审核(只有一级通过才执行)"""
if state.get("level2_approved"):
return {
"level2_approved": True,
"level2_comments": state.get("level2_comments", "二级审核通过"),
"final_approved": True
}
else:
return {
"level2_approved": False,
"final_approved": False
}
# 构建多级审核图
workflow = StateGraph(MultiLevelReviewState)
workflow.add_node("level1", level1_review_node)
workflow.add_node("level2", level2_review_node)
workflow.set_entry_point("level1")
# 一级审核后路由
def after_level1(state):
return "level2" if state.get("level1_approved") else "end"
workflow.add_conditional_edges(
"level1",
after_level1,
{"level2": "level2", "end": END}
)
workflow.add_edge("level2", END)
# 配置两个中断点
app = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["level1", "level2"]
)
# 4. 超时审核节点
class TimeoutReviewState(TypedDict):
content: str
review_deadline: str
review_result: str
auto_approved_on_timeout: bool
def timeout_review_node(state: TimeoutReviewState):
"""带超时的审核节点"""
deadline_str = state.get("review_deadline", "")
if deadline_str:
deadline = datetime.fromisoformat(deadline_str)
# 检查是否超时
if datetime.now() > deadline:
# 超时处理
if state.get("auto_approved_on_timeout"):
return {
"review_result": "approved",
"review_comments": "审核超时,自动批准",
"reviewed_at": datetime.now().isoformat()
}
else:
return {
"review_result": "rejected",
"review_comments": "审核超时,自动拒绝",
"reviewed_at": datetime.now().isoformat()
}
# 未超时,正常审核
return {}
# 5. 批量审核节点
class BatchReviewState(TypedDict):
items: list # 待审核项目列表
reviewed_items: list # 已审核项目
current_index: int
all_approved: bool
def batch_review_node(state: BatchReviewState):
"""批量审核节点"""
items = state.get("items", [])
reviewed = state.get("reviewed_items", [])
current_index = state.get("current_index", 0)
if current_index < len(items):
# 当前项目的审核结果
current_item = items[current_index]
approved = current_item.get("approved", False)
reviewed.append({
**current_item,
"approved": approved,
"reviewed_at": datetime.now().isoformat()
})
return {
"reviewed_items": reviewed,
"current_index": current_index + 1
}
else:
# 所有项目审核完成
all_approved = all([item.get("approved") for item in reviewed])
return {
"all_approved": all_approved
}
# 6. 智能审核节点
from langchain.chat_models import ChatOpenAI
def ai_assisted_review_node(state: ReviewState):
"""AI辅助的人工审核节点"""
content = state.get("content", "")
# AI预审
llm = ChatOpenAI(model="gpt-4")
pre_review = llm.predict(f"""
分析以下内容,给出审核建议:
内容:{content}
请评估:
1. 是否符合规范
2. 是否存在风险
3. 建议批准或拒绝及理由
只返回JSON格式:{{"recommendation": "approve/reject", "reason": "理由", "risk_level": "low/medium/high"}}
""")
# 解析AI建议
import json
ai_suggestion = json.loads(pre_review)
# 将AI建议作为参考提供给人工审核
return {
"ai_recommendation": ai_suggestion["recommendation"],
"ai_reason": ai_suggestion["reason"],
"risk_level": ai_suggestion["risk_level"]
}
---
02.审核流程
a.审批工作流
a.功能说明
构建完整的审批工作流,包括提交、预审、人工审核、后处理等环节。使用条件边根据审核结果路由到不同后续节点。支持审核拒绝后的重新提交、修改、上诉等流程。审批工作流是企业应用中常见的业务场景。
b.代码示例
---
from langgraph.graph import StateGraph, END
from typing import TypedDict, Literal
class ApprovalWorkflow(TypedDict):
# 申请信息
request_id: str
requester: str
request_type: str
request_content: str
# 预审
pre_check_passed: bool
pre_check_message: str
# 审核
reviewer: str
review_result: Literal["pending", "approved", "rejected", "revision_required"]
review_comments: str
# 修改
revision_count: int
revised_content: str
# 最终状态
final_status: Literal["approved", "rejected", "cancelled"]
approval_history: list
# 1. 预审节点
def pre_check_node(state: ApprovalWorkflow):
"""自动预审"""
content = state["request_content"]
request_type = state["request_type"]
# 执行自动检查
checks = [
validate_format(content),
check_completeness(content),
check_permissions(state["requester"], request_type)
]
if all(checks):
return {
"pre_check_passed": True,
"pre_check_message": "预审通过",
"approval_history": [
{"step": "pre_check", "result": "passed", "time": datetime.now().isoformat()}
]
}
else:
return {
"pre_check_passed": False,
"pre_check_message": "预审不通过:" + get_failure_reason(checks),
"final_status": "rejected"
}
# 2. 分配审核人
def assign_reviewer_node(state: ApprovalWorkflow):
"""分配审核人"""
request_type = state["request_type"]
# 根据请求类型分配审核人
reviewer = find_appropriate_reviewer(request_type)
return {
"reviewer": reviewer,
"approval_history": state.get("approval_history", []) + [
{"step": "assign", "reviewer": reviewer, "time": datetime.now().isoformat()}
]
}
# 3. 人工审核节点
def human_review_node(state: ApprovalWorkflow):
"""人工审核(interrupt点)"""
result = state.get("review_result", "pending")
history_entry = {
"step": "review",
"reviewer": state.get("reviewer", ""),
"result": result,
"comments": state.get("review_comments", ""),
"time": datetime.now().isoformat()
}
return {
"approval_history": state.get("approval_history", []) + [history_entry]
}
# 4. 修改节点
def revision_node(state: ApprovalWorkflow):
"""申请人修改内容"""
revised = state.get("revised_content", "")
if revised:
return {
"request_content": revised,
"revision_count": state.get("revision_count", 0) + 1,
"review_result": "pending", # 重置审核状态
"approval_history": state.get("approval_history", []) + [
{"step": "revision", "count": state.get("revision_count", 0) + 1, "time": datetime.now().isoformat()}
]
}
else:
# 放弃修改
return {
"final_status": "cancelled"
}
# 5. 最终处理节点
def finalize_node(state: ApprovalWorkflow):
"""最终处理"""
result = state.get("review_result")
if result == "approved":
# 执行批准后的操作
execute_approved_request(state)
return {
"final_status": "approved",
"approval_history": state.get("approval_history", []) + [
{"step": "finalize", "status": "approved", "time": datetime.now().isoformat()}
]
}
else:
return {
"final_status": "rejected",
"approval_history": state.get("approval_history", []) + [
{"step": "finalize", "status": "rejected", "time": datetime.now().isoformat()}
]
}
# 6. 构建工作流
workflow = StateGraph(ApprovalWorkflow)
workflow.add_node("pre_check", pre_check_node)
workflow.add_node("assign", assign_reviewer_node)
workflow.add_node("review", human_review_node)
workflow.add_node("revision", revision_node)
workflow.add_node("finalize", finalize_node)
workflow.set_entry_point("pre_check")
# 预审后路由
def after_pre_check(state):
return "assign" if state.get("pre_check_passed") else "end"
workflow.add_conditional_edges(
"pre_check",
after_pre_check,
{"assign": "assign", "end": END}
)
workflow.add_edge("assign", "review")
# 审核后路由
def after_review(state):
result = state.get("review_result")
if result == "approved":
return "finalize"
elif result == "revision_required":
# 检查修改次数
if state.get("revision_count", 0) < 3:
return "revision"
else:
return "finalize" # 超过3次修改,拒绝
else:
return "finalize"
workflow.add_conditional_edges(
"review",
after_review,
{
"finalize": "finalize",
"revision": "revision"
}
)
# 修改后回到分配(重新审核)
workflow.add_edge("revision", "assign")
workflow.add_edge("finalize", END)
# 配置中断点
app = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["review", "revision"]
)
# 7. 执行审批流程
config = {"configurable": {"thread_id": "request_001"}}
# 提交申请
result = app.invoke({
"request_id": "REQ001",
"requester": "张三",
"request_type": "资源申请",
"request_content": "申请服务器资源...",
"revision_count": 0,
"approval_history": []
}, config=config)
# 审核人审核
result = app.invoke({
"review_result": "revision_required",
"review_comments": "请补充预算说明"
}, config=config)
# 申请人修改
result = app.invoke({
"revised_content": "修改后的申请内容,包含预算..."
}, config=config)
# 再次审核
final_result = app.invoke({
"review_result": "approved",
"review_comments": "批准"
}, config=config)
print(f"最终状态:{final_result['final_status']}")
print("审批历史:")
for entry in final_result["approval_history"]:
print(f" {entry}")
---
6 流式处理
6.1 流式执行
01.stream方法
a.基础流式
a.功能说明
使用stream()方法流式执行图,逐个节点yield输出而非等待全部完成。每个节点执行完立即返回结果,实现实时响应。适用于长时间任务、进度展示、用户交互等场景。流式执行提升用户体验,降低感知延迟。
b.代码示例
---
from langgraph.graph import StateGraph, END
from typing import TypedDict
class StreamState(TypedDict):
input: str
step1_result: str
step2_result: str
final_result: str
# 1. 定义节点
def step1_node(state: StreamState):
"""步骤1:耗时操作"""
import time
time.sleep(2) # 模拟耗时
return {"step1_result": "步骤1完成"}
def step2_node(state: StreamState):
"""步骤2:耗时操作"""
import time
time.sleep(2)
return {"step2_result": "步骤2完成"}
def final_node(state: StreamState):
"""最终处理"""
return {"final_result": "全部完成"}
# 2. 构建图
workflow = StateGraph(StreamState)
workflow.add_node("step1", step1_node)
workflow.add_node("step2", step2_node)
workflow.add_node("final", final_node)
workflow.set_entry_point("step1")
workflow.add_edge("step1", "step2")
workflow.add_edge("step2", "final")
workflow.add_edge("final", END)
app = workflow.compile()
# 3. 非流式执行(等待全部完成)
import time
start = time.time()
result = app.invoke({"input": "test"})
print(f"总耗时:{time.time() - start:.2f}秒")
print(f"结果:{result}")
# 4. 流式执行(逐个节点返回)
start = time.time()
for chunk in app.stream({"input": "test"}):
elapsed = time.time() - start
print(f"[{elapsed:.2f}s] 收到输出:{chunk}")
# 输出示例:
# [2.01s] 收到输出:{'step1': {'step1_result': '步骤1完成'}}
# [4.02s] 收到输出:{'step2': {'step2_result': '步骤2完成'}}
# [4.03s] 收到输出:{'final': {'final_result': '全部完成'}}
# 5. 流式执行with进度条
from rich.progress import Progress, SpinnerColumn, TextColumn
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
transient=True
) as progress:
task = progress.add_task("执行中...", total=None)
for chunk in app.stream({"input": "test"}):
node_name = list(chunk.keys())[0]
progress.update(task, description=f"完成节点:{node_name}")
print("执行完成!")
# 6. 获取最终结果
final_state = None
for chunk in app.stream({"input": "test"}):
# 合并所有chunk
if final_state is None:
final_state = {}
for node_output in chunk.values():
final_state.update(node_output)
print(f"最终状态:{final_state}")
---
b.astream异步流式
a.功能说明
使用astream()异步流式执行,配合asyncio实现高并发。异步流式适用于多用户、高并发、WebSocket等场景。每个chunk异步yield,不阻塞事件循环。astream提升系统吞吐量和响应性能。
b.代码示例
---
import asyncio
from langgraph.graph import StateGraph, END
# 1. 异步节点
async def async_step1(state: StreamState):
"""异步步骤1"""
await asyncio.sleep(2)
return {"step1_result": "异步步骤1完成"}
async def async_step2(state: StreamState):
"""异步步骤2"""
await asyncio.sleep(2)
return {"step2_result": "异步步骤2完成"}
# 2. 异步流式执行
async def run_async_stream():
"""异步流式执行"""
async for chunk in app.astream({"input": "test"}):
print(f"收到chunk:{chunk}")
# 运行
asyncio.run(run_async_stream())
# 3. 多个并发流式任务
async def concurrent_streams():
"""并发执行多个流式任务"""
tasks = []
for i in range(5):
async def run_stream(task_id):
async for chunk in app.astream({"input": f"task_{task_id}"}):
print(f"任务{task_id}:{chunk}")
tasks.append(run_stream(i))
await asyncio.gather(*tasks)
asyncio.run(concurrent_streams())
# 4. WebSocket流式输出
from fastapi import FastAPI, WebSocket
import json
app_api = FastAPI()
@app_api.websocket("/stream")
async def websocket_stream(websocket: WebSocket):
"""WebSocket流式输出"""
await websocket.accept()
try:
# 接收输入
data = await websocket.receive_json()
input_data = data.get("input", "")
# 流式执行
async for chunk in app.astream({"input": input_data}):
# 发送每个chunk
await websocket.send_json({
"type": "chunk",
"data": chunk
})
# 完成信号
await websocket.send_json({
"type": "complete"
})
except Exception as e:
await websocket.send_json({
"type": "error",
"message": str(e)
})
finally:
await websocket.close()
# 5. SSE (Server-Sent Events) 流式输出
from fastapi.responses import StreamingResponse
@app_api.get("/stream-sse")
async def sse_stream(input: str):
"""SSE流式输出"""
async def event_generator():
async for chunk in app.astream({"input": input}):
# SSE格式
yield f"data: {json.dumps(chunk)}\n\n"
# 结束信号
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream"
)
# 6. 异步批处理
async def async_batch_stream(inputs: list):
"""异步批处理流式"""
tasks = []
for input_data in inputs:
async def process_one(data):
results = []
async for chunk in app.astream({"input": data}):
results.append(chunk)
return results
tasks.append(process_one(input_data))
# 并发执行所有任务
all_results = await asyncio.gather(*tasks)
return all_results
# 执行
results = asyncio.run(async_batch_stream(["input1", "input2", "input3"]))
---
02.流式控制
a.chunk格式
a.功能说明
stream()返回的chunk是包含节点名称和输出的字典,格式为{node_name: node_output}。每个节点执行完返回一个chunk,可以根据节点名识别进度。chunk包含节点的状态更新,需要合并才能获得完整状态。理解chunk格式是正确处理流式输出的基础。
b.代码示例
---
from langgraph.graph import StateGraph, END
# 1. 解析chunk
for chunk in app.stream({"input": "test"}):
# chunk格式:{node_name: {field1: value1, field2: value2}}
for node_name, node_output in chunk.items():
print(f"节点:{node_name}")
print(f"输出:{node_output}")
# 访问具体字段
if "step1_result" in node_output:
print(f"步骤1结果:{node_output['step1_result']}")
# 2. 累积状态
class StateAccumulator:
"""状态累积器"""
def __init__(self):
self.state = {}
def update(self, chunk: dict):
"""更新状态"""
for node_output in chunk.values():
self.state.update(node_output)
def get_state(self):
"""获取当前完整状态"""
return self.state.copy()
# 使用累积器
accumulator = StateAccumulator()
for chunk in app.stream({"input": "test"}):
accumulator.update(chunk)
current_state = accumulator.get_state()
print(f"当前状态:{current_state}")
final_state = accumulator.get_state()
print(f"最终状态:{final_state}")
# 3. 过滤特定节点
def filter_nodes(stream, node_names: list):
"""只输出特定节点的chunk"""
for chunk in stream:
for node_name in node_names:
if node_name in chunk:
yield {node_name: chunk[node_name]}
# 只关注step1和final节点
for chunk in filter_nodes(app.stream({"input": "test"}), ["step1", "final"]):
print(chunk)
# 4. chunk转换
def transform_chunks(stream):
"""转换chunk格式"""
for chunk in stream:
# 转换为扁平格式
flat_chunk = {}
for node_name, node_output in chunk.items():
for key, value in node_output.items():
flat_chunk[f"{node_name}_{key}"] = value
yield flat_chunk
for transformed in transform_chunks(app.stream({"input": "test"})):
print(transformed)
# 输出:{'step1_step1_result': '步骤1完成'}
# 5. 带元数据的chunk
def add_metadata(stream):
"""为chunk添加元数据"""
import time
start_time = time.time()
chunk_index = 0
for chunk in stream:
chunk_index += 1
yield {
"index": chunk_index,
"elapsed": time.time() - start_time,
"data": chunk
}
for enriched_chunk in add_metadata(app.stream({"input": "test"})):
print(f"Chunk #{enriched_chunk['index']} "
f"at {enriched_chunk['elapsed']:.2f}s: "
f"{enriched_chunk['data']}")
---
b.流式缓冲
a.功能说明
流式输出可能产生大量小chunk,需要缓冲聚合后再输出。设置缓冲区收集chunk,定时或达到阈值时批量输出。流式缓冲减少I/O次数,提升传输效率,适用于网络传输、日志记录等场景。
b.代码示例
---
from collections import deque
import time
# 1. 时间缓冲
class TimeBufferedStream:
"""时间缓冲流"""
def __init__(self, stream, buffer_seconds: float = 1.0):
self.stream = stream
self.buffer_seconds = buffer_seconds
self.buffer = []
self.last_flush = time.time()
def __iter__(self):
for chunk in self.stream:
self.buffer.append(chunk)
# 检查是否需要flush
if time.time() - self.last_flush >= self.buffer_seconds:
yield self.buffer.copy()
self.buffer.clear()
self.last_flush = time.time()
# 输出剩余buffer
if self.buffer:
yield self.buffer
# 使用时间缓冲
buffered = TimeBufferedStream(app.stream({"input": "test"}), buffer_seconds=2.0)
for chunk_batch in buffered:
print(f"收到{len(chunk_batch)}个chunk:{chunk_batch}")
# 2. 大小缓冲
class SizeBufferedStream:
"""大小缓冲流"""
def __init__(self, stream, buffer_size: int = 5):
self.stream = stream
self.buffer_size = buffer_size
self.buffer = []
def __iter__(self):
for chunk in self.stream:
self.buffer.append(chunk)
if len(self.buffer) >= self.buffer_size:
yield self.buffer.copy()
self.buffer.clear()
if self.buffer:
yield self.buffer
# 3. 智能缓冲
class SmartBuffer:
"""智能缓冲(结合时间和大小)"""
def __init__(self, stream, max_size: int = 10, max_time: float = 2.0):
self.stream = stream
self.max_size = max_size
self.max_time = max_time
self.buffer = []
self.last_flush = time.time()
def __iter__(self):
for chunk in self.stream:
self.buffer.append(chunk)
# 达到大小限制或时间限制
should_flush = (
len(self.buffer) >= self.max_size or
time.time() - self.last_flush >= self.max_time
)
if should_flush:
yield self.buffer.copy()
self.buffer.clear()
self.last_flush = time.time()
if self.buffer:
yield self.buffer
# 4. 压缩缓冲
import gzip
import json
class CompressedBuffer:
"""压缩缓冲"""
def __init__(self, stream, buffer_size: int = 100):
self.stream = stream
self.buffer_size = buffer_size
self.buffer = []
def __iter__(self):
for chunk in self.stream:
self.buffer.append(chunk)
if len(self.buffer) >= self.buffer_size:
# 序列化并压缩
data = json.dumps(self.buffer).encode()
compressed = gzip.compress(data)
yield compressed
self.buffer.clear()
if self.buffer:
data = json.dumps(self.buffer).encode()
yield gzip.compress(data)
---
6.2 事件监听
01.astream_events
a.事件类型
a.功能说明
astream_events()提供更细粒度的事件流,包括节点开始、结束、LLM token、工具调用等。每个事件包含类型、数据、元数据。支持on_chain_start、on_chain_end、on_llm_start、on_llm_new_token、on_tool_start等事件。事件监听实现精细化的进度跟踪和调试。
b.代码示例
---
from langgraph.graph import StateGraph, END
import asyncio
# 1. 监听所有事件
async def listen_all_events():
"""监听所有事件"""
async for event in app.astream_events({"input": "test"}, version="v1"):
event_type = event["event"]
data = event.get("data", {})
name = event.get("name", "")
print(f"事件:{event_type}")
print(f" 名称:{name}")
print(f" 数据:{data}")
print()
asyncio.run(listen_all_events())
# 2. 过滤特定事件类型
async def listen_llm_tokens():
"""只监听LLM token事件"""
async for event in app.astream_events({"input": "test"}, version="v1"):
if event["event"] == "on_llm_new_token":
token = event["data"]["chunk"]
print(token, end="", flush=True)
# 3. 事件类型示例
# on_chain_start: 链/节点开始
# on_chain_end: 链/节点结束
# on_llm_start: LLM调用开始
# on_llm_new_token: LLM生成新token
# on_llm_end: LLM调用结束
# on_tool_start: 工具调用开始
# on_tool_end: 工具调用结束
async def categorize_events():
"""分类处理不同事件"""
async for event in app.astream_events({"input": "test"}, version="v1"):
event_type = event["event"]
if event_type == "on_chain_start":
print(f"▶ 开始执行:{event['name']}")
elif event_type == "on_chain_end":
print(f"✓ 完成执行:{event['name']}")
elif event_type == "on_llm_new_token":
print(event["data"]["chunk"], end="", flush=True)
elif event_type == "on_tool_start":
tool_name = event["name"]
tool_input = event["data"]["input"]
print(f"🔧 调用工具:{tool_name}({tool_input})")
elif event_type == "on_tool_end":
tool_output = event["data"]["output"]
print(f" 工具输出:{tool_output}")
# 4. 事件过滤器
class EventFilter:
"""事件过滤器"""
def __init__(self, include_types: list = None, exclude_types: list = None):
self.include_types = include_types
self.exclude_types = exclude_types or []
async def filter_events(self, stream):
"""过滤事件流"""
async for event in stream:
event_type = event["event"]
# 排除类型
if event_type in self.exclude_types:
continue
# 包含类型
if self.include_types is None or event_type in self.include_types:
yield event
# 只监听LLM相关事件
llm_filter = EventFilter(include_types=[
"on_llm_start",
"on_llm_new_token",
"on_llm_end"
])
async for event in llm_filter.filter_events(
app.astream_events({"input": "test"}, version="v1")
):
print(event)
# 5. 事件统计
class EventStats:
"""事件统计"""
def __init__(self):
self.stats = {}
self.start_times = {}
async def track_events(self, stream):
"""跟踪和统计事件"""
import time
async for event in stream:
event_type = event["event"]
name = event.get("name", "unknown")
# 统计事件数量
self.stats[event_type] = self.stats.get(event_type, 0) + 1
# 记录开始时间
if event_type.endswith("_start"):
self.start_times[name] = time.time()
# 计算耗时
if event_type.endswith("_end"):
if name in self.start_times:
duration = time.time() - self.start_times[name]
print(f"{name} 耗时:{duration:.2f}秒")
yield event
def print_stats(self):
"""打印统计信息"""
print("\n事件统计:")
for event_type, count in sorted(self.stats.items()):
print(f" {event_type}: {count}")
stats = EventStats()
async for event in stats.track_events(
app.astream_events({"input": "test"}, version="v1")
):
pass
stats.print_stats()
# 6. 实时UI更新
async def update_ui_with_events():
"""使用事件更新UI"""
ui_state = {
"current_node": "",
"progress": 0,
"llm_output": "",
"tool_calls": []
}
async for event in app.astream_events({"input": "test"}, version="v1"):
event_type = event["event"]
if event_type == "on_chain_start":
ui_state["current_node"] = event["name"]
ui_state["progress"] += 10
update_ui(ui_state)
elif event_type == "on_llm_new_token":
ui_state["llm_output"] += event["data"]["chunk"]
update_ui(ui_state)
elif event_type == "on_tool_start":
ui_state["tool_calls"].append({
"name": event["name"],
"status": "running"
})
update_ui(ui_state)
---
02.进度跟踪
a.节点进度
a.功能说明
通过监听节点开始和结束事件,计算执行进度。统计总节点数、已完成节点数、当前节点等。实时显示进度百分比、进度条、状态信息。进度跟踪提升长时间任务的用户体验,让用户了解执行情况。
b.代码示例
---
import asyncio
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
# 1. 基础进度跟踪
class NodeProgressTracker:
"""节点进度追踪器"""
def __init__(self, total_nodes: int):
self.total_nodes = total_nodes
self.completed_nodes = 0
self.current_node = ""
async def track(self, stream):
"""跟踪进度"""
async for event in stream:
if event["event"] == "on_chain_start":
self.current_node = event["name"]
print(f"开始:{self.current_node}")
elif event["event"] == "on_chain_end":
self.completed_nodes += 1
progress = (self.completed_nodes / self.total_nodes) * 100
print(f"完成:{event['name']} ({progress:.1f}%)")
yield event
# 使用进度跟踪
tracker = NodeProgressTracker(total_nodes=5)
async for event in tracker.track(
app.astream_events({"input": "test"}, version="v1")
):
pass
# 2. Rich进度条
async def run_with_progress_bar():
"""使用Rich进度条"""
with Progress(
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeRemainingColumn()
) as progress:
# 创建任务
task = progress.add_task("执行中...", total=100)
completed = 0
total_nodes = 5
async for event in app.astream_events({"input": "test"}, version="v1"):
if event["event"] == "on_chain_start":
progress.update(
task,
description=f"执行:{event['name']}"
)
elif event["event"] == "on_chain_end":
completed += 1
progress.update(
task,
completed=(completed / total_nodes) * 100
)
asyncio.run(run_with_progress_bar())
# 3. 多阶段进度
class MultiStageProgress:
"""多阶段进度"""
def __init__(self, stages: dict):
# stages: {node_name: weight}
self.stages = stages
self.total_weight = sum(stages.values())
self.completed_weight = 0
self.current_stage = ""
async def track(self, stream):
"""跟踪多阶段进度"""
async for event in stream:
if event["event"] == "on_chain_start":
self.current_stage = event["name"]
elif event["event"] == "on_chain_end":
node_name = event["name"]
if node_name in self.stages:
self.completed_weight += self.stages[node_name]
progress = (self.completed_weight / self.total_weight) * 100
print(f"总进度:{progress:.1f}%")
yield event
# 定义各阶段权重
stages = {
"load_data": 10,
"process": 50,
"analyze": 30,
"output": 10
}
progress = MultiStageProgress(stages)
async for event in progress.track(
app.astream_events({"input": "test"}, version="v1")
):
pass
# 4. 实时进度API
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import json
app_api = FastAPI()
@app_api.get("/progress/{task_id}")
async def get_progress(task_id: str):
"""SSE进度推送"""
async def event_stream():
completed = 0
total = 5
async for event in app.astream_events(
{"input": task_id},
version="v1"
):
if event["event"] == "on_chain_end":
completed += 1
progress_data = {
"completed": completed,
"total": total,
"percentage": (completed / total) * 100,
"current": event["name"]
}
yield f"data: {json.dumps(progress_data)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
event_stream(),
media_type="text/event-stream"
)
# 5. WebSocket进度推送
from fastapi import WebSocket
@app_api.websocket("/ws/progress")
async def websocket_progress(websocket: WebSocket):
"""WebSocket实时进度"""
await websocket.accept()
try:
data = await websocket.receive_json()
input_data = data.get("input")
completed = 0
total = 5
async for event in app.astream_events(
{"input": input_data},
version="v1"
):
if event["event"] == "on_chain_start":
await websocket.send_json({
"type": "node_start",
"node": event["name"]
})
elif event["event"] == "on_chain_end":
completed += 1
await websocket.send_json({
"type": "progress",
"completed": completed,
"total": total,
"percentage": (completed / total) * 100
})
await websocket.send_json({"type": "complete"})
except Exception as e:
await websocket.send_json({
"type": "error",
"message": str(e)
})
finally:
await websocket.close()
---
6.3 实时反馈
01.Token流式输出
a.打字机效果
a.功能说明
流式输出LLM生成的token,实现打字机效果。监听on_llm_new_token事件获取每个token,实时显示或传输。Token级流式大幅降低首字节延迟,提升交互体验。适用于对话系统、内容生成、实时翻译等场景。
b.代码示例
---
from langgraph.graph import StateGraph, END
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
import asyncio
class LLMStreamState(TypedDict):
input: str
output: str
# 1. 基础token流
async def stream_llm_tokens():
"""流式输出LLM tokens"""
async for event in app.astream_events(
{"input": "讲个故事"},
version="v1"
):
if event["event"] == "on_llm_new_token":
token = event["data"]["chunk"]
print(token, end="", flush=True)
asyncio.run(stream_llm_tokens())
# 2. WebSocket token流
from fastapi import WebSocket
import json
@app_api.websocket("/llm-stream")
async def websocket_llm_stream(websocket: WebSocket):
"""WebSocket LLM token流"""
await websocket.accept()
try:
data = await websocket.receive_json()
user_input = data.get("input")
async for event in app.astream_events(
{"input": user_input},
version="v1"
):
if event["event"] == "on_llm_new_token":
await websocket.send_json({
"type": "token",
"data": event["data"]["chunk"]
})
await websocket.send_json({"type": "done"})
except Exception as e:
await websocket.send_json({
"type": "error",
"message": str(e)
})
finally:
await websocket.close()
# 3. SSE token流
@app_api.get("/llm-stream-sse")
async def sse_llm_stream(query: str):
"""SSE LLM token流"""
async def token_generator():
async for event in app.astream_events(
{"input": query},
version="v1"
):
if event["event"] == "on_llm_new_token":
token = event["data"]["chunk"]
yield f"data: {json.dumps({'token': token})}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
token_generator(),
media_type="text/event-stream"
)
# 4. Token缓冲
class TokenBuffer:
"""Token缓冲器"""
def __init__(self, buffer_size: int = 5):
self.buffer_size = buffer_size
self.buffer = ""
async def stream_tokens(self, event_stream):
"""缓冲token后输出"""
async for event in event_stream:
if event["event"] == "on_llm_new_token":
token = event["data"]["chunk"]
self.buffer += token
if len(self.buffer) >= self.buffer_size:
yield self.buffer
self.buffer = ""
if self.buffer:
yield self.buffer
# 5. 前端集成
html_stream = """
<!DOCTYPE html>
<html>
<head>
<title>LLM流式输出</title>
</head>
<body>
<input id="input" type="text">
<button onclick="sendQuery()">发送</button>
<div id="output"></div>
<script>
async function sendQuery() {
const input = document.getElementById('input').value;
const output = document.getElementById('output');
output.textContent = '';
const ws = new WebSocket('ws://localhost:8000/llm-stream');
ws.onopen = () => {
ws.send(JSON.stringify({input: input}));
};
ws.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.type === 'token') {
output.textContent += data.data;
} else if (data.type === 'done') {
ws.close();
}
};
}
</script>
</body>
</html>
"""
---
02.进度反馈
a.实时状态更新
a.功能说明
流式执行过程中实时更新状态,反馈当前进度、节点信息、中间结果等。通过WebSocket、SSE等技术推送状态更新到客户端。实时状态更新让用户了解执行情况,提升交互体验。支持进度百分比、步骤名称、时间估算等信息。
b.代码示例
---
from fastapi import FastAPI, WebSocket
import json
import asyncio
# 1. WebSocket状态推送
@app_api.websocket("/ws/status")
async def websocket_status(websocket: WebSocket):
"""实时状态推送"""
await websocket.accept()
try:
data = await websocket.receive_json()
input_data = data.get("input")
completed = 0
total = 5
async for event in app.astream_events(
{"input": input_data},
version="v1"
):
if event["event"] == "on_chain_start":
await websocket.send_json({
"type": "status",
"node": event["name"],
"status": "started",
"progress": (completed / total) * 100
})
elif event["event"] == "on_chain_end":
completed += 1
await websocket.send_json({
"type": "status",
"node": event["name"],
"status": "completed",
"progress": (completed / total) * 100
})
elif event["event"] == "on_llm_new_token":
await websocket.send_json({
"type": "token",
"data": event["data"]["chunk"]
})
await websocket.send_json({"type": "done"})
except Exception as e:
await websocket.send_json({
"type": "error",
"message": str(e)
})
finally:
await websocket.close()
# 2. SSE状态流
@app_api.get("/sse/status")
async def sse_status(query: str):
"""SSE状态推送"""
async def status_generator():
completed = 0
total = 5
async for event in app.astream_events(
{"input": query},
version="v1"
):
status_update = None
if event["event"] == "on_chain_start":
status_update = {
"type": "node_start",
"node": event["name"],
"progress": (completed / total) * 100
}
elif event["event"] == "on_chain_end":
completed += 1
status_update = {
"type": "node_end",
"node": event["name"],
"progress": (completed / total) * 100
}
elif event["event"] == "on_llm_new_token":
status_update = {
"type": "token",
"data": event["data"]["chunk"]
}
if status_update:
yield f"data: {json.dumps(status_update)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
status_generator(),
media_type="text/event-stream"
)
# 3. 前端UI更新
html_status = """
<!DOCTYPE html>
<html>
<head>
<title>实时状态</title>
<style>
.progress-bar {
width: 100%;
height: 30px;
background: #f0f0f0;
border-radius: 5px;
}
.progress-fill {
height: 100%;
background: #4CAF50;
transition: width 0.3s;
}
.status {
margin: 10px 0;
padding: 10px;
border-left: 4px solid #2196F3;
}
</style>
</head>
<body>
<h1>执行状态</h1>
<div class="progress-bar">
<div id="progress" class="progress-fill" style="width: 0%"></div>
</div>
<div id="current-node" class="status">准备中...</div>
<div id="output"></div>
<button onclick="start()">开始</button>
<script>
function start() {
const ws = new WebSocket('ws://localhost:8000/ws/status');
ws.onopen = () => {
ws.send(JSON.stringify({input: '开始任务'}));
};
ws.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.type === 'status') {
// 更新进度条
document.getElementById('progress').style.width =
data.progress + '%';
// 更新当前节点
document.getElementById('current-node').textContent =
`当前:${data.node} (${data.status})`;
}
else if (data.type === 'token') {
// 添加token
document.getElementById('output').textContent += data.data;
}
else if (data.type === 'done') {
document.getElementById('current-node').textContent = '完成!';
ws.close();
}
};
}
</script>
</body>
</html>
"""
@app_api.get("/ui")
async def get_status_ui():
from fastapi.responses import HTMLResponse
return HTMLResponse(content=html_status)
# 4. 聚合状态反馈
class StreamStateAggregator:
"""流式状态聚合器"""
def __init__(self):
self.state = {}
self.progress = 0
self.current_node = ""
async def track_and_broadcast(self, event_stream, websocket):
"""跟踪状态并广播"""
async for event in event_stream:
if event["event"] == "on_chain_start":
self.current_node = event["name"]
await websocket.send_json({
"type": "status",
"current_node": self.current_node,
"progress": self.progress
})
elif event["event"] == "on_chain_end":
self.progress += 20
# 更新状态
node_output = event.get("data", {}).get("output", {})
self.state.update(node_output)
await websocket.send_json({
"type": "update",
"progress": self.progress,
"state": self.state
})
# 5. 信创环境适配
# 使用国产浏览器的WebSocket
@app_api.websocket("/ws/status/xinchuang")
async def websocket_status_xinchuang(websocket: WebSocket):
"""信创环境状态推送"""
await websocket.accept()
try:
# 兼容国产浏览器
data = await websocket.receive_text()
input_data = json.loads(data).get("input")
async for event in app.astream_events(
{"input": input_data},
version="v1"
):
# 使用简化的JSON格式
if event["event"] == "on_chain_start":
await websocket.send_text(json.dumps({
"t": "start",
"n": event["name"]
}, ensure_ascii=False))
elif event["event"] == "on_chain_end":
await websocket.send_text(json.dumps({
"t": "end",
"n": event["name"]
}, ensure_ascii=False))
except Exception as e:
await websocket.send_text(json.dumps({
"t": "error",
"m": str(e)
}, ensure_ascii=False))
finally:
await websocket.close()
---
7 高级特性
7.1 子图嵌套
01.子图定义
a.创建子图
a.功能说明
将复杂逻辑封装为独立的子图(SubGraph),然后作为节点嵌入主图。子图有自己的入口、出口、内部节点和边。子图提升代码复用性、可维护性,实现模块化设计。适用于复杂工作流的分层建模,如子流程、可复用组件等。
b.代码示例
---
from langgraph.graph import StateGraph, END
from typing import TypedDict
# 1. 定义子图状态
class SubGraphState(TypedDict):
sub_input: str
sub_result: str
# 2. 创建子图
def create_data_processing_subgraph():
"""创建数据处理子图"""
def load_node(state):
return {"sub_result": f"加载:{state['sub_input']}"}
def validate_node(state):
return {"sub_result": state["sub_result"] + " -> 验证通过"}
def transform_node(state):
return {"sub_result": state["sub_result"] + " -> 转换完成"}
# 构建子图
subgraph = StateGraph(SubGraphState)
subgraph.add_node("load", load_node)
subgraph.add_node("validate", validate_node)
subgraph.add_node("transform", transform_node)
subgraph.set_entry_point("load")
subgraph.add_edge("load", "validate")
subgraph.add_edge("validate", "transform")
subgraph.add_edge("transform", END)
return subgraph.compile()
# 3. 在主图中使用子图
class MainState(TypedDict):
input: str
processed_data: str
final_result: str
# 创建子图实例
data_processor = create_data_processing_subgraph()
def prepare_node(state: MainState):
"""准备节点"""
return {"processed_data": "准备完成"}
def process_with_subgraph_node(state: MainState):
"""使用子图处理"""
# 调用子图
result = data_processor.invoke({
"sub_input": state["input"],
"sub_result": ""
})
return {"processed_data": result["sub_result"]}
def finalize_node(state: MainState):
"""最终处理"""
return {"final_result": f"完成:{state['processed_data']}"}
# 构建主图
main_workflow = StateGraph(MainState)
main_workflow.add_node("prepare", prepare_node)
main_workflow.add_node("process", process_with_subgraph_node)
main_workflow.add_node("finalize", finalize_node)
main_workflow.set_entry_point("prepare")
main_workflow.add_edge("prepare", "process")
main_workflow.add_edge("process", "finalize")
main_workflow.add_edge("finalize", END)
app = main_workflow.compile()
# 执行
result = app.invoke({
"input": "原始数据",
"processed_data": "",
"final_result": ""
})
print(result["final_result"])
# 4. 可配置子图
def create_configurable_subgraph(steps: list):
"""创建可配置的子图"""
subgraph = StateGraph(SubGraphState)
# 动态添加节点
for i, step_func in enumerate(steps):
node_name = f"step_{i}"
subgraph.add_node(node_name, step_func)
if i == 0:
subgraph.set_entry_point(node_name)
else:
prev_node = f"step_{i-1}"
subgraph.add_edge(prev_node, node_name)
# 最后一个节点连接END
last_node = f"step_{len(steps)-1}"
subgraph.add_edge(last_node, END)
return subgraph.compile()
# 定义步骤
steps = [
lambda s: {"sub_result": "步骤1"},
lambda s: {"sub_result": s["sub_result"] + " -> 步骤2"},
lambda s: {"sub_result": s["sub_result"] + " -> 步骤3"}
]
configurable_subgraph = create_configurable_subgraph(steps)
# 5. 子图工厂
class SubGraphFactory:
"""子图工厂"""
@staticmethod
def create_validation_subgraph():
"""创建验证子图"""
def check_format(state):
return {"sub_result": "格式检查通过"}
def check_content(state):
return {"sub_result": state["sub_result"] + ", 内容检查通过"}
subgraph = StateGraph(SubGraphState)
subgraph.add_node("format", check_format)
subgraph.add_node("content", check_content)
subgraph.set_entry_point("format")
subgraph.add_edge("format", "content")
subgraph.add_edge("content", END)
return subgraph.compile()
@staticmethod
def create_enrichment_subgraph():
"""创建数据增强子图"""
def add_metadata(state):
return {"sub_result": state["sub_input"] + " [metadata]"}
def add_tags(state):
return {"sub_result": state["sub_result"] + " [tags]"}
subgraph = StateGraph(SubGraphState)
subgraph.add_node("metadata", add_metadata)
subgraph.add_node("tags", add_tags)
subgraph.set_entry_point("metadata")
subgraph.add_edge("metadata", "tags")
subgraph.add_edge("tags", END)
return subgraph.compile()
# 使用工厂
validator = SubGraphFactory.create_validation_subgraph()
enricher = SubGraphFactory.create_enrichment_subgraph()
---
02.状态映射
a.输入输出映射
a.功能说明
主图和子图的状态结构可能不同,需要进行状态映射。定义映射函数转换主图状态到子图输入,转换子图输出回主图状态。状态映射实现图之间的解耦,允许独立设计状态结构。支持字段选择、格式转换、数据聚合等操作。
b.代码示例
---
from typing import TypedDict
# 1. 定义不同的状态结构
class MainGraphState(TypedDict):
user_request: str
analysis_result: dict
final_output: str
class AnalysisSubGraphState(TypedDict):
text: str
sentiment: str
keywords: list
# 2. 状态映射函数
def map_to_subgraph_input(main_state: MainGraphState) -> AnalysisSubGraphState:
"""主图状态 -> 子图输入"""
return {
"text": main_state["user_request"],
"sentiment": "",
"keywords": []
}
def map_from_subgraph_output(sub_result: AnalysisSubGraphState) -> dict:
"""子图输出 -> 主图状态"""
return {
"analysis_result": {
"sentiment": sub_result["sentiment"],
"keywords": sub_result["keywords"],
"analyzed": True
}
}
# 3. 创建分析子图
def create_analysis_subgraph():
"""创建文本分析子图"""
def analyze_sentiment(state: AnalysisSubGraphState):
# 情感分析
sentiment = "positive" if "好" in state["text"] else "neutral"
return {"sentiment": sentiment}
def extract_keywords(state: AnalysisSubGraphState):
# 关键词提取
keywords = state["text"].split()[:3]
return {"keywords": keywords}
subgraph = StateGraph(AnalysisSubGraphState)
subgraph.add_node("sentiment", analyze_sentiment)
subgraph.add_node("keywords", extract_keywords)
subgraph.set_entry_point("sentiment")
subgraph.add_edge("sentiment", "keywords")
subgraph.add_edge("keywords", END)
return subgraph.compile()
# 4. 在主图中使用(带映射)
analysis_subgraph = create_analysis_subgraph()
def call_analysis_node(state: MainGraphState):
"""调用分析子图(带状态映射)"""
# 映射输入
sub_input = map_to_subgraph_input(state)
# 调用子图
sub_result = analysis_subgraph.invoke(sub_input)
# 映射输出
main_update = map_from_subgraph_output(sub_result)
return main_update
# 5. 通用映射器
class StateMapper:
"""通用状态映射器"""
def __init__(self, input_mapping: dict, output_mapping: dict):
# input_mapping: {subgraph_field: main_field}
# output_mapping: {main_field: subgraph_field}
self.input_mapping = input_mapping
self.output_mapping = output_mapping
def map_input(self, main_state: dict) -> dict:
"""映射输入"""
sub_input = {}
for sub_field, main_field in self.input_mapping.items():
if callable(main_field):
# 函数映射
sub_input[sub_field] = main_field(main_state)
else:
# 直接映射
sub_input[sub_field] = main_state.get(main_field, "")
return sub_input
def map_output(self, sub_result: dict) -> dict:
"""映射输出"""
main_update = {}
for main_field, sub_field in self.output_mapping.items():
if callable(sub_field):
# 函数映射
main_update[main_field] = sub_field(sub_result)
else:
# 直接映射
main_update[main_field] = sub_result.get(sub_field, "")
return main_update
# 使用通用映射器
mapper = StateMapper(
input_mapping={
"text": "user_request", # 子图的text <- 主图的user_request
"sentiment": lambda s: "",
"keywords": lambda s: []
},
output_mapping={
"analysis_result": lambda sr: { # 主图的analysis_result <- 子图结果
"sentiment": sr["sentiment"],
"keywords": sr["keywords"]
}
}
)
def mapped_call_node(state: MainGraphState):
"""使用映射器调用子图"""
sub_input = mapper.map_input(state)
sub_result = analysis_subgraph.invoke(sub_input)
main_update = mapper.map_output(sub_result)
return main_update
# 6. 自动映射
def auto_map_subgraph_call(
subgraph,
main_state: dict,
input_fields: dict,
output_field: str
) -> dict:
"""自动映射子图调用"""
# 构造子图输入
sub_input = {
sub_key: main_state.get(main_key, default)
for sub_key, (main_key, default) in input_fields.items()
}
# 调用子图
sub_result = subgraph.invoke(sub_input)
# 返回映射
return {output_field: sub_result}
# 使用自动映射
def auto_mapped_node(state: MainGraphState):
return auto_map_subgraph_call(
analysis_subgraph,
state,
input_fields={
"text": ("user_request", ""),
"sentiment": (None, ""),
"keywords": (None, [])
},
output_field="analysis_result"
)
---
7.2 动态路由
01.基于状态路由
a.条件判断
a.功能说明
根据运行时状态动态决定下一个执行节点。使用conditional_edges定义路由函数,函数检查状态字段返回目标节点名。动态路由实现灵活的工作流,支持分支、循环、跳转等控制流。适用于业务规则复杂、执行路径多样的场景。
b.代码示例
---
from langgraph.graph import StateGraph, END
from typing import TypedDict, Literal
class DynamicState(TypedDict):
input: str
score: float
error_count: int
retry_count: int
result: str
# 1. 基础条件路由
def analyze_node(state: DynamicState):
"""分析节点"""
score = calculate_score(state["input"])
return {"score": score}
def route_by_score(state: DynamicState) -> Literal["high_path", "low_path"]:
"""根据分数路由"""
if state["score"] > 0.8:
return "high_path"
else:
return "low_path"
workflow = StateGraph(DynamicState)
workflow.add_node("analyze", analyze_node)
workflow.add_node("high_path", lambda s: {"result": "高分处理"})
workflow.add_node("low_path", lambda s: {"result": "低分处理"})
workflow.set_entry_point("analyze")
workflow.add_conditional_edges(
"analyze",
route_by_score,
{
"high_path": "high_path",
"low_path": "low_path"
}
)
workflow.add_edge("high_path", END)
workflow.add_edge("low_path", END)
# 2. 多条件路由
def multi_condition_route(state: DynamicState) -> str:
"""多条件判断路由"""
score = state.get("score", 0)
error_count = state.get("error_count", 0)
# 优先级判断
if error_count > 3:
return "error_handler"
elif score > 0.9:
return "excellent_path"
elif score > 0.7:
return "good_path"
elif score > 0.5:
return "acceptable_path"
else:
return "retry_path"
workflow.add_conditional_edges(
"analyze",
multi_condition_route,
{
"error_handler": "error_handler",
"excellent_path": "excellent_path",
"good_path": "good_path",
"acceptable_path": "acceptable_path",
"retry_path": "retry_path"
}
)
# 3. 循环与退出
def retry_route(state: DynamicState) -> str:
"""重试路由"""
retry_count = state.get("retry_count", 0)
if retry_count >= 3:
# 超过重试次数,退出
return "end"
else:
# 继续重试
return "retry"
def retry_node(state: DynamicState):
"""重试节点"""
return {
"retry_count": state.get("retry_count", 0) + 1,
"error_count": 0
}
workflow.add_node("retry", retry_node)
workflow.add_conditional_edges(
"retry",
retry_route,
{
"retry": "analyze", # 回到分析节点
"end": END
}
)
# 4. 基于类型的路由
def type_based_route(state: DynamicState) -> str:
"""基于输入类型路由"""
input_data = state["input"]
if input_data.startswith("http"):
return "url_processor"
elif input_data.endswith((".jpg", ".png")):
return "image_processor"
elif input_data.endswith(".pdf"):
return "pdf_processor"
else:
return "text_processor"
# 5. 优先级队列路由
class PriorityRouter:
"""优先级路由器"""
def __init__(self):
self.rules = []
def add_rule(self, condition: callable, target: str, priority: int = 0):
"""添加路由规则"""
self.rules.append({
"condition": condition,
"target": target,
"priority": priority
})
# 按优先级排序
self.rules.sort(key=lambda r: r["priority"], reverse=True)
def route(self, state: dict) -> str:
"""执行路由"""
for rule in self.rules:
if rule["condition"](state):
return rule["target"]
return "default"
# 使用优先级路由
router = PriorityRouter()
router.add_rule(
condition=lambda s: s.get("error_count", 0) > 5,
target="critical_error_handler",
priority=10
)
router.add_rule(
condition=lambda s: s.get("score", 0) > 0.95,
target="fast_track",
priority=5
)
router.add_rule(
condition=lambda s: s.get("score", 0) > 0.5,
target="normal_path",
priority=1
)
def priority_route(state: DynamicState) -> str:
return router.route(state)
---
02.基于LLM路由
a.智能决策
a.功能说明
使用LLM分析状态内容,智能决定执行路径。LLM根据自然语言描述的规则、上下文、历史等因素做出路由决策。适用于规则复杂、难以编码、需要理解语义的路由场景。LLM路由提供最大的灵活性,但增加了延迟和成本。
b.代码示例
---
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
import json
# 1. LLM路由器
class LLMRouter:
"""基于LLM的路由器"""
def __init__(self, routes: dict):
# routes: {route_name: description}
self.routes = routes
self.llm = ChatOpenAI(model="gpt-4", temperature=0)
def route(self, state: dict) -> str:
"""使用LLM决定路由"""
# 构建提示
routes_desc = "\n".join([
f"- {name}: {desc}"
for name, desc in self.routes.items()
])
prompt = ChatPromptTemplate.from_template("""
根据当前状态,选择最合适的处理路径。
可选路径:
{routes}
当前状态:
输入:{input}
分数:{score}
错误次数:{error_count}
只返回路径名称,不要解释。
""")
response = self.llm.predict(
prompt.format(
routes=routes_desc,
input=state.get("input", ""),
score=state.get("score", 0),
error_count=state.get("error_count", 0)
)
)
route_name = response.strip().lower()
# 验证路由名称
if route_name in self.routes:
return route_name
else:
return "default"
# 使用LLM路由
llm_router = LLMRouter({
"expert_review": "需要专家深入审核的复杂内容",
"auto_approve": "简单明确可以自动批准的内容",
"need_revision": "需要修改完善的内容",
"reject": "明显不合格需要拒绝的内容"
})
def llm_route_func(state: DynamicState) -> str:
return llm_router.route(state)
workflow.add_conditional_edges(
"analyze",
llm_route_func,
{
"expert_review": "expert_review_node",
"auto_approve": "approve_node",
"need_revision": "revision_node",
"reject": "reject_node",
"default": "default_node"
}
)
# 2. 意图识别路由
def intent_based_route(state: DynamicState) -> str:
"""基于意图识别的路由"""
llm = ChatOpenAI(model="gpt-4", temperature=0)
user_input = state["input"]
prompt = f"""
识别用户意图,返回以下之一:
- query: 查询信息
- create: 创建内容
- update: 更新数据
- delete: 删除操作
- analyze: 分析数据
用户输入:{user_input}
只返回意图类型:"""
intent = llm.predict(prompt).strip().lower()
return f"{intent}_handler"
# 3. 情感路由
def sentiment_route(state: DynamicState) -> str:
"""基于情感分析路由"""
llm = ChatOpenAI(model="gpt-4", temperature=0)
text = state["input"]
sentiment = llm.predict(f"""
分析以下文本的情感倾向,返回:positive, negative, 或 neutral
文本:{text}
只返回情感类型:""").strip().lower()
if sentiment == "positive":
return "positive_response"
elif sentiment == "negative":
return "negative_response"
else:
return "neutral_response"
# 4. 复杂度路由
def complexity_route(state: DynamicState) -> str:
"""根据任务复杂度路由"""
llm = ChatOpenAI(model="gpt-4", temperature=0)
task = state["input"]
response = llm.predict(f"""
评估任务复杂度(1-10分):
{task}
只返回数字:""")
complexity = int(response.strip())
if complexity <= 3:
return "simple_handler"
elif complexity <= 7:
return "medium_handler"
else:
return "complex_handler"
# 5. 结构化输出路由
from langchain.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field
class RouteDecision(BaseModel):
route: str = Field(description="选择的路由名称")
confidence: float = Field(description="决策置信度 0-1")
reason: str = Field(description="路由原因")
def structured_llm_route(state: DynamicState) -> str:
"""使用结构化输出的LLM路由"""
llm = ChatOpenAI(model="gpt-4", temperature=0)
parser = PydanticOutputParser(pydantic_object=RouteDecision)
prompt = ChatPromptTemplate.from_template("""
分析当前情况,决定路由。
可选路由:expert_review, auto_approve, reject
当前状态:{state}
{format_instructions}
""")
response = llm.predict(
prompt.format(
state=json.dumps(state, ensure_ascii=False),
format_instructions=parser.get_format_instructions()
)
)
decision = parser.parse(response)
# 记录决策
print(f"路由:{decision.route}, 置信度:{decision.confidence}, 原因:{decision.reason}")
return decision.route
---
7.3 循环处理
01.循环控制
a.计数循环
a.功能说明
使用状态字段记录循环次数,在路由函数中判断是否继续循环或退出。设置最大迭代次数防止无限循环。计数循环适用于固定次数的迭代、批处理、重试等场景。循环控制确保流程可控,避免资源耗尽。
b.代码示例
---
from langgraph.graph import StateGraph, END
from typing import TypedDict
class LoopState(TypedDict):
items: list
current_index: int
processed_items: list
max_iterations: int
# 1. 基础计数循环
def process_item_node(state: LoopState):
"""处理当前项"""
items = state["items"]
index = state["current_index"]
if index < len(items):
item = items[index]
processed = process_single_item(item)
return {
"processed_items": state["processed_items"] + [processed],
"current_index": index + 1
}
else:
return {}
def loop_condition(state: LoopState) -> str:
"""循环条件判断"""
if state["current_index"] < len(state["items"]):
return "continue"
else:
return "end"
workflow = StateGraph(LoopState)
workflow.add_node("process", process_item_node)
workflow.set_entry_point("process")
workflow.add_conditional_edges(
"process",
loop_condition,
{
"continue": "process", # 循环回自己
"end": END
}
)
app = workflow.compile()
# 执行
result = app.invoke({
"items": ["item1", "item2", "item3"],
"current_index": 0,
"processed_items": [],
"max_iterations": 10
})
print(f"处理结果:{result['processed_items']}")
# 2. 条件循环
def while_loop_condition(state: LoopState) -> str:
"""while循环条件"""
current_index = state.get("current_index", 0)
max_iterations = state.get("max_iterations", 100)
# 检查是否满足继续条件
should_continue = (
current_index < max_iterations and
state.get("error_count", 0) < 3 and
not state.get("target_reached", False)
)
return "continue" if should_continue else "end"
# 3. 带退出条件的循环
class OptimizationState(TypedDict):
value: float
iteration: int
converged: bool
history: list
def optimize_node(state: OptimizationState):
"""优化迭代节点"""
current_value = state["value"]
# 执行优化步骤
new_value = optimize_step(current_value)
# 检查是否收敛
converged = abs(new_value - current_value) < 0.001
return {
"value": new_value,
"iteration": state["iteration"] + 1,
"converged": converged,
"history": state["history"] + [new_value]
}
def optimization_condition(state: OptimizationState) -> str:
"""优化循环条件"""
if state.get("converged"):
return "converged"
elif state["iteration"] >= 100:
return "max_iterations"
else:
return "continue"
opt_workflow = StateGraph(OptimizationState)
opt_workflow.add_node("optimize", optimize_node)
opt_workflow.set_entry_point("optimize")
opt_workflow.add_conditional_edges(
"optimize",
optimization_condition,
{
"continue": "optimize",
"converged": END,
"max_iterations": END
}
)
# 4. 嵌套循环
class NestedLoopState(TypedDict):
outer_index: int
inner_index: int
outer_items: list
inner_items: list
results: list
def inner_loop_node(state: NestedLoopState):
"""内层循环节点"""
outer_idx = state["outer_index"]
inner_idx = state["inner_index"]
outer_item = state["outer_items"][outer_idx]
inner_item = state["inner_items"][inner_idx]
result = process_pair(outer_item, inner_item)
return {
"results": state["results"] + [result],
"inner_index": inner_idx + 1
}
def nested_loop_condition(state: NestedLoopState) -> str:
"""嵌套循环条件"""
inner_idx = state["inner_index"]
outer_idx = state["outer_index"]
inner_len = len(state["inner_items"])
outer_len = len(state["outer_items"])
if inner_idx < inner_len:
# 内层继续
return "inner_continue"
elif outer_idx + 1 < outer_len:
# 内层结束,外层继续
return "outer_continue"
else:
# 全部结束
return "end"
def reset_inner_node(state: NestedLoopState):
"""重置内层循环"""
return {
"inner_index": 0,
"outer_index": state["outer_index"] + 1
}
nested_workflow = StateGraph(NestedLoopState)
nested_workflow.add_node("inner_loop", inner_loop_node)
nested_workflow.add_node("reset_inner", reset_inner_node)
nested_workflow.set_entry_point("inner_loop")
nested_workflow.add_conditional_edges(
"inner_loop",
nested_loop_condition,
{
"inner_continue": "inner_loop",
"outer_continue": "reset_inner",
"end": END
}
)
nested_workflow.add_edge("reset_inner", "inner_loop")
# 5. 循环限制器
class LoopLimiter:
"""循环限制器"""
def __init__(self, max_iterations: int = 100, timeout_seconds: float = 60):
self.max_iterations = max_iterations
self.timeout_seconds = timeout_seconds
self.start_time = None
self.iteration_count = 0
def start(self):
"""开始计时"""
import time
self.start_time = time.time()
self.iteration_count = 0
def check(self) -> tuple[bool, str]:
"""检查是否应该继续"""
import time
self.iteration_count += 1
# 检查迭代次数
if self.iteration_count >= self.max_iterations:
return False, "max_iterations_reached"
# 检查超时
if self.start_time and (time.time() - self.start_time) > self.timeout_seconds:
return False, "timeout"
return True, "ok"
# 使用限制器
limiter = LoopLimiter(max_iterations=100, timeout_seconds=30)
limiter.start()
def limited_loop_condition(state: LoopState) -> str:
"""带限制的循环条件"""
can_continue, reason = limiter.check()
if not can_continue:
print(f"循环终止:{reason}")
return "end"
# 业务条件
if state["current_index"] < len(state["items"]):
return "continue"
else:
return "end"
---
02.批处理循环
a.分批处理
a.功能说明
将大量数据分批处理,每批次执行一次循环迭代。设置batch_size控制每批处理的数量,使用offset跟踪进度。分批处理避免内存溢出,支持断点续传,适用于大规模数据处理、ETL任务、批量操作等场景。
b.代码示例
---
from typing import TypedDict, List
class BatchState(TypedDict):
all_items: list
batch_size: int
offset: int
processed_count: int
results: list
errors: list
# 1. 基础批处理
def batch_process_node(state: BatchState):
"""批处理节点"""
all_items = state["all_items"]
batch_size = state["batch_size"]
offset = state["offset"]
# 获取当前批次
batch = all_items[offset:offset + batch_size]
# 处理批次
batch_results = []
batch_errors = []
for item in batch:
try:
result = process_item(item)
batch_results.append(result)
except Exception as e:
batch_errors.append({"item": item, "error": str(e)})
return {
"offset": offset + batch_size,
"processed_count": state["processed_count"] + len(batch),
"results": state["results"] + batch_results,
"errors": state["errors"] + batch_errors
}
def batch_condition(state: BatchState) -> str:
"""批处理循环条件"""
if state["offset"] < len(state["all_items"]):
return "continue"
else:
return "end"
batch_workflow = StateGraph(BatchState)
batch_workflow.add_node("process_batch", batch_process_node)
batch_workflow.set_entry_point("process_batch")
batch_workflow.add_conditional_edges(
"process_batch",
batch_condition,
{
"continue": "process_batch",
"end": END
}
)
batch_app = batch_workflow.compile()
# 执行批处理
large_dataset = list(range(1000))
result = batch_app.invoke({
"all_items": large_dataset,
"batch_size": 50,
"offset": 0,
"processed_count": 0,
"results": [],
"errors": []
})
print(f"处理了{result['processed_count']}个项目")
print(f"成功{len(result['results'])}个,失败{len(result['errors'])}个")
# 2. 并行批处理(实际需要异步支持)
import asyncio
async def parallel_batch_process(all_items: list, batch_size: int):
"""并行批处理"""
batches = [
all_items[i:i+batch_size]
for i in range(0, len(all_items), batch_size)
]
async def process_single_batch(batch):
"""处理单个批次"""
results = []
for item in batch:
result = await async_process_item(item)
results.append(result)
return results
# 并行处理所有批次
all_results = await asyncio.gather(*[
process_single_batch(batch)
for batch in batches
])
# 展平结果
flat_results = [item for batch in all_results for item in batch]
return flat_results
# 3. 带进度的批处理
from rich.progress import Progress
def batch_with_progress(all_items: list, batch_size: int):
"""带进度条的批处理"""
total_batches = (len(all_items) + batch_size - 1) // batch_size
with Progress() as progress:
task = progress.add_task("处理中...", total=total_batches)
results = []
for i in range(0, len(all_items), batch_size):
batch = all_items[i:i+batch_size]
batch_results = [process_item(item) for item in batch]
results.extend(batch_results)
progress.update(task, advance=1)
return results
# 4. 流式批处理
def streaming_batch_process(all_items: list, batch_size: int):
"""流式批处理(yield每批结果)"""
for i in range(0, len(all_items), batch_size):
batch = all_items[i:i+batch_size]
batch_results = [process_item(item) for item in batch]
yield {
"batch_index": i // batch_size,
"results": batch_results,
"progress": min((i + batch_size) / len(all_items), 1.0)
}
# 使用流式批处理
for batch_output in streaming_batch_process(large_dataset, batch_size=50):
print(f"批次{batch_output['batch_index']}完成,"
f"进度:{batch_output['progress']*100:.1f}%")
# 5. 自适应批处理
class AdaptiveBatchProcessor:
"""自适应批处理"""
def __init__(self, initial_batch_size: int = 10):
self.batch_size = initial_batch_size
self.processing_times = []
def adjust_batch_size(self, processing_time: float):
"""根据处理时间调整批次大小"""
self.processing_times.append(processing_time)
if len(self.processing_times) >= 3:
avg_time = sum(self.processing_times[-3:]) / 3
# 如果太快,增大批次
if avg_time < 1.0:
self.batch_size = min(self.batch_size * 2, 1000)
# 如果太慢,减小批次
elif avg_time > 5.0:
self.batch_size = max(self.batch_size // 2, 1)
def process_all(self, items: list):
"""自适应批处理所有项"""
import time
offset = 0
results = []
while offset < len(items):
start_time = time.time()
# 处理当前批次
batch = items[offset:offset + self.batch_size]
batch_results = [process_item(item) for item in batch]
results.extend(batch_results)
# 调整批次大小
processing_time = time.time() - start_time
self.adjust_batch_size(processing_time)
offset += self.batch_size
print(f"批次大小:{self.batch_size}, 耗时:{processing_time:.2f}秒")
return results
# 使用自适应批处理
processor = AdaptiveBatchProcessor(initial_batch_size=10)
results = processor.process_all(large_dataset)
---
7.4 并行执行
01.并行节点
a.独立并行
a.功能说明
多个节点之间没有依赖关系,可以并行执行。LangGraph自动识别并行机会,同时执行独立节点。并行执行减少总耗时,提升吞吐量。适用于独立的API调用、数据查询、文件操作等可并行的任务。需要确保节点间状态更新互不冲突。
b.代码示例
---
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated
import operator
class ParallelState(TypedDict):
input: str
results: Annotated[list, operator.add]
api_result: str
db_result: str
file_result: str
# 1. 定义独立节点
def api_call_node(state: ParallelState):
"""API调用节点"""
import time
time.sleep(2) # 模拟API调用
return {
"api_result": "API数据",
"results": ["api完成"]
}
def db_query_node(state: ParallelState):
"""数据库查询节点"""
import time
time.sleep(2) # 模拟查询
return {
"db_result": "DB数据",
"results": ["db完成"]
}
def file_read_node(state: ParallelState):
"""文件读取节点"""
import time
time.sleep(2) # 模拟文件读取
return {
"file_result": "文件数据",
"results": ["file完成"]
}
def aggregate_node(state: ParallelState):
"""聚合节点"""
final_result = {
"api": state["api_result"],
"db": state["db_result"],
"file": state["file_result"]
}
return {"results": state["results"] + [f"聚合完成:{final_result}"]}
# 2. 构建并行图
workflow = StateGraph(ParallelState)
workflow.add_node("api", api_call_node)
workflow.add_node("db", db_query_node)
workflow.add_node("file", file_read_node)
workflow.add_node("aggregate", aggregate_node)
# 设置多个入口点(并行执行)
workflow.set_entry_point("api")
workflow.set_entry_point("db")
workflow.set_entry_point("file")
# 三个并行节点都连接到聚合节点
workflow.add_edge("api", "aggregate")
workflow.add_edge("db", "aggregate")
workflow.add_edge("file", "aggregate")
workflow.add_edge("aggregate", END)
app = workflow.compile()
# 执行(api、db、file会并行执行)
import time
start = time.time()
result = app.invoke({
"input": "test",
"results": [],
"api_result": "",
"db_result": "",
"file_result": ""
})
elapsed = time.time() - start
print(f"总耗时:{elapsed:.2f}秒") # 约2秒(而非6秒)
print(f"结果:{result['results']}")
# 3. 显式并行分支
# 注意:实际并行执行需要LangGraph底层支持
# 这里展示概念性代码
def split_node(state: ParallelState):
"""分发节点"""
return {"input": state["input"]}
workflow2 = StateGraph(ParallelState)
workflow2.add_node("split", split_node)
workflow2.add_node("branch_a", lambda s: {"results": ["分支A"]})
workflow2.add_node("branch_b", lambda s: {"results": ["分支B"]})
workflow2.add_node("branch_c", lambda s: {"results": ["分支C"]})
workflow2.add_node("merge", aggregate_node)
workflow2.set_entry_point("split")
# 从split并行到三个分支
workflow2.add_edge("split", "branch_a")
workflow2.add_edge("split", "branch_b")
workflow2.add_edge("split", "branch_c")
# 三个分支汇总到merge
workflow2.add_edge("branch_a", "merge")
workflow2.add_edge("branch_b", "merge")
workflow2.add_edge("branch_c", "merge")
workflow2.add_edge("merge", END)
# 4. 动态并行
class DynamicParallelState(TypedDict):
tasks: list
results: Annotated[dict, operator.add]
def dynamic_parallel_node(state: DynamicParallelState):
"""动态并行执行多个任务"""
import concurrent.futures
tasks = state["tasks"]
def execute_task(task):
"""执行单个任务"""
return {"id": task["id"], "result": process_task(task)}
# 使用线程池并行执行
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(execute_task, task) for task in tasks]
results = [f.result() for f in concurrent.futures.as_completed(futures)]
return {"results": {r["id"]: r["result"] for r in results}}
# 5. 限制并发度
import asyncio
from asyncio import Semaphore
async def limited_parallel_execution(tasks: list, max_concurrent: int = 5):
"""限制并发度的并行执行"""
semaphore = Semaphore(max_concurrent)
async def execute_with_semaphore(task):
"""带信号量的执行"""
async with semaphore:
return await async_process_task(task)
results = await asyncio.gather(*[
execute_with_semaphore(task)
for task in tasks
])
return results
---
02.扇出扇入
a.Map-Reduce模式
a.功能说明
扇出(Fan-out)将任务分发到多个并行分支,扇入(Fan-in)收集各分支结果进行聚合。Map-Reduce是经典的并行处理模式,Map阶段并行处理数据分片,Reduce阶段聚合结果。适用于大数据处理、批量API调用、分布式计算等场景。
b.代码示例
---
from typing import TypedDict, List, Annotated
import operator
class MapReduceState(TypedDict):
input_data: list
map_results: Annotated[list, operator.add]
reduce_result: dict
# 1. Map阶段
def map_node(state: MapReduceState):
"""Map节点:并行处理数据分片"""
input_data = state["input_data"]
# 将数据分片并行处理
import concurrent.futures
def map_function(item):
"""Map函数"""
return {
"key": item["category"],
"value": item["amount"]
}
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
map_results = list(executor.map(map_function, input_data))
return {"map_results": map_results}
# 2. Reduce阶段
def reduce_node(state: MapReduceState):
"""Reduce节点:聚合Map结果"""
map_results = state["map_results"]
# 按key聚合
aggregated = {}
for item in map_results:
key = item["key"]
value = item["value"]
if key not in aggregated:
aggregated[key] = []
aggregated[key].append(value)
# Reduce操作(求和)
final_result = {
key: sum(values)
for key, values in aggregated.items()
}
return {"reduce_result": final_result}
# 构建Map-Reduce图
mr_workflow = StateGraph(MapReduceState)
mr_workflow.add_node("map", map_node)
mr_workflow.add_node("reduce", reduce_node)
mr_workflow.set_entry_point("map")
mr_workflow.add_edge("map", "reduce")
mr_workflow.add_edge("reduce", END)
mr_app = mr_workflow.compile()
# 执行
input_data = [
{"category": "A", "amount": 10},
{"category": "B", "amount": 20},
{"category": "A", "amount": 15},
{"category": "C", "amount": 30},
{"category": "B", "amount": 25}
]
result = mr_app.invoke({
"input_data": input_data,
"map_results": [],
"reduce_result": {}
})
print(f"聚合结果:{result['reduce_result']}")
# 输出:{'A': 25, 'B': 45, 'C': 30}
# 3. 多阶段Map-Reduce
class MultiStageState(TypedDict):
input: list
stage1_results: Annotated[list, operator.add]
stage2_results: Annotated[list, operator.add]
final_result: dict
def stage1_map(state):
"""第一阶段Map"""
return {"stage1_results": [process_stage1(item) for item in state["input"]]}
def stage1_reduce(state):
"""第一阶段Reduce"""
intermediate = aggregate_stage1(state["stage1_results"])
return {"stage1_results": [intermediate]}
def stage2_map(state):
"""第二阶段Map"""
return {"stage2_results": [process_stage2(item) for item in state["stage1_results"]]}
def stage2_reduce(state):
"""第二阶段Reduce"""
final = aggregate_stage2(state["stage2_results"])
return {"final_result": final}
# 4. 分片并行处理
def shard_based_processing(data: list, num_shards: int = 4):
"""基于分片的并行处理"""
import math
# 计算分片大小
shard_size = math.ceil(len(data) / num_shards)
# 创建分片
shards = [
data[i:i+shard_size]
for i in range(0, len(data), shard_size)
]
# 并行处理分片
import concurrent.futures
def process_shard(shard):
"""处理单个分片"""
return [process_item(item) for item in shard]
with concurrent.futures.ThreadPoolExecutor(max_workers=num_shards) as executor:
shard_results = list(executor.map(process_shard, shards))
# 合并结果
final_results = [item for shard in shard_results for item in shard]
return final_results
# 5. 实时流式Map-Reduce
async def streaming_map_reduce(data_stream):
"""流式Map-Reduce"""
import asyncio
# Map阶段:流式处理
async def map_stream():
async for item in data_stream:
mapped = await async_map_function(item)
yield mapped
# 收集Map结果
map_results = []
async for mapped in map_stream():
map_results.append(mapped)
# 当累积足够多时,执行部分Reduce
if len(map_results) >= 100:
partial_result = reduce_function(map_results)
yield partial_result
map_results.clear()
# 最终Reduce
if map_results:
final_result = reduce_function(map_results)
yield final_result
# 6. 容错并行处理
class FaultTolerantParallel:
"""容错并行处理"""
def __init__(self, max_retries: int = 3):
self.max_retries = max_retries
async def process_with_retry(self, item):
"""带重试的处理"""
for attempt in range(self.max_retries):
try:
return await async_process_item(item)
except Exception as e:
if attempt == self.max_retries - 1:
return {"item": item, "error": str(e)}
await asyncio.sleep(2 ** attempt)
async def parallel_process(self, items: list):
"""并行处理with容错"""
tasks = [self.process_with_retry(item) for item in items]
results = await asyncio.gather(*tasks, return_exceptions=True)
# 分离成功和失败
successes = [r for r in results if not isinstance(r, Exception)]
failures = [r for r in results if isinstance(r, Exception)]
return {
"successes": successes,
"failures": failures,
"success_rate": len(successes) / len(items) if items else 0
}
# 使用容错并行
processor = FaultTolerantParallel(max_retries=3)
result = asyncio.run(processor.parallel_process(large_dataset))
print(f"成功:{len(result['successes'])}, "
f"失败:{len(result['failures'])}, "
f"成功率:{result['success_rate']*100:.1f}%")
---
8 实战案例
8.1 评标专家组
01.场景分析
a.业务需求
a.功能说明
政府采购和工程招标需要评标专家组进行多轮评审打分。专家组由技术、商务、法律等不同领域专家组成,独立评审后汇总。流程包括资格预审、技术评审、商务评审、综合评分、专家讨论、最终决策等环节。要求全程可追溯、公正透明、符合法规。信创环境需适配国产浏览器、达梦数据库、麒麟操作系统。
b.代码示例
---
from langgraph.graph import StateGraph, END
from typing import TypedDict, List, Annotated
import operator
from datetime import datetime
# 1. 定义状态
class BiddingEvaluationState(TypedDict):
# 项目信息
project_id: str
project_name: str
bidders: list # 投标人列表
# 专家组
experts: list # 专家列表
expert_scores: Annotated[dict, operator.add]
# 评审阶段
current_stage: str
qualification_results: dict
technical_scores: dict
commercial_scores: dict
comprehensive_scores: dict
# 讨论记录
discussion_records: Annotated[list, operator.add]
# 最终结果
final_ranking: list
winner: str
evaluation_report: str
# 2. 资格预审节点
def qualification_review_node(state: BiddingEvaluationState):
"""资格预审"""
bidders = state["bidders"]
qualified = {}
for bidder in bidders:
# 检查资质文件
checks = {
"营业执照": check_business_license(bidder),
"资质证书": check_certificates(bidder),
"财务状况": check_financial_status(bidder),
"业绩证明": check_performance_records(bidder)
}
# 全部通过才合格
qualified[bidder["id"]] = all(checks.values())
return {
"qualification_results": qualified,
"current_stage": "qualification_complete"
}
# 3. 技术评审节点
def technical_review_node(state: BiddingEvaluationState):
"""技术评审"""
qualified_bidders = [
bidder for bidder in state["bidders"]
if state["qualification_results"].get(bidder["id"])
]
expert_scores = {}
for expert in state["experts"]:
if expert["specialty"] == "技术":
# 专家独立打分
scores = {}
for bidder in qualified_bidders:
score = expert_technical_score(expert, bidder, {
"技术方案": 30,
"实施能力": 20,
"创新性": 15,
"可行性": 15
})
scores[bidder["id"]] = score
expert_scores[expert["id"]] = scores
return {
"expert_scores": expert_scores,
"current_stage": "technical_complete"
}
# 4. 商务评审节点
def commercial_review_node(state: BiddingEvaluationState):
"""商务评审"""
qualified_bidders = [
bidder for bidder in state["bidders"]
if state["qualification_results"].get(bidder["id"])
]
expert_scores = {}
for expert in state["experts"]:
if expert["specialty"] == "商务":
scores = {}
for bidder in qualified_bidders:
score = expert_commercial_score(expert, bidder, {
"报价合理性": 30,
"付款条件": 10,
"履约保证": 10,
"售后服务": 10
})
scores[bidder["id"]] = score
expert_scores[expert["id"]] = scores
return {
"expert_scores": expert_scores,
"current_stage": "commercial_complete"
}
# 5. 综合评分节点
def comprehensive_scoring_node(state: BiddingEvaluationState):
"""综合评分"""
all_expert_scores = state["expert_scores"]
bidders = [b for b in state["bidders"] if state["qualification_results"].get(b["id"])]
final_scores = {}
for bidder in bidders:
bidder_id = bidder["id"]
# 收集所有专家对该投标人的打分
scores = []
for expert_id, expert_score_dict in all_expert_scores.items():
if bidder_id in expert_score_dict:
scores.append(expert_score_dict[bidder_id])
# 计算平均分(去掉最高最低分)
if len(scores) > 2:
scores_sorted = sorted(scores)
avg_score = sum(scores_sorted[1:-1]) / (len(scores_sorted) - 2)
else:
avg_score = sum(scores) / len(scores) if scores else 0
final_scores[bidder_id] = round(avg_score, 2)
return {
"comprehensive_scores": final_scores,
"current_stage": "scoring_complete"
}
# 6. 专家讨论节点(Human-in-the-Loop)
def expert_discussion_node(state: BiddingEvaluationState):
"""专家讨论环节(中断等待)"""
# 在这个节点中断,等待专家讨论
# 如果有讨论记录,说明已讨论完毕
if state.get("discussion_records"):
return {"current_stage": "discussion_complete"}
else:
# 等待专家讨论
return {"current_stage": "awaiting_discussion"}
# 7. 最终决策节点
def final_decision_node(state: BiddingEvaluationState):
"""最终决策"""
scores = state["comprehensive_scores"]
# 按分数排名
ranking = sorted(
scores.items(),
key=lambda x: x[1],
reverse=True
)
# 确定中标人(第一名)
winner_id = ranking[0][0] if ranking else None
winner = next((b for b in state["bidders"] if b["id"] == winner_id), None)
# 生成评审报告
report = generate_evaluation_report(state, ranking)
return {
"final_ranking": ranking,
"winner": winner["name"] if winner else "",
"evaluation_report": report,
"current_stage": "completed"
}
# 8. 构建评标流程图
workflow = StateGraph(BiddingEvaluationState)
workflow.add_node("qualification", qualification_review_node)
workflow.add_node("technical", technical_review_node)
workflow.add_node("commercial", commercial_review_node)
workflow.add_node("scoring", comprehensive_scoring_node)
workflow.add_node("discussion", expert_discussion_node)
workflow.add_node("decision", final_decision_node)
workflow.set_entry_point("qualification")
workflow.add_edge("qualification", "technical")
workflow.add_edge("technical", "commercial")
workflow.add_edge("commercial", "scoring")
workflow.add_edge("scoring", "discussion")
# 讨论后路由
def after_discussion(state):
return "decision" if state["current_stage"] == "discussion_complete" else "discussion"
workflow.add_conditional_edges(
"discussion",
after_discussion,
{
"decision": "decision",
"discussion": "discussion"
}
)
workflow.add_edge("decision", END)
# 配置检查点和中断
from langgraph.checkpoint.sqlite import SqliteSaver
checkpointer = SqliteSaver.from_conn_string("bidding_evaluation.db")
app = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["discussion"] # 在讨论前中断
)
# 9. 执行评标流程
config = {"configurable": {"thread_id": "project_2024001"}}
# 第一阶段:自动执行到讨论环节
result = app.invoke({
"project_id": "2024001",
"project_name": "信创平台建设项目",
"bidders": [
{"id": "A", "name": "公司A", "documents": {...}},
{"id": "B", "name": "公司B", "documents": {...}},
{"id": "C", "name": "公司C", "documents": {...}}
],
"experts": [
{"id": "E1", "name": "专家1", "specialty": "技术"},
{"id": "E2", "name": "专家2", "specialty": "技术"},
{"id": "E3", "name": "专家3", "specialty": "商务"},
{"id": "E4", "name": "专家4", "specialty": "商务"},
{"id": "E5", "name": "专家5", "specialty": "法律"}
],
"expert_scores": {},
"current_stage": "initial",
"qualification_results": {},
"technical_scores": {},
"commercial_scores": {},
"comprehensive_scores": {},
"discussion_records": [],
"final_ranking": [],
"winner": "",
"evaluation_report": ""
}, config=config)
print(f"当前阶段:{result['current_stage']}")
print(f"综合评分:{result['comprehensive_scores']}")
# 第二阶段:专家讨论(人工介入)
discussion_record = {
"timestamp": datetime.now().isoformat(),
"topic": "对公司B的技术方案有争议",
"opinions": [
{"expert": "专家1", "opinion": "技术方案创新性不足"},
{"expert": "专家2", "opinion": "实施方案详细,建议加分"}
],
"conclusion": "维持原评分"
}
# 继续执行
final_result = app.invoke({
"discussion_records": [discussion_record]
}, config=config)
print(f"中标人:{final_result['winner']}")
print(f"最终排名:{final_result['final_ranking']}")
---
02.信创适配
a.达梦数据库
a.功能说明
评标数据需持久化到达梦数据库,记录完整的评审过程、专家打分、讨论记录等。使用dmPython连接达梦数据库,创建自定义Checkpointer持久化状态。适配达梦的SQL语法、数据类型、事务处理等特性。确保数据安全、审计追溯、容灾备份符合信创要求。
b.代码示例
---
import dmPython
from langgraph.checkpoint.base import BaseCheckpointSaver
import json
from datetime import datetime
# 1. 达梦数据库Checkpointer
class DmBiddingCheckpointer(BaseCheckpointSaver):
"""达梦数据库评标检查点存储"""
def __init__(self, connection_string: str):
self.conn_str = connection_string
self._ensure_tables()
def _ensure_tables(self):
"""创建数据库表"""
conn = dmPython.connect(self.conn_str)
cursor = conn.cursor()
# 评标项目表
cursor.execute("""
CREATE TABLE IF NOT EXISTS bidding_projects (
project_id VARCHAR(100) PRIMARY KEY,
project_name VARCHAR(200),
project_data TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# 评标检查点表
cursor.execute("""
CREATE TABLE IF NOT EXISTS bidding_checkpoints (
checkpoint_id VARCHAR(100) PRIMARY KEY,
project_id VARCHAR(100),
checkpoint_data TEXT,
stage VARCHAR(50),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (project_id) REFERENCES bidding_projects(project_id)
)
""")
# 专家评分表
cursor.execute("""
CREATE TABLE IF NOT EXISTS expert_scores (
score_id INTEGER IDENTITY(1,1) PRIMARY KEY,
project_id VARCHAR(100),
expert_id VARCHAR(50),
bidder_id VARCHAR(50),
score_type VARCHAR(20),
score DECIMAL(5,2),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# 讨论记录表
cursor.execute("""
CREATE TABLE IF NOT EXISTS discussion_records (
record_id INTEGER IDENTITY(1,1) PRIMARY KEY,
project_id VARCHAR(100),
discussion_data TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
conn.commit()
conn.close()
def put(self, config, checkpoint, metadata):
"""保存检查点"""
project_id = config["configurable"]["thread_id"]
checkpoint_id = checkpoint.get("id", "")
conn = dmPython.connect(self.conn_str)
cursor = conn.cursor()
try:
# 检查点数据
checkpoint_json = json.dumps(checkpoint, ensure_ascii=False)
stage = checkpoint.get("channel_values", {}).get("current_stage", "")
# 删除旧检查点
cursor.execute(
"DELETE FROM bidding_checkpoints WHERE checkpoint_id = ?",
(checkpoint_id,)
)
# 插入新检查点
cursor.execute("""
INSERT INTO bidding_checkpoints
(checkpoint_id, project_id, checkpoint_data, stage)
VALUES (?, ?, ?, ?)
""", (checkpoint_id, project_id, checkpoint_json, stage))
# 保存专家评分
state = checkpoint.get("channel_values", {})
expert_scores = state.get("expert_scores", {})
for expert_id, scores_dict in expert_scores.items():
for bidder_id, score in scores_dict.items():
cursor.execute("""
INSERT INTO expert_scores
(project_id, expert_id, bidder_id, score_type, score)
VALUES (?, ?, ?, ?, ?)
""", (project_id, expert_id, bidder_id, "综合", score))
conn.commit()
finally:
conn.close()
def get(self, config):
"""获取最新检查点"""
project_id = config["configurable"]["thread_id"]
conn = dmPython.connect(self.conn_str)
cursor = conn.cursor()
try:
cursor.execute("""
SELECT checkpoint_data
FROM bidding_checkpoints
WHERE project_id = ?
ORDER BY created_at DESC
LIMIT 1
""", (project_id,))
result = cursor.fetchone()
if result:
return json.loads(result[0])
return None
finally:
conn.close()
# 使用达梦检查点
dm_checkpointer = DmBiddingCheckpointer(
"dm://SYSDBA:SYSDBA@localhost:5236/BIDDING"
)
app = workflow.compile(checkpointer=dm_checkpointer)
# 2. 国密算法加密
from gmssl import sm4
def encrypt_sensitive_data(data: str, key: bytes) -> str:
"""使用SM4加密敏感数据"""
cipher = sm4.CryptSM4()
cipher.set_key(key, sm4.SM4_ENCRYPT)
encrypted = cipher.crypt_ecb(data.encode('utf-8'))
return encrypted.hex()
def decrypt_sensitive_data(encrypted_hex: str, key: bytes) -> str:
"""使用SM4解密"""
cipher = sm4.CryptSM4()
cipher.set_key(key, sm4.SM4_DECRYPT)
encrypted_bytes = bytes.fromhex(encrypted_hex)
decrypted = cipher.crypt_ecb(encrypted_bytes)
return decrypted.decode('utf-8')
# 加密专家评分
sm4_key = b'0123456789abcdef' # 16字节密钥
score_data = json.dumps(expert_scores)
encrypted_score = encrypt_sensitive_data(score_data, sm4_key)
# 存储加密数据
cursor.execute(
"INSERT INTO expert_scores (project_id, encrypted_data) VALUES (?, ?)",
(project_id, encrypted_score)
)
---
8.2 客服工作流
01.智能客服架构
a.意图识别路由
a.功能说明
智能客服系统使用LangGraph实现多轮对话、意图识别、工单创建、人工介入等功能。首先识别用户意图(咨询、投诉、建议等),路由到相应处理节点。支持FAQ自动回复、知识库检索、订单查询、退换货处理等常见场景。复杂问题自动转人工客服,记录完整对话历史。信创环境使用Ollama本地部署LLM,保障数据安全。
b.代码示例
---
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated, Literal
import operator
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, AIMessage
# 1. 定义客服状态
class CustomerServiceState(TypedDict):
# 用户信息
user_id: str
user_name: str
user_level: str # VIP, 普通
# 对话历史
messages: Annotated[list, operator.add]
# 意图识别
intent: str # inquiry, complaint, suggestion, order_query
confidence: float
# 业务数据
order_info: dict
faq_result: str
knowledge_base_result: str
# 工单
ticket_id: str
ticket_status: str
# 转人工
needs_human: bool
agent_id: str
# 满意度
satisfaction: int
# 2. 意图识别节点
def intent_recognition_node(state: CustomerServiceState):
"""意图识别"""
messages = state["messages"]
last_message = messages[-1] if messages else ""
llm = ChatOpenAI(model="gpt-4", temperature=0)
prompt = f"""
分析用户消息的意图,返回JSON格式:
{{"intent": "inquiry/complaint/suggestion/order_query", "confidence": 0.0-1.0}}
用户消息:{last_message}
只返回JSON:"""
import json
response = llm.predict(prompt)
result = json.loads(response)
return {
"intent": result["intent"],
"confidence": result["confidence"]
}
# 3. FAQ自动回复节点
def faq_node(state: CustomerServiceState):
"""FAQ自动回复"""
last_message = state["messages"][-1]
# 检索FAQ库
faq_answer = search_faq_database(last_message)
if faq_answer:
return {
"faq_result": faq_answer,
"messages": [AIMessage(content=faq_answer)]
}
else:
return {"faq_result": ""}
# 4. 知识库检索节点
def knowledge_base_node(state: CustomerServiceState):
"""知识库检索"""
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
query = state["messages"][-1]
# 向量检索
embeddings = OpenAIEmbeddings()
vector_store = Chroma(
persist_directory="./customer_service_kb",
embedding_function=embeddings
)
docs = vector_store.similarity_search(query, k=3)
if docs:
# LLM生成回复
llm = ChatOpenAI(model="gpt-4")
context = "\n".join([doc.page_content for doc in docs])
response = llm.predict(f"""
基于以下知识库内容回答用户问题:
{context}
用户问题:{query}
回答:""")
return {
"knowledge_base_result": response,
"messages": [AIMessage(content=response)]
}
else:
return {"knowledge_base_result": ""}
# 5. 订单查询节点
def order_query_node(state: CustomerServiceState):
"""订单查询"""
user_id = state["user_id"]
last_message = state["messages"][-1]
# 从消息中提取订单号
import re
order_numbers = re.findall(r'\d{10,}', last_message)
if order_numbers:
order_no = order_numbers[0]
# 查询订单信息
order_info = query_order_from_db(order_no, user_id)
if order_info:
response = f"""
您的订单信息如下:
订单号:{order_info['order_no']}
状态:{order_info['status']}
下单时间:{order_info['created_at']}
预计送达:{order_info['estimated_delivery']}
"""
return {
"order_info": order_info,
"messages": [AIMessage(content=response)]
}
return {
"messages": [AIMessage(content="未找到订单信息,请提供正确的订单号")]
}
# 6. 投诉处理节点
def complaint_handling_node(state: CustomerServiceState):
"""投诉处理(创建工单)"""
user_id = state["user_id"]
complaint = state["messages"][-1]
# 创建投诉工单
ticket_id = create_ticket({
"user_id": user_id,
"type": "complaint",
"content": complaint,
"priority": "high" if state["user_level"] == "VIP" else "normal",
"status": "pending"
})
response = f"""
已为您创建投诉工单,工单号:{ticket_id}
我们将在24小时内处理,感谢您的反馈。
"""
return {
"ticket_id": ticket_id,
"ticket_status": "pending",
"needs_human": True, # 投诉需要人工处理
"messages": [AIMessage(content=response)]
}
# 7. 转人工节点
def transfer_to_human_node(state: CustomerServiceState):
"""转人工客服"""
# 分配人工客服
agent = assign_customer_service_agent(state["user_level"])
return {
"agent_id": agent["id"],
"messages": [AIMessage(content=f"正在为您转接人工客服 {agent['name']},请稍候...")]
}
# 8. 满意度评价节点
def satisfaction_survey_node(state: CustomerServiceState):
"""满意度调查"""
return {
"messages": [AIMessage(content="请对本次服务进行评价:1-5星")]
}
# 9. 构建客服工作流
workflow = StateGraph(CustomerServiceState)
workflow.add_node("intent", intent_recognition_node)
workflow.add_node("faq", faq_node)
workflow.add_node("kb", knowledge_base_node)
workflow.add_node("order", order_query_node)
workflow.add_node("complaint", complaint_handling_node)
workflow.add_node("transfer", transfer_to_human_node)
workflow.add_node("survey", satisfaction_survey_node)
workflow.set_entry_point("intent")
# 意图路由
def route_by_intent(state: CustomerServiceState) -> str:
intent = state["intent"]
confidence = state["confidence"]
# 低置信度转人工
if confidence < 0.7:
return "transfer"
# 根据意图路由
if intent == "inquiry":
return "faq"
elif intent == "order_query":
return "order"
elif intent == "complaint":
return "complaint"
else:
return "kb"
workflow.add_conditional_edges(
"intent",
route_by_intent,
{
"faq": "faq",
"kb": "kb",
"order": "order",
"complaint": "complaint",
"transfer": "transfer"
}
)
# FAQ后路由
def after_faq(state):
return "survey" if state["faq_result"] else "kb"
workflow.add_conditional_edges(
"faq",
after_faq,
{"survey": "survey", "kb": "kb"}
)
# 知识库后路由
def after_kb(state):
return "survey" if state["knowledge_base_result"] else "transfer"
workflow.add_conditional_edges(
"kb",
after_kb,
{"survey": "survey", "transfer": "transfer"}
)
workflow.add_edge("order", "survey")
workflow.add_edge("complaint", "transfer")
workflow.add_edge("transfer", END)
workflow.add_edge("survey", END)
# 配置检查点
from langgraph.checkpoint.sqlite import SqliteSaver
checkpointer = SqliteSaver.from_conn_string("customer_service.db")
app = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["transfer"] # 转人工前中断
)
# 10. 执行客服对话
config = {"configurable": {"thread_id": "user_12345_session_1"}}
result = app.invoke({
"user_id": "12345",
"user_name": "张三",
"user_level": "VIP",
"messages": [HumanMessage(content="我的订单什么时候能到?订单号1234567890")],
"intent": "",
"confidence": 0.0,
"order_info": {},
"faq_result": "",
"knowledge_base_result": "",
"ticket_id": "",
"ticket_status": "",
"needs_human": False,
"agent_id": "",
"satisfaction": 0
}, config=config)
print("AI回复:", result["messages"][-1].content)
---
02.信创本地LLM
a.Ollama集成
a.功能说明
信创环境下使用Ollama本地部署大模型,避免数据外传。支持Qwen、ChatGLM等国产模型,运行在麒麟操作系统、昇腾NPU上。使用LangChain的Ollama集成,配置本地推理服务。实现离线运行、数据隔离、合规安全。性能优化包括模型量化、批处理推理、KV缓存等。
b.代码示例
---
from langchain.llms import Ollama
from langchain.chat_models import ChatOllama
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
# 1. 配置Ollama本地LLM
llm = Ollama(
model="qwen:7b",
base_url="http://localhost:11434",
temperature=0.7
)
# 2. 在客服节点中使用
def local_llm_intent_node(state: CustomerServiceState):
"""使用本地LLM识别意图"""
last_message = state["messages"][-1]
prompt = f"""
分析用户消息的意图,只返回:咨询、投诉、建议、订单查询 其中之一。
用户消息:{last_message}
意图:"""
intent_text = llm.predict(prompt).strip()
# 映射到英文intent
intent_map = {
"咨询": "inquiry",
"投诉": "complaint",
"建议": "suggestion",
"订单查询": "order_query"
}
return {
"intent": intent_map.get(intent_text, "inquiry"),
"confidence": 0.8
}
# 3. 流式输出
def streaming_llm_response(query: str):
"""流式生成回复"""
llm_stream = ChatOllama(
model="qwen:7b",
base_url="http://localhost:11434",
streaming=True,
callbacks=[StreamingStdOutCallbackHandler()]
)
response = llm_stream.predict(query)
return response
# 4. 批量推理优化
from typing import List
def batch_llm_inference(queries: List[str]):
"""批量推理"""
llm = Ollama(model="qwen:7b")
# Ollama支持批量
results = llm.generate(queries)
return [gen[0].text for gen in results.generations]
# 批量处理多个用户消息
user_queries = [
"我的订单在哪里?",
"如何退货?",
"客服电话是多少?"
]
responses = batch_llm_inference(user_queries)
# 5. 模型切换策略
class AdaptiveLLMRouter:
"""自适应LLM路由"""
def __init__(self):
self.small_model = Ollama(model="qwen:1.8b") # 快速模型
self.large_model = Ollama(model="qwen:14b") # 精确模型
def route_and_generate(self, query: str, complexity: str = "auto"):
"""根据复杂度选择模型"""
if complexity == "auto":
# 自动判断复杂度
complexity = self._estimate_complexity(query)
if complexity == "simple":
return self.small_model.predict(query)
else:
return self.large_model.predict(query)
def _estimate_complexity(self, query: str) -> str:
"""估算查询复杂度"""
# 简单规则
if len(query) < 20:
return "simple"
elif any(kw in query for kw in ["订单", "查询", "电话"]):
return "simple"
else:
return "complex"
router = AdaptiveLLMRouter()
simple_response = router.route_and_generate("订单号是多少?")
complex_response = router.route_and_generate("为什么我的退款申请被拒绝了?具体原因是什么?")
# 6. 信创GPU加速
# 配置昇腾NPU
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["ASCEND_DEVICE_ID"] = "0"
# Ollama会自动使用昇腾NPU
ascend_llm = Ollama(
model="qwen:7b",
base_url="http://localhost:11434",
num_gpu=1 # 使用NPU
)
---
8.3 文档审批流程
01.多级审批
a.流程设计
a.功能说明
企业文档审批需要多级签批,包括部门主管、分管领导、总经理等。每级审批可批准、退回、转交。支持并行会签、串行审批、加签等模式。记录完整审批轨迹,支持催办、超期提醒。信创环境使用国产办公软件、电子签章。
b.代码示例
---
from langgraph.graph import StateGraph, END
from typing import TypedDict, List, Annotated
import operator
from datetime import datetime, timedelta
class DocumentApprovalState(TypedDict):
doc_id: str
doc_title: str
doc_content: str
submitter: str
current_level: int
approval_chain: list
approval_records: Annotated[list, operator.add]
status: str
final_approver: str
completed_at: str
def level1_approval_node(state):
"""一级审批(部门主管)"""
if state.get("level1_approved"):
return {
"current_level": 2,
"approval_records": [{
"level": 1,
"approver": "部门主管",
"decision": "批准",
"time": datetime.now().isoformat()
}]
}
else:
return {"status": "rejected"}
def level2_approval_node(state):
"""二级审批(分管领导)"""
if state.get("level2_approved"):
return {
"current_level": 3,
"approval_records": [{
"level": 2,
"approver": "分管领导",
"decision": "批准",
"time": datetime.now().isoformat()
}]
}
else:
return {"status": "rejected"}
def final_approval_node(state):
"""最终审批(总经理)"""
if state.get("final_approved"):
return {
"status": "approved",
"final_approver": "总经理",
"completed_at": datetime.now().isoformat()
}
else:
return {"status": "rejected"}
workflow = StateGraph(DocumentApprovalState)
workflow.add_node("level1", level1_approval_node)
workflow.add_node("level2", level2_approval_node)
workflow.add_node("final", final_approval_node)
workflow.set_entry_point("level1")
workflow.add_edge("level1", "level2")
workflow.add_edge("level2", "final")
workflow.add_edge("final", END)
from langgraph.checkpoint.sqlite import SqliteSaver
checkpointer = SqliteSaver.from_conn_string("approval.db")
app = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["level1", "level2", "final"]
)
---
02.电子签章
a.国密签名
a.功能说明
使用国密SM2算法实现电子签章,符合《电子签名法》要求。审批完成后生成数字签名,验证文档完整性和签署人身份。集成信创电子签章平台,支持可视化签章、时间戳、证书链验证。适配麒麟操作系统的USBKey驱动。
b.代码示例
---
from gmssl import sm2
import base64
def sign_document_sm2(document_hash: str, private_key: str):
"""使用SM2签名文档"""
sm2_crypt = sm2.CryptSM2(private_key=private_key, public_key="")
signature = sm2_crypt.sign(document_hash.encode(), "")
return base64.b64encode(signature).decode()
def verify_signature_sm2(document_hash: str, signature: str, public_key: str):
"""验证SM2签名"""
sm2_crypt = sm2.CryptSM2(private_key="", public_key=public_key)
sig_bytes = base64.b64decode(signature)
return sm2_crypt.verify(sig_bytes, document_hash.encode())
def approval_with_signature_node(state):
"""审批并签名"""
import hashlib
doc_hash = hashlib.sha256(state["doc_content"].encode()).hexdigest()
private_key = get_approver_private_key(state["approver"])
signature = sign_document_sm2(doc_hash, private_key)
return {
"approval_records": [{
"approver": state["approver"],
"signature": signature,
"timestamp": datetime.now().isoformat()
}]
}
---
8.4 智能决策系统
01.决策树引擎
a.规则引擎
a.功能说明
构建基于规则的智能决策系统,支持复杂业务逻辑。使用LangGraph实现决策树、决策表、评分卡等模型。动态加载规则配置,支持规则热更新。集成LLM增强决策能力,处理非结构化信息。适用于信贷审批、风险评估、资源分配等场景。
b.代码示例
---
from langgraph.graph import StateGraph, END
class DecisionState(TypedDict):
input_data: dict
rules_result: dict
llm_analysis: str
final_decision: str
confidence: float
def rule_engine_node(state):
"""规则引擎"""
data = state["input_data"]
score = 0
# 规则1:收入评估
if data["income"] > 10000:
score += 30
elif data["income"] > 5000:
score += 15
# 规则2:信用评分
if data["credit_score"] > 700:
score += 40
elif data["credit_score"] > 600:
score += 20
# 规则3:负债比
if data["debt_ratio"] < 0.3:
score += 30
elif data["debt_ratio"] < 0.5:
score += 15
return {"rules_result": {"score": score, "max": 100}}
def llm_analysis_node(state):
"""LLM辅助分析"""
from langchain.chat_models import ChatOpenAI
llm = ChatOpenAI(model="gpt-4")
data = state["input_data"]
prompt = f"""
分析申请人资质:
收入:{data['income']}
信用分:{data['credit_score']}
负债比:{data['debt_ratio']}
工作年限:{data['work_years']}
给出风险评估和建议:"""
analysis = llm.predict(prompt)
return {"llm_analysis": analysis}
def final_decision_node(state):
"""最终决策"""
score = state["rules_result"]["score"]
if score >= 80:
decision = "批准"
confidence = 0.95
elif score >= 60:
decision = "人工审核"
confidence = 0.7
else:
decision = "拒绝"
confidence = 0.9
return {
"final_decision": decision,
"confidence": confidence
}
workflow = StateGraph(DecisionState)
workflow.add_node("rules", rule_engine_node)
workflow.add_node("llm", llm_analysis_node)
workflow.add_node("decision", final_decision_node)
workflow.set_entry_point("rules")
workflow.add_edge("rules", "llm")
workflow.add_edge("llm", "decision")
workflow.add_edge("decision", END)
---
02.A/B测试
a.策略对比
a.功能说明
使用LangGraph实现决策策略的A/B测试。同时运行多个决策版本,收集结果对比效果。支持灰度发布、流量分配、效果统计。适用于优化决策模型、测试新规则、验证LLM提示词等场景。
b.代码示例
---
import random
class ABTestState(TypedDict):
input: dict
strategy: str
result_a: str
result_b: str
chosen_result: str
def strategy_selector_node(state):
"""策略选择"""
strategy = random.choice(["A", "B"])
return {"strategy": strategy}
def strategy_a_node(state):
"""策略A"""
result = process_with_strategy_a(state["input"])
return {"result_a": result}
def strategy_b_node(state):
"""策略B"""
result = process_with_strategy_b(state["input"])
return {"result_b": result}
def result_router(state):
return state["strategy"].lower()
workflow = StateGraph(ABTestState)
workflow.add_node("selector", strategy_selector_node)
workflow.add_node("strategy_a", strategy_a_node)
workflow.add_node("strategy_b", strategy_b_node)
workflow.set_entry_point("selector")
workflow.add_conditional_edges(
"selector",
result_router,
{"a": "strategy_a", "b": "strategy_b"}
)
workflow.add_edge("strategy_a", END)
workflow.add_edge("strategy_b", END)
app = workflow.compile()
# 运行AB测试
results_a = []
results_b = []
for i in range(1000):
result = app.invoke({"input": {"data": i}})
if result["strategy"] == "A":
results_a.append(result["result_a"])
else:
results_b.append(result["result_b"])
print(f"策略A:{len(results_a)}次,策略B:{len(results_b)}次")
---
9 调试与监控
9.1 图可视化
01.图结构展示
a.Mermaid导出
a.功能说明
LangGraph支持导出为Mermaid格式,可视化图的结构。展示节点、边、条件路由等。帮助理解工作流逻辑,进行架构Review。使用get_graph()获取图定义,转换为Mermaid语法。可嵌入Markdown文档、Jupyter Notebook等。
b.代码示例
---
from langgraph.graph import StateGraph, END
workflow = StateGraph(State)
# ... 添加节点和边 ...
app = workflow.compile()
# 获取图结构
graph = app.get_graph()
# 导出为Mermaid
mermaid_code = graph.draw_mermaid()
print(mermaid_code)
# 输出示例:
# graph TD
# __start__ --> analyze
# analyze --> route
# route --> |high| high_path
# route --> |low| low_path
# high_path --> __end__
# low_path --> __end__
# 在Jupyter中显示
from IPython.display import display, Image
png = graph.draw_mermaid_png()
display(Image(png))
---
02.执行可视化
a.实时图
a.功能说明
可视化图的执行过程,高亮当前节点、已完成节点、待执行节点。实时更新执行状态,展示数据流动。使用流式输出或事件监听获取执行信息,动态更新可视化界面。适用于调试、演示、监控等场景。
b.代码示例
---
import asyncio
from rich.console import Console
from rich.tree import Tree
async def visualize_execution():
"""可视化执行过程"""
console = Console()
tree = Tree("执行流程")
node_trees = {}
async for event in app.astream_events({"input": "test"}, version="v1"):
if event["event"] == "on_chain_start":
node_name = event["name"]
node_tree = tree.add(f"[yellow]▶ {node_name}")
node_trees[node_name] = node_tree
elif event["event"] == "on_chain_end":
node_name = event["name"]
if node_name in node_trees:
node_trees[node_name].label = f"[green]✓ {node_name}"
console.clear()
console.print(tree)
await asyncio.sleep(0.1)
asyncio.run(visualize_execution())
---
9.2 执行追踪
01.LangSmith集成
a.追踪配置
a.功能说明
使用LangSmith追踪LangGraph执行,记录每个节点的输入输出、LLM调用、耗时等。配置LANGCHAIN_TRACING_V2环境变量启用追踪。在LangSmith平台查看完整执行轨迹,分析性能瓶颈,定位错误。支持自定义标签、过滤、搜索。
b.代码示例
---
import os
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_API_KEY"] = "your_api_key"
os.environ["LANGCHAIN_PROJECT"] = "my_project"
# 执行会自动追踪
result = app.invoke({"input": "test"})
# 在LangSmith平台查看追踪
# https://smith.langchain.com
# 自定义追踪标签
from langchain.callbacks import LangChainTracer
tracer = LangChainTracer(
project_name="my_project",
tags=["production", "v2.0"]
)
result = app.invoke(
{"input": "test"},
config={"callbacks": [tracer]}
)
---
02.自定义日志
a.结构化日志
a.功能说明
实现自定义日志记录,捕获关键执行信息。使用Python logging模块,配置JSON格式输出。记录节点开始/结束、状态变化、错误异常等。支持日志聚合、检索、告警。适配ELK、Splunk等日志平台。
b.代码示例
---
import logging
import json
from datetime import datetime
class StructuredLogger:
def __init__(self, name: str):
self.logger = logging.getLogger(name)
self.logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(message)s'))
self.logger.addHandler(handler)
def log_node_start(self, node_name: str, state: dict):
self.logger.info(json.dumps({
"event": "node_start",
"node": node_name,
"timestamp": datetime.now().isoformat(),
"state_keys": list(state.keys())
}, ensure_ascii=False))
def log_node_end(self, node_name: str, duration: float):
self.logger.info(json.dumps({
"event": "node_end",
"node": node_name,
"duration_ms": duration * 1000,
"timestamp": datetime.now().isoformat()
}, ensure_ascii=False))
logger = StructuredLogger("langgraph")
def logged_node(state):
import time
start = time.time()
logger.log_node_start("my_node", state)
result = process(state)
logger.log_node_end("my_node", time.time() - start)
return result
---
9.3 性能分析
01.性能指标
a.耗时统计
a.功能说明
统计各节点执行时间,识别性能瓶颈。记录总耗时、节点耗时、LLM调用时间、数据库查询时间等。计算百分位数、平均值、最大值。生成性能报告,指导优化。使用Python time模块或cProfile进行性能分析。
b.代码示例
---
import time
from collections import defaultdict
class PerformanceMonitor:
def __init__(self):
self.node_times = defaultdict(list)
self.start_times = {}
async def track_performance(self, event_stream):
async for event in event_stream:
event_type = event["event"]
name = event.get("name", "")
if event_type.endswith("_start"):
self.start_times[name] = time.time()
elif event_type.endswith("_end"):
if name in self.start_times:
duration = time.time() - self.start_times[name]
self.node_times[name].append(duration)
del self.start_times[name]
yield event
def print_report(self):
import statistics
print("\n性能报告:")
for node, times in self.node_times.items():
print(f"{node}:")
print(f" 调用次数:{len(times)}")
print(f" 平均耗时:{statistics.mean(times):.3f}秒")
print(f" 最大耗时:{max(times):.3f}秒")
print(f" P95:{statistics.quantiles(times, n=20)[18]:.3f}秒")
monitor = PerformanceMonitor()
async for event in monitor.track_performance(
app.astream_events({"input": "test"}, version="v1")
):
pass
monitor.print_report()
---
02.资源监控
a.内存CPU
a.功能说明
监控图执行过程中的资源使用,包括内存、CPU、GPU等。使用psutil库获取系统资源信息。追踪内存泄漏、CPU峰值、GPU利用率。设置资源告警阈值,防止资源耗尽。适用于生产环境监控、容量规划。
b.代码示例
---
import psutil
import os
class ResourceMonitor:
def __init__(self):
self.process = psutil.Process(os.getpid())
self.memory_samples = []
self.cpu_samples = []
def sample(self):
mem_mb = self.process.memory_info().rss / 1024 / 1024
cpu_percent = self.process.cpu_percent()
self.memory_samples.append(mem_mb)
self.cpu_samples.append(cpu_percent)
return {"memory_mb": mem_mb, "cpu_percent": cpu_percent}
def report(self):
import statistics
return {
"memory": {
"avg": statistics.mean(self.memory_samples),
"max": max(self.memory_samples),
"samples": len(self.memory_samples)
},
"cpu": {
"avg": statistics.mean(self.cpu_samples),
"max": max(self.cpu_samples)
}
}
monitor = ResourceMonitor()
async for event in app.astream_events({"input": "test"}, version="v1"):
if event["event"] == "on_chain_start":
monitor.sample()
print(monitor.report())
---
10 最佳实践
10.1 状态设计
01.状态结构
a.字段设计
a.功能说明
合理设计状态结构是构建高效LangGraph的基础。使用TypedDict定义类型,明确字段含义。区分必需字段和可选字段。使用Annotated配置合并策略(add、replace等)。避免嵌套过深、字段过多。状态应自包含,减少外部依赖。
b.代码示例
---
from typing import TypedDict, Annotated, Optional, List
import operator
# 良好的状态设计
class WellDesignedState(TypedDict):
# 输入数据(必需)
input: str
# 中间结果(可选)
intermediate_results: Annotated[List[dict], operator.add]
# 最终输出(可选)
final_output: Optional[str]
# 元数据
metadata: dict
# 避免的设计
class PoorState(TypedDict):
data: dict # 太模糊
result1: str # 命名不清
result2: str
result3: str
temp_var: str # 临时变量不应在状态中
---
02.状态管理
a.状态清理
a.功能说明
大状态对象影响性能和存储。定期清理不再需要的字段,保持状态精简。使用状态转换函数删除临时数据。区分持久化状态和临时状态。实现状态压缩、归档策略。适用于长时间运行、大数据量的工作流。
b.代码示例
---
def cleanup_state_node(state: State):
"""清理状态"""
# 删除大对象
cleaned = {k: v for k, v in state.items()
if k not in ["temp_data", "cache"]}
# 只保留必要字段
return {
"final_result": state["final_result"],
"metadata": {"processed": True}
}
---
10.2 错误处理
01.异常捕获
a.节点错误处理
a.功能说明
在节点内部捕获异常,避免整个图崩溃。使用try-except处理预期错误,记录日志,返回错误状态。区分可恢复错误和致命错误。实现重试机制、降级策略、错误通知。确保错误信息清晰,便于排查。
b.代码示例
---
from tenacity import retry, stop_after_attempt, wait_exponential
import logging
logger = logging.getLogger(__name__)
def safe_node(state: State):
"""安全的节点实现"""
try:
result = risky_operation(state["input"])
return {"output": result}
except ValueError as e:
logger.error(f"值错误:{e}")
return {
"output": "",
"error": str(e),
"needs_retry": True
}
except Exception as e:
logger.exception("未知错误")
return {
"output": "",
"error": str(e),
"fatal": True
}
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10)
)
def retry_node(state: State):
"""带重试的节点"""
return risky_operation(state["input"])
---
02.降级策略
a.备用方案
a.功能说明
主路径失败时,自动切换到备用方案。使用条件边实现降级路由。备用方案可以是简化逻辑、缓存数据、默认值等。记录降级事件,用于后续分析。提升系统鲁棒性,确保核心功能可用。
b.代码示例
---
def primary_node(state):
try:
return {"result": expensive_operation(state["input"])}
except Exception:
return {"needs_fallback": True}
def fallback_node(state):
return {"result": cheap_operation(state["input"])}
def route_with_fallback(state):
return "fallback" if state.get("needs_fallback") else "success"
workflow.add_conditional_edges(
"primary",
route_with_fallback,
{"success": END, "fallback": "fallback"}
)
---
10.3 性能优化
01.缓存策略
a.结果缓存
a.功能说明
缓存节点输出,避免重复计算。使用LRU缓存、Redis等实现。设置合理的缓存失效策略。特别适合LLM调用、数据库查询等耗时操作。注意缓存一致性,避免脏数据。监控缓存命中率,调整缓存大小。
b.代码示例
---
from functools import lru_cache
import redis
import hashlib
import json
@lru_cache(maxsize=128)
def cached_llm_call(prompt: str):
"""LLM调用缓存"""
return llm.predict(prompt)
class RedisCache:
def __init__(self):
self.redis = redis.Redis(host='localhost', port=6379)
def get_or_compute(self, key: str, compute_func):
cached = self.redis.get(key)
if cached:
return json.loads(cached)
result = compute_func()
self.redis.setex(key, 3600, json.dumps(result))
return result
cache = RedisCache()
def cached_node(state):
key = hashlib.md5(state["input"].encode()).hexdigest()
result = cache.get_or_compute(
key,
lambda: expensive_operation(state["input"])
)
return {"output": result}
---
02.并行优化
a.异步执行
a.功能说明
使用异步节点提升并发性能。将同步I/O改为异步I/O,提高吞吐量。使用asyncio.gather并行执行独立任务。注意异步上下文管理、异常处理。适用于API调用、数据库查询、文件操作等I/O密集型任务。
b.代码示例
---
import asyncio
async def async_node(state: State):
"""异步节点"""
results = await asyncio.gather(
async_api_call_1(state["input"]),
async_api_call_2(state["input"]),
async_db_query(state["input"])
)
return {
"api1": results[0],
"api2": results[1],
"db": results[2]
}
async def parallel_processing(items: list):
"""并行处理多个项目"""
tasks = [process_item(item) for item in items]
results = await asyncio.gather(*tasks)
return results
# 使用
result = asyncio.run(parallel_processing(large_list))
---