1 快速开始
1.1 安装部署
01.环境准备
a.安装LangServe
a.功能说明
LangServe是LangChain官方的部署框架,将LangChain应用快速部署为REST API服务。基于FastAPI构建,支持高性能异步处理。自动生成OpenAPI文档和交互式Playground。内置流式响应、批量处理、错误处理等功能。使用pip安装langserve及相关依赖。安装简单,集成LangChain生态,是生产部署的首选方案。
b.代码示例
---
# 1. 安装LangServe
# pip install "langserve[all]"
# 或者分别安装
# pip install langserve
# pip install "langserve[server]" # 服务端
# pip install "langserve[client]" # 客户端
# 2. 安装依赖
# pip install langchain
# pip install fastapi
# pip install "uvicorn[standard]"
# 3. 验证安装
import langserve
print(f"LangServe版本:{langserve.__version__}")
from fastapi import FastAPI
print("✓ FastAPI已安装")
import uvicorn
print("✓ Uvicorn已安装")
# 4. 完整依赖列表
# requirements.txt
"""
langchain>=0.1.0
langserve>=0.0.30
fastapi>=0.104.0
uvicorn[standard]>=0.24.0
pydantic>=2.0.0
httpx>=0.25.0
sse-starlette>=1.6.5
"""
# 5. 开发环境安装
# python -m venv venv
# source venv/bin/activate # Linux/Mac
# venv\Scripts\activate # Windows
# pip install -r requirements.txt
# 6. 信创环境安装
# 使用国内镜像源
# pip install -i https://pypi.tuna.tsinghua.edu.cn/simple "langserve[all]"
# 麒麟系统安装
# sudo yum install python3-devel gcc
# pip3 install "langserve[all]"
# 验证安装
# python3 -c "import langserve; print('✓ LangServe安装成功')"
---
b.创建项目
a.功能说明
创建LangServe项目目录结构,组织代码文件。主文件(如server.py)包含FastAPI应用和路由配置。chains目录存放各个Chain实现。模型配置、环境变量等分离管理。合理的项目结构便于开发、测试和部署。
b.代码示例
---
# 1. 项目结构
"""
langserve-app/
├── server.py # 主服务文件
├── chains/ # Chain实现
│ ├── __init__.py
│ ├── qa_chain.py
│ └── summary_chain.py
├── config.py # 配置
├── requirements.txt # 依赖
└── .env # 环境变量
"""
# 2. 创建项目目录
import os
project_dirs = [
"langserve-app",
"langserve-app/chains"
]
for dir_path in project_dirs:
os.makedirs(dir_path, exist_ok=True)
print("✓ 项目目录已创建")
# 3. 环境变量文件
# .env
"""
OPENAI_API_KEY=your_openai_key
LANGCHAIN_API_KEY=your_langsmith_key
LANGCHAIN_TRACING_V2=true
LANGCHAIN_PROJECT=langserve_dev
"""
# 4. 配置文件
# config.py
"""
import os
from dotenv import load_dotenv
load_dotenv()
class Config:
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
SERVER_HOST = os.getenv("SERVER_HOST", "0.0.0.0")
SERVER_PORT = int(os.getenv("SERVER_PORT", 8000))
ENABLE_TRACING = os.getenv("LANGCHAIN_TRACING_V2", "false") == "true"
config = Config()
"""
# 5. 初始化文件
# chains/__init__.py
"""
from .qa_chain import qa_chain
from .summary_chain import summary_chain
__all__ = ["qa_chain", "summary_chain"]
"""
# 6. 快速脚手架
def create_langserve_project(project_name: str):
    """Create a minimal LangServe project scaffold on disk.

    The scaffold consists of ``server.py``, a ``chains`` package
    (``__init__.py`` plus an example ``qa_chain.py``) and
    ``requirements.txt``. Existing directories are reused; files with the
    same names are overwritten.

    Args:
        project_name: Path of the project directory to create.
    """
    import os

    base_dir = project_name
    chains_dir = os.path.join(base_dir, "chains")
    os.makedirs(chains_dir, exist_ok=True)

    # Main service file: exposes the example chain under /qa.
    server_code = '''\
from fastapi import FastAPI
from langserve import add_routes
from chains import qa_chain

app = FastAPI(
    title="LangServe API",
    version="1.0",
    description="LangChain应用API服务"
)

# 添加路由
add_routes(app, qa_chain, path="/qa")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
'''

    # Without this package initializer, ``from chains import qa_chain`` in
    # server.py would bind the *submodule* instead of the chain object.
    chains_init_code = '''\
from .qa_chain import qa_chain

__all__ = ["qa_chain"]
'''

    # Example chain: prompt piped into the chat model.
    chain_code = '''\
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate

llm = ChatOpenAI(model="gpt-3.5-turbo")
prompt = ChatPromptTemplate.from_template(
    "回答问题:{question}"
)
qa_chain = prompt | llm
'''

    # One requirement per line; real newlines, not the two characters "\n".
    requirements = "langserve[all]\nlangchain\nopenai\n"

    files = {
        os.path.join(base_dir, "server.py"): server_code,
        os.path.join(chains_dir, "__init__.py"): chains_init_code,
        os.path.join(chains_dir, "qa_chain.py"): chain_code,
        os.path.join(base_dir, "requirements.txt"): requirements,
    }
    for path, content in files.items():
        # The templates contain non-ASCII text, so the encoding is pinned
        # instead of relying on the platform default.
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)

    print(f"✓ 项目'{project_name}'已创建")
# 使用
# create_langserve_project("my-langserve-app")
---
02.快速启动
a.最小示例
a.功能说明
创建最简单的LangServe服务,理解基本概念。导入FastAPI和add_routes。创建一个简单的Chain(如LLM)。使用add_routes将Chain添加到应用。运行服务器访问API。最小示例是学习LangServe的起点。
b.代码示例
---
# server.py
from fastapi import FastAPI
from langchain.chat_models import ChatOpenAI
from langserve import add_routes
# 1. 创建FastAPI应用
app = FastAPI(
title="LangServe Demo",
version="1.0",
description="最简单的LangServe服务"
)
# 2. 创建Chain
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)
# 3. 添加路由
add_routes(
app,
llm,
path="/chat"
)
# 4. 运行服务器
if __name__ == "__main__":
import uvicorn
uvicorn.run(
app,
host="0.0.0.0",
port=8000
)
# 启动服务
# python server.py
# 5. 访问服务
# 在浏览器访问:
# http://localhost:8000/docs # OpenAPI文档
# http://localhost:8000/chat/playground/ # 交互界面
# 6. 测试API
# 使用curl测试
"""
curl -X POST "http://localhost:8000/chat/invoke" \\
-H "Content-Type: application/json" \\
-d '{"input": "你好"}'
"""
# 使用Python测试
import requests
response = requests.post(
"http://localhost:8000/chat/invoke",
json={"input": "什么是LangServe?"}
)
print(response.json())
# 输出:{"output": {"content": "LangServe是..."}}
# 7. 添加健康检查
@app.get("/health")
def health_check():
"""健康检查端点"""
return {"status": "healthy", "service": "langserve"}
# 8. 添加多个Chain
from langchain.prompts import ChatPromptTemplate
# QA Chain
qa_prompt = ChatPromptTemplate.from_template("回答:{question}")
qa_chain = qa_prompt | llm
add_routes(app, qa_chain, path="/qa")
# 摘要Chain
summary_prompt = ChatPromptTemplate.from_template("总结:{text}")
summary_chain = summary_prompt | llm
add_routes(app, summary_chain, path="/summary")
# 现在有3个端点:
# /chat/invoke
# /qa/invoke
# /summary/invoke
---
b.启动服务
a.功能说明
使用Uvicorn启动LangServe服务器,处理HTTP请求。配置主机、端口、工作进程数等参数。支持热重载(开发模式)和生产模式。监听日志输出,查看请求处理情况。启动服务是部署的关键步骤。
b.代码示例
---
import uvicorn
from fastapi import FastAPI
app = FastAPI()
# 1. 基础启动
# 在server.py末尾
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
# 命令行启动
# python server.py
# 2. 开发模式(热重载)
if __name__ == "__main__":
uvicorn.run(
"server:app", # 模块:应用对象
host="0.0.0.0",
port=8000,
reload=True, # 代码变更自动重启
log_level="info"
)
# 或使用命令行
# uvicorn server:app --reload --port 8000
# 3. 生产模式(多进程)
if __name__ == "__main__":
uvicorn.run(
"server:app",
host="0.0.0.0",
port=8000,
workers=4, # 4个工作进程
log_level="warning"
)
# 4. HTTPS配置
if __name__ == "__main__":
uvicorn.run(
app,
host="0.0.0.0",
port=443,
ssl_keyfile="/path/to/key.pem",
ssl_certfile="/path/to/cert.pem"
)
# 5. 使用Gunicorn(生产推荐)
# pip install gunicorn
# gunicorn配置文件 gunicorn.conf.py
"""
bind = "0.0.0.0:8000"
workers = 4
worker_class = "uvicorn.workers.UvicornWorker"
timeout = 120
keepalive = 5
"""
# 启动命令
# gunicorn server:app -c gunicorn.conf.py
# 6. systemd服务(Linux)
# /etc/systemd/system/langserve.service
"""
[Unit]
Description=LangServe API Service
After=network.target
[Service]
Type=simple
User=www-data
WorkingDirectory=/opt/langserve
Environment="PATH=/opt/langserve/venv/bin"
ExecStart=/opt/langserve/venv/bin/uvicorn server:app --host 0.0.0.0 --port 8000
Restart=always
[Install]
WantedBy=multi-user.target
"""
# 启动服务
# sudo systemctl start langserve
# sudo systemctl enable langserve # 开机自启
# sudo systemctl status langserve # 查看状态
# 7. 信创环境启动
# 麒麟系统服务配置
"""
[Unit]
Description=LangServe on Kylin
After=network.target
[Service]
Type=simple
User=langserve
WorkingDirectory=/home/langserve/app
ExecStart=/usr/bin/python3 -m uvicorn server:app --host 0.0.0.0 --port 8000
Restart=always
RestartSec=5
[Install]
WantedBy=multi-user.target
"""
# 8. 启动脚本
# start.sh
"""
#!/bin/bash
# 激活虚拟环境
source venv/bin/activate
# 设置环境变量
export OPENAI_API_KEY="your_key"
export LANGCHAIN_TRACING_V2="true"
# 启动服务
uvicorn server:app \\
--host 0.0.0.0 \\
--port 8000 \\
--workers 4 \\
--log-level info
"""
# chmod +x start.sh
# ./start.sh
# 9. Docker启动(见第6章详细内容)
# docker run -p 8000:8000 -e OPENAI_API_KEY=xxx langserve-app
---
1.2 创建服务
01.FastAPI应用
a.创建应用实例
a.功能说明
LangServe基于FastAPI构建,首先创建FastAPI应用实例。配置应用的标题、版本、描述等元数据,这些信息会显示在OpenAPI文档中。可以配置中间件、异常处理器、生命周期事件等。FastAPI应用是LangServe的基础容器。
b.代码示例
---
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# 1. 基础应用
app = FastAPI()
# 2. 完整配置
app = FastAPI(
title="LangServe API",
version="1.0.0",
description="基于LangChain的AI服务API",
docs_url="/docs", # Swagger UI
redoc_url="/redoc", # ReDoc
openapi_url="/openapi.json"
)
# 3. 添加CORS
# A wildcard origin together with allow_credentials=True would let any
# site issue credentialed cross-origin requests, so the trusted origins
# are listed explicitly. Replace the placeholder with the real front-end
# origin(s) of the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# 4. 生命周期事件
@app.on_event("startup")
async def startup_event():
"""应用启动时执行"""
print("🚀 LangServe启动中...")
# 初始化数据库连接
# 预加载模型
@app.on_event("shutdown")
async def shutdown_event():
"""应用关闭时执行"""
print("👋 LangServe关闭中...")
# 清理资源
# 关闭连接
# 5. 全局异常处理
from fastapi import Request
from fastapi.responses import JSONResponse
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
"""全局异常捕获"""
return JSONResponse(
status_code=500,
content={
"error": str(exc),
"path": request.url.path
}
)
# 6. 自定义中间件
from starlette.middleware.base import BaseHTTPMiddleware
import time
class TimingMiddleware(BaseHTTPMiddleware):
    """Middleware that reports request handling time in ``X-Process-Time``."""

    async def dispatch(self, request, call_next):
        """Time the downstream handler and attach the duration to the response.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request to the next handler.

        Returns:
            The downstream response with an ``X-Process-Time`` header holding
            the elapsed seconds as a string.
        """
        # perf_counter() is monotonic, so the measured duration cannot go
        # negative or jump when the system clock is adjusted.
        started = time.perf_counter()
        response = await call_next(request)
        elapsed = time.perf_counter() - started
        response.headers["X-Process-Time"] = str(elapsed)
        return response
app.add_middleware(TimingMiddleware)
# 7. 健康检查
@app.get("/health")
def health_check():
"""健康检查端点"""
return {
"status": "healthy",
"service": "langserve",
"version": "1.0.0"
}
@app.get("/")
def root():
"""根端点"""
return {
"message": "Welcome to LangServe API",
"docs": "/docs"
}
# 8. 信创环境配置
# 达梦数据库连接
import dmPython
@app.on_event("startup")
async def init_db():
"""初始化达梦数据库"""
try:
conn = dmPython.connect(
user="SYSDBA",
password="SYSDBA",
server="localhost",
port=5236
)
app.state.db = conn
print("✓ 达梦数据库连接成功")
except Exception as e:
print(f"✗ 数据库连接失败:{e}")
@app.on_event("shutdown")
async def close_db():
"""关闭数据库连接"""
if hasattr(app.state, "db"):
app.state.db.close()
---
b.添加路由组
a.功能说明
使用FastAPI的APIRouter组织路由,实现模块化管理。将不同功能的端点分组,便于维护和扩展。可以为路由组添加统一的前缀、标签、依赖等。路由组提高代码的可维护性和可读性。
b.代码示例
---
from fastapi import APIRouter, Depends
# 1. 创建路由组
api_router = APIRouter(prefix="/api/v1", tags=["AI Services"])
# 2. 添加端点
@api_router.get("/status")
def get_status():
"""获取服务状态"""
return {"status": "running"}
@api_router.post("/chat")
def chat(message: str):
"""聊天接口"""
return {"response": f"收到消息:{message}"}
# 3. 注册路由组
from fastapi import FastAPI
app = FastAPI()
app.include_router(api_router)
# 现在可以访问:
# /api/v1/status
# /api/v1/chat
# 4. 多个路由组
# 用户相关
user_router = APIRouter(prefix="/users", tags=["Users"])
@user_router.get("/{user_id}")
def get_user(user_id: int):
return {"user_id": user_id}
# 管理相关
admin_router = APIRouter(prefix="/admin", tags=["Admin"])
@admin_router.get("/stats")
def get_stats():
return {"total_requests": 1000}
# Register the remaining routers. api_router was already included right
# after it was defined; including it a second time would register every
# /api/v1 route twice.
app.include_router(user_router)
app.include_router(admin_router)
# 5. 依赖注入
from fastapi import Header, HTTPException
def verify_token(authorization: str = Header(None)):
    """Extract the bearer token from the ``Authorization`` header.

    Returns:
        The header value with the leading ``"Bearer "`` prefix removed.

    Raises:
        HTTPException: 401 when the header is missing, empty, or does not
            start with ``"Bearer "``.
    """
    scheme_prefix = "Bearer "
    if authorization and authorization.startswith(scheme_prefix):
        return authorization[len(scheme_prefix):]
    raise HTTPException(status_code=401, detail="未授权")
# 应用到路由组
protected_router = APIRouter(
prefix="/protected",
tags=["Protected"],
dependencies=[Depends(verify_token)]
)
@protected_router.get("/data")
def get_protected_data():
"""需要认证的端点"""
return {"data": "sensitive"}
app.include_router(protected_router)
# 6. 版本管理
# v1路由
v1_router = APIRouter(prefix="/api/v1", tags=["V1"])
@v1_router.get("/model")
def get_model_v1():
return {"model": "gpt-3.5-turbo"}
# v2路由
v2_router = APIRouter(prefix="/api/v2", tags=["V2"])
@v2_router.get("/model")
def get_model_v2():
return {"model": "gpt-4", "version": "2"}
app.include_router(v1_router)
app.include_router(v2_router)
# 7. 完整示例
from fastapi import FastAPI, APIRouter
from langserve import add_routes
from langchain.chat_models import ChatOpenAI
app = FastAPI(title="LangServe Multi-Router")
# AI服务路由
ai_router = APIRouter(prefix="/ai", tags=["AI"])
llm = ChatOpenAI()
add_routes(ai_router, llm, path="/chat")
# 工具路由
tools_router = APIRouter(prefix="/tools", tags=["Tools"])
@tools_router.post("/translate")
def translate(text: str, target_lang: str):
return {"translated": f"[{target_lang}]{text}"}
# 注册所有路由
app.include_router(ai_router)
app.include_router(tools_router)
# 访问:
# /ai/chat/invoke
# /tools/translate
---
02.Chain部署
a.部署Runnable
a.功能说明
将LangChain的Runnable对象部署为API端点。Runnable是LangChain的统一接口,Chain、LLM、Prompt等都实现了Runnable。使用add_routes函数将Runnable添加到FastAPI应用。自动生成invoke、batch、stream等端点。
b.代码示例
---
from langserve import add_routes
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
# 1. 部署LLM
llm = ChatOpenAI(model="gpt-3.5-turbo")
add_routes(
app,
llm,
path="/llm"
)
# 自动生成端点:
# POST /llm/invoke - 单次调用
# POST /llm/batch - 批量调用
# POST /llm/stream - 流式调用
# POST /llm/stream_log - 流式日志
# GET /llm/playground - 交互界面
# 2. 部署Chain
prompt = ChatPromptTemplate.from_template("翻译为英文:{text}")
chain = prompt | llm
add_routes(app, chain, path="/translate")
# 3. 部署带配置的Runnable
# configurable_fields() expects ConfigurableField descriptors. RunnableConfig
# is the per-call config dictionary type and cannot describe a field.
from langchain.schema.runnable import ConfigurableField

configurable_llm = ChatOpenAI(
    model="gpt-3.5-turbo"
).configurable_fields(
    temperature=ConfigurableField(
        id="temperature",
        name="Temperature",
        description="LLM温度参数",
    )
)
add_routes(app, configurable_llm, path="/configurable_llm")
# 调用时可以配置
# {"input": "你好", "config": {"configurable": {"temperature": 0.9}}}
# 4. 部署多个Chain
# QA Chain
qa_prompt = ChatPromptTemplate.from_template("问题:{question}")
qa_chain = qa_prompt | llm
add_routes(app, qa_chain, path="/qa")
# 摘要Chain
summary_prompt = ChatPromptTemplate.from_template("总结:{text}")
summary_chain = summary_prompt | llm
add_routes(app, summary_chain, path="/summary")
# 情感分析Chain
sentiment_prompt = ChatPromptTemplate.from_template("分析情感:{text}")
sentiment_chain = sentiment_prompt | llm
add_routes(app, sentiment_chain, path="/sentiment")
# 5. 部署复杂Chain
from langchain.chains import LLMChain
from langchain.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field
class TranslationResult(BaseModel):
original: str = Field(description="原文")
translated: str = Field(description="译文")
language: str = Field(description="目标语言")
parser = PydanticOutputParser(pydantic_object=TranslationResult)
complex_prompt = ChatPromptTemplate.from_template(
"将'{text}'翻译为{target_lang}\\n{format_instructions}"
).partial(format_instructions=parser.get_format_instructions())
complex_chain = complex_prompt | llm | parser
add_routes(app, complex_chain, path="/advanced_translate")
# 6. 信创环境部署
# 使用Ollama本地模型
from langchain.llms import Ollama
local_llm = Ollama(model="qwen:7b", base_url="http://localhost:11434")
local_prompt = ChatPromptTemplate.from_template("回答:{question}")
local_chain = local_prompt | local_llm
add_routes(app, local_chain, path="/local_qa")
# 麒麟系统优化
import os
os.environ["OLLAMA_HOST"] = "localhost:11434"
os.environ["OLLAMA_NUM_PARALLEL"] = "2"
print("✓ 本地模型部署成功")
---
b.配置选项
a.功能说明
add_routes提供丰富的配置选项,控制端点行为。启用或禁用特定端点(invoke、batch、stream)。配置输入输出类型、验证规则。设置并发限制、超时时间等。合理配置提升服务性能和安全性。
b.代码示例
---
from langserve import add_routes
from langchain.chat_models import ChatOpenAI
from pydantic import BaseModel, Field
llm = ChatOpenAI()
# 1. 基础配置
add_routes(
    app,
    llm,
    path="/chat",
    # enabled_endpoints and disabled_endpoints are mutually exclusive in
    # LangServe; listing the enabled ones already leaves "batch" switched off.
    enabled_endpoints=["invoke", "stream"],
)
# 2. 输入输出类型
class ChatInput(BaseModel):
message: str = Field(..., description="用户消息")
context: str = Field(default="", description="上下文")
class ChatOutput(BaseModel):
response: str = Field(..., description="AI回复")
tokens: int = Field(..., description="Token数量")
add_routes(
app,
llm,
path="/typed_chat",
input_type=ChatInput,
output_type=ChatOutput
)
# 3. 配置文档
add_routes(
app,
llm,
path="/documented",
config_keys=["tags", "metadata"], # 允许的配置键
include_callback_events=True # 包含回调事件
)
# 4. 并发限制
from langserve import add_routes
import asyncio
# 使用信号量限制并发
semaphore = asyncio.Semaphore(5) # 最多5个并发
async def limited_llm(input_data):
async with semaphore:
return await llm.ainvoke(input_data)
# 注意:需要包装为Runnable
from langchain.schema.runnable import RunnableLambda
limited_runnable = RunnableLambda(limited_llm)
add_routes(app, limited_runnable, path="/limited")
# 5. 超时配置
from langchain.callbacks import get_openai_callback
import asyncio
async def timeout_wrapper(input_data):
try:
return await asyncio.wait_for(
llm.ainvoke(input_data),
timeout=30.0 # 30秒超时
)
except asyncio.TimeoutError:
return {"error": "请求超时"}
timeout_runnable = RunnableLambda(timeout_wrapper)
add_routes(app, timeout_runnable, path="/timeout")
# 6. 完整配置示例
add_routes(
app,
llm,
path="/full_config",
# 端点配置
enabled_endpoints=["invoke", "batch", "stream"],
# 类型配置
input_type=dict,
output_type=dict,
# 文档配置
config_keys=["tags", "metadata", "callbacks"],
# Playground配置
playground_type="default",
# 其他配置
include_callback_events=False
)
# 7. 信创环境配置
# 使用达梦数据库存储请求日志
import dmPython
class DmLogger:
    """Persist request/response pairs to the ``request_logs`` table."""

    def __init__(self, user="SYSDBA", password="SYSDBA", server="localhost", port=5236):
        """Open the database connection through ``dmPython``.

        The defaults reproduce the demo values; pass real credentials (for
        example read from the environment) outside of local experiments.

        Args:
            user: Database user name.
            password: Database password.
            server: Database host.
            port: Database port.
        """
        self.conn = dmPython.connect(
            user=user,
            password=password,
            server=server,
            port=port,
        )

    def log_request(self, input_data, output_data):
        """Insert one log row containing the stringified input and output.

        Args:
            input_data: Request payload; stored as ``str(input_data)``.
            output_data: Response payload; stored as ``str(output_data)``.

        Raises:
            Exception: Whatever the driver raises; the transaction is rolled
                back before the error propagates.
        """
        cursor = self.conn.cursor()
        try:
            cursor.execute(
                "INSERT INTO request_logs (input, output, created_at) VALUES (?, ?, CURRENT_TIMESTAMP)",
                (str(input_data), str(output_data))
            )
            self.conn.commit()
        except Exception:
            # Do not leave a half-finished transaction on the shared connection.
            self.conn.rollback()
            raise
        finally:
            # Release the cursor on success and on failure alike.
            cursor.close()
logger = DmLogger()
# 包装Chain添加日志
def logged_chain(input_data):
result = llm.invoke(input_data)
logger.log_request(input_data, result)
return result
logged_runnable = RunnableLambda(logged_chain)
add_routes(app, logged_runnable, path="/logged")
---
1.3 添加路由
01.add_routes函数
a.基本用法
a.功能说明
add_routes是LangServe的核心函数,将Runnable添加为API端点。自动生成RESTful API路由(invoke、batch、stream等)。生成OpenAPI文档和Playground交互界面。简化LangChain应用的API部署流程。
b.代码示例
---
from fastapi import FastAPI
from langserve import add_routes
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
app = FastAPI()
# 1. 最简单用法
llm = ChatOpenAI(model="gpt-3.5-turbo")
add_routes(app, llm, path="/chat")
# 生成的端点:
# POST /chat/invoke - 单次调用
# POST /chat/batch - 批量调用
# POST /chat/stream - 流式调用
# POST /chat/stream_log - 流式日志
# GET /chat/input_schema - 输入schema
# GET /chat/output_schema- 输出schema
# GET /chat/config_schema- 配置schema
# GET /chat/playground - Playground界面
# 2. 添加Chain
prompt = ChatPromptTemplate.from_template("回答问题:{question}")
chain = prompt | llm
add_routes(app, chain, path="/qa")
# 3. 添加多个路由
# 翻译服务
translate_prompt = ChatPromptTemplate.from_template("翻译为{lang}:{text}")
translate_chain = translate_prompt | llm
add_routes(app, translate_chain, path="/translate")
# 摘要服务
summary_prompt = ChatPromptTemplate.from_template("总结以下内容:{content}")
summary_chain = summary_prompt | llm
add_routes(app, summary_chain, path="/summary")
# 情感分析服务
sentiment_prompt = ChatPromptTemplate.from_template("分析情感(积极/消极/中性):{text}")
sentiment_chain = sentiment_prompt | llm
add_routes(app, sentiment_chain, path="/sentiment")
# 4. 使用路由组
from fastapi import APIRouter
# 创建路由组
api_router = APIRouter(prefix="/api/v1")
# 添加到路由组
add_routes(api_router, chain, path="/chat")
# 注册路由组到应用
app.include_router(api_router)
# 最终路径:/api/v1/chat/invoke
# 5. 测试路由
import requests
# 测试invoke端点
response = requests.post(
"http://localhost:8000/chat/invoke",
json={"input": "你好"}
)
print(response.json())
# 输出:{"output": {"content": "你好!有什么我可以帮助你的吗?"}}
# 测试stream端点
with requests.post(
    "http://localhost:8000/chat/stream",
    json={"input": "讲个故事"},
    stream=True
) as r:
    # Network chunks can end in the middle of a multi-byte UTF-8 character,
    # so decoding each chunk on its own may raise. With the encoding pinned,
    # decode_unicode=True decodes incrementally across chunk boundaries.
    r.encoding = "utf-8"
    for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="")
# 6. 访问Playground
# 在浏览器打开:http://localhost:8000/chat/playground
# 可以交互式测试Chain
# 7. 查看Schema
schema_response = requests.get("http://localhost:8000/chat/input_schema")
print(schema_response.json())
# 显示输入数据的JSON Schema
# 8. 信创环境示例
# 使用Ollama本地模型
from langchain.llms import Ollama
local_llm = Ollama(
model="qwen:7b",
base_url="http://localhost:11434"
)
local_prompt = ChatPromptTemplate.from_template("问题:{question}")
local_chain = local_prompt | local_llm
add_routes(app, local_chain, path="/local_chat")
print("✓ 本地模型路由已添加:/local_chat")
---
b.路由规则
a.功能说明
LangServe的路由遵循RESTful设计原则,规范统一。path参数定义基础路径,自动生成子端点。支持路径参数和查询参数。理解路由规则有助于设计清晰的API结构。
b.代码示例
---
from fastapi import FastAPI
from langserve import add_routes
from langchain.chat_models import ChatOpenAI
app = FastAPI()
llm = ChatOpenAI()
# 1. 路径命名规则
# 推荐使用小写字母和下划线
add_routes(app, llm, path="/chat") # ✓ 好
add_routes(app, llm, path="/text_analysis") # ✓ 好
add_routes(app, llm, path="/qa_service") # ✓ 好
# 避免
# add_routes(app, llm, path="/Chat") # ✗ 不推荐大写
# add_routes(app, llm, path="/text-analysis") # ✗ 不推荐连字符
# 2. 嵌套路径
from fastapi import APIRouter
# 版本化API
v1_router = APIRouter(prefix="/api/v1")
add_routes(v1_router, llm, path="/chat")
app.include_router(v1_router)
# 完整路径:/api/v1/chat/invoke
# 按功能分组
ai_router = APIRouter(prefix="/ai")
add_routes(ai_router, llm, path="/chat")
add_routes(ai_router, llm, path="/translate")
app.include_router(ai_router)
# 路径:/ai/chat/invoke, /ai/translate/invoke
# 3. 路由优先级
# FastAPI按定义顺序匹配路由
# 具体路径优先
@app.get("/users/me")
def get_current_user():
return {"user": "current"}
@app.get("/users/{user_id}")
def get_user(user_id: int):
return {"user_id": user_id}
# /users/me 会匹配第一个,不会被第二个捕获
# 4. 动态路由
# 虽然add_routes不直接支持路径参数,
# 但可以通过包装实现
from langchain.schema.runnable import RunnableLambda
def create_chat_handler(model_name: str):
"""为不同模型创建处理器"""
llm = ChatOpenAI(model=model_name)
return llm
# 为不同模型创建路由
for model in ["gpt-3.5-turbo", "gpt-4"]:
model_llm = create_chat_handler(model)
path = f"/chat/{model.replace('.', '_')}"
add_routes(app, model_llm, path=path)
# 生成路由:
# /chat/gpt-3_5-turbo/invoke
# /chat/gpt-4/invoke
# 5. 路由冲突处理
from fastapi import HTTPException
# 自定义端点避免冲突
@app.post("/custom/invoke")
def custom_invoke(data: dict):
"""自定义invoke逻辑"""
if "input" not in data:
raise HTTPException(status_code=400, detail="缺少input字段")
return {"output": f"处理:{data['input']}"}
# add_routes使用不同路径
add_routes(app, llm, path="/langserve_chat")
# 6. 路由文档
from fastapi import APIRouter
# 为路由组添加文档
documented_router = APIRouter(
prefix="/api",
tags=["AI Services"],
responses={
404: {"description": "未找到"},
500: {"description": "服务器错误"}
}
)
add_routes(documented_router, llm, path="/chat")
app.include_router(documented_router)
# 7. 完整路由结构示例
app = FastAPI(title="LangServe API")
# 公开API
public_router = APIRouter(prefix="/public", tags=["Public"])
add_routes(public_router, llm, path="/chat")
# 内部API(需要认证)
from fastapi import Depends, Header, HTTPException
def verify_api_key(x_api_key: str = Header(...)):
    """Validate the ``X-API-Key`` request header.

    Returns:
        The supplied key when it matches the expected one.

    Raises:
        HTTPException: 401 when the key does not match.
    """
    import secrets

    # Demo value; load the real key from configuration in a deployment.
    expected_key = "secret"
    # Constant-time comparison avoids leaking how many leading characters
    # matched. Bytes are compared because compare_digest rejects non-ASCII str.
    supplied = x_api_key.encode("utf-8")
    if not secrets.compare_digest(supplied, expected_key.encode("utf-8")):
        raise HTTPException(status_code=401)
    return x_api_key
internal_router = APIRouter(
prefix="/internal",
tags=["Internal"],
dependencies=[Depends(verify_api_key)]
)
add_routes(internal_router, llm, path="/admin_chat")
# 注册所有路由
app.include_router(public_router)
app.include_router(internal_router)
# 最终结构:
# /public/chat/invoke - 公开访问
# /internal/admin_chat/invoke - 需要API Key
# 8. 信创环境路由配置
# 分离本地模型和云模型
local_router = APIRouter(prefix="/local", tags=["本地模型"])
cloud_router = APIRouter(prefix="/cloud", tags=["云模型"])
# 本地Ollama
from langchain.llms import Ollama
local_llm = Ollama(model="qwen:7b")
add_routes(local_router, local_llm, path="/chat")
# 云端OpenAI(如果可用)
cloud_llm = ChatOpenAI(model="gpt-3.5-turbo")
add_routes(cloud_router, cloud_llm, path="/chat")
app.include_router(local_router)
app.include_router(cloud_router)
# 路径:
# /local/chat/invoke - 使用本地模型
# /cloud/chat/invoke - 使用云端模型
print("✓ 路由配置完成")
print(f"总路由数:{len(app.routes)}")
---
02.端点类型
a.invoke端点
a.功能说明
invoke端点用于单次同步调用,发送一个输入,返回一个输出。适用于简单的请求-响应场景,如单条消息的问答。支持配置参数,如temperature、max_tokens等。invoke是最常用的端点类型。
b.代码示例
---
# 1. invoke端点请求格式
# POST /chat/invoke
{
"input": "你好", # 必需:输入数据
"config": { # 可选:配置参数
"configurable": {
"temperature": 0.7
},
"tags": ["production"],
"metadata": {"user_id": "123"}
}
}
# 2. invoke端点响应格式
{
"output": { # 输出数据
"content": "你好!有什么可以帮助你的吗?"
},
"metadata": { # 元数据(可选)
"run_id": "abc123",
"tokens": 15
}
}
# 3. Python客户端调用
import requests
response = requests.post(
"http://localhost:8000/chat/invoke",
json={"input": "什么是LangServe?"}
)
result = response.json()
print(result["output"]["content"])
# 4. 带配置的调用
response = requests.post(
"http://localhost:8000/chat/invoke",
json={
"input": "讲个笑话",
"config": {
"configurable": {
"temperature": 0.9, # 更随机
"max_tokens": 100
}
}
}
)
# 5. 错误处理
try:
response = requests.post(
"http://localhost:8000/chat/invoke",
json={"input": "测试"},
timeout=30
)
response.raise_for_status()
result = response.json()
except requests.exceptions.Timeout:
print("请求超时")
except requests.exceptions.HTTPError as e:
print(f"HTTP错误:{e.response.status_code}")
print(e.response.json())
# 6. 异步调用
import httpx
import asyncio
async def async_invoke():
async with httpx.AsyncClient() as client:
response = await client.post(
"http://localhost:8000/chat/invoke",
json={"input": "异步测试"}
)
return response.json()
result = asyncio.run(async_invoke())
print(result)
# 7. 批量invoke(手动循环)
questions = ["问题1", "问题2", "问题3"]
for q in questions:
response = requests.post(
"http://localhost:8000/chat/invoke",
json={"input": q}
)
print(f"Q: {q}")
print(f"A: {response.json()['output']['content']}")
# 8. 信创环境示例
# 在信创内网环境中使用requests调用服务
import requests
# 配置代理(如果需要)
proxies = {
"http": "http://proxy.company.com:8080",
"https": "https://proxy.company.com:8080"
}
response = requests.post(
"http://internal-server:8000/local/chat/invoke",
json={"input": "测试信创环境"},
proxies=proxies,
verify=False # 如果使用自签名证书
)
print(response.json())
---
b.batch端点
a.功能说明
batch端点用于批量处理多个输入,提高效率。一次请求发送多个输入,返回多个输出。适用于批量数据处理场景,如批量翻译、批量分类。支持并发处理,性能优于多次invoke。
b.代码示例
---
# 1. batch端点请求格式
# POST /chat/batch
{
"inputs": [ # 多个输入
"你好",
"再见",
"谢谢"
],
"config": { # 全局配置(可选)
"max_concurrency": 5 # 并发数
}
}
# 2. batch端点响应格式
{
"outputs": [ # 多个输出,顺序对应输入
{"content": "你好!有什么可以帮助你的吗?"},
{"content": "再见!祝你愉快!"},
{"content": "不客气!"}
]
}
# 3. Python客户端调用
import requests
inputs = [
"什么是AI?",
"什么是机器学习?",
"什么是深度学习?"
]
response = requests.post(
"http://localhost:8000/chat/batch",
json={"inputs": inputs}
)
outputs = response.json()["outputs"]
for i, output in enumerate(outputs):
print(f"Q{i+1}: {inputs[i]}")
print(f"A{i+1}: {output['content']}\n")
# 4. 批量翻译示例
texts = [
"Hello",
"Good morning",
"Thank you"
]
response = requests.post(
"http://localhost:8000/translate/batch",
json={
"inputs": [
{"text": t, "lang": "中文"}
for t in texts
]
}
)
translations = response.json()["outputs"]
for orig, trans in zip(texts, translations):
print(f"{orig} -> {trans['content']}")
# 5. 控制并发
response = requests.post(
"http://localhost:8000/chat/batch",
json={
"inputs": [f"问题{i}" for i in range(100)],
"config": {
"max_concurrency": 10 # 最多10个并发
}
}
)
# 6. 异步批量处理
import httpx
import asyncio
async def async_batch(inputs):
async with httpx.AsyncClient(timeout=60) as client:
response = await client.post(
"http://localhost:8000/chat/batch",
json={"inputs": inputs}
)
return response.json()["outputs"]
inputs = [f"异步问题{i}" for i in range(20)]
results = asyncio.run(async_batch(inputs))
# 7. 错误处理
response = requests.post(
"http://localhost:8000/chat/batch",
json={"inputs": ["正常输入", None, "另一个输入"]} # 包含无效输入
)
outputs = response.json()["outputs"]
for i, output in enumerate(outputs):
if "error" in output:
print(f"输入{i}失败:{output['error']}")
else:
print(f"输入{i}成功:{output['content']}")
# 8. 性能对比
import time
questions = [f"问题{i}" for i in range(10)]
# 方式1:多次invoke
start = time.time()
for q in questions:
requests.post(
"http://localhost:8000/chat/invoke",
json={"input": q}
)
invoke_time = time.time() - start
print(f"Invoke方式耗时:{invoke_time:.2f}秒")
# 方式2:batch
start = time.time()
requests.post(
"http://localhost:8000/chat/batch",
json={"inputs": questions}
)
batch_time = time.time() - start
print(f"Batch方式耗时:{batch_time:.2f}秒")
print(f"性能提升:{(invoke_time/batch_time):.2f}倍")
# 9. 信创环境批量处理
# 批量处理本地文档
documents = [
"文档内容1...",
"文档内容2...",
"文档内容3..."
]
response = requests.post(
"http://localhost:8000/local/summary/batch",
json={
"inputs": [{"content": doc} for doc in documents],
"config": {"max_concurrency": 3}
}
)
summaries = response.json()["outputs"]
for i, summary in enumerate(summaries):
print(f"文档{i+1}摘要:{summary['content']}")
---
c.stream端点
a.功能说明
stream端点用于流式输出,边生成边返回。适用于长文本生成场景,如聊天、写作助手。提升用户体验,无需等待完整结果。使用Server-Sent Events(SSE)协议传输。
b.代码示例
---
# 1. stream端点请求格式
# POST /chat/stream
{
"input": "讲一个长故事",
"config": {} # 可选配置
}
# 2. stream端点响应格式(SSE)
# 响应头:Content-Type: text/event-stream
# 数据格式:
"""
event: data
data: {"content": "很"}
event: data
data: {"content": "久"}
event: data
data: {"content": "以前"}
event: end
"""
# 3. Python客户端(requests)
import requests
import json
with requests.post(
"http://localhost:8000/chat/stream",
json={"input": "写一首诗"},
stream=True # 关键:开启流式
) as response:
for line in response.iter_lines():
if line:
line = line.decode('utf-8')
if line.startswith('data: '):
data = json.loads(line[6:]) # 去掉"data: "前缀
if "content" in data:
print(data["content"], end="", flush=True)
print() # 换行
# 4. Python客户端(httpx + async)
import httpx
import asyncio
async def stream_chat(question: str):
    """Stream an answer from the ``/chat/stream`` endpoint and print it.

    Args:
        question: Text sent as the ``input`` field of the request body.
    """
    import json

    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST",
            "http://localhost:8000/chat/stream",
            json={"input": question}
        ) as response:
            async for line in response.aiter_lines():
                if not line.startswith('data: '):
                    continue
                # The payload is JSON coming from the network: eval() would
                # execute arbitrary expressions and also fails on JSON
                # literals such as true/false/null.
                try:
                    data = json.loads(line[6:])
                except json.JSONDecodeError:
                    continue
                if isinstance(data, dict) and "content" in data:
                    print(data["content"], end="", flush=True)
asyncio.run(stream_chat("解释量子计算"))
# 5. JavaScript客户端
"""
// 使用Fetch API
const response = await fetch('http://localhost:8000/chat/stream', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({input: '你好'})
});
const reader = response.body.getReader();
const decoder = new TextDecoder();
while (true) {
const {done, value} = await reader.read();
if (done) break;
const chunk = decoder.decode(value);
const lines = chunk.split('\\n');
for (const line of lines) {
if (line.startsWith('data: ')) {
const data = JSON.parse(line.slice(6));
console.log(data.content);
}
}
}
"""
# 6. 完整流式示例
import requests
import json
import sys
def stream_response(url, input_data):
    """Stream a response from a LangServe ``stream`` endpoint.

    Each received ``content`` fragment is printed as it arrives.

    Args:
        url: Full URL of the stream endpoint.
        input_data: Value sent as the ``input`` field of the request body.

    Returns:
        The concatenation of all ``content`` fragments received.

    Raises:
        requests.exceptions.HTTPError: When the server answers with an
            error status.
    """
    fragments = []
    with requests.post(
        url,
        json={"input": input_data},
        stream=True,
        timeout=60
    ) as response:
        response.raise_for_status()
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            line_str = raw_line.decode('utf-8')
            # The end marker is its own "event:" line, so it has to be
            # checked before (not inside) the "data: " branch.
            if line_str == "event: end":
                break
            if not line_str.startswith('data: '):
                continue
            try:
                data = json.loads(line_str[6:])
            except json.JSONDecodeError:
                continue
            # Non-object payloads (plain strings, numbers) carry no
            # "content" key and are skipped.
            if isinstance(data, dict) and "content" in data:
                content = data["content"]
                print(content, end="", flush=True)
                fragments.append(content)
    print()  # final newline after the streamed text
    return "".join(fragments)
# 使用
result = stream_response(
"http://localhost:8000/chat/stream",
"解释相对论"
)
# 7. 错误处理
try:
with requests.post(
"http://localhost:8000/chat/stream",
json={"input": "测试"},
stream=True,
timeout=30
) as response:
response.raise_for_status()
for line in response.iter_lines():
if line:
line_str = line.decode('utf-8')
if 'error' in line_str.lower():
print(f"错误:{line_str}")
break
print(line_str)
except requests.exceptions.Timeout:
print("流式请求超时")
except requests.exceptions.RequestException as e:
print(f"请求失败:{e}")
# 8. 信创环境示例
# 在麒麟系统上使用流式输出
def local_stream_chat(question: str):
    """Stream an answer from the local-model endpoint and print it.

    Args:
        question: Text sent as the ``input`` field of the request body.
    """
    import json

    import requests

    url = "http://localhost:8000/local/chat/stream"
    print(f"问:{question}")
    print("答:", end="")
    with requests.post(
        url,
        json={"input": question},
        stream=True,
        verify=False  # intranet deployment
    ) as response:
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            line_str = raw_line.decode('utf-8')
            if not line_str.startswith('data: '):
                continue
            # Parse as JSON rather than eval(): the payload comes from the
            # network and may contain true/false/null.
            try:
                data = json.loads(line_str[6:])
            except json.JSONDecodeError:
                # Skip malformed frames only; other errors should surface.
                continue
            if isinstance(data, dict) and "content" in data:
                print(data["content"], end="", flush=True)
    print("\n")
# 测试
local_stream_chat("介绍一下信创技术")
---
2 服务端开发
2.1 Runnable部署
01.Chain转Runnable
a.LCEL语法
a.功能说明
LCEL(LangChain Expression Language)是LangChain的链式表达式语法,使用管道操作符(|)连接组件。每个组件都实现Runnable接口,天然支持LangServe部署。LCEL提供简洁的链构建方式,代码可读性强。支持复杂的数据流转换和条件分支。
b.代码示例
---
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
# 1. 基础LCEL
llm = ChatOpenAI(model="gpt-3.5-turbo")
prompt = ChatPromptTemplate.from_template("翻译为英文:{text}")
# 使用 | 连接
chain = prompt | llm | StrOutputParser()
# 这个chain是Runnable,可以直接部署
from langserve import add_routes
add_routes(app, chain, path="/translate")
# 2. 多步骤LCEL
from langchain.schema.runnable import RunnablePassthrough
# 步骤1:提取关键词
keyword_prompt = ChatPromptTemplate.from_template("提取关键词:{text}")
keyword_chain = keyword_prompt | llm | StrOutputParser()
# 步骤2:基于关键词生成摘要
summary_prompt = ChatPromptTemplate.from_template(
"基于关键词'{keywords}',总结原文:{text}"
)
# 组合
full_chain = {
"keywords": keyword_chain,
"text": RunnablePassthrough()
} | summary_prompt | llm | StrOutputParser()
add_routes(app, full_chain, path="/smart_summary")
# 3. 条件分支
from langchain.schema.runnable import RunnableBranch
# 根据输入长度选择不同处理
def get_length_category(inputs):
    """Bucket ``inputs["text"]`` by length: "short", "medium" or "long"."""
    size = len(inputs["text"])
    # Upper bounds are exclusive and checked in ascending order.
    for limit, label in ((100, "short"), (500, "medium")):
        if size < limit:
            return label
    return "long"
short_prompt = ChatPromptTemplate.from_template("简短总结:{text}")
medium_prompt = ChatPromptTemplate.from_template("中等总结:{text}")
long_prompt = ChatPromptTemplate.from_template("详细总结:{text}")
branch_chain = RunnableBranch(
(lambda x: get_length_category(x) == "short", short_prompt | llm),
(lambda x: get_length_category(x) == "medium", medium_prompt | llm),
long_prompt | llm # 默认
)
add_routes(app, branch_chain, path="/adaptive_summary")
# 4. 并行执行
from langchain.schema.runnable import RunnableParallel
# 同时执行多个任务
parallel_chain = RunnableParallel(
summary=ChatPromptTemplate.from_template("总结:{text}") | llm,
keywords=ChatPromptTemplate.from_template("关键词:{text}") | llm,
sentiment=ChatPromptTemplate.from_template("情感:{text}") | llm
)
add_routes(app, parallel_chain, path="/analyze")
# 调用返回:
# {
# "summary": "...",
# "keywords": "...",
# "sentiment": "..."
# }
# 5. 数据转换
from langchain.schema.runnable import RunnableLambda
def format_input(data):
    """Build the prompt variables: the upper-cased ``content`` under ``text``."""
    content = data["content"]
    return dict(text=content.upper())
def format_output(result):
    """Wrap a model result (anything with ``.content``) in a response dict."""
    text = result.content
    payload = {"result": text}
    payload["length"] = len(text)
    payload["timestamp"] = "2024-01-01"  # fixed placeholder value
    return payload
transform_chain = (
RunnableLambda(format_input) |
prompt | llm |
RunnableLambda(format_output)
)
add_routes(app, transform_chain, path="/transform")
# 6. 错误处理
from langchain.schema.runnable import RunnableWithFallbacks
# 主模型
primary_llm = ChatOpenAI(model="gpt-4", temperature=0)
# 备用模型
fallback_llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
safe_chain = prompt | primary_llm.with_fallbacks([fallback_llm])
add_routes(app, safe_chain, path="/safe_chat")
# 7. 完整示例:智能客服
from langchain.memory import ConversationBufferMemory
# 意图识别
intent_prompt = ChatPromptTemplate.from_template(
"识别意图(咨询/投诉/建议):{question}"
)
intent_chain = intent_prompt | llm | StrOutputParser()
# 知识库检索(简化)
def retrieve_knowledge(inputs):
    """Attach placeholder knowledge text for the detected intent.

    Returns a dict with ``knowledge`` and the original ``question``; a real
    implementation would query a vector store instead.
    """
    intent = inputs["intent"]
    return dict(
        knowledge=f"关于{intent}的知识...",
        question=inputs["question"],
    )
# 生成回复
answer_prompt = ChatPromptTemplate.from_template(
"基于知识'{knowledge}',回答'{question}'"
)
# 组合
customer_service_chain = (
{"intent": intent_chain, "question": RunnablePassthrough()} |
RunnableLambda(retrieve_knowledge) |
answer_prompt | llm | StrOutputParser()
)
add_routes(app, customer_service_chain, path="/customer_service")
# 8. 信创环境LCEL
from langchain.llms import Ollama
# 本地模型
local_llm = Ollama(model="qwen:7b", base_url="http://localhost:11434")
# 简单链
local_prompt = ChatPromptTemplate.from_template("回答:{question}")
local_chain = local_prompt | local_llm | StrOutputParser()
add_routes(app, local_chain, path="/local_qa")
# 混合链(本地+云端)
hybrid_chain = RunnableWithFallbacks(
runnable=prompt | local_llm, # 优先本地
fallbacks=[prompt | ChatOpenAI()] # 失败用云端
)
add_routes(app, hybrid_chain, path="/hybrid_chat")
print("✓ LCEL Chain部署完成")
---
b.自定义Runnable
a.功能说明
继承Runnable基类创建自定义组件,实现特定业务逻辑。必须实现invoke方法;batch、stream等方法已有基于invoke的默认实现,可按需重写。自定义Runnable可以集成到LCEL链中,与其他组件无缝组合。适用于复杂业务逻辑或第三方API集成。
b.代码示例
---
from langchain.schema.runnable import Runnable
from typing import Any, List, Iterator
# 1. 基础自定义Runnable
class MyCustomRunnable(Runnable):
    """Minimal custom Runnable that puts a fixed prefix before its input."""

    def invoke(self, input: Any, config: dict = None) -> Any:
        """Return the input formatted behind the prefix (synchronous call)."""
        return f"处理:{input}"

    def batch(self, inputs: List[Any], config: dict = None) -> List[Any]:
        """Apply ``invoke`` to every input, in order."""
        return list(map(lambda item: self.invoke(item, config), inputs))

    def stream(self, input: Any, config: dict = None) -> Iterator[Any]:
        """Yield the ``invoke`` result one character at a time."""
        yield from self.invoke(input, config)
# 部署
custom_runnable = MyCustomRunnable()
add_routes(app, custom_runnable, path="/custom")
# 2. 数据库查询Runnable
import dmPython
from pydantic import BaseModel
class QueryInput(BaseModel):
table: str
condition: str
class DatabaseQueryRunnable(Runnable[QueryInput, list]):
    """Runnable that runs a ``SELECT *`` against a database via dmPython.

    The connection is opened once in ``__init__`` and reused by every call.
    """

    def __init__(self, connection_string: str):
        """Open the connection described by ``connection_string``."""
        self.conn = dmPython.connect(connection_string)

    def invoke(self, input: QueryInput, config: dict = None) -> list:
        """Execute the query and return each row as a ``{column: value}`` dict.

        Args:
            input: Table name and WHERE-clause text.
            config: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If ``input.table`` is not a plain (optionally
                schema-qualified) identifier.
        """
        import re

        # Identifiers cannot be bound as query parameters, so anything that is
        # not a bare identifier is rejected before it reaches the SQL text.
        if not re.fullmatch(
            r"[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)?", input.table
        ):
            raise ValueError(f"invalid table name: {input.table!r}")
        # SECURITY NOTE(review): ``input.condition`` is still spliced in as raw
        # SQL. Do not expose this route to untrusted callers; replace it with
        # structured filters bound as parameters.
        query = f"SELECT * FROM {input.table} WHERE {input.condition}"
        cursor = self.conn.cursor()
        try:
            cursor.execute(query)
            # DB-API rows are plain sequences; pair them with the column names
            # from cursor.description (dict(row) would raise TypeError).
            columns = [col[0] for col in cursor.description]
            return [dict(zip(columns, row)) for row in cursor.fetchall()]
        finally:
            cursor.close()
# 部署
db_runnable = DatabaseQueryRunnable(
"user=SYSDBA;password=SYSDBA;server=localhost;port=5236"
)
add_routes(app, db_runnable, path="/db_query")
# 3. API调用Runnable
import requests
class APICallRunnable(Runnable[dict, dict]):
    """Runnable that forwards its input to a third-party HTTP API."""

    def __init__(self, api_url: str, api_key: str, timeout: float = 30.0):
        """Store the endpoint, bearer token and per-request timeout (seconds)."""
        self.api_url = api_url
        self.api_key = api_key
        self.timeout = timeout

    def invoke(self, input: dict, config: dict = None) -> dict:
        """POST ``input`` as JSON and return the decoded JSON response.

        Raises:
            requests.exceptions.Timeout: If the API does not answer in time.
            requests.exceptions.HTTPError: If the API answers with 4xx/5xx.
        """
        response = requests.post(
            self.api_url,
            json=input,
            headers={"Authorization": f"Bearer {self.api_key}"},
            # Without a timeout a stalled upstream blocks the worker forever.
            timeout=self.timeout,
        )
        # Surface HTTP errors instead of returning an error body as a result.
        response.raise_for_status()
        return response.json()
api_runnable = APICallRunnable(
"https://api.example.com/process",
"your_api_key"
)
add_routes(app, api_runnable, path="/external_api")
# 4. 文件处理Runnable
class FileProcessRunnable(Runnable[str, str]):
    """Runnable that reads a text file and reports its word count."""

    def invoke(self, input: str, config: dict = None) -> str:
        """Count whitespace-separated words in the UTF-8 file at path ``input``."""
        # NOTE(review): the path comes straight from the caller; confirm it is
        # restricted before exposing this route publicly.
        with open(input, 'r', encoding='utf-8') as handle:
            words = handle.read().split()
        return f"文件包含{len(words)}个单词"
file_runnable = FileProcessRunnable()
add_routes(app, file_runnable, path="/file_process")
# 5. 缓存Runnable
from functools import lru_cache
class CachedRunnable(Runnable[str, str]):
    """Runnable that memoizes its result per input string."""

    # The cache lives on a staticmethod so it is keyed on the input only.
    # Decorating an instance method would put ``self`` in every cache key and
    # keep each instance alive for as long as the cache holds it.
    @staticmethod
    @lru_cache(maxsize=100)
    def _cached_invoke(input: str) -> str:
        """Compute the result for ``input``; repeated inputs hit the cache."""
        import time

        time.sleep(2)  # simulate an expensive operation
        return f"结果:{input}"

    def invoke(self, input: str, config: dict = None) -> str:
        """Return the cached result for ``input`` (``config`` is unused)."""
        return self._cached_invoke(input)
cached_runnable = CachedRunnable()
add_routes(app, cached_runnable, path="/cached")
# 6. 组合自定义Runnable
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
# 自定义预处理
class PreprocessRunnable(Runnable[str, str]):
    """Pre-processing step: normalizes the raw text before prompting."""

    def invoke(self, input: str, config: dict = None) -> str:
        """Strip surrounding whitespace and lower-case the text."""
        return input.strip().lower()
# 自定义后处理
class PostprocessRunnable(Runnable[Any, dict]):
    """Post-processing step: wraps the model reply in a result dict."""

    def invoke(self, input: Any, config: dict = None) -> dict:
        """Return the reply text, its length and a ``processed`` marker."""
        text = input.content
        return dict(result=text, length=len(text), processed=True)
# 组合
llm = ChatOpenAI()
prompt = ChatPromptTemplate.from_template("处理:{text}")
full_chain = (
PreprocessRunnable() |
{"text": lambda x: x} |
prompt | llm |
PostprocessRunnable()
)
add_routes(app, full_chain, path="/pipeline")
# 7. 异步Runnable
import asyncio
class AsyncRunnable(Runnable[str, str]):
    """Runnable whose real work is done in a coroutine."""

    async def ainvoke(self, input: str, config: dict = None) -> str:
        """Wait one second (stand-in for async work), then format the input."""
        await asyncio.sleep(1)
        return f"异步结果:{input}"

    def invoke(self, input: str, config: dict = None) -> str:
        """Synchronous entry point that drives ``ainvoke`` to completion."""
        pending = self.ainvoke(input, config)
        return asyncio.run(pending)
async_runnable = AsyncRunnable()
add_routes(app, async_runnable, path="/async")
# 8. 信创环境自定义Runnable
class XinchuangRunnable(Runnable[dict, dict]):
    """Runnable that loads a document via dmPython and analyses it with Ollama."""

    def __init__(self):
        """Open the database connection and create the local LLM client."""
        # Connect to the Dameng database
        self.db = dmPython.connect(
            user="SYSDBA",
            password="SYSDBA",
            server="localhost",
            port=5236
        )
        # Connect to the local LLM
        from langchain.llms import Ollama
        self.llm = Ollama(model="qwen:7b")

    def invoke(self, input: dict, config: dict = None) -> dict:
        """Look up ``input["doc_id"]`` and return the LLM analysis of the row.

        Args:
            input: Must contain ``doc_id``.
            config: Accepted for interface compatibility; not used here.

        Returns:
            Dict with ``doc_id``, ``analysis`` and a fixed ``source`` marker.

        Raises:
            ValueError: If no row matches ``doc_id``.
        """
        doc_id = input["doc_id"]
        # 1. Query the database; always release the cursor.
        cursor = self.db.cursor()
        try:
            cursor.execute("SELECT * FROM docs WHERE id = ?", (doc_id,))
            doc = cursor.fetchone()
        finally:
            cursor.close()
        # fetchone() returns None when nothing matches; fail with a clear
        # message instead of a TypeError from indexing None.
        if doc is None:
            raise ValueError(f"document not found: {doc_id!r}")
        # 2. Analyse with the local LLM (second column is used as the text).
        result = self.llm.invoke(f"分析文档:{doc[1]}")
        # 3. Return the result
        return {
            "doc_id": doc_id,
            "analysis": result,
            "source": "xinchuang"
        }
xc_runnable = XinchuangRunnable()
add_routes(app, xc_runnable, path="/xinchuang")
print("✓ 自定义Runnable部署完成")
---
02.输入输出定义
a.Pydantic模型
a.功能说明
使用Pydantic定义输入输出的数据结构,实现类型验证和文档生成。FastAPI自动根据Pydantic模型生成OpenAPI Schema。提供字段验证、默认值、描述等功能。清晰的输入输出定义提升API可用性。
b.代码示例
---
from pydantic import BaseModel, Field
from typing import Optional, List
from langserve import add_routes
# 1. 基础输入输出模型
class ChatInput(BaseModel):
"""聊天输入"""
message: str = Field(..., description="用户消息")
temperature: Optional[float] = Field(0.7, description="温度参数", ge=0, le=1)
class ChatOutput(BaseModel):
"""聊天输出"""
response: str = Field(..., description="AI回复")
tokens: int = Field(..., description="使用的Token数")
# 2. 在Chain中使用
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnableLambda
llm = ChatOpenAI()
def process_input(data: ChatInput) -> str:
"""处理输入"""
return data.message
def process_output(result) -> ChatOutput:
    """Convert a chat-model result into a ``ChatOutput``.

    The token count is read from ``response_metadata["token_usage"]`` and
    falls back to 0 when that information is absent.
    """
    reply = result.content
    usage = result.response_metadata.get("token_usage", {})
    return ChatOutput(response=reply, tokens=usage.get("total_tokens", 0))
chain = (
RunnableLambda(process_input) |
llm |
RunnableLambda(process_output)
)
add_routes(
app,
chain,
path="/typed_chat",
input_type=ChatInput,
output_type=ChatOutput
)
# 3. 复杂输入模型
class TranslateInput(BaseModel):
    """Input for the translation endpoint."""

    text: str = Field(..., description="要翻译的文本", min_length=1, max_length=5000)
    # The default "auto" must itself satisfy the pattern; otherwise a client
    # that sends the documented default explicitly is rejected.
    source_lang: str = Field("auto", description="源语言", pattern="^(auto|[a-z]{2})$")
    target_lang: str = Field(..., description="目标语言", pattern="^[a-z]{2}$")
    # pydantic v2 spells this constraint ``pattern``; ``regex`` is rejected.
    formality: Optional[str] = Field("default", description="正式程度", pattern="^(default|formal|informal)$")
class TranslateOutput(BaseModel):
"""翻译输出"""
translated_text: str = Field(..., description="翻译后的文本")
detected_lang: str = Field(..., description="检测到的源语言")
confidence: float = Field(..., description="置信度", ge=0, le=1)
# 4. 嵌套模型
class Address(BaseModel):
"""地址"""
street: str
city: str
country: str
class User(BaseModel):
    """User record with a nested address."""

    name: str
    age: int = Field(..., ge=0, le=150)
    # pydantic v2 spells this constraint ``pattern``; ``regex`` is rejected.
    email: str = Field(..., pattern=r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$")
    address: Address
class UserAnalysisInput(BaseModel):
"""用户分析输入"""
user: User
analysis_type: str = Field(..., description="分析类型")
# 5. 列表类型
class BatchTranslateInput(BaseModel):
"""批量翻译输入"""
texts: List[str] = Field(..., description="文本列表", min_items=1, max_items=100)
target_lang: str = Field(..., description="目标语言")
class BatchTranslateOutput(BaseModel):
"""批量翻译输出"""
translations: List[str] = Field(..., description="翻译结果列表")
total_count: int = Field(..., description="总数")
# 6. 可选字段和默认值
class SummaryInput(BaseModel):
"""摘要输入"""
content: str = Field(..., description="内容")
max_length: int = Field(100, description="最大长度", ge=10, le=1000)
style: str = Field("concise", description="风格")
include_keywords: bool = Field(False, description="是否包含关键词")
language: Optional[str] = Field(None, description="语言")
# 7. 自定义验证
from pydantic import validator
class QuestionInput(BaseModel):
"""问题输入"""
question: str = Field(..., description="问题")
context: Optional[str] = Field(None, description="上下文")
@validator('question')
def question_not_empty(cls, v):
"""验证问题非空"""
if not v.strip():
raise ValueError('问题不能为空')
return v.strip()
@validator('context')
def context_length(cls, v):
"""验证上下文长度"""
if v and len(v) > 10000:
raise ValueError('上下文过长(最多10000字符)')
return v
# 8. 配置Pydantic
class ConfiguredInput(BaseModel):
"""配置的输入"""
data: str
class Config:
# 允许额外字段
extra = "allow"
# 使用枚举值
use_enum_values = True
# 验证赋值
validate_assignment = True
# 9. 枚举类型
from enum import Enum
class AnalysisType(str, Enum):
"""分析类型"""
SENTIMENT = "sentiment"
KEYWORDS = "keywords"
SUMMARY = "summary"
class AnalysisInput(BaseModel):
"""分析输入"""
text: str = Field(..., description="文本")
type: AnalysisType = Field(..., description="分析类型")
# 10. 信创环境示例
class XinchuangDocInput(BaseModel):
    """Document request for the Xinchuang (domestic stack) deployment."""

    doc_id: str = Field(..., description="文档ID")
    # pydantic v2 spells this constraint ``pattern``; ``regex`` is rejected.
    db_type: str = Field("dameng", description="数据库类型", pattern="^(dameng|postgresql)$")
    llm_type: str = Field("ollama", description="LLM类型", pattern="^(ollama|openai)$")

    # pydantic v2 configuration: ``schema_extra`` was renamed to
    # ``json_schema_extra`` and lives in ``model_config``.
    model_config = {
        "json_schema_extra": {
            "example": {
                "doc_id": "DOC001",
                "db_type": "dameng",
                "llm_type": "ollama"
            }
        }
    }
class XinchuangDocOutput(BaseModel):
"""信创文档输出"""
doc_id: str
content: str
analysis: str
source: str = Field(..., description="数据来源")
processed_at: str = Field(..., description="处理时间")
print("✓ Pydantic模型定义完成")
---
2.2 路由配置
01.路径参数
a.基础路径
a.功能说明
配置API的基础路径,组织端点结构。使用前缀分组相关功能,如/api/v1、/chat等。路径应简洁、语义清晰、符合RESTful规范。合理的路径设计提升API的可用性和可维护性。
b.代码示例
---
from fastapi import FastAPI, APIRouter
from langserve import add_routes
from langchain.chat_models import ChatOpenAI
app = FastAPI()
llm = ChatOpenAI()
# 1. 基础路径
add_routes(app, llm, path="/chat")
# 端点:/chat/invoke, /chat/stream, etc.
# 2. 版本化路径
v1_router = APIRouter(prefix="/api/v1")
add_routes(v1_router, llm, path="/chat")
app.include_router(v1_router)
# 端点:/api/v1/chat/invoke
v2_router = APIRouter(prefix="/api/v2")
add_routes(v2_router, llm, path="/chat")
app.include_router(v2_router)
# 端点:/api/v2/chat/invoke
# 3. 功能分组
ai_router = APIRouter(prefix="/ai", tags=["AI Services"])
add_routes(ai_router, llm, path="/chat")
add_routes(ai_router, llm, path="/translate")
app.include_router(ai_router)
# 端点:/ai/chat/invoke, /ai/translate/invoke
# 4. 多级路径
api_router = APIRouter(prefix="/api")
v1_subrouter = APIRouter(prefix="/v1")
ai_subrouter = APIRouter(prefix="/ai")
add_routes(ai_subrouter, llm, path="/chat")
v1_subrouter.include_router(ai_subrouter)
api_router.include_router(v1_subrouter)
app.include_router(api_router)
# 端点:/api/v1/ai/chat/invoke
# 5. 路径命名最佳实践
# ✓ 推荐
add_routes(app, llm, path="/chat") # 简洁
add_routes(app, llm, path="/text_analysis") # 下划线
add_routes(app, llm, path="/qa_service") # 语义清晰
# ✗ 避免
# add_routes(app, llm, path="/Chat") # 大写
# add_routes(app, llm, path="/text-analysis") # 连字符
# add_routes(app, llm, path="/qa/service") # 不必要的嵌套
# 6. 环境区分
import os
# Map the deployment environment to a URL prefix; any environment not listed
# here (e.g. production) is served without a prefix.
env = os.getenv("ENV", "dev")
prefix = {"dev": "/dev", "staging": "/staging"}.get(env, "")
router = APIRouter(prefix=prefix)
add_routes(router, llm, path="/chat")
app.include_router(router)
# 开发:/dev/chat/invoke
# 生产:/chat/invoke
# 7. 完整路径示例
app = FastAPI(title="LangServe API")
# 公开API
public_router = APIRouter(prefix="/public", tags=["Public"])
add_routes(public_router, llm, path="/chat")
# 内部API
internal_router = APIRouter(prefix="/internal", tags=["Internal"])
add_routes(internal_router, llm, path="/admin_chat")
# 注册
app.include_router(public_router)
app.include_router(internal_router)
print("路由结构:")
print(" /public/chat/invoke")
print(" /internal/admin_chat/invoke")
# 8. 信创环境路径配置
# 本地模型路径
local_router = APIRouter(prefix="/local", tags=["本地模型"])
from langchain.llms import Ollama
local_llm = Ollama(model="qwen:7b")
add_routes(local_router, local_llm, path="/chat")
# 云端模型路径
cloud_router = APIRouter(prefix="/cloud", tags=["云端模型"])
cloud_llm = ChatOpenAI()
add_routes(cloud_router, cloud_llm, path="/chat")
app.include_router(local_router)
app.include_router(cloud_router)
print("✓ 信创路径配置完成")
print(" 本地:/local/chat/invoke")
print(" 云端:/cloud/chat/invoke")
---
b.路径前缀
a.功能说明
使用APIRouter的prefix参数为多个端点添加统一前缀。简化路由管理,避免重复代码。支持嵌套前缀,构建多级路径结构。前缀是组织大型API的关键机制。
b.代码示例
---
from fastapi import APIRouter
from langserve import add_routes
from langchain.chat_models import ChatOpenAI
# 1. 单级前缀
api_router = APIRouter(prefix="/api")
llm = ChatOpenAI()
add_routes(api_router, llm, path="/chat")
add_routes(api_router, llm, path="/translate")
app.include_router(api_router)
# 端点:/api/chat/invoke, /api/translate/invoke
# 2. 版本前缀
v1 = APIRouter(prefix="/api/v1", tags=["V1"])
v2 = APIRouter(prefix="/api/v2", tags=["V2"])
# V1使用旧模型
v1_llm = ChatOpenAI(model="gpt-3.5-turbo")
add_routes(v1, v1_llm, path="/chat")
# V2使用新模型
v2_llm = ChatOpenAI(model="gpt-4")
add_routes(v2, v2_llm, path="/chat")
app.include_router(v1)
app.include_router(v2)
# 端点:/api/v1/chat/invoke, /api/v2/chat/invoke
# 3. 功能模块前缀
# AI模块
ai_router = APIRouter(prefix="/ai", tags=["AI"])
add_routes(ai_router, llm, path="/chat")
add_routes(ai_router, llm, path="/translate")
# 工具模块
tools_router = APIRouter(prefix="/tools", tags=["Tools"])
@tools_router.post("/convert")
def convert_format(data: dict):
return {"converted": data}
app.include_router(ai_router)
app.include_router(tools_router)
# 端点:/ai/chat/invoke, /tools/convert
# 4. 嵌套前缀
root = APIRouter(prefix="/api")
v1 = APIRouter(prefix="/v1")
services = APIRouter(prefix="/services")
add_routes(services, llm, path="/chat")
v1.include_router(services)
root.include_router(v1)
app.include_router(root)
# 端点:/api/v1/services/chat/invoke
# 5. 条件前缀
import os
def get_prefix():
    """Return the URL prefix for the environment named by ``ENVIRONMENT``.

    ``development`` maps to ``/dev`` and ``staging`` to ``/stage``; every
    other value, including the default ``production``, maps to no prefix.
    """
    prefixes = {"development": "/dev", "staging": "/stage"}
    return prefixes.get(os.getenv("ENVIRONMENT", "production"), "")
dynamic_router = APIRouter(prefix=get_prefix())
add_routes(dynamic_router, llm, path="/chat")
app.include_router(dynamic_router)
# 6. 多租户前缀
def create_tenant_router(tenant_id: str):
    """Build a router under ``/tenant/{tenant_id}`` with that tenant's chat route.

    NOTE(review): ``get_tenant_api_key`` is not defined in this snippet; it is
    assumed to be provided elsewhere.
    """
    # Tenant-specific LLM configuration
    tenant_llm = ChatOpenAI(api_key=get_tenant_api_key(tenant_id))
    tenant_router = APIRouter(prefix=f"/tenant/{tenant_id}")
    add_routes(tenant_router, tenant_llm, path="/chat")
    return tenant_router
# 为多个租户创建路由
tenants = ["tenant_a", "tenant_b", "tenant_c"]
for tenant in tenants:
app.include_router(create_tenant_router(tenant))
# 端点:
# /tenant/tenant_a/chat/invoke
# /tenant/tenant_b/chat/invoke
# /tenant/tenant_c/chat/invoke
# 7. 地域前缀
# 亚洲区
asia_router = APIRouter(prefix="/asia")
asia_llm = ChatOpenAI(base_url="https://api.asia.openai.com")
add_routes(asia_router, asia_llm, path="/chat")
# 欧洲区
europe_router = APIRouter(prefix="/europe")
europe_llm = ChatOpenAI(base_url="https://api.europe.openai.com")
add_routes(europe_router, europe_llm, path="/chat")
app.include_router(asia_router)
app.include_router(europe_router)
# 8. 信创环境前缀
# 不同信创组件前缀
xc_router = APIRouter(prefix="/xinchuang", tags=["信创"])
# 达梦数据库相关
dm_router = APIRouter(prefix="/dameng")
add_routes(dm_router, llm, path="/query")
# Ollama本地模型
ollama_router = APIRouter(prefix="/ollama")
from langchain.llms import Ollama
local_llm = Ollama(model="qwen:7b")
add_routes(ollama_router, local_llm, path="/chat")
# 注册信创子路由
xc_router.include_router(dm_router)
xc_router.include_router(ollama_router)
app.include_router(xc_router)
# 最终端点:
# /xinchuang/dameng/query/invoke
# /xinchuang/ollama/chat/invoke
print("✓ 路径前缀配置完成")
---
02.参数传递
a.查询参数
a.功能说明
通过URL查询参数传递配置,灵活调整Chain行为。适用于简单的配置选项,如temperature、max_tokens。查询参数易于测试和调试。需要在端点处理器中解析和应用参数。
b.代码示例
---
from fastapi import FastAPI, Query
from langserve import add_routes
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnableLambda
app = FastAPI()
# 1. 自定义端点带查询参数
@app.post("/custom_chat")
def custom_chat(
message: str,
temperature: float = Query(0.7, ge=0, le=1, description="温度参数"),
max_tokens: int = Query(100, ge=1, le=2000, description="最大Token数")
):
"""自定义聊天端点"""
llm = ChatOpenAI(
model="gpt-3.5-turbo",
temperature=temperature,
max_tokens=max_tokens
)
result = llm.invoke(message)
return {"response": result.content}
# 调用:POST /custom_chat?message=你好&temperature=0.9&max_tokens=50
# 2. 条件路由基于查询参数
@app.post("/smart_chat")
def smart_chat(
message: str,
model: str = Query("gpt-3.5-turbo", description="模型名称")
):
"""根据参数选择模型"""
llm = ChatOpenAI(model=model)
result = llm.invoke(message)
return {"response": result.content, "model": model}
# 调用:POST /smart_chat?message=测试&model=gpt-4
# 3. 配置传递到Chain
def create_configurable_chain(temperature: float = 0.7):
"""创建可配置的Chain"""
llm = ChatOpenAI(temperature=temperature)
prompt = ChatPromptTemplate.from_template("回答:{question}")
return prompt | llm
@app.post("/configurable")
def configurable_endpoint(
question: str,
temperature: float = Query(0.7, ge=0, le=1)
):
chain = create_configurable_chain(temperature)
result = chain.invoke({"question": question})
return {"answer": result.content}
# 4. 多个查询参数
@app.post("/translate")
def translate_text(
    text: str,
    source_lang: str = Query("en", description="源语言"),
    target_lang: str = Query("zh", description="目标语言"),
    # FastAPI deprecates ``regex`` in favour of ``pattern``.
    formality: str = Query("default", pattern="^(default|formal|informal)$")
):
    """Translate ``text`` from ``source_lang`` to ``target_lang``.

    Returns:
        Dict with the translated text and the two language codes.
    """
    # All request values are passed as template variables. Formatting them
    # into the template text itself would let a "{" or "}" in a query
    # parameter corrupt the template or introduce new variables.
    prompt = ChatPromptTemplate.from_template(
        "将以下{source_lang}文本翻译为{target_lang}(风格:{formality}):{text}"
    )
    llm = ChatOpenAI()
    chain = prompt | llm
    result = chain.invoke({
        "text": text,
        "source_lang": source_lang,
        "target_lang": target_lang,
        "formality": formality,
    })
    return {
        "translated": result.content,
        "source_lang": source_lang,
        "target_lang": target_lang
    }
# 调用:POST /translate?text=hello&source_lang=en&target_lang=zh&formality=formal
# 5. 可选查询参数
@app.post("/analyze")
def analyze_text(
    text: str,
    include_sentiment: bool = Query(False, description="包含情感分析"),
    include_keywords: bool = Query(False, description="包含关键词提取"),
    max_keywords: int = Query(5, ge=1, le=20, description="最大关键词数")
):
    """Return the requested analyses (placeholder values) for ``text``.

    Only the sections switched on by the flags appear in the result; the
    keyword list is truncated to ``max_keywords``.
    """
    # Sentiment analysis (placeholder result)
    sentiment = {"sentiment": "positive"} if include_sentiment else {}
    # Keyword extraction (placeholder result)
    keywords = (
        {"keywords": ["关键词1", "关键词2"][:max_keywords]}
        if include_keywords
        else {}
    )
    return {**sentiment, **keywords}
# 调用:POST /analyze?text=测试&include_sentiment=true&include_keywords=true&max_keywords=3
# 6. 验证查询参数
from pydantic import validator
@app.post("/validated")
def validated_endpoint(
    text: str,
    length: int = Query(..., ge=10, le=1000, description="文本长度限制")
):
    """Echo ``text`` with its length, or an error dict when it exceeds ``length``."""
    actual = len(text)
    if actual <= length:
        return {"text": text, "length": actual}
    return {"error": f"文本超过{length}字符限制"}
# 7. 查询参数默认值
@app.post("/chat_with_defaults")
def chat_with_defaults(
message: str,
temperature: float = Query(0.7),
model: str = Query("gpt-3.5-turbo"),
stream: bool = Query(False)
):
"""带默认值的端点"""
llm = ChatOpenAI(model=model, temperature=temperature)
result = llm.invoke(message)
return {
"response": result.content,
"config": {
"model": model,
"temperature": temperature,
"stream": stream
}
}
# 8. 信创环境查询参数
@app.post("/xinchuang_query")
def xinchuang_query(
    question: str,
    use_local: bool = Query(True, description="使用本地模型"),
    # FastAPI deprecates ``regex`` in favour of ``pattern``.
    db_type: str = Query("dameng", pattern="^(dameng|postgresql)$")
):
    """Answer ``question`` with the selected LLM and database backend.

    Returns:
        Dict with the answer text, which LLM was used and the database type.
    """
    # Choose the LLM
    if use_local:
        from langchain.llms import Ollama
        llm = Ollama(model="qwen:7b")
    else:
        llm = ChatOpenAI()
    # Choose the database
    if db_type == "dameng":
        import dmPython
        conn = dmPython.connect("...")
    else:
        import psycopg2
        conn = psycopg2.connect("...")
    try:
        # Query processing would use ``conn`` here...
        result = llm.invoke(question)
    finally:
        # Release the connection on every path; it was previously leaked.
        conn.close()
    return {
        # The local LLM returns a plain string while the chat model returns a
        # message object; normalise both to text.
        "answer": getattr(result, "content", result),
        "llm": "本地" if use_local else "云端",
        "db": db_type
    }
print("✓ 查询参数配置完成")
---
2.3 输入输出模式
01.Schema定义
a.自动生成
a.功能说明
LangServe根据Runnable的输入输出类型自动生成JSON Schema。Schema用于文档生成、客户端验证、IDE提示。访问/input_schema、/output_schema端点获取Schema。自动生成简化开发,确保文档与实现一致。
b.代码示例
---
from fastapi import FastAPI
from langserve import add_routes
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
import requests
app = FastAPI()
llm = ChatOpenAI()
prompt = ChatPromptTemplate.from_template("回答:{question}")
chain = prompt | llm
add_routes(app, chain, path="/qa")
# 1. 获取输入Schema
input_schema = requests.get("http://localhost:8000/qa/input_schema")
print("输入Schema:")
print(input_schema.json())
# 输出示例:
# {
# "type": "object",
# "properties": {
# "question": {"type": "string"}
# },
# "required": ["question"]
# }
# 2. 获取输出Schema
output_schema = requests.get("http://localhost:8000/qa/output_schema")
print("输出Schema:")
print(output_schema.json())
# 3. 获取配置Schema
config_schema = requests.get("http://localhost:8000/qa/config_schema")
print("配置Schema:")
print(config_schema.json())
# 4. 使用Pydantic定义明确Schema
from pydantic import BaseModel, Field
from typing import Optional
class QuestionInput(BaseModel):
    """Question request with optional context and a language code."""

    question: str = Field(..., description="用户问题", min_length=1)
    context: Optional[str] = Field(None, description="上下文信息")
    # pydantic v2 spells this constraint ``pattern``; ``regex`` is rejected.
    language: str = Field("zh", description="语言", pattern="^(zh|en)$")
class AnswerOutput(BaseModel):
"""答案输出"""
answer: str = Field(..., description="AI答案")
confidence: float = Field(..., description="置信度", ge=0, le=1)
sources: list = Field(default_factory=list, description="来源")
# 使用明确类型
from langchain.schema.runnable import RunnableLambda
def process_input(data: QuestionInput) -> str:
return data.question
def process_output(result) -> AnswerOutput:
return AnswerOutput(
answer=result.content,
confidence=0.9,
sources=[]
)
typed_chain = (
RunnableLambda(process_input) |
llm |
RunnableLambda(process_output)
)
add_routes(
app,
typed_chain,
path="/typed_qa",
input_type=QuestionInput,
output_type=AnswerOutput
)
# 获取typed_qa的Schema
typed_input_schema = requests.get("http://localhost:8000/typed_qa/input_schema")
print("类型化输入Schema:")
print(typed_input_schema.json())
# 包含QuestionInput的详细定义
# 5. Schema在OpenAPI文档中
# 访问 http://localhost:8000/docs
# 可以看到自动生成的API文档,包含完整的Schema定义
# 6. 复杂Schema示例
from typing import Any, Dict, List


class ComplexInput(BaseModel):
    """Request carrying several queries plus filter and option maps."""

    queries: List[str] = Field(..., description="查询列表")
    filters: Dict[str, str] = Field(default_factory=dict, description="过滤条件")
    options: Dict[str, bool] = Field(default_factory=dict, description="选项")


class ComplexOutput(BaseModel):
    """Response carrying free-form result records and metadata."""

    # ``typing.Any`` is the "any type" annotation; lowercase ``any`` is the
    # builtin function and is not a valid type.
    results: List[Dict[str, Any]] = Field(..., description="结果列表")
    total: int = Field(..., description="总数")
    metadata: Dict[str, Any] = Field(..., description="元数据")
# 7. 嵌套Schema
class Address(BaseModel):
street: str
city: str
country: str
class UserProfile(BaseModel):
name: str
age: int
address: Address # 嵌套
class UserAnalysis(BaseModel):
profile: UserProfile
insights: List[str]
# 8. 信创环境Schema
class XinchuangInput(BaseModel):
    """Request model for the Xinchuang (domestic stack) environment."""

    doc_id: str = Field(..., description="文档ID", pattern="^DOC[0-9]{6}$")
    # Use ``pattern`` throughout: pydantic v2 rejects the v1 ``regex`` keyword.
    db_source: str = Field("dameng", description="数据源", pattern="^(dameng|postgresql)$")
    llm_provider: str = Field("ollama", description="LLM提供商", pattern="^(ollama|openai)$")

    # pydantic v2 configuration: ``schema_extra`` was renamed to
    # ``json_schema_extra`` and lives in ``model_config``.
    model_config = {
        "json_schema_extra": {
            "example": {
                "doc_id": "DOC000001",
                "db_source": "dameng",
                "llm_provider": "ollama"
            }
        }
    }
print("✓ Schema自动生成配置完成")
---
b.手动指定
a.功能说明
显式指定输入输出类型,覆盖自动推断。使用input_type和output_type参数设置Pydantic模型。手动指定提供更精确的类型控制和文档。适用于复杂业务逻辑或需要严格验证的场景。
b.代码示例
---
from pydantic import BaseModel, Field, validator
from typing import Optional, List
from langserve import add_routes
from langchain.chat_models import ChatOpenAI
from langchain.schema.runnable import RunnableLambda
# 1. 基础手动指定
class ChatInput(BaseModel):
message: str = Field(..., description="用户消息")
class ChatOutput(BaseModel):
response: str = Field(..., description="AI回复")
llm = ChatOpenAI()
def wrap_input(data: ChatInput) -> str:
return data.message
def wrap_output(result) -> ChatOutput:
return ChatOutput(response=result.content)
chain = (
RunnableLambda(wrap_input) |
llm |
RunnableLambda(wrap_output)
)
add_routes(
app,
chain,
path="/manual_chat",
input_type=ChatInput,
output_type=ChatOutput
)
# 2. 带验证的输入
class ValidatedInput(BaseModel):
    """Text request validated by field constraints and a custom check."""

    text: str = Field(..., min_length=1, max_length=5000)
    # pydantic v2 spells this constraint ``pattern``; ``regex`` is rejected.
    language: str = Field("zh", pattern="^(zh|en|ja)$")

    @validator('text')
    def text_not_empty(cls, v):
        """Reject whitespace-only text and return the stripped value."""
        if not v.strip():
            raise ValueError('文本不能为空')
        return v.strip()
class ValidatedOutput(BaseModel):
result: str
word_count: int
language: str
def process_validated(data: ValidatedInput) -> ValidatedOutput:
# 处理逻辑...
return ValidatedOutput(
result=f"处理:{data.text}",
word_count=len(data.text.split()),
language=data.language
)
validated_chain = RunnableLambda(process_validated)
add_routes(
app,
validated_chain,
path="/validated",
input_type=ValidatedInput,
output_type=ValidatedOutput
)
# 3. 复杂嵌套类型
class DocumentMetadata(BaseModel):
title: str
author: str
created_at: str
class DocumentInput(BaseModel):
content: str = Field(..., description="文档内容")
metadata: DocumentMetadata
tags: List[str] = Field(default_factory=list)
class AnalysisResult(BaseModel):
summary: str
keywords: List[str]
sentiment: str
score: float
class DocumentOutput(BaseModel):
doc_id: str
analysis: AnalysisResult
processed_at: str
def analyze_document(data: DocumentInput) -> DocumentOutput:
# 分析逻辑...
return DocumentOutput(
doc_id="DOC123",
analysis=AnalysisResult(
summary="摘要...",
keywords=["关键词1", "关键词2"],
sentiment="positive",
score=0.85
),
processed_at="2024-01-01T00:00:00"
)
doc_chain = RunnableLambda(analyze_document)
add_routes(
app,
doc_chain,
path="/document_analysis",
input_type=DocumentInput,
output_type=DocumentOutput
)
# 4. 可选字段
class FlexibleInput(BaseModel):
required_field: str = Field(..., description="必需字段")
optional_field: Optional[str] = Field(None, description="可选字段")
default_field: str = Field("default", description="有默认值的字段")
flag: bool = Field(False, description="布尔标志")
class FlexibleOutput(BaseModel):
result: str
has_optional: bool
flags: List[str]
# 5. 枚举类型
from enum import Enum
class TaskType(str, Enum):
TRANSLATE = "translate"
SUMMARIZE = "summarize"
ANALYZE = "analyze"
class TaskInput(BaseModel):
text: str
task_type: TaskType = Field(..., description="任务类型")
class TaskOutput(BaseModel):
result: str
task_type: TaskType
duration_ms: int
def process_task(data: TaskInput) -> TaskOutput:
# 根据任务类型处理...
return TaskOutput(
result=f"完成{data.task_type.value}",
task_type=data.task_type,
duration_ms=100
)
task_chain = RunnableLambda(process_task)
add_routes(
app,
task_chain,
path="/task",
input_type=TaskInput,
output_type=TaskOutput
)
# 6. 列表类型
class BatchInput(BaseModel):
items: List[str] = Field(..., min_items=1, max_items=100)
parallel: bool = Field(False)
class BatchOutput(BaseModel):
results: List[str]
total: int
failed: int
# 7. 字典类型
class ConfigInput(BaseModel):
action: str
params: dict = Field(default_factory=dict, description="参数字典")
class ConfigOutput(BaseModel):
success: bool
data: dict
message: str
# 8. 信创环境手动指定
class XinchuangQueryInput(BaseModel):
    """Query request for the Xinchuang environment."""

    query_id: str = Field(..., pattern="^QRY[0-9]{8}$")
    doc_ids: List[str] = Field(..., min_items=1, max_items=10)
    db_config: dict = Field(
        default={"type": "dameng", "pool_size": 5},
        description="数据库配置"
    )
    llm_config: dict = Field(
        default={"provider": "ollama", "model": "qwen:7b"},
        description="LLM配置"
    )

    @validator('doc_ids')
    def validate_doc_ids(cls, v):
        """Require every document ID to start with ``DOC``; report the first that does not."""
        offender = next((item for item in v if not item.startswith('DOC')), None)
        if offender is not None:
            raise ValueError(f'无效文档ID:{offender}')
        return v
class XinchuangQueryOutput(BaseModel):
"""信创查询输出"""
query_id: str
results: List[dict]
total_found: int
execution_time_ms: int
db_source: str
llm_provider: str
class Config:
schema_extra = {
"example": {
"query_id": "QRY20240101",
"results": [{"doc_id": "DOC001", "content": "..."}],
"total_found": 5,
"execution_time_ms": 150,
"db_source": "dameng",
"llm_provider": "ollama"
}
}
def xinchuang_query_handler(data: XinchuangQueryInput) -> XinchuangQueryOutput:
    """Fetch the requested documents and report them with query metadata.

    Args:
        data: Query ID, document IDs and the database/LLM configuration dicts.

    Returns:
        The documents that were found, with counts and the configured sources.
    """
    import dmPython
    from langchain.llms import Ollama

    # "provider" is routing metadata for this handler, not an Ollama
    # constructor argument, so it is removed before building the client.
    llm_kwargs = {k: v for k, v in data.llm_config.items() if k != "provider"}
    llm = Ollama(**llm_kwargs)  # reserved for the (elided) LLM processing step

    # Connect to the Dameng database
    conn = dmPython.connect(
        user="SYSDBA",
        password="SYSDBA",
        server="localhost",
        port=5236
    )
    results = []
    try:
        # One cursor serves every lookup.
        cursor = conn.cursor()
        for doc_id in data.doc_ids:
            cursor.execute("SELECT * FROM docs WHERE id = ?", (doc_id,))
            doc = cursor.fetchone()
            if doc:
                results.append({"doc_id": doc_id, "content": doc[1]})
    finally:
        # Release the connection on every path; it was previously leaked.
        conn.close()
    return XinchuangQueryOutput(
        query_id=data.query_id,
        results=results,
        total_found=len(results),
        execution_time_ms=100,
        # Client-supplied config dicts may omit these keys; fall back to the
        # values used by the field defaults instead of raising KeyError.
        db_source=data.db_config.get("type", "dameng"),
        llm_provider=data.llm_config.get("provider", "ollama")
    )
xc_chain = RunnableLambda(xinchuang_query_handler)
add_routes(
app,
xc_chain,
path="/xinchuang_query",
input_type=XinchuangQueryInput,
output_type=XinchuangQueryOutput
)
print("✓ 手动指定类型配置完成")
---
02.数据转换
a.输入映射
a.功能说明
在数据进入Chain前进行预处理和转换。统一输入格式,适配不同数据源。提取、清洗、规范化数据。输入映射确保Chain收到正确格式的数据。
b.代码示例
---
from langchain.schema.runnable import RunnableLambda
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
# 1. 基础输入映射
def map_input(data: dict) -> dict:
    """Rename incoming keys: ``content`` -> ``text`` and ``lang`` -> ``language``.

    Missing keys fall back to an empty text and the language ``"zh"``.
    """
    text = data.get("content", "")
    language = data.get("lang", "zh")
    return dict(text=text, language=language)
llm = ChatOpenAI()
prompt = ChatPromptTemplate.from_template("翻译为英文:{text}")
chain = (
RunnableLambda(map_input) |
prompt | llm
)
add_routes(app, chain, path="/translate_mapped")
# 2. 复杂字段提取
def extract_fields(data: dict) -> dict:
    """Flatten a nested request into ``question``, ``context`` and ``user_id``.

    ``query.text`` and ``meta.user_id`` are required (``KeyError`` when
    absent); ``documents`` is optional and is joined with single spaces.
    """
    question = data["query"]["text"]
    context = " ".join(data.get("documents", []))
    user_id = data["meta"]["user_id"]
    return {"question": question, "context": context, "user_id": user_id}
qa_chain = (
RunnableLambda(extract_fields) |
ChatPromptTemplate.from_template(
"基于上下文'{context}',回答'{question}'"
) | llm
)
# 3. 数据清洗
def clean_input(data: dict) -> dict:
    """Strip, filter and truncate the ``text`` field.

    Surrounding whitespace is removed, every character that is not a word
    character, whitespace or one of ``,.!?`` is dropped, and the result is
    cut to at most 5000 characters.
    """
    import re

    max_length = 5000
    trimmed = data.get("text", "").strip()
    filtered = re.sub(r'[^\w\s,.!?]', '', trimmed)
    # Slicing is a no-op when the text is already short enough.
    return {"cleaned_text": filtered[:max_length]}
clean_chain = (
RunnableLambda(clean_input) |
ChatPromptTemplate.from_template("分析:{cleaned_text}") |
llm
)
# 4. 数据规范化
def normalize_input(data: dict) -> dict:
    """Normalise casing and the date separator of the incoming fields.

    ``text`` is lower-cased, ``type`` is upper-cased (default ``general``)
    and every ``/`` in ``date`` becomes ``-``.
    """
    raw_text = data.get("text", "")
    raw_type = data.get("type", "general")
    raw_date = data.get("date", "")
    return {
        "text": raw_text.lower(),
        "type": raw_type.upper(),
        "date": raw_date.replace("/", "-"),
    }
# 5. 多字段组合
def combine_fields(data: dict) -> dict:
    """Merge the optional title, content and tags into one prompt string.

    Only keys present in ``data`` contribute a line. The lines are joined
    with a real newline and returned under the ``"combined"`` key.
    """
    parts = []
    if "title" in data:
        parts.append(f"标题:{data['title']}")
    if "content" in data:
        parts.append(f"内容:{data['content']}")
    if "tags" in data:
        parts.append(f"标签:{', '.join(data['tags'])}")
    # Use an actual line break; the doubled backslash produced the two
    # characters "\" + "n" in the prompt instead.
    return {"combined": "\n".join(parts)}
combine_chain = (
RunnableLambda(combine_fields) |
ChatPromptTemplate.from_template("摘要:{combined}") |
llm
)
# 6. 条件映射
def conditional_map(data: dict) -> dict:
    """Build the prompt according to the request ``type``.

    ``question`` and ``translate`` requests get a task prefix; any other
    type passes ``text`` through unchanged. ``text`` is always required.
    """
    kind = data.get("type")
    text = data["text"]
    if kind == "question":
        return {"prompt": f"回答问题:{text}"}
    if kind == "translate":
        return {"prompt": f"翻译:{text}"}
    return {"prompt": text}
# 7. 默认值填充
def fill_defaults(data: dict) -> dict:
    """Return the four known options, using defaults for missing keys.

    Keys outside the known set are not copied to the result.
    """
    defaults = {
        "text": "默认文本",
        "language": "zh",
        "max_length": 100,
        "temperature": 0.7,
    }
    return {key: data.get(key, fallback) for key, fallback in defaults.items()}
# 8. 信创环境输入映射
def xinchuang_input_mapper(data: dict) -> dict:
    """Normalise a loosely-shaped request into the internal field names.

    The document id and query text are taken from the first truthy alias.
    Unknown database or LLM names fall back to ``dameng`` / ``ollama``.
    """
    def first_truthy(*keys):
        # Mirrors `a or b or c`: first truthy value, else the last lookup.
        value = None
        for key in keys:
            value = data.get(key)
            if value:
                break
        return value

    doc_id = first_truthy("documentId", "doc_id", "id")

    db_type = data.get("database", "dameng").lower()
    if db_type not in ("dameng", "postgresql"):
        db_type = "dameng"

    llm_provider = data.get("llm", "ollama").lower()
    if llm_provider not in ("ollama", "openai"):
        llm_provider = "ollama"

    query = first_truthy("query", "question", "text")

    return {
        "doc_id": doc_id,
        "db_type": db_type,
        "llm_provider": llm_provider,
        "query": query,
        "options": data.get("options", {}),
    }
xc_input_chain = RunnableLambda(xinchuang_input_mapper)
# 组合到完整Chain
def process_xinchuang(mapped_data: dict) -> dict:
    """Handle a mapped Xinchuang request with the selected database and LLM.

    Args:
        mapped_data: Mapper output; ``db_type`` and ``llm_provider`` select
            the backends.

    Returns:
        A placeholder result plus the backends that were selected.
    """
    # Open the database connection (connection strings are placeholders).
    if mapped_data["db_type"] == "dameng":
        import dmPython
        conn = dmPython.connect("...")
    else:
        import psycopg2
        conn = psycopg2.connect("...")

    try:
        # Pick the LLM; it is reserved for the elided processing step.
        if mapped_data["llm_provider"] == "ollama":
            from langchain.llms import Ollama
            llm = Ollama(model="qwen:7b")  # noqa: F841
        else:
            llm = ChatOpenAI()  # noqa: F841

        # Query and processing are elided in this example.
        return {
            "result": "处理结果",
            "source": {
                "db": mapped_data["db_type"],
                "llm": mapped_data["llm_provider"]
            }
        }
    finally:
        # Release the connection on every exit path; it used to be leaked.
        conn.close()
xc_full_chain = (
xc_input_chain |
RunnableLambda(process_xinchuang)
)
add_routes(app, xc_full_chain, path="/xinchuang_process")
print("✓ 输入映射配置完成")
---
b.输出格式化
a.功能说明
在返回给客户端前格式化输出数据。统一响应格式,添加元数据。转换数据类型,增强可读性。输出格式化提供一致的API响应结构。
b.代码示例
---
from langchain.schema.runnable import RunnableLambda
from langchain.chat_models import ChatOpenAI
import time
llm = ChatOpenAI()
# 1. 基础输出格式化
def format_output(result) -> dict:
    """Wrap the model reply in a success envelope with a Unix timestamp."""
    envelope = {"response": result.content, "status": "success"}
    envelope["timestamp"] = time.time()
    return envelope
chain = llm | RunnableLambda(format_output)
add_routes(app, chain, path="/formatted_chat")
# 2. 添加元数据
def add_metadata(result) -> dict:
    """Return the reply text together with a metadata section.

    The metadata carries a fixed model name and version, the token usage
    read from ``result.response_metadata`` and the local generation time.
    """
    payload = {"content": result.content, "type": "text"}
    token_usage = result.response_metadata.get("token_usage", {})
    generated_at = time.strftime("%Y-%m-%d %H:%M:%S")
    metadata = {
        "model": "gpt-3.5-turbo",
        "tokens": token_usage,
        "generated_at": generated_at,
        "version": "1.0",
    }
    return {"data": payload, "metadata": metadata}
meta_chain = llm | RunnableLambda(add_metadata)
# 3. 结构化响应
def structured_response(result) -> dict:
    """Build the fixed success envelope around the model's answer."""
    answer_block = {"answer": result.content, "confidence": 0.95}
    return dict(
        success=True,
        code=200,
        message="处理成功",
        data=answer_block,
        error=None,
    )
# 4. 错误处理
def safe_format(result) -> dict:
    """Format a result without raising.

    Returns a success dict when ``result`` has a ``content`` attribute, a
    failure dict when it does not, and a failure dict with the exception
    text if reading the attribute raises.
    """
    try:
        if not hasattr(result, 'content'):
            return {"success": False, "error": "无效的结果格式"}
        return {"success": True, "data": result.content}
    except Exception as exc:
        return {"success": False, "error": str(exc)}
# 5. 分页输出
def paginate_output(results: list, page: int = 1, page_size: int = 10) -> dict:
    """Return one page of ``results`` plus pagination metadata.

    Args:
        results: The full result list.
        page: 1-based page number.
        page_size: Number of items per page.

    Returns:
        ``{"data": [...], "pagination": {...}}``; ``data`` is empty when
        ``page`` is past the last page.

    Raises:
        ValueError: If ``page`` or ``page_size`` is smaller than 1. A page
            below 1 used to produce a negative slice start and return items
            from the tail; a zero page size raised ZeroDivisionError.
    """
    if page < 1:
        raise ValueError(f"page must be >= 1, got {page}")
    if page_size < 1:
        raise ValueError(f"page_size must be >= 1, got {page_size}")

    start = (page - 1) * page_size
    total = len(results)
    return {
        "data": results[start:start + page_size],
        "pagination": {
            "page": page,
            "page_size": page_size,
            "total": total,
            # Ceiling division without floats.
            "total_pages": (total + page_size - 1) // page_size
        }
    }
# 6. 数据转换
def transform_output(result) -> dict:
    """Split the reply into a summary line, detail lines and simple counts.

    Returns:
        ``summary`` is the first line, ``details`` the remaining lines,
        ``word_count`` the number of whitespace-separated tokens and
        ``char_count`` the length of the raw content.
    """
    content = result.content
    # Split on real line breaks; the doubled backslash matched the literal
    # two-character sequence "\" + "n", so nothing was ever split.
    lines = content.split('\n')
    return {
        # str.split always returns at least one element.
        "summary": lines[0],
        "details": lines[1:],
        "word_count": len(content.split()),
        "char_count": len(content)
    }
# 7. 多语言输出
def multilang_output(result) -> dict:
    """Return the reply in Chinese and English.

    Relies on a ``translate_to_english`` helper that this snippet assumes
    exists elsewhere.
    """
    chinese = result.content
    english = translate_to_english(result.content)
    return {
        "zh": chinese,
        "en": english,
        "metadata": {"default_lang": "zh"},
    }
# 8. 信创环境输出格式化
def xinchuang_format_output(result: dict) -> dict:
    """Wrap a processing result in the Xinchuang response envelope.

    The envelope has a business payload (``data``), the backend description
    (``system_info``) and timing figures (``performance``). Missing keys
    fall back to empty values or zero.
    """
    import time

    payload = {
        "content": result.get("result", ""),
        "doc_id": result.get("doc_id", ""),
        "analysis": result.get("analysis", {}),
    }
    system_info = {
        "db_type": result.get("db_type", "dameng"),
        "llm_provider": result.get("llm_provider", "ollama"),
        "server": "xinchuang",
        "timestamp": time.time(),
        # Millisecond clock reading doubles as a simple request id.
        "request_id": f"REQ{int(time.time() * 1000)}",
    }
    performance = {
        "execution_time_ms": result.get("execution_time", 0),
        "db_query_time_ms": result.get("db_time", 0),
        "llm_process_time_ms": result.get("llm_time", 0),
    }
    return {
        "code": 0,
        "message": "操作成功",
        "data": payload,
        "system_info": system_info,
        "performance": performance,
    }
xc_output_chain = RunnableLambda(xinchuang_format_output)
# 完整信创Chain
def xinchuang_process(input_data: dict) -> dict:
    """Run the (elided) database and LLM stages and report their timings.

    Returns a placeholder result with the request's ``doc_id``, the chosen
    backends and three elapsed times in milliseconds.
    """
    overall_begin = time.time()

    # Stage 1: database access (body elided in this example).
    stage_begin = time.time()
    db_time = (time.time() - stage_begin) * 1000

    # Stage 2: LLM processing (body elided in this example).
    stage_begin = time.time()
    llm_time = (time.time() - stage_begin) * 1000

    execution_time = (time.time() - overall_begin) * 1000

    output = {"result": "处理结果", "doc_id": input_data.get("doc_id"), "analysis": {}}
    output["db_type"] = input_data.get("db_type", "dameng")
    output["llm_provider"] = input_data.get("llm_provider", "ollama")
    output["execution_time"] = execution_time
    output["db_time"] = db_time
    output["llm_time"] = llm_time
    return output
xc_complete_chain = (
RunnableLambda(xinchuang_process) |
xc_output_chain
)
add_routes(app, xc_complete_chain, path="/xinchuang_formatted")
print("✓ 输出格式化配置完成")
---
2.4 流式响应
01.启用流式输出
a.配置stream方法
a.功能说明
实现Runnable的stream方法,支持流式输出。使用Python生成器(generator)逐步返回数据。LangServe自动将生成器转换为SSE(Server-Sent Events)。流式输出提升用户体验,适用于长文本生成。
b.代码示例
---
from langchain.schema.runnable import Runnable
from typing import Iterator, Any
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langserve import add_routes
# 1. 基础流式Runnable
class StreamingRunnable(Runnable):
    """Minimal Runnable that streams a fixed sentence one character at a time."""

    def invoke(self, input: Any, config: dict = None) -> Any:
        """Return the fixed non-streaming response."""
        return "完整响应"

    def stream(self, input: Any, config: dict = None) -> Iterator[Any]:
        """Yield the demo sentence character by character."""
        text = "这是一个流式响应示例"
        yield from text

    async def astream(self, input: Any, config: dict = None, **kwargs):
        """Async counterpart of :meth:`stream`.

        LangServe serves its streaming endpoint through ``astream``, and the
        base-class default falls back to ``ainvoke`` (one single chunk), so
        the async method has to be overridden for the endpoint to stream.
        """
        for char in self.stream(input, config):
            yield char
stream_runnable = StreamingRunnable()
add_routes(app, stream_runnable, path="/basic_stream")
# 2. LLM流式输出
llm = ChatOpenAI(streaming=True) # 关键:启用streaming
prompt = ChatPromptTemplate.from_template("讲一个关于{topic}的故事")
stream_chain = prompt | llm
add_routes(app, stream_chain, path="/story_stream")
# 客户端调用:
# POST /story_stream/stream
# {"input": {"topic": "AI"}}
# 3. 自定义流式处理
def custom_stream_generator(input_text: str, delay: float = 0.1) -> Iterator[str]:
    """Yield the whitespace-separated words of ``input_text`` one by one.

    Args:
        input_text: Text to stream.
        delay: Seconds to pause after each word to simulate generation
            latency; pass ``0`` to disable the pause.

    Yields:
        Each word followed by a single space.
    """
    import time

    for word in input_text.split():
        yield word + " "
        if delay > 0:
            time.sleep(delay)  # simulated latency
from langchain.schema.runnable import RunnableLambda
def create_stream(input_data: dict) -> Iterator[str]:
    """Return a word-by-word generator over ``input_data["text"]``."""
    source_text = input_data["text"]
    return custom_stream_generator(source_text)
custom_stream_chain = RunnableLambda(create_stream)
add_routes(app, custom_stream_chain, path="/custom_stream")
# 4. 组合流式Chain
from langchain.schema.output_parser import StrOutputParser
# Prompt -> LLM -> Parser 都支持流式
full_stream_chain = (
ChatPromptTemplate.from_template("翻译为英文:{text}") |
ChatOpenAI(streaming=True) |
StrOutputParser()
)
add_routes(app, full_stream_chain, path="/translate_stream")
# 5. 流式+日志
class LoggingStreamRunnable(Runnable):
    """Wrap a runnable and print every streamed chunk."""

    def __init__(self, inner_runnable):
        self.inner = inner_runnable

    def invoke(self, input: Any, config: dict = None) -> Any:
        """Delegate non-streaming calls to the wrapped runnable.

        ``Runnable.invoke`` is abstract, so without this override the class
        cannot be instantiated.
        """
        return self.inner.invoke(input, config)

    def stream(self, input: Any, config: dict = None) -> Iterator[Any]:
        """Stream from the wrapped runnable, logging each chunk."""
        print(f"[流式开始] 输入:{input}")
        chunk_count = 0
        for chunk in self.inner.stream(input, config):
            chunk_count += 1
            print(f"[Chunk {chunk_count}] {chunk}")
            yield chunk
        print(f"[流式结束] 总Chunk数:{chunk_count}")

    async def astream(self, input: Any, config: dict = None, **kwargs):
        """Async variant of :meth:`stream`; LangServe streams through this."""
        print(f"[流式开始] 输入:{input}")
        chunk_count = 0
        async for chunk in self.inner.astream(input, config, **kwargs):
            chunk_count += 1
            print(f"[Chunk {chunk_count}] {chunk}")
            yield chunk
        print(f"[流式结束] 总Chunk数:{chunk_count}")
logged_stream = LoggingStreamRunnable(llm)
add_routes(app, logged_stream, path="/logged_stream")
# 6. 流式+缓存
class CachedStreamRunnable(Runnable):
    """Wrap a runnable and replay completed streams from an in-memory cache.

    The cache key is ``str(input)``. A stream is stored only after it has
    been fully consumed, so an aborted stream is never replayed.
    NOTE(review): the cache is unbounded - add eviction before production use.
    """

    def __init__(self, inner_runnable):
        self.inner = inner_runnable
        self.cache = {}

    def invoke(self, input: Any, config: dict = None) -> Any:
        """Delegate non-streaming calls; required because ``invoke`` is abstract."""
        return self.inner.invoke(input, config)

    def stream(self, input: Any, config: dict = None) -> Iterator[Any]:
        """Stream from the cache when possible, otherwise generate and store."""
        cache_key = str(input)
        if cache_key in self.cache:
            print("[缓存命中]")
            yield from self.cache[cache_key]
            return
        chunks = []
        for chunk in self.inner.stream(input, config):
            chunks.append(chunk)
            yield chunk
        self.cache[cache_key] = chunks

    async def astream(self, input: Any, config: dict = None, **kwargs):
        """Async variant of :meth:`stream`; LangServe streams through this."""
        cache_key = str(input)
        if cache_key in self.cache:
            print("[缓存命中]")
            for chunk in self.cache[cache_key]:
                yield chunk
            return
        chunks = []
        async for chunk in self.inner.astream(input, config, **kwargs):
            chunks.append(chunk)
            yield chunk
        self.cache[cache_key] = chunks
cached_stream = CachedStreamRunnable(llm)
add_routes(app, cached_stream, path="/cached_stream")
# 7. 流式错误处理
class SafeStreamRunnable(Runnable):
    """Wrap a runnable so a streaming failure becomes a final error chunk."""

    def __init__(self, inner_runnable):
        self.inner = inner_runnable

    def invoke(self, input: Any, config: dict = None) -> Any:
        """Delegate non-streaming calls; required because ``invoke`` is abstract."""
        return self.inner.invoke(input, config)

    def stream(self, input: Any, config: dict = None) -> Iterator[Any]:
        """Stream chunks; on failure emit one error message chunk and stop."""
        try:
            yield from self.inner.stream(input, config)
        except Exception as e:
            # Real newline before the marker (was a literal backslash + "n").
            yield f"\n[错误:{str(e)}]"

    async def astream(self, input: Any, config: dict = None, **kwargs):
        """Async variant of :meth:`stream`; LangServe streams through this."""
        try:
            async for chunk in self.inner.astream(input, config, **kwargs):
                yield chunk
        except Exception as e:
            yield f"\n[错误:{str(e)}]"
safe_stream = SafeStreamRunnable(llm)
add_routes(app, safe_stream, path="/safe_stream")
# 8. 信创环境流式配置
from langchain.llms import Ollama
# 本地Ollama支持流式
local_llm = Ollama(model="qwen:7b", base_url="http://localhost:11434")
local_prompt = ChatPromptTemplate.from_template("回答:{question}")
local_stream_chain = local_prompt | local_llm
add_routes(app, local_stream_chain, path="/local_stream")
# 混合流式(本地+云端)
class HybridStreamRunnable(Runnable):
    """Stream from the local Ollama model, falling back to the cloud model.

    NOTE(review): if the local model fails after it has already produced
    chunks, the cloud model starts over, so the client sees the partial
    local output, the switch notice and then the full cloud output.
    """

    def __init__(self):
        self.local_llm = Ollama(model="qwen:7b")
        self.cloud_llm = ChatOpenAI(streaming=True)

    def invoke(self, input: Any, config: dict = None) -> Any:
        """Non-streaming call with the same local-first fallback.

        Required because ``Runnable.invoke`` is abstract.
        """
        try:
            return self.local_llm.invoke(input, config)
        except Exception:
            return self.cloud_llm.invoke(input, config)

    def stream(self, input: Any, config: dict = None) -> Iterator[Any]:
        """Prefer the local model; switch to the cloud model on any error."""
        try:
            yield from self.local_llm.stream(input, config)
        except Exception:
            # Real newlines around the notice (were literal backslash + "n").
            yield "\n[切换到云端模型]\n"
            yield from self.cloud_llm.stream(input, config)

    async def astream(self, input: Any, config: dict = None, **kwargs):
        """Async variant of :meth:`stream`; LangServe streams through this."""
        try:
            async for chunk in self.local_llm.astream(input, config, **kwargs):
                yield chunk
        except Exception:
            yield "\n[切换到云端模型]\n"
            async for chunk in self.cloud_llm.astream(input, config, **kwargs):
                yield chunk
hybrid_stream = HybridStreamRunnable()
add_routes(app, hybrid_stream, path="/hybrid_stream")
print("✓ 流式输出配置完成")
---
b.SSE协议
a.功能说明
LangServe使用SSE(Server-Sent Events)协议传输流式数据。SSE是HTML5标准,单向服务器推送。数据格式为文本,每条消息前缀"data: "。客户端使用EventSource或fetch API接收。
b.代码示例
---
# 1. SSE响应格式
# Content-Type: text/event-stream
#
# event: data
# data: {"content": "第"}
#
# event: data
# data: {"content": "一"}
#
# event: data
# data: {"content": "段"}
#
# event: end
# 2. Python客户端(requests)
import requests
import json
def stream_request(url: str, input_data: dict):
    """POST ``input_data`` to a streaming endpoint and print the content.

    Reads the SSE response line by line, prints the ``content`` field of
    every JSON ``data:`` payload and stops at the ``event: end`` line.
    Payloads that are not valid JSON are skipped.
    """
    with requests.post(
        url,
        json={"input": input_data},
        stream=True,  # consume the body incrementally
        timeout=60
    ) as response:
        for raw in response.iter_lines():
            if not raw:
                continue
            decoded = raw.decode('utf-8')
            if decoded == "event: end":
                break
            if not decoded.startswith('data: '):
                continue
            try:
                payload = json.loads(decoded[6:])  # drop the "data: " prefix
            except json.JSONDecodeError:
                continue
            if "content" in payload:
                print(payload["content"], end="", flush=True)
    print()  # final newline
# 使用
stream_request(
"http://localhost:8000/story_stream/stream",
{"topic": "未来"}
)
# 3. Python客户端(httpx + async)
import httpx
import asyncio
async def async_stream_request(url: str, input_data: dict):
    """Async version: stream an SSE response with httpx and print the content.

    Only ``data:`` lines are inspected; payloads that are not valid JSON are
    skipped.
    """
    async with httpx.AsyncClient() as client:
        request = client.stream(
            "POST",
            url,
            json={"input": input_data},
            timeout=60
        )
        async with request as response:
            async for line in response.aiter_lines():
                if not line.startswith('data: '):
                    continue
                try:
                    payload = json.loads(line[6:])
                except json.JSONDecodeError:
                    continue
                if "content" in payload:
                    print(payload["content"], end="", flush=True)
# 使用
asyncio.run(async_stream_request(
"http://localhost:8000/story_stream/stream",
{"topic": "AI"}
))
# 4. JavaScript客户端(Fetch API)
"""
async function streamRequest(url, inputData) {
const response = await fetch(url, {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({input: inputData})
});
const reader = response.body.getReader();
const decoder = new TextDecoder();
while (true) {
const {done, value} = await reader.read();
if (done) break;
const chunk = decoder.decode(value);
const lines = chunk.split('\\n');
for (const line of lines) {
if (line.startsWith('data: ')) {
const data = JSON.parse(line.slice(6));
if (data.content) {
document.getElementById('output').textContent += data.content;
}
}
}
}
}
// 使用
streamRequest('http://localhost:8000/chat/stream', {message: '你好'});
"""
# 5. JavaScript客户端(EventSource)
"""
// 注意:EventSource只支持GET,不适用于POST请求
// 如果需要POST,使用Fetch API
const eventSource = new EventSource('/stream_endpoint?param=value');
eventSource.onmessage = (event) => {
const data = JSON.parse(event.data);
console.log(data);
};
eventSource.onerror = (error) => {
console.error('SSE Error:', error);
eventSource.close();
};
"""
# 6. 自定义SSE服务器端
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import json
import time
app = FastAPI()
@app.post("/custom_sse")
async def custom_sse(data: dict):
    """Stream the request's ``message`` back one character per SSE event.

    Emits a ``start`` event, one ``data`` event per character and a final
    ``end`` event.
    """
    def sse(event: str, payload: dict) -> str:
        # An SSE frame ends with a blank line, so real newlines are needed
        # (the doubled backslashes emitted the characters "\" + "n").
        # json.dumps escapes quotes, backslashes and control characters that
        # the plain f-string interpolation left raw.
        return f"event: {event}\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"

    async def event_generator():
        """Produce the SSE frames for this request."""
        message = data.get("message", "")
        yield sse("start", {"status": "processing"})
        for char in message:
            yield sse("data", {"content": char})
            await asyncio.sleep(0.1)
        yield sse("end", {"status": "complete"})

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive"
        }
    )
# 7. SSE错误处理
def safe_sse_generator(count: int = 10, delay: float = 0.1):
    """Yield a demo SSE stream: ``start``, ``count`` numbered events, ``end``.

    Args:
        count: Number of ``data`` events to emit.
        delay: Seconds to sleep after each ``data`` event; ``0`` disables it.

    Yields:
        Complete SSE frames. If anything raises while producing them, one
        ``error`` frame carrying the exception text is emitted instead.
    """
    try:
        # Frames are terminated by real newlines; the doubled backslashes
        # produced literal "\n" text that SSE clients cannot parse.
        yield "event: start\ndata: {}\n\n"
        for i in range(count):
            yield f"event: data\ndata: {json.dumps({'num': i})}\n\n"
            if delay > 0:
                time.sleep(delay)
        yield "event: end\ndata: {}\n\n"
    except Exception as e:
        error_msg = json.dumps({"error": str(e)})
        yield f"event: error\ndata: {error_msg}\n\n"
@app.get("/safe_sse")
async def safe_sse():
    """Serve the demo generator as a ``text/event-stream`` response."""
    frames = safe_sse_generator()
    return StreamingResponse(frames, media_type="text/event-stream")
# 8. 信创环境SSE示例
@app.post("/xinchuang_sse")
async def xinchuang_sse(data: dict):
    """Stream a local Ollama answer to ``data["question"]`` as SSE.

    Emits ``start``, one ``data`` event per generated chunk, an ``error``
    event if generation fails, and always a closing ``end`` event.
    """
    async def xc_event_generator():
        """Produce the SSE frames for this request."""
        from langchain.llms import Ollama
        llm = Ollama(model="qwen:7b")
        question = data.get("question", "")

        # Frames must end with real newlines (were literal "\" + "n").
        start_data = json.dumps({"llm": "ollama", "model": "qwen:7b"})
        yield f"event: start\ndata: {start_data}\n\n"

        try:
            # The async stream keeps the event loop free; the synchronous
            # llm.stream() blocked every other request while generating.
            async for chunk in llm.astream(question):
                chunk_data = json.dumps({"content": chunk})
                yield f"event: data\ndata: {chunk_data}\n\n"
        except Exception as e:
            error_data = json.dumps({"error": str(e)})
            yield f"event: error\ndata: {error_data}\n\n"

        end_data = json.dumps({"status": "complete"})
        yield f"event: end\ndata: {end_data}\n\n"

    return StreamingResponse(
        xc_event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no"  # disable Nginx response buffering
        }
    )
print("✓ SSE协议配置完成")
---
02.客户端处理
a.Python客户端
a.功能说明
使用Python作为客户端接收流式响应。requests库的stream=True参数启用流式接收。httpx库提供异步流式支持。逐行解析SSE数据,提取内容。
b.代码示例
---
import requests
import json
# 1. 基础流式客户端
def simple_stream_client(url: str, input_data: dict):
    """POST ``input_data`` and print every non-empty raw response line.

    The response is used as a context manager so the pooled connection is
    released even if iteration stops early or raises; it used to stay open.
    """
    with requests.post(
        url,
        json={"input": input_data},
        stream=True  # consume the body incrementally
    ) as response:
        for line in response.iter_lines():
            if line:
                print(line.decode('utf-8'))
# 使用
simple_stream_client(
"http://localhost:8000/chat/stream",
{"message": "你好"}
)
# 2. 解析SSE格式
def parse_sse_stream(url: str, input_data: dict):
    """Stream an SSE response, echo its content and return the full text.

    Each JSON ``data:`` payload with a ``content`` field is printed as it
    arrives and appended to the returned string. Reading stops at the
    ``event: end`` line; payloads that are not valid JSON are skipped.
    """
    collected = ""
    with requests.post(
        url,
        json={"input": input_data},
        stream=True,
        timeout=60
    ) as response:
        for raw in response.iter_lines():
            if not raw:
                continue
            decoded = raw.decode('utf-8')
            if decoded == "event: end":
                break
            if not decoded.startswith('data: '):
                continue
            try:
                payload = json.loads(decoded[6:])
            except json.JSONDecodeError:
                continue
            if "content" in payload:
                piece = payload["content"]
                print(piece, end="", flush=True)
                collected += piece
    print()  # final newline
    return collected
# 使用
result = parse_sse_stream(
"http://localhost:8000/story/stream",
{"topic": "未来"}
)
print(f"\\n完整文本长度:{len(result)}")
# 3. 带进度显示
def stream_with_progress(url: str, input_data: dict):
    """Stream an SSE response and print a progress note every 100 characters.

    The note is printed whenever the running total crosses the next multiple
    of 100. The earlier ``count % 100 == 0`` test only fired when a chunk
    landed exactly on a multiple, and also fired at zero.
    """
    import sys

    progress_step = 100
    next_mark = progress_step
    char_count = 0
    with requests.post(
        url,
        json={"input": input_data},
        stream=True
    ) as response:
        for line in response.iter_lines():
            if not line:
                continue
            line_str = line.decode('utf-8')
            if not line_str.startswith('data: '):
                continue
            try:
                data = json.loads(line_str[6:])
            except json.JSONDecodeError:
                continue
            if "content" in data:
                content = data["content"]
                sys.stdout.write(content)
                sys.stdout.flush()
                char_count += len(content)
                if char_count >= next_mark:
                    # Real newlines around the note (were literal "\" + "n").
                    sys.stdout.write(f"\n[已生成{char_count}字符]\n")
                    next_mark = (char_count // progress_step + 1) * progress_step
# 4. 异步流式客户端
import httpx
import asyncio
async def async_stream_client(url: str, input_data: dict):
    """Stream an SSE response with httpx and print the ``content`` fields.

    Only malformed JSON payloads are skipped. The earlier bare ``except``
    also swallowed cancellation and programming errors.
    """
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST",
            url,
            json={"input": input_data},
            timeout=60.0
        ) as response:
            async for line in response.aiter_lines():
                if not line.startswith('data: '):
                    continue
                try:
                    data = json.loads(line[6:])
                except json.JSONDecodeError:
                    continue
                if isinstance(data, dict) and "content" in data:
                    print(data["content"], end="", flush=True)
# 使用
asyncio.run(async_stream_client(
"http://localhost:8000/chat/stream",
{"message": "测试"}
))
# 5. 并发流式请求
async def concurrent_streams(urls: list, inputs: list):
    """Run one streaming request per ``(url, input)`` pair concurrently.

    Args:
        urls: Endpoint URLs.
        inputs: Input payloads, paired with ``urls`` by position; extra
            items on the longer list are ignored.
    """
    async def single_stream(url, input_data, stream_id):
        """Stream one endpoint, prefixing each chunk with ``stream_id``."""
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                url,
                json={"input": input_data}
            ) as response:
                # Real newlines in the markers (were literal "\" + "n").
                print(f"\n[Stream {stream_id}开始]")
                async for line in response.aiter_lines():
                    if not line.startswith('data: '):
                        continue
                    try:
                        data = json.loads(line[6:])
                    except json.JSONDecodeError:
                        # Skip malformed payloads only; a bare except also
                        # swallowed task cancellation.
                        continue
                    if isinstance(data, dict) and "content" in data:
                        print(f"[{stream_id}] {data['content']}", end="")
                print(f"\n[Stream {stream_id}结束]")

    tasks = [
        single_stream(url, inp, i)
        for i, (url, inp) in enumerate(zip(urls, inputs))
    ]
    await asyncio.gather(*tasks)
# 使用
# asyncio.run(concurrent_streams(
# ["http://localhost:8000/chat/stream"] * 3,
# [{"message": f"问题{i}"} for i in range(3)]
# ))
# 6. 错误处理
def robust_stream_client(url: str, input_data: dict, max_retries: int = 3):
    """Stream an SSE response, retrying on timeouts.

    Args:
        url: Streaming endpoint URL.
        input_data: Payload sent under the ``input`` key.
        max_retries: Maximum number of attempts when requests time out.

    A server-side error is recognised by its ``event: error`` frame. The
    earlier substring test on every line treated any generated text
    containing "error" as a failure. Other request errors are reported
    without retrying.
    """
    for attempt in range(max_retries):
        try:
            with requests.post(
                url,
                json={"input": input_data},
                stream=True,
                timeout=30
            ) as response:
                response.raise_for_status()
                current_event = None
                for line in response.iter_lines():
                    if not line:
                        current_event = None  # a blank line ends an SSE event
                        continue
                    line_str = line.decode('utf-8')
                    if line_str.startswith('event: '):
                        current_event = line_str[7:].strip()
                        continue
                    if not line_str.startswith('data: '):
                        continue
                    if current_event == 'error':
                        print(f"\n[错误] {line_str}")
                        break
                    try:
                        data = json.loads(line_str[6:])
                    except json.JSONDecodeError:
                        continue  # skip malformed payloads only
                    if isinstance(data, dict) and "content" in data:
                        print(data["content"], end="", flush=True)
            return  # stream handled, no retry needed
        except requests.exceptions.Timeout:
            print(f"\n[超时] 第{attempt+1}次尝试失败")
            if attempt < max_retries - 1:
                print("重试中...")
            else:
                print("达到最大重试次数")
        except requests.exceptions.RequestException as e:
            print(f"\n[请求失败] {e}")
            break
# 7. 保存流式输出
def save_stream_to_file(url: str, input_data: dict, output_file: str):
    """Stream an SSE response to ``output_file`` while echoing it to stdout.

    Each chunk is flushed to disk immediately so a partial answer survives
    an interruption. Only malformed JSON payloads are skipped.
    """
    with requests.post(
        url,
        json={"input": input_data},
        stream=True
    ) as response:
        with open(output_file, 'w', encoding='utf-8') as f:
            for line in response.iter_lines():
                if not line:
                    continue
                line_str = line.decode('utf-8')
                if not line_str.startswith('data: '):
                    continue
                try:
                    data = json.loads(line_str[6:])
                except json.JSONDecodeError:
                    # A bare except here also hid write and encoding errors.
                    continue
                if isinstance(data, dict) and "content" in data:
                    content = data["content"]
                    f.write(content)
                    f.flush()  # write through immediately
                    print(content, end="", flush=True)
    # Real newline before the summary (was a literal "\" + "n").
    print(f"\n输出已保存到:{output_file}")
# 8. 信创环境客户端
def xinchuang_stream_client(url: str, question: str):
    """Ask ``question`` on an internal SSE endpoint and print the answer.

    The body is sent as ``{"question": ...}``. Printing stops at an ``end``
    event or at the first payload carrying an ``error`` key.
    """
    print(f"问:{question}")
    print("答:", end="")
    with requests.post(
        url,
        json={"question": question},
        stream=True,
        # NOTE(review): certificate checks are disabled for self-signed
        # intranet certificates; prefer passing the internal CA bundle.
        verify=False,
        timeout=60
    ) as response:
        for line in response.iter_lines():
            if not line:
                continue
            line_str = line.decode('utf-8')
            if line_str.startswith('event: end'):
                break
            if not line_str.startswith('data: '):
                # Covers 'event: start' and any other non-data line.
                continue
            try:
                data = json.loads(line_str[6:])
            except json.JSONDecodeError:
                continue  # skip malformed payloads only
            if not isinstance(data, dict):
                continue
            if "content" in data:
                print(data["content"], end="", flush=True)
            elif "error" in data:
                print(f"\n[错误] {data['error']}")
                break
    # Blank line after the answer (was the literal text "\n").
    print("\n")
# 使用
xinchuang_stream_client(
"http://internal-server:8000/local/chat/stream",
"介绍一下信创技术"
)
print("✓ Python客户端配置完成")
---
3 客户端调用
3.1 RemoteRunnable
01.客户端连接
a.创建RemoteRunnable
a.功能说明
RemoteRunnable是LangServe的客户端类,连接远程Chain服务。提供与本地Runnable相同的接口(invoke、batch、stream)。自动处理HTTP请求与序列化;超时与失败重试需由调用方自行处理(参见后文"错误处理")。简化客户端开发,像调用本地对象一样调用远程服务。
b.代码示例
---
from langserve import RemoteRunnable
# 1. 基础连接
remote_chain = RemoteRunnable("http://localhost:8000/chat")
# 使用remote_chain就像本地Runnable
result = remote_chain.invoke({"message": "你好"})
print(result)
# 输出:{"content": "你好!有什么可以帮助你的吗?"}
# 2. 带超时的连接
remote_chain = RemoteRunnable(
"http://localhost:8000/chat",
timeout=30.0 # 30秒超时
)
# 3. 自定义headers
remote_chain = RemoteRunnable(
"http://localhost:8000/chat",
headers={
"Authorization": "Bearer your_token",
"X-Custom-Header": "value"
}
)
# 4. HTTPS connection
# RemoteRunnable has no `verify`/`proxies` keyword arguments; options for the
# underlying httpx clients are forwarded through `client_kwargs`.
remote_chain = RemoteRunnable(
    "https://api.example.com/chat",
    client_kwargs={"verify": True}  # verify the TLS certificate
)
# 5. Proxy configuration
# httpx expects URL-pattern keys such as "http://", not bare scheme names.
# NOTE(review): newer httpx releases replace `proxies` with `proxy`/`mounts` -
# check the installed httpx version.
remote_chain = RemoteRunnable(
    "http://api.company.com/chat",
    client_kwargs={
        "proxies": {
            "http://": "http://proxy.company.com:8080",
            "https://": "https://proxy.company.com:8080"
        }
    }
)
# 6. 连接多个服务
chat_chain = RemoteRunnable("http://localhost:8000/chat")
translate_chain = RemoteRunnable("http://localhost:8000/translate")
summary_chain = RemoteRunnable("http://localhost:8000/summary")
# 分别调用
chat_result = chat_chain.invoke({"message": "你好"})
translate_result = translate_chain.invoke({"text": "Hello", "lang": "zh"})
summary_result = summary_chain.invoke({"content": "长文本..."})
# 7. 连接池配置
import httpx
# 创建自定义HTTP客户端
client = httpx.Client(
timeout=60.0,
limits=httpx.Limits(
max_connections=100, # 最大连接数
max_keepalive_connections=20 # 最大keep-alive连接数
)
)
# 注意:当前RemoteRunnable不直接支持自定义client
# 但可以通过继承扩展
# 8. 连接验证
def test_connection(url: str) -> bool:
    """Probe ``url`` with a test invocation and report the outcome.

    Returns ``True`` after a successful call; on any exception prints the
    failure and returns ``False``.
    """
    try:
        probe = RemoteRunnable(url, timeout=5.0)
        # The response body is irrelevant; only success or failure matters.
        probe.invoke({"test": "connection"})
        print(f"✓ 连接成功:{url}")
        return True
    except Exception as exc:
        print(f"✗ 连接失败:{url} - {exc}")
        return False
# 使用
test_connection("http://localhost:8000/chat")
# 9. 信创环境连接
# 连接本地Ollama服务
local_chain = RemoteRunnable(
    "http://localhost:8000/local/chat",
    timeout=60.0,
    # `verify` is an httpx client option, so it goes through client_kwargs.
    # Intranet deployments may use self-signed certificates.
    client_kwargs={"verify": False}
)
# 连接达梦数据库服务
db_chain = RemoteRunnable(
"http://internal-server:8000/xinchuang/dameng/query"
)
# 测试连接
try:
result = local_chain.invoke({"question": "测试连接"})
print(f"✓ 本地服务连接成功")
print(f"响应:{result}")
except Exception as e:
print(f"✗ 连接失败:{e}")
# 10. 连接管理器
class RemoteChainManager:
    """Create and cache one RemoteRunnable per path under a base URL."""

    def __init__(self, base_url: str):
        self.base_url = base_url
        self.chains = {}

    def get_chain(self, path: str) -> RemoteRunnable:
        """Return the cached client for ``path``, creating it on first use.

        The base URL and the path are joined with exactly one slash, so
        ``"http://h/" + "/chat"`` no longer becomes ``http://h//chat`` and
        ``"http://h" + "chat"`` no longer becomes ``http://hchat``.
        """
        if path not in self.chains:
            url = f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
            self.chains[path] = RemoteRunnable(url)
        return self.chains[path]

    def invoke_chain(self, path: str, input_data: dict):
        """Invoke the chain served at ``path`` with ``input_data``."""
        return self.get_chain(path).invoke(input_data)
# 使用
manager = RemoteChainManager("http://localhost:8000")
chat_result = manager.invoke_chain("/chat", {"message": "你好"})
translate_result = manager.invoke_chain("/translate", {"text": "Hello"})
print("✓ RemoteRunnable连接配置完成")
---
b.配置选项
a.功能说明
RemoteRunnable支持多种配置选项,定制客户端行为。设置超时时间、重试策略、请求头。配置代理、SSL验证等网络参数。合理配置提升客户端的可靠性和性能。
b.代码示例
---
from langserve import RemoteRunnable
import time
# 1. 超时配置
remote_chain = RemoteRunnable(
"http://localhost:8000/chat",
timeout=30.0 # 请求超时30秒
)
# 不同操作使用不同超时
try:
# 快速请求
result = remote_chain.invoke(
{"message": "快速问题"},
config={"timeout": 5.0}
)
except TimeoutError:
print("请求超时")
# 2. 重试配置
from tenacity import retry, stop_after_attempt, wait_exponential
@retry(
    stop=stop_after_attempt(3),  # give up after three attempts
    wait=wait_exponential(multiplier=1, min=2, max=10)  # exponential backoff, 2-10 s
)
def invoke_with_retry(chain, input_data):
    """Invoke ``chain`` with ``input_data``; tenacity retries on any exception."""
    return chain.invoke(input_data)
# 使用
remote_chain = RemoteRunnable("http://localhost:8000/chat")
result = invoke_with_retry(remote_chain, {"message": "测试"})
# 3. 认证配置
# API Key认证
remote_chain = RemoteRunnable(
"http://localhost:8000/protected/chat",
headers={"X-API-Key": "your_api_key"}
)
# Bearer Token认证
remote_chain = RemoteRunnable(
"http://localhost:8000/protected/chat",
headers={"Authorization": "Bearer your_jwt_token"}
)
# Basic认证
import base64
credentials = base64.b64encode(b"username:password").decode()
remote_chain = RemoteRunnable(
"http://localhost:8000/chat",
headers={"Authorization": f"Basic {credentials}"}
)
# 4. 自定义headers
remote_chain = RemoteRunnable(
"http://localhost:8000/chat",
headers={
"User-Agent": "MyApp/1.0",
"X-Request-ID": "req-12345",
"X-Client-Version": "1.0.0",
"Accept": "application/json"
}
)
# 5. Proxy configuration
# RemoteRunnable has no `proxies`/`verify` keyword arguments; options for the
# underlying httpx clients are forwarded through `client_kwargs`.
# httpx expects URL-pattern keys such as "http://", not bare scheme names.
# NOTE(review): newer httpx releases replace `proxies` with `proxy`/`mounts` -
# check the installed httpx version.
# HTTP proxy
remote_chain = RemoteRunnable(
    "http://api.example.com/chat",
    client_kwargs={
        "proxies": {
            "http://": "http://proxy.company.com:8080"
        }
    }
)
# HTTPS proxy
remote_chain = RemoteRunnable(
    "https://api.example.com/chat",
    client_kwargs={
        "proxies": {
            "https://": "https://proxy.company.com:8443"
        }
    }
)
# SOCKS proxy (requires the optional SOCKS support of httpx)
remote_chain = RemoteRunnable(
    "http://api.example.com/chat",
    client_kwargs={
        "proxies": {
            "http://": "socks5://proxy.company.com:1080"
        }
    }
)
# 6. SSL/TLS configuration
# Disable certificate verification (not recommended, testing only)
remote_chain = RemoteRunnable(
    "https://localhost:8000/chat",
    client_kwargs={"verify": False}
)
# Use a custom CA bundle
remote_chain = RemoteRunnable(
    "https://internal-server:8000/chat",
    client_kwargs={"verify": "/path/to/ca-bundle.crt"}
)
# 7. 日志配置
import logging
logging.basicConfig(level=logging.DEBUG)
remote_chain = RemoteRunnable("http://localhost:8000/chat")
# 将记录详细的请求/响应日志
# 8. 配置类封装
class RemoteChainConfig:
    """Bundle the connection settings for one remote chain.

    NOTE(review): ``max_retries`` is stored but not applied by
    ``create_chain``; callers must implement retries themselves.
    """

    def __init__(
        self,
        url: str,
        timeout: float = 30.0,
        max_retries: int = 3,
        api_key: str = None,
        proxies: dict = None,
        verify_ssl: bool = True
    ):
        self.url = url
        self.timeout = timeout
        self.max_retries = max_retries
        self.api_key = api_key
        self.proxies = proxies
        self.verify_ssl = verify_ssl

    def create_chain(self) -> RemoteRunnable:
        """Create a RemoteRunnable configured from these settings.

        ``verify`` and ``proxies`` are options of the underlying httpx
        clients, not RemoteRunnable keyword arguments, so they are forwarded
        through ``client_kwargs``.
        """
        headers = {}
        if self.api_key:
            headers["X-API-Key"] = self.api_key

        client_kwargs = {"verify": self.verify_ssl}
        if self.proxies is not None:
            # NOTE(review): httpx expects URL-pattern keys such as "http://",
            # and newer releases use `proxy`/`mounts` instead of `proxies`.
            client_kwargs["proxies"] = self.proxies

        return RemoteRunnable(
            self.url,
            timeout=self.timeout,
            headers=headers,
            client_kwargs=client_kwargs
        )
# 使用
config = RemoteChainConfig(
url="http://localhost:8000/chat",
timeout=60.0,
api_key="my_api_key",
max_retries=3
)
remote_chain = config.create_chain()
# 9. 环境变量配置
import os
REMOTE_URL = os.getenv("LANGSERVE_URL", "http://localhost:8000")
REMOTE_TIMEOUT = float(os.getenv("LANGSERVE_TIMEOUT", "30.0"))
API_KEY = os.getenv("LANGSERVE_API_KEY")
headers = {}
if API_KEY:
headers["X-API-Key"] = API_KEY
remote_chain = RemoteRunnable(
f"{REMOTE_URL}/chat",
timeout=REMOTE_TIMEOUT,
headers=headers
)
# 10. 信创环境配置
class XinchuangRemoteConfig:
    """Connection settings for the chains served in a Xinchuang intranet."""

    def __init__(
        self,
        internal_url: str = "http://internal-server:8000",
        use_local_llm: bool = True,
        db_type: str = "dameng",
        timeout: float = 60.0
    ):
        self.internal_url = internal_url
        self.use_local_llm = use_local_llm
        self.db_type = db_type
        self.timeout = timeout

    def _remote(self, path: str) -> RemoteRunnable:
        """Build a client for ``path`` with the shared intranet settings."""
        return RemoteRunnable(
            f"{self.internal_url}{path}",
            timeout=self.timeout,
            # `verify` is an httpx client option and must go through
            # client_kwargs. Intranet servers may use self-signed certificates.
            client_kwargs={"verify": False}
        )

    def get_chat_chain(self) -> RemoteRunnable:
        """Return the chat chain: local LLM route or cloud route."""
        path = "/local/chat" if self.use_local_llm else "/cloud/chat"
        return self._remote(path)

    def get_db_chain(self) -> RemoteRunnable:
        """Return the query chain for the configured database type."""
        return self._remote(f"/xinchuang/{self.db_type}/query")
# 使用
xc_config = XinchuangRemoteConfig(
internal_url="http://192.168.1.100:8000",
use_local_llm=True,
db_type="dameng",
timeout=60.0
)
chat_chain = xc_config.get_chat_chain()
db_chain = xc_config.get_db_chain()
# 测试
try:
result = chat_chain.invoke({"question": "测试信创环境"})
print(f"✓ 信创环境连接成功")
print(f"响应:{result}")
except Exception as e:
print(f"✗ 连接失败:{e}")
print("✓ RemoteRunnable配置完成")
---
02.错误处理
a.异常捕获
a.功能说明
RemoteRunnable调用可能失败(网络错误、服务器错误、超时)。捕获并处理异常,提供友好的错误信息。实现重试机制,提高可靠性。错误处理是构建健壮客户端的关键。
b.代码示例
---
from langserve import RemoteRunnable
import requests
remote_chain = RemoteRunnable("http://localhost:8000/chat")
# 1. 基础异常捕获
try:
result = remote_chain.invoke({"message": "你好"})
print(result)
except Exception as e:
print(f"调用失败:{e}")
# 2. 特定异常处理
try:
result = remote_chain.invoke({"message": "测试"})
except requests.exceptions.Timeout:
print("请求超时,请稍后重试")
except requests.exceptions.ConnectionError:
print("无法连接到服务器")
except requests.exceptions.HTTPError as e:
print(f"HTTP错误:{e.response.status_code}")
except Exception as e:
print(f"未知错误:{e}")
# 3. 带重试的调用
def invoke_with_retry(chain, input_data, max_retries=3):
    """Invoke ``chain`` and retry on timeouts and connection failures.

    Args:
        chain: Object exposing ``invoke(input_data)``.
        input_data: Payload passed to ``invoke``.
        max_retries: Total number of attempts; must be at least 1.

    Returns:
        The result of the first successful invocation.

    Raises:
        ValueError: If ``max_retries`` is below 1 (previously the function
            returned ``None`` without calling the chain).
        Exception: The last timeout or connection error once attempts are
            exhausted; any other exception immediately.
    """
    import time

    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    for attempt in range(max_retries):
        try:
            return chain.invoke(input_data)
        except Exception as exc:
            # RemoteRunnable is built on httpx, so its failures are httpx
            # exceptions. The `requests` types are kept for chains that wrap
            # `requests`.
            import httpx

            timed_out = isinstance(
                exc, (requests.exceptions.Timeout, httpx.TimeoutException)
            )
            unreachable = isinstance(
                exc, (requests.exceptions.ConnectionError, httpx.ConnectError)
            )
            out_of_attempts = attempt == max_retries - 1
            if not (timed_out or unreachable) or out_of_attempts:
                raise
            if timed_out:
                wait_time = 2 ** attempt  # exponential backoff
                print(f"超时,{wait_time}秒后重试...")
                time.sleep(wait_time)
            else:
                print(f"连接失败,重试中({attempt+1}/{max_retries})...")
                time.sleep(2)
# 使用
try:
result = invoke_with_retry(remote_chain, {"message": "测试"})
print(result)
except Exception as e:
print(f"所有重试失败:{e}")
# 4. 优雅降级
def invoke_with_fallback(primary_chain, fallback_chain, input_data):
    """Invoke the primary chain; on any exception use the fallback chain.

    The failure is printed before switching. An exception raised by the
    fallback chain propagates to the caller.
    """
    try:
        outcome = primary_chain.invoke(input_data)
    except Exception as exc:
        print(f"主服务失败:{exc}")
        print("切换到备用服务...")
        outcome = fallback_chain.invoke(input_data)
    return outcome
# 使用
primary = RemoteRunnable("http://server1:8000/chat")
fallback = RemoteRunnable("http://server2:8000/chat")
result = invoke_with_fallback(primary, fallback, {"message": "测试"})
# 5. 详细错误信息
def safe_invoke(chain, input_data):
    """Invoke ``chain`` and return a result dict instead of raising.

    Returns:
        ``{"success": True, "data": ...}`` on success, otherwise a dict with
        ``success=False``, an ``error`` category (``timeout``,
        ``connection_error``, ``http_error`` or ``unknown``) and a
        ``message``; HTTP errors also carry ``status_code``.
    """
    try:
        result = chain.invoke(input_data)
        return {"success": True, "data": result}
    except Exception as exc:
        # RemoteRunnable is built on httpx, so its failures are httpx
        # exceptions. Matching only the `requests` types sent every real
        # timeout and connection error to the "unknown" bucket.
        import httpx

        if isinstance(exc, (requests.exceptions.Timeout, httpx.TimeoutException)):
            return {
                "success": False,
                "error": "timeout",
                "message": "请求超时"
            }
        if isinstance(exc, (requests.exceptions.ConnectionError, httpx.ConnectError)):
            return {
                "success": False,
                "error": "connection_error",
                "message": "无法连接到服务器"
            }
        if isinstance(exc, (requests.exceptions.HTTPError, httpx.HTTPStatusError)):
            return {
                "success": False,
                "error": "http_error",
                "status_code": exc.response.status_code,
                "message": exc.response.text
            }
        return {
            "success": False,
            "error": "unknown",
            "message": str(exc)
        }
# 使用
result = safe_invoke(remote_chain, {"message": "测试"})
if result["success"]:
print(f"成功:{result['data']}")
else:
print(f"失败:{result['error']} - {result['message']}")
# 6. 错误日志
import logging
logger = logging.getLogger(__name__)
def logged_invoke(chain, input_data):
    """Invoke ``chain``, logging the input, the result and any failure.

    A failure is logged at error level with its traceback and re-raised.
    """
    try:
        logger.info(f"调用Chain:{input_data}")
        outcome = chain.invoke(input_data)
        logger.info(f"调用成功:{outcome}")
        return outcome
    except Exception as exc:
        # logger.exception logs at ERROR level with the traceback attached.
        logger.exception(f"调用失败:{exc}")
        raise
# 7. 超时处理
import signal
class TimeoutError(TimeoutError):
    """Raised when an invocation exceeds its time budget.

    The name shadows the builtin ``TimeoutError``. Deriving from the builtin
    (the base expression is evaluated before this name is bound) keeps
    ``except TimeoutError`` working whichever of the two a caller refers to,
    and the class is still an ``Exception`` subclass.
    """
def timeout_handler(signum, frame):
    """Signal handler that raises ``TimeoutError``; both arguments are unused."""
    raise TimeoutError("操作超时")
def invoke_with_timeout(chain, input_data, timeout_seconds=30):
    """Invoke ``chain`` and abort with ``TimeoutError`` after a deadline.

    Uses ``SIGALRM``, so it works only on Unix and only in the main thread.

    Args:
        chain: Object exposing ``invoke(input_data)``.
        input_data: Payload passed to ``invoke``.
        timeout_seconds: Whole seconds before the alarm fires.

    Raises:
        TimeoutError: If the call does not finish in time.
    """
    # Remember the handler that was installed so it can be restored; the
    # earlier version left timeout_handler installed process-wide.
    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(timeout_seconds)
    try:
        return chain.invoke(input_data)
    except TimeoutError:
        print(f"操作超时({timeout_seconds}秒)")
        raise
    finally:
        signal.alarm(0)  # cancel the pending alarm on every exit path
        signal.signal(signal.SIGALRM, previous_handler)
# 8. 错误统计
class ErrorStats:
    """Count successful and failed requests, with failures grouped by type."""

    def __init__(self):
        self.total_requests = 0
        self.success_count = 0
        self.error_count = 0
        self.errors = {}

    def record_success(self):
        """Register one successful request."""
        self.total_requests += 1
        self.success_count += 1

    def record_error(self, error_type: str):
        """Register one failed request under ``error_type``."""
        self.total_requests += 1
        self.error_count += 1
        seen = self.errors.get(error_type, 0)
        self.errors[error_type] = seen + 1

    def get_stats(self):
        """Return the counters plus the error rate (0 when nothing was recorded)."""
        if self.total_requests > 0:
            rate = self.error_count / self.total_requests
        else:
            rate = 0
        return {
            "total": self.total_requests,
            "success": self.success_count,
            "error": self.error_count,
            "error_rate": rate,
            "errors": self.errors,
        }
stats = ErrorStats()
def tracked_invoke(chain, input_data):
    """Invoke ``chain`` and record the outcome in the module-level ``stats``.

    A success is counted after ``chain.invoke`` returns; a failure is counted
    under the exception's class name and the exception is re-raised.
    """
    try:
        outcome = chain.invoke(input_data)
        stats.record_success()
    except Exception as exc:
        stats.record_error(type(exc).__name__)
        raise
    return outcome
# 使用
# Fire ten tracked calls. Failures are already counted inside tracked_invoke,
# so they are deliberately ignored here; only Exception is swallowed so that
# Ctrl-C and SystemExit still stop the script.
for i in range(10):
    try:
        tracked_invoke(remote_chain, {"message": f"测试{i}"})
    except Exception:
        pass
print(stats.get_stats())
# 9. 信创环境错误处理
def xinchuang_safe_invoke(chain, input_data):
    """Invoke ``chain`` and wrap the outcome in a ``code/message/data`` envelope.

    Returns:
        ``code`` 0 with the result on success, -1 on
        ``requests.exceptions.Timeout``, -2 on
        ``requests.exceptions.ConnectionError``, and -99 for any other
        exception. ``data`` is ``None`` on every failure; nothing is raised.
    """

    def _failure(code, message):
        # Shared shape for every error envelope.
        return {"code": code, "message": message, "data": None}

    try:
        result = chain.invoke(input_data)
    except requests.exceptions.Timeout:
        return _failure(-1, "请求超时,请检查网络连接")
    except requests.exceptions.ConnectionError:
        return _failure(-2, "无法连接到内部服务器,请联系管理员")
    except Exception as e:
        return _failure(-99, f"系统错误:{str(e)}")
    return {"code": 0, "message": "成功", "data": result}
# 使用
xc_chain = RemoteRunnable("http://internal-server:8000/local/chat")
result = xinchuang_safe_invoke(xc_chain, {"question": "测试"})
if result["code"] == 0:
print(f"成功:{result['data']}")
else:
print(f"失败[{result['code']}]:{result['message']}")
print("✓ 错误处理配置完成")
---
3.2 同步调用
01.invoke方法
a.基础用法
a.功能说明
invoke是RemoteRunnable的基础同步调用方法,发送单个请求并等待响应。阻塞式调用,适用于简单的请求-响应场景。返回完整结果,不支持流式输出。使用简单,适合快速原型和简单脚本。
b.代码示例
---
from langserve import RemoteRunnable
# 1. 最简单调用
remote_chain = RemoteRunnable("http://localhost:8000/chat")
result = remote_chain.invoke({"message": "你好"})
print(result)
# 输出:{"content": "你好!有什么可以帮助你的吗?"}
# 2. 提取响应内容
response = remote_chain.invoke({"message": "什么是AI?"})
if "content" in response:
print(f"AI回复:{response['content']}")
# 3. 带参数调用
result = remote_chain.invoke({
"message": "讲个笑话",
"temperature": 0.9,
"max_tokens": 100
})
# 4. 多次调用
questions = [
"什么是Python?",
"什么是机器学习?",
"什么是深度学习?"
]
for q in questions:
result = remote_chain.invoke({"message": q})
print(f"Q: {q}")
print(f"A: {result.get('content', result)}")
print()
# 5. 不同Chain调用
chat_chain = RemoteRunnable("http://localhost:8000/chat")
translate_chain = RemoteRunnable("http://localhost:8000/translate")
summary_chain = RemoteRunnable("http://localhost:8000/summary")
chat_result = chat_chain.invoke({"message": "你好"})
translate_result = translate_chain.invoke({"text": "Hello", "target_lang": "zh"})
summary_result = summary_chain.invoke({"content": "长文本内容..."})
# 6. 带配置调用
result = remote_chain.invoke(
{"message": "测试"},
config={
"configurable": {
"temperature": 0.7,
"model": "gpt-4"
},
"tags": ["production"],
"metadata": {"user_id": "123"}
}
)
# 7. 结果验证
def safe_invoke(chain, input_data):
    """Invoke ``chain`` and return the ``content`` field of its dict result.

    Raises:
        ValueError: If the result is not a dict, or has no ``content`` key.
    """
    result = chain.invoke(input_data)
    if isinstance(result, dict):
        if "content" in result:
            return result["content"]
        raise ValueError("响应缺少content字段")
    raise ValueError("无效的响应格式")
# 使用
content = safe_invoke(remote_chain, {"message": "测试"})
print(content)
# 8. 性能测量
import time
def timed_invoke(chain, input_data):
    """Invoke ``chain``, print how long the call took, and return its result.

    The duration is measured with ``time.perf_counter()``, a monotonic
    high-resolution clock, so system clock adjustments cannot produce a
    negative or skewed reading the way ``time.time()`` can.
    """
    start = time.perf_counter()
    result = chain.invoke(input_data)
    elapsed = time.perf_counter() - start
    print(f"调用耗时:{elapsed:.2f}秒")
    return result
# 使用
result = timed_invoke(remote_chain, {"message": "测试性能"})
# 9. 信创环境调用
# 本地模型调用
local_chain = RemoteRunnable("http://localhost:8000/local/chat")
result = local_chain.invoke({
"question": "介绍一下信创技术",
"llm_provider": "ollama",
"model": "qwen:7b"
})
print(f"本地模型响应:{result}")
# 数据库查询调用
db_chain = RemoteRunnable("http://localhost:8000/xinchuang/dameng/query")
result = db_chain.invoke({
"table": "documents",
"condition": "id = 'DOC001'"
})
print(f"查询结果:{result}")
# 10. 批量顺序调用
def sequential_invoke(chain, inputs: list):
    """Invoke ``chain`` once per input, in order, printing progress.

    Returns:
        The results in the same order as ``inputs``.
    """
    collected = []
    i = 0
    while i < len(inputs):
        print(f"处理 {i+1}/{len(inputs)}...")
        collected.append(chain.invoke(inputs[i]))
        i += 1
    return collected
# 使用
inputs = [
{"message": "问题1"},
{"message": "问题2"},
{"message": "问题3"}
]
results = sequential_invoke(remote_chain, inputs)
print(f"完成{len(results)}个调用")
print("✓ invoke方法配置完成")
---
b.响应处理
a.功能说明
处理invoke返回的响应数据,提取所需信息。解析JSON响应,处理嵌套结构。错误检查和数据验证。格式化输出,适配业务需求。
b.代码示例
---
from langserve import RemoteRunnable
remote_chain = RemoteRunnable("http://localhost:8000/chat")
# 1. 基础响应处理
result = remote_chain.invoke({"message": "你好"})
# 提取content
if isinstance(result, dict) and "content" in result:
content = result["content"]
print(f"AI回复:{content}")
else:
print(f"原始响应:{result}")
# 2. 处理复杂响应
result = remote_chain.invoke({"message": "分析这段文本"})
# 响应可能是嵌套结构
# {
# "output": {
# "content": "分析结果",
# "metadata": {...}
# }
# }
if "output" in result:
output = result["output"]
content = output.get("content", "")
metadata = output.get("metadata", {})
print(f"内容:{content}")
print(f"元数据:{metadata}")
# 3. 提取Token信息
result = remote_chain.invoke({"message": "测试"})
# 可能包含Token使用信息
tokens = result.get("metadata", {}).get("tokens", {})
if tokens:
print(f"输入Token:{tokens.get('prompt_tokens', 0)}")
print(f"输出Token:{tokens.get('completion_tokens', 0)}")
print(f"总Token:{tokens.get('total_tokens', 0)}")
# 4. 响应转换
def transform_response(result: dict) -> str:
    """Reduce a response dict to text.

    Lookup order: top-level ``content``; then ``output`` (its ``content`` if
    it is a dict, else ``str(output)``); finally ``str(result)``.
    """
    if "content" in result:
        return result["content"]
    if "output" not in result:
        return str(result)
    output = result["output"]
    if not isinstance(output, dict):
        return str(output)
    return output.get("content", str(output))
# 使用
result = remote_chain.invoke({"message": "测试"})
text = transform_response(result)
print(text)
# 5. 结构化输出
class Response:
    """Thin wrapper exposing ``content`` and ``metadata`` of a raw response."""

    def __init__(self, raw_data: dict):
        self.raw = raw_data
        self.content = self._extract_content()
        self.metadata = self._extract_metadata()

    def _extract_content(self) -> str:
        """Return top-level ``content``, else ``output['content']`` when
        ``output`` is a dict, else an empty string."""
        raw = self.raw
        if "content" in raw:
            return raw["content"]
        nested = raw["output"] if "output" in raw else None
        if isinstance(nested, dict):
            return nested.get("content", "")
        return ""

    def _extract_metadata(self) -> dict:
        """Return the top-level ``metadata`` entry, or an empty dict."""
        return self.raw.get("metadata", {})

    def __str__(self):
        return self.content
# 使用
result = remote_chain.invoke({"message": "测试"})
response = Response(result)
print(f"内容:{response.content}")
print(f"元数据:{response.metadata}")
# 6. 批量响应处理
def process_batch_results(results: list) -> list:
    """Summarise each batch result as a success or error record.

    Each record carries ``index`` and ``success``; successful ones add
    ``content`` (``result['content']`` or ``str(result)``), failed ones add
    ``error`` with the exception text (e.g. when an item has no ``.get``).
    """

    def _summarise(i, result):
        try:
            content = result.get("content", str(result))
        except Exception as e:
            return {"index": i, "error": str(e), "success": False}
        return {"index": i, "content": content, "success": True}

    return [_summarise(i, result) for i, result in enumerate(results)]
# 7. 错误响应处理
def safe_process_response(result: dict) -> dict:
    """Normalise a response into ``{"success": ..., ...}`` without raising.

    An ``error`` key makes the outcome a failure carrying that error.
    Otherwise the content is ``content``, falling back to ``output`` and then
    to ``""``. Any exception while inspecting ``result`` is reported as a
    failure.
    """
    try:
        if "error" not in result:
            fallback = result.get("output", "")
            return {"success": True, "content": result.get("content", fallback)}
        return {"success": False, "error": result["error"]}
    except Exception as e:
        return {"success": False, "error": f"处理失败:{e}"}
# 使用
result = remote_chain.invoke({"message": "测试"})
processed = safe_process_response(result)
if processed["success"]:
print(processed["content"])
else:
print(f"错误:{processed['error']}")
# 8. 响应缓存
class ResponseCache:
    """In-memory response cache keyed by ``str(input_data)``.

    The cache is unbounded and never expires entries.
    """

    def __init__(self):
        self.cache = {}

    def get(self, key: str):
        """Return the cached value for ``key``, or ``None`` if absent."""
        return self.cache.get(key)

    def set(self, key: str, value):
        """Store ``value`` under ``key``, replacing any previous entry."""
        self.cache[key] = value

    def invoke_with_cache(self, chain, input_data):
        """Return the cached result for ``input_data``, invoking on a miss.

        Presence is tested with ``in`` rather than truthiness, so a cached
        falsy result (``{}``, ``""``, ``0``, ``None``) is still a hit instead
        of triggering another remote call.
        """
        cache_key = str(input_data)
        if cache_key in self.cache:
            print("[缓存命中]")
            return self.cache[cache_key]
        result = chain.invoke(input_data)
        self.set(cache_key, result)
        return result
cache = ResponseCache()
# 第一次调用
result1 = cache.invoke_with_cache(remote_chain, {"message": "测试"})
# 第二次调用(使用缓存)
result2 = cache.invoke_with_cache(remote_chain, {"message": "测试"})
# 9. 响应日志
import logging
logger = logging.getLogger(__name__)


def logged_invoke(chain, input_data):
    """Invoke ``chain``, logging the request, the response, and the
    ``metadata['tokens']`` entry when the result has a ``metadata`` key.

    Returns:
        The result of ``chain.invoke`` unchanged.
    """
    logger.info(f"请求:{input_data}")
    result = chain.invoke(input_data)
    logger.info(f"响应:{result}")
    if "metadata" not in result:
        return result
    tokens = result["metadata"].get("tokens", {})
    logger.info(f"Token使用:{tokens}")
    return result
# 10. 信创环境响应处理
def xinchuang_process_response(result: dict) -> dict:
    """Normalise a response into a ``success`` envelope.

    Responses with a ``code`` key are treated as
    ``{"code": ..., "message": ..., "data": ...}``: code 0 is a success
    carrying ``data`` (default ``{}``); any other code is a failure carrying
    ``error_code`` and ``error_message``. Responses without ``code`` are
    passed through as successful ``data``.
    """
    if "code" not in result:
        return {"success": True, "data": result}
    if result["code"] != 0:
        return {
            "success": False,
            "error_code": result["code"],
            "error_message": result.get("message", "未知错误"),
        }
    return {"success": True, "data": result.get("data", {})}
# 使用
xc_chain = RemoteRunnable("http://internal-server:8000/local/chat")
result = xc_chain.invoke({"question": "测试"})
processed = xinchuang_process_response(result)
if processed["success"]:
print(f"成功:{processed['data']}")
else:
print(f"失败[{processed['error_code']}]:{processed['error_message']}")
print("✓ 响应处理配置完成")
---
02.性能优化
a.连接复用
a.功能说明
复用HTTP连接,减少连接建立开销。使用连接池管理多个连接。提升并发性能和吞吐量。合理配置连接池参数,平衡资源和性能。
b.代码示例
---
from langserve import RemoteRunnable
import httpx
import time
# 1. 基础连接复用
# RemoteRunnable内部自动复用连接
remote_chain = RemoteRunnable("http://localhost:8000/chat")
# 多次调用会复用连接
for i in range(10):
result = remote_chain.invoke({"message": f"消息{i}"})
# 2. 显式连接池(使用httpx)
# 创建带连接池的客户端
client = httpx.Client(
limits=httpx.Limits(
max_connections=100, # 最大连接数
max_keepalive_connections=20 # 保持活跃的连接数
)
)
# 注意:RemoteRunnable不直接支持自定义client
# 但可以通过继承扩展
# 3. 连接池管理器
class ConnectionPoolManager:
    """Keeps one ``RemoteRunnable`` per URL so repeated calls reuse it.

    ``max_connections`` is stored on the instance but not otherwise used by
    this class.
    """

    def __init__(self, max_connections=100):
        self.max_connections = max_connections
        self.chains = {}

    def get_chain(self, url: str) -> RemoteRunnable:
        """Return the runnable cached for ``url``, creating it on first use."""
        if url in self.chains:
            return self.chains[url]
        created = self.chains[url] = RemoteRunnable(url)
        return created

    def invoke(self, url: str, input_data: dict):
        """Invoke the runnable for ``url`` with ``input_data``."""
        return self.get_chain(url).invoke(input_data)
# 使用
pool = ConnectionPoolManager()
# 多次调用,复用同一个RemoteRunnable实例
for i in range(10):
result = pool.invoke(
"http://localhost:8000/chat",
{"message": f"消息{i}"}
)
# 4. 性能对比
def benchmark_connection_reuse():
    """Compare 20 calls made with a fresh ``RemoteRunnable`` per call against
    20 calls made through one reused instance, and print both timings.

    Durations use ``time.perf_counter()`` (monotonic). The speed-up ratio is
    reported as ``inf`` when the reuse run measures as zero, instead of
    raising ``ZeroDivisionError``.
    """
    url = "http://localhost:8000/chat"
    n = 20

    # Variant 1: build a new instance for every request.
    start = time.perf_counter()
    for i in range(n):
        chain = RemoteRunnable(url)
        chain.invoke({"message": f"测试{i}"})
    no_reuse_time = time.perf_counter() - start

    # Variant 2: build one instance and reuse it.
    start = time.perf_counter()
    chain = RemoteRunnable(url)
    for i in range(n):
        chain.invoke({"message": f"测试{i}"})
    reuse_time = time.perf_counter() - start

    speedup = no_reuse_time / reuse_time if reuse_time > 0 else float("inf")
    print(f"不复用耗时:{no_reuse_time:.2f}秒")
    print(f"复用耗时:{reuse_time:.2f}秒")
    print(f"性能提升:{speedup:.2f}倍")
# 运行测试
# benchmark_connection_reuse()
# 5. 连接预热
def warmup_connection(chain, warmup_count=3):
    """Send ``warmup_count`` throwaway requests through ``chain``.

    Warm-up is best effort: ordinary failures are ignored so a cold or flaky
    server does not abort start-up. Only ``Exception`` is suppressed, so
    ``KeyboardInterrupt`` and ``SystemExit`` still propagate (the original
    bare ``except:`` swallowed those too).
    """
    print("预热连接中...")
    for _ in range(warmup_count):
        try:
            chain.invoke({"message": "预热"})
        except Exception:
            pass
    print("预热完成")
# 使用
remote_chain = RemoteRunnable("http://localhost:8000/chat")
warmup_connection(remote_chain)
# 6. 并发控制
import threading
from queue import Queue
class ConcurrentInvoker:
    """Runs ``chain.invoke`` for many inputs on a bounded thread pool."""

    def __init__(self, chain, max_workers=5):
        self.chain = chain
        self.max_workers = max_workers

    def invoke_concurrent(self, inputs: list):
        """Invoke the chain for every input, at most ``max_workers`` at once.

        The original started one thread per element of
        ``inputs[:max_workers]``, silently discarding every input past
        ``max_workers``, and lost results whose worker raised. Here all
        inputs are processed and a worker exception propagates to the caller.

        Returns:
            Results in the same order as ``inputs``.
        """
        from concurrent.futures import ThreadPoolExecutor

        if not inputs:
            return []
        with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
            # Executor.map yields results in input order.
            return list(pool.map(self.chain.invoke, inputs))
# 使用
invoker = ConcurrentInvoker(remote_chain, max_workers=5)
results = invoker.invoke_concurrent([
{"message": f"问题{i}"} for i in range(5)
])
# 7. 连接健康检查
def check_connection_health(chain, timeout=5):
    """Probe ``chain`` with one request and report health and latency.

    Latency is measured with ``time.perf_counter()`` so wall-clock
    adjustments cannot yield a negative value.

    Args:
        chain: Object exposing ``invoke``.
        timeout: Accepted for interface compatibility; this function does not
            apply it to the request.

    Returns:
        ``{"healthy": True, "latency_ms": ...}`` on success, otherwise
        ``{"healthy": False, "error": str(exc)}``.
    """
    try:
        start = time.perf_counter()
        chain.invoke({"message": "健康检查"})
        latency = time.perf_counter() - start
    except Exception as e:
        return {"healthy": False, "error": str(e)}
    return {"healthy": True, "latency_ms": latency * 1000}
# 使用
health = check_connection_health(remote_chain)
if health["healthy"]:
print(f"✓ 连接健康,延迟:{health['latency_ms']:.2f}ms")
else:
print(f"✗ 连接异常:{health['error']}")
# 8. 连接超时管理
class TimeoutManager:
    """Computes a size-dependent timeout figure and retries once on errors
    whose message mentions "timeout"."""

    def __init__(self, default_timeout=30):
        self.default_timeout = default_timeout

    def invoke_with_adaptive_timeout(self, chain, input_data, base_timeout=None):
        """Invoke ``chain``; on a timeout-looking error, invoke once more.

        NOTE(review): the computed ``timeout`` is only used in the printed
        message; it is not passed to ``chain.invoke``, and the retry is an
        identical call.

        Raises:
            Exception: Errors whose text lacks "timeout" are re-raised, as is
                any error from the retry.
        """
        timeout = base_timeout if base_timeout else self.default_timeout
        # Inputs whose string form exceeds 1000 characters double the figure.
        if len(str(input_data)) > 1000:
            timeout *= 2
        try:
            return chain.invoke(input_data)
        except Exception as e:
            if "timeout" not in str(e).lower():
                raise
            print(f"超时({timeout}秒),尝试更长超时...")
            return chain.invoke(input_data)
# 9. 信创环境连接优化
class XinchuangConnectionPool:
    """Lazily creates and reuses the two runnables of an internal server:
    the local-LLM chat chain and the database query chain."""

    def __init__(self, internal_url: str):
        self.internal_url = internal_url
        self.local_chain = self.db_chain = None

    def get_local_chain(self) -> RemoteRunnable:
        """Return the ``/local/chat`` runnable, building it on first use."""
        if self.local_chain is not None:
            return self.local_chain
        endpoint = f"{self.internal_url}/local/chat"
        self.local_chain = RemoteRunnable(endpoint, timeout=60.0, verify=False)
        return self.local_chain

    def get_db_chain(self) -> RemoteRunnable:
        """Return the ``/xinchuang/dameng/query`` runnable, building it on
        first use."""
        if self.db_chain is not None:
            return self.db_chain
        endpoint = f"{self.internal_url}/xinchuang/dameng/query"
        self.db_chain = RemoteRunnable(endpoint, timeout=30.0, verify=False)
        return self.db_chain

    def invoke_local(self, input_data: dict):
        """Invoke the local-LLM chain with ``input_data``."""
        return self.get_local_chain().invoke(input_data)

    def invoke_db(self, input_data: dict):
        """Invoke the database chain with ``input_data``."""
        return self.get_db_chain().invoke(input_data)
# 使用
xc_pool = XinchuangConnectionPool("http://internal-server:8000")
# 多次调用,复用连接
for i in range(5):
result = xc_pool.invoke_local({"question": f"问题{i}"})
print(f"问题{i}完成")
print("✓ 连接复用配置完成")
---
3.3 异步调用
01.ainvoke方法
a.异步基础
a.功能说明
ainvoke是RemoteRunnable的异步调用方法,使用Python asyncio。非阻塞式调用,提升并发性能。适用于高并发场景,如Web服务器。需要在async函数中使用await调用。
b.代码示例
---
from langserve import RemoteRunnable
import asyncio
# 1. 基础异步调用
async def basic_async_call():
    """Await a single ``ainvoke`` against the chat endpoint and print it."""
    endpoint = "http://localhost:8000/chat"
    print(await RemoteRunnable(endpoint).ainvoke({"message": "你好"}))
# 运行
asyncio.run(basic_async_call())
# 2. 多个异步调用
async def multiple_async_calls():
    """Await three ``ainvoke`` calls one after another and print the three
    results on one line."""
    remote_chain = RemoteRunnable("http://localhost:8000/chat")
    answers = []
    # Sequential on purpose: each request starts after the previous finishes.
    for text in ("问题1", "问题2", "问题3"):
        answers.append(await remote_chain.ainvoke({"message": text}))
    print(*answers)
asyncio.run(multiple_async_calls())
# 3. 并发调用(gather)
async def concurrent_calls():
    """Run three ``ainvoke`` calls concurrently with ``asyncio.gather`` and
    print each result with its 1-based position."""
    remote_chain = RemoteRunnable("http://localhost:8000/chat")
    pending = [
        remote_chain.ainvoke({"message": text})
        for text in ("问题1", "问题2", "问题3")
    ]
    results = await asyncio.gather(*pending)
    i = 0
    for result in results:
        print(f"结果{i+1}:{result}")
        i += 1
asyncio.run(concurrent_calls())
# 4. 带超时的异步调用
async def async_call_with_timeout():
    """Await one ``ainvoke`` under a 5-second ``asyncio.wait_for`` limit,
    printing the result or a timeout notice."""
    remote_chain = RemoteRunnable("http://localhost:8000/chat")
    limit_seconds = 5.0
    try:
        request = remote_chain.ainvoke({"message": "测试"})
        print(await asyncio.wait_for(request, timeout=limit_seconds))
    except asyncio.TimeoutError:
        print("请求超时")
asyncio.run(async_call_with_timeout())
# 5. 异步错误处理
async def async_with_error_handling():
    """Await one ``ainvoke`` and print either the result or the failure."""
    remote_chain = RemoteRunnable("http://localhost:8000/chat")
    payload = {"message": "测试"}
    try:
        print(await remote_chain.ainvoke(payload))
    except Exception as e:
        print(f"调用失败:{e}")
asyncio.run(async_with_error_handling())
# 6. 异步重试
async def async_retry(chain, input_data, max_retries=3):
    """Await ``chain.ainvoke`` with exponential back-off between attempts.

    After a failed attempt ``k`` (0-based) the coroutine sleeps ``2 ** k``
    seconds, except after the final attempt, whose exception is re-raised.
    With ``max_retries`` of 0 no attempt is made and ``None`` is returned.
    """
    final_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            return await chain.ainvoke(input_data)
        except Exception:
            if attempt >= final_attempt:
                raise
            wait_time = 2 ** attempt
            print(f"失败,{wait_time}秒后重试...")
            await asyncio.sleep(wait_time)
# 使用
async def use_retry():
    """Call the chat endpoint through ``async_retry`` and print the result."""
    target = RemoteRunnable("http://localhost:8000/chat")
    print(await async_retry(target, {"message": "测试"}))
asyncio.run(use_retry())
# 7. 批量异步处理
async def batch_async_process(inputs: list):
    """Run ``ainvoke`` for every input concurrently and report each outcome.

    Exceptions are collected rather than raised
    (``return_exceptions=True``), so the returned list holds either a result
    or the exception object at each position.
    """
    remote_chain = RemoteRunnable("http://localhost:8000/chat")
    pending = []
    for input_data in inputs:
        pending.append(remote_chain.ainvoke(input_data))
    results = await asyncio.gather(*pending, return_exceptions=True)
    for i, result in enumerate(results):
        if not isinstance(result, Exception):
            print(f"输入{i}成功:{result}")
            continue
        print(f"输入{i}失败:{result}")
    return results
# 使用
inputs = [{"message": f"问题{i}"} for i in range(5)]
asyncio.run(batch_async_process(inputs))
# 8. 异步上下文管理
class AsyncChainManager:
    """Async context manager wrapping a ``RemoteRunnable`` for one URL.

    Entering and leaving only print a message; exceptions raised inside the
    ``async with`` body are not suppressed.
    """

    def __init__(self, url: str):
        self.chain = RemoteRunnable(url)

    async def __aenter__(self):
        print("初始化Chain...")
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        print("清理资源...")

    async def invoke(self, input_data: dict):
        """Await ``ainvoke`` on the wrapped chain and return its result."""
        reply = await self.chain.ainvoke(input_data)
        return reply
# 使用
async def use_context_manager():
    """Open an ``AsyncChainManager``, send one message and print the reply."""
    endpoint = "http://localhost:8000/chat"
    async with AsyncChainManager(endpoint) as manager:
        print(await manager.invoke({"message": "测试"}))
asyncio.run(use_context_manager())
# 9. 信创环境异步调用
async def xinchuang_async_call():
    """Await one ``ainvoke`` against the internal local-model endpoint
    (certificate verification disabled) and print the reply."""
    endpoint = "http://internal-server:8000/local/chat"
    local_chain = RemoteRunnable(endpoint, verify=False)
    request = {"question": "介绍信创技术", "llm_provider": "ollama"}
    result = await local_chain.ainvoke(request)
    print(f"本地模型响应:{result}")
asyncio.run(xinchuang_async_call())
# 10. 完整异步应用示例
class AsyncChatApp:
    """Minimal interactive chat client on top of ``RemoteRunnable.ainvoke``."""

    def __init__(self, url: str):
        self.chain = RemoteRunnable(url)

    async def chat(self, message: str) -> str:
        """Send ``message`` and return the reply text.

        Dict results yield their ``content`` entry (or ``str(result)``).
        Other results, such as a plain string or an object exposing
        ``.content``, are converted instead of failing on a missing ``.get``.
        """
        result = await self.chain.ainvoke({"message": message})
        if isinstance(result, dict):
            return result.get("content", str(result))
        return str(getattr(result, "content", result))

    async def chat_loop(self):
        """Read lines from stdin until ``exit`` or end-of-input, printing the
        reply to each one.

        ``input`` runs in the default executor so the event loop is not
        blocked while waiting for the user.
        """
        print("聊天开始(输入'exit'退出)")
        loop = asyncio.get_running_loop()
        while True:
            try:
                message = await loop.run_in_executor(None, input, "你:")
            except EOFError:
                # stdin was closed; treat it like an explicit exit.
                break
            if message.lower() == "exit":
                break
            try:
                response = await self.chat(message)
                print(f"AI:{response}")
            except Exception as e:
                print(f"错误:{e}")
# 使用
# app = AsyncChatApp("http://localhost:8000/chat")
# asyncio.run(app.chat_loop())
print("✓ ainvoke方法配置完成")
---
b.并发控制
a.功能说明
控制异步调用的并发数量,避免资源耗尽。使用信号量(Semaphore)限制并发。实现任务队列,平衡性能和资源。并发控制是高性能异步应用的关键。
b.代码示例
---
import asyncio
from langserve import RemoteRunnable
# 1. 使用Semaphore限制并发
async def limited_concurrent_calls(inputs: list, max_concurrent=5):
    """Run ``ainvoke`` for every input with at most ``max_concurrent`` in
    flight, and return the results in input order."""
    remote_chain = RemoteRunnable("http://localhost:8000/chat")
    gate = asyncio.Semaphore(max_concurrent)

    async def limited_invoke(input_data):
        # Hold one semaphore slot for the duration of the request.
        await gate.acquire()
        try:
            return await remote_chain.ainvoke(input_data)
        finally:
            gate.release()

    return await asyncio.gather(*map(limited_invoke, inputs))
# 使用
inputs = [{"message": f"问题{i}"} for i in range(20)]
results = asyncio.run(limited_concurrent_calls(inputs, max_concurrent=5))
print(f"完成{len(results)}个调用")
# 2. 任务队列
class AsyncTaskQueue:
    """Feeds inputs through an ``asyncio.Queue`` to a fixed number of worker
    tasks, each of which awaits ``chain.ainvoke`` and prints the outcome."""

    def __init__(self, chain, max_workers=5):
        self.chain = chain
        self.max_workers = max_workers
        self.queue = asyncio.Queue()

    async def worker(self, worker_id):
        """Consume queue items until a ``None`` sentinel arrives.

        Failures are printed, not raised, so one bad input does not stop the
        worker.
        """
        while (input_data := await self.queue.get()) is not None:
            try:
                result = await self.chain.ainvoke(input_data)
                print(f"[Worker {worker_id}] 完成:{result}")
            except Exception as e:
                print(f"[Worker {worker_id}] 失败:{e}")
            self.queue.task_done()

    async def process(self, inputs: list):
        """Process every input, then shut the workers down."""
        workers = [
            asyncio.create_task(self.worker(n)) for n in range(self.max_workers)
        ]
        for input_data in inputs:
            await self.queue.put(input_data)
        # Wait until every real item has been marked done.
        await self.queue.join()
        # One sentinel per worker so each of them leaves its loop.
        for _ in workers:
            await self.queue.put(None)
        await asyncio.gather(*workers)
# 使用
async def use_task_queue():
    """Push ten questions through an ``AsyncTaskQueue`` with three workers."""
    remote_chain = RemoteRunnable("http://localhost:8000/chat")
    inputs = []
    for i in range(10):
        inputs.append({"message": f"问题{i}"})
    await AsyncTaskQueue(remote_chain, max_workers=3).process(inputs)
asyncio.run(use_task_queue())
# 3. 速率限制
class RateLimiter:
    """Token-bucket limiter allowing ``rate`` acquisitions per ``per`` seconds.

    Two fixes over the original:
    * Time comes from ``time.monotonic()`` instead of
      ``asyncio.get_event_loop().time()``, so an instance can be created
      without an event loop.
    * ``last_check`` is refreshed after sleeping. Previously the time spent
      sleeping was credited a second time on the next ``acquire``, letting
      callers exceed the configured rate.
    """

    def __init__(self, rate: int, per: float):
        """
        Args:
            rate: Number of requests allowed per window.
            per: Window length in seconds.
        """
        import time

        self.rate = rate
        self.per = per
        self.allowance = rate  # The bucket starts full.
        self.last_check = time.monotonic()

    async def acquire(self):
        """Take one token, sleeping until one is available if necessary."""
        import time

        current = time.monotonic()
        time_passed = current - self.last_check
        self.last_check = current
        # Refill in proportion to elapsed time, capped at the bucket size.
        self.allowance += time_passed * (self.rate / self.per)
        if self.allowance > self.rate:
            self.allowance = self.rate
        if self.allowance < 1.0:
            sleep_time = (1.0 - self.allowance) * (self.per / self.rate)
            await asyncio.sleep(sleep_time)
            # The sleep earned exactly the token being consumed now; restart
            # the refill clock so that time is not counted again.
            self.last_check = time.monotonic()
            self.allowance = 0.0
        else:
            self.allowance -= 1.0
# 使用
async def rate_limited_calls():
    """Issue twenty sequential requests, each gated by a limiter configured
    for five requests per second."""
    remote_chain = RemoteRunnable("http://localhost:8000/chat")
    limiter = RateLimiter(rate=5, per=1.0)
    i = 0
    while i < 20:
        await limiter.acquire()
        await remote_chain.ainvoke({"message": f"问题{i}"})
        print(f"完成{i}")
        i += 1
asyncio.run(rate_limited_calls())
# 4. 动态并发调整
class AdaptiveConcurrencyController:
    """Tracks call outcomes and nudges a ``concurrency`` figure up or down.

    The class only maintains the number; it does not itself limit how many
    calls run at once.
    """

    def __init__(self, initial_concurrency=5, min_concurrency=1, max_concurrency=20):
        self.concurrency = initial_concurrency
        self.min = min_concurrency
        self.max = max_concurrency
        self.success_count = self.error_count = 0

    def adjust(self):
        """Re-evaluate ``concurrency`` once at least ten outcomes are counted.

        Above a 10% error rate the figure drops by one (not below ``min``);
        below 1% it rises by one (not above ``max``). The counters are then
        reset.
        """
        total = self.success_count + self.error_count
        if total >= 10:
            error_rate = self.error_count / total
            if error_rate > 0.1:
                self.concurrency = max(self.min, self.concurrency - 1)
            elif error_rate < 0.01:
                self.concurrency = min(self.max, self.concurrency + 1)
            self.success_count = self.error_count = 0

    async def invoke(self, chain, input_data):
        """Await ``chain.ainvoke``, count the outcome, then call ``adjust``.

        Exceptions are counted and re-raised.
        """
        try:
            outcome = await chain.ainvoke(input_data)
        except Exception:
            self.error_count += 1
            raise
        else:
            self.success_count += 1
            return outcome
        finally:
            self.adjust()
# 5. 批量处理with并发控制
async def batch_with_concurrency_control(inputs: list, max_concurrent=10):
    """Run ``ainvoke`` for every input with bounded concurrency.

    Returns:
        One record per input, in order: ``{"index", "success": True,
        "result"}`` or ``{"index", "success": False, "error"}``. A
        success/failure tally is printed before returning.
    """
    remote_chain = RemoteRunnable("http://localhost:8000/chat")
    semaphore = asyncio.Semaphore(max_concurrent)

    async def process_one(input_data, index):
        async with semaphore:
            try:
                result = await remote_chain.ainvoke(input_data)
            except Exception as e:
                return {"index": index, "success": False, "error": str(e)}
            return {"index": index, "success": True, "result": result}

    results = await asyncio.gather(
        *(process_one(inp, i) for i, inp in enumerate(inputs))
    )
    success = len([r for r in results if r["success"]])
    failed = len(results) - success
    print(f"成功:{success},失败:{failed}")
    return results
# 6. 超时批量处理
async def batch_with_timeout(inputs: list, timeout=30):
    """Run ``ainvoke`` for every input concurrently, each under its own
    ``timeout`` seconds.

    A timed-out call contributes ``{"error": "timeout"}``; other exceptions
    are returned in place as exception objects.
    """
    remote_chain = RemoteRunnable("http://localhost:8000/chat")

    async def invoke_with_timeout(input_data):
        request = remote_chain.ainvoke(input_data)
        try:
            return await asyncio.wait_for(request, timeout=timeout)
        except asyncio.TimeoutError:
            return {"error": "timeout"}

    pending = list(map(invoke_with_timeout, inputs))
    return await asyncio.gather(*pending, return_exceptions=True)
# 7. 优先级队列
import heapq
class PriorityAsyncQueue:
    """Heap-backed priority queue with coroutine ``put``/``get`` methods.

    Lower ``priority`` values come out first; items of equal priority come
    out in insertion order. ``get`` does not wait: it returns ``None`` when
    the queue is empty.
    """

    def __init__(self):
        self.heap = []
        self.counter = 0

    async def put(self, item, priority=0):
        """Add ``item`` with the given priority (smaller means sooner)."""
        # The running counter breaks ties so items are never compared.
        entry = (priority, self.counter, item)
        self.counter += 1
        heapq.heappush(self.heap, entry)

    async def get(self):
        """Pop and return the most urgent item, or ``None`` if empty."""
        if not self.heap:
            return None
        return heapq.heappop(self.heap)[2]
# 8. 并发监控
class ConcurrencyMonitor:
    """Counts how many tracked awaitables are in flight, the peak reached,
    and the total number started."""

    def __init__(self):
        self.active = self.peak = self.total = 0

    async def track(self, coro):
        """Await ``coro`` while accounting for it; return its result.

        ``active`` is decremented even when ``coro`` raises.
        """
        self.total += 1
        self.active += 1
        if self.active > self.peak:
            self.peak = self.active
        try:
            return await coro
        finally:
            self.active -= 1

    def stats(self):
        """Return the ``active``, ``peak`` and ``total`` counters."""
        return dict(active=self.active, peak=self.peak, total=self.total)
# 使用
async def monitored_concurrent_calls():
    """Run twenty concurrent requests through a ``ConcurrencyMonitor`` and
    print its counters."""
    remote_chain = RemoteRunnable("http://localhost:8000/chat")
    monitor = ConcurrencyMonitor()
    tracked = []
    for i in range(20):
        request = remote_chain.ainvoke({"message": f"问题{i}"})
        tracked.append(monitor.track(request))
    await asyncio.gather(*tracked)
    print(f"并发统计:{monitor.stats()}")
asyncio.run(monitored_concurrent_calls())
# 9. 信创环境并发控制
async def xinchuang_concurrent_calls(questions: list, max_concurrent=5):
    """Ask every question against the internal local-model endpoint with at
    most ``max_concurrent`` requests in flight.

    Returns:
        Results in question order; a failed question yields ``None`` (the
        failure is printed, not raised).
    """
    endpoint = "http://internal-server:8000/local/chat"
    local_chain = RemoteRunnable(endpoint, verify=False)
    semaphore = asyncio.Semaphore(max_concurrent)

    async def process_question(question, index):
        request = {"question": question, "llm_provider": "ollama"}
        async with semaphore:
            try:
                result = await local_chain.ainvoke(request)
                print(f"[{index}] 完成")
            except Exception as e:
                print(f"[{index}] 失败:{e}")
                return None
            return result

    pending = [process_question(q, i) for i, q in enumerate(questions)]
    return await asyncio.gather(*pending)
# 使用
questions = [f"问题{i}" for i in range(10)]
results = asyncio.run(xinchuang_concurrent_calls(questions, max_concurrent=3))
print("✓ 并发控制配置完成")
---
02.集成应用
a.FastAPI集成
a.功能说明
在FastAPI应用中集成RemoteRunnable异步调用。利用FastAPI的异步支持提升性能。实现高并发的API代理或网关。FastAPI集成是构建生产级AI服务的常见模式。
b.代码示例
---
from fastapi import FastAPI, HTTPException
from langserve import RemoteRunnable
from pydantic import BaseModel
app = FastAPI()
# 1. 基础集成
remote_chain = RemoteRunnable("http://localhost:8000/chat")
class ChatRequest(BaseModel):
message: str
@app.post("/chat")
async def chat(request: ChatRequest):
    """Forward the message to the remote chat chain and return its reply."""
    payload = {"message": request.message}
    return await remote_chain.ainvoke(payload)
# 2. 带错误处理
@app.post("/safe_chat")
async def safe_chat(request: ChatRequest):
    """Forward the message to the remote chat chain.

    Returns ``{"success": True, "data": ...}``; any failure is converted into
    an HTTP 500 whose detail is the exception text.
    """
    payload = {"message": request.message}
    try:
        reply = await remote_chain.ainvoke(payload)
        return {"success": True, "data": reply}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# 3. 多个Chain代理
translate_chain = RemoteRunnable("http://localhost:8000/translate")
summary_chain = RemoteRunnable("http://localhost:8000/summary")
@app.post("/translate")
async def translate(text: str, target_lang: str):
    """Proxy a translation request to the remote translate chain."""
    payload = {"text": text, "target_lang": target_lang}
    return await translate_chain.ainvoke(payload)
@app.post("/summary")
async def summary(content: str):
    """Proxy a summarisation request to the remote summary chain."""
    payload = {"content": content}
    return await summary_chain.ainvoke(payload)
# 4. 流式代理
from fastapi.responses import StreamingResponse
@app.post("/stream_chat")
async def stream_chat(request: ChatRequest):
    """Stream the remote chain's output to the client as server-sent events.

    Each chunk is emitted as a ``data: ...`` line followed by a blank line,
    which is the SSE event delimiter. The original used ``\\n\\n`` inside the
    f-string, which sends the literal characters backslash-n rather than
    newlines, so events were never terminated.
    """

    async def generate():
        async for chunk in remote_chain.astream({"message": request.message}):
            yield f"data: {chunk}\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
# 5. 批量处理端点
class BatchRequest(BaseModel):
messages: list[str]
@app.post("/batch_chat")
async def batch_chat(request: BatchRequest):
    """Send every message to the remote chain concurrently.

    Returns:
        ``{"results": [...]}`` with one entry per message, in order.
        Successful entries are ``{"success": True, "data": result}``. Failed
        entries are ``{"success": False, "data": None, "error": "<text>"}``;
        the original put the raw exception object into ``data``, which is not
        JSON-serialisable.
    """
    tasks = [remote_chain.ainvoke({"message": msg}) for msg in request.messages]
    outcomes = await asyncio.gather(*tasks, return_exceptions=True)
    results = []
    for outcome in outcomes:
        if isinstance(outcome, Exception):
            results.append({"success": False, "data": None, "error": str(outcome)})
        else:
            results.append({"success": True, "data": outcome})
    return {"results": results}
# 6. 带超时
import asyncio
@app.post("/timeout_chat")
async def timeout_chat(request: ChatRequest, timeout: float = 30):
    """Forward the message to the remote chain under a ``timeout``-second
    limit; a timeout becomes an HTTP 408 response."""
    try:
        pending = remote_chain.ainvoke({"message": request.message})
        return await asyncio.wait_for(pending, timeout=timeout)
    except asyncio.TimeoutError:
        raise HTTPException(status_code=408, detail="请求超时")
# 7. 健康检查
@app.get("/health")
async def health_check():
    """Probe the remote chain with one request and report the outcome.

    Always answers with a normal response body: ``{"status": "healthy"}`` or
    ``{"status": "unhealthy", "error": ...}``.
    """
    probe = {"message": "健康检查"}
    try:
        await remote_chain.ainvoke(probe)
    except Exception as e:
        return {"status": "unhealthy", "error": str(e)}
    return {"status": "healthy"}
# 8. 信创环境FastAPI集成
xc_local_chain = RemoteRunnable(
"http://internal-server:8000/local/chat",
verify=False
)
@app.post("/xinchuang/chat")
async def xinchuang_chat(question: str):
    """Ask the internal local-model chain and wrap the reply in a
    ``code/message/data`` envelope (code 0 on success, -1 on any error)."""
    payload = {"question": question, "llm_provider": "ollama"}
    try:
        result = await xc_local_chain.ainvoke(payload)
        envelope = {"code": 0, "message": "成功", "data": result}
    except Exception as e:
        envelope = {"code": -1, "message": str(e), "data": None}
    return envelope
print("✓ FastAPI集成配置完成")
---
3.4 批量调用
01.batch方法
a.同步批量
a.功能说明
batch方法一次发送多个输入,返回多个输出。服务器端并发处理,比多次invoke更高效。适用于批量数据处理场景。batch是提升吞吐量的关键方法。
b.代码示例
---
from langserve import RemoteRunnable
remote_chain = RemoteRunnable("http://localhost:8000/chat")
# 1. 基础批量调用
inputs = [
{"message": "你好"},
{"message": "再见"},
{"message": "谢谢"}
]
results = remote_chain.batch(inputs)
for i, result in enumerate(results):
print(f"输入{i}: {inputs[i]}")
print(f"输出{i}: {result}")
print()
# 2. 批量翻译
translate_chain = RemoteRunnable("http://localhost:8000/translate")
texts = ["Hello", "Good morning", "Thank you"]
inputs = [{"text": t, "target_lang": "zh"} for t in texts]
translations = translate_chain.batch(inputs)
for orig, trans in zip(texts, translations):
print(f"{orig} -> {trans}")
# 3. 大批量处理
def batch_process_large(chain, all_inputs, batch_size=100):
    """Send ``all_inputs`` through ``chain.batch`` in slices of at most
    ``batch_size`` items and return the concatenated results."""
    results = []
    i = 0
    while i < len(all_inputs):
        batch = all_inputs[i:i+batch_size]
        print(f"处理批次{i//batch_size + 1},共{len(batch)}项...")
        results += chain.batch(batch)
        i += batch_size
    return results
# 使用
large_inputs = [{"message": f"问题{i}"} for i in range(500)]
all_results = batch_process_large(remote_chain, large_inputs, batch_size=50)
print(f"总共处理{len(all_results)}项")
# 4. 批量with错误处理
def safe_batch(chain, inputs):
    """Call ``chain.batch`` and report the outcome as a dict, never raising.

    Returns:
        ``{"success": True, "results": [...], "total": n}`` on success, or
        ``{"success": False, "error": str(exc), "total": 0}`` on failure.
    """
    try:
        results = chain.batch(inputs)
        report = {"success": True, "results": results, "total": len(results)}
    except Exception as e:
        report = {"success": False, "error": str(e), "total": 0}
    return report
# 使用
result = safe_batch(remote_chain, inputs)
if result["success"]:
print(f"成功处理{result['total']}项")
else:
print(f"批量处理失败:{result['error']}")
# 5. 配置批量行为
# batch方法支持config参数
results = remote_chain.batch(
inputs,
config={
"max_concurrency": 10 # 服务器端并发数
}
)
# 6. 结果验证
def validate_batch_results(inputs, results):
    """Check that every batch result is a non-empty dict.

    Prints a warning when the input and output counts differ, a line for each
    invalid result, and a final tally.

    Returns:
        ``True`` only if every entry of ``results`` is a truthy dict.
    """
    if len(inputs) != len(results):
        print(f"警告:输入{len(inputs)}个,输出{len(results)}个")
    valid_count = 0
    for i, result in enumerate(results):
        if not (result and isinstance(result, dict)):
            print(f"结果{i}无效:{result}")
            continue
        valid_count += 1
    print(f"有效结果:{valid_count}/{len(results)}")
    return valid_count == len(results)
# 使用
results = remote_chain.batch(inputs)
validate_batch_results(inputs, results)
# 7. 性能对比
import time
def benchmark_batch_vs_sequential():
    """Time 20 sequential ``invoke`` calls against one ``batch`` call on the
    module-level ``remote_chain`` and print both durations.

    Durations use ``time.perf_counter()`` (monotonic). The speed-up is
    reported as ``inf`` when the batch run measures as zero, instead of
    raising ``ZeroDivisionError``.
    """
    inputs = [{"message": f"问题{i}"} for i in range(20)]

    # Sequential: one request per input.
    start = time.perf_counter()
    for inp in inputs:
        remote_chain.invoke(inp)
    seq_time = time.perf_counter() - start

    # Batch: a single call carrying all inputs.
    start = time.perf_counter()
    remote_chain.batch(inputs)
    batch_time = time.perf_counter() - start

    speedup = seq_time / batch_time if batch_time > 0 else float("inf")
    print(f"顺序调用:{seq_time:.2f}秒")
    print(f"批量调用:{batch_time:.2f}秒")
    print(f"性能提升:{speedup:.2f}倍")
# 运行测试
# benchmark_batch_vs_sequential()
# 8. 批量结果处理
def process_batch_results(results):
    """Summarise each batch result.

    For every result the record holds its position, its ``content`` (or
    ``str(result)`` when absent), the length of that content, and whether the
    result has an ``error`` key.
    """

    def _describe(i, result):
        content = result.get("content", str(result))
        return {
            "index": i,
            "content": content,
            "length": len(content),
            "has_error": "error" in result,
        }

    return [_describe(i, result) for i, result in enumerate(results)]
# 使用
results = remote_chain.batch(inputs)
processed = process_batch_results(results)
for item in processed:
print(f"[{item['index']}] 长度{item['length']}")
# 9. 信创环境批量调用
xc_chain = RemoteRunnable("http://internal-server:8000/local/chat")
questions = [f"问题{i}" for i in range(10)]
inputs = [{"question": q, "llm_provider": "ollama"} for q in questions]
results = xc_chain.batch(inputs)
print(f"批量处理{len(results)}个问题")
for i, result in enumerate(results):
print(f"Q{i}: {questions[i]}")
print(f"A{i}: {result}")
print()
# 10. 批量重试
def batch_with_retry(chain, inputs, max_retries=3):
    """Call ``chain.batch``, retrying the whole batch after a failure.

    Waits two seconds between attempts; the exception from the final attempt
    is re-raised. With ``max_retries`` of 0 nothing is attempted and ``None``
    is returned.
    """
    final_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            return chain.batch(inputs)
        except Exception:
            if attempt >= final_attempt:
                raise
            print(f"批量调用失败,重试{attempt+1}/{max_retries-1}...")
            time.sleep(2)
# 使用
results = batch_with_retry(remote_chain, inputs)
print("✓ batch方法配置完成")
---
b.异步批量
a.功能说明
abatch方法是batch的异步版本,支持异步批量处理。结合asyncio实现高并发批量调用。适用于大规模数据处理和高吞吐量场景。异步批量是性能优化的高级技巧。
b.代码示例
---
from langserve import RemoteRunnable
import asyncio
remote_chain = RemoteRunnable("http://localhost:8000/chat")
# 1. 基础异步批量
async def basic_async_batch():
    """Send three greetings through ``abatch`` and print each result."""
    inputs = [{"message": text} for text in ("你好", "再见", "谢谢")]
    results = await remote_chain.abatch(inputs)
    i = 0
    for result in results:
        print(f"结果{i}: {result}")
        i += 1
asyncio.run(basic_async_batch())
# 2. 大批量异步处理
async def async_batch_large(inputs, batch_size=100):
    """Await ``abatch`` on consecutive slices of at most ``batch_size``
    inputs, one slice at a time, and return the concatenated results."""
    results = []
    i = 0
    while i < len(inputs):
        print(f"异步处理批次{i//batch_size + 1}...")
        results += await remote_chain.abatch(inputs[i:i+batch_size])
        i += batch_size
    return results
# 使用
async def process_large():
inputs = [{"message": f"问题{i}"} for i in range(500)]
results = await async_batch_large(inputs, batch_size=50)
print(f"完成{len(results)}项")
asyncio.run(process_large())
# 3. 并发多个批量调用
async def concurrent_batch_calls():
    """Issue three ten-item ``abatch`` calls concurrently and print each group's size."""
    group_ids = (1, 2, 3)
    groups = [
        [{"message": f"组{group_id}-{i}"} for i in range(10)]
        for group_id in group_ids
    ]
    # gather keeps the order of its arguments, so outcomes line up with group_ids.
    outcomes = await asyncio.gather(*(remote_chain.abatch(group) for group in groups))
    for group_id, outcome in zip(group_ids, outcomes):
        print(f"组{group_id}完成{len(outcome)}项")
asyncio.run(concurrent_batch_calls())
# 4. 异步批量with超时
async def async_batch_with_timeout():
    """Run a 20-item ``abatch`` call, giving up after 30 seconds."""
    payloads = [{"message": f"问题{n}"} for n in range(20)]
    pending = remote_chain.abatch(payloads)
    try:
        answers = await asyncio.wait_for(pending, timeout=30.0)
    except asyncio.TimeoutError:
        print("批量调用超时")
    else:
        print(f"完成{len(answers)}项")
asyncio.run(async_batch_with_timeout())
# 5. 异步批量错误处理
async def safe_async_batch(inputs):
    """Run ``abatch`` and report the outcome as a dict instead of raising.

    Returns:
        ``{"success": True, "results": ...}`` on success, otherwise
        ``{"success": False, "error": <message>}``.
    """
    try:
        outcome = await remote_chain.abatch(inputs)
    except Exception as exc:
        return {"success": False, "error": str(exc)}
    return {"success": True, "results": outcome}
# 使用
async def test_safe_batch():
inputs = [{"message": f"测试{i}"} for i in range(10)]
result = await safe_async_batch(inputs)
if result["success"]:
print(f"成功处理{len(result['results'])}项")
else:
print(f"处理失败:{result['error']}")
asyncio.run(test_safe_batch())
# 6. 流式批量处理
async def streaming_batch_process(all_inputs, chunk_size=10):
    """Async generator: batch ``all_inputs`` chunk by chunk and yield each result.

    Results of one chunk are yielded before the next chunk is requested.
    """
    offsets = range(0, len(all_inputs), chunk_size)
    for chunk_no, offset in enumerate(offsets, start=1):
        piece = all_inputs[offset:offset + chunk_size]
        print(f"处理chunk {chunk_no}...")
        for item in await remote_chain.abatch(piece):
            yield item
# 使用
async def use_streaming_batch():
inputs = [{"message": f"问题{i}"} for i in range(100)]
async for result in streaming_batch_process(inputs, chunk_size=10):
print(f"收到结果:{result}")
asyncio.run(use_streaming_batch())
# 7. 动态批量大小
async def adaptive_batch_process(inputs, chain=None):
    """Process ``inputs`` in batches whose size adapts to success and failure.

    The batch size starts at 50, grows by 10 (up to 100) after a successful
    batch and shrinks by 10 (down to 10) after a failed one. A failed batch
    is counted and skipped, not retried.

    Args:
        inputs: Sequence of chain inputs.
        chain: Object exposing ``abatch``; defaults to the module-level
            ``remote_chain``.

    Returns:
        ``(success_count, error_count)``: successfully processed items and
        number of failed batches.
    """
    chain = remote_chain if chain is None else chain
    batch_size = 50
    success_count = 0
    error_count = 0
    # Track the position explicitly: a ``range(0, n, batch_size)`` fixes its
    # step up front, so resizing inside the loop would duplicate or skip items.
    position = 0
    while position < len(inputs):
        batch = inputs[position:position + batch_size]
        position += len(batch)
        try:
            results = await chain.abatch(batch)
        except Exception as e:
            error_count += 1
            print(f"批次失败:{e}")
            batch_size = max(batch_size - 10, 10)
        else:
            success_count += len(results)
            batch_size = min(batch_size + 10, 100)
    print(f"成功:{success_count},失败批次:{error_count}")
    return success_count, error_count
# 8. 进度跟踪
async def batch_with_progress(inputs, batch_size=50):
    """Batch ``inputs`` through ``remote_chain.abatch``, printing progress after each batch."""
    total = len(inputs)
    done = 0
    for offset in range(0, total, batch_size):
        replies = await remote_chain.abatch(inputs[offset:offset + batch_size])
        done += len(replies)
        percent = (done / total) * 100
        print(f"进度:{done}/{total} ({percent:.1f}%)")
    print("完成!")
# 使用
async def test_progress():
inputs = [{"message": f"问题{i}"} for i in range(200)]
await batch_with_progress(inputs, batch_size=20)
asyncio.run(test_progress())
# 9. 信创环境异步批量
async def xinchuang_async_batch():
    """Send 20 demo questions to the internal server in one ``abatch`` call.

    Returns:
        The list returned by ``abatch``.
    """
    xc_chain = RemoteRunnable(
        "http://internal-server:8000/local/chat",
        verify=False
    )
    payloads = [
        {"question": f"信创问题{n}", "llm_provider": "ollama"}
        for n in range(20)
    ]
    answers = await xc_chain.abatch(payloads)
    print(f"批量处理{len(answers)}个问题")
    return answers
asyncio.run(xinchuang_async_batch())
# 10. 完整异步批量应用
class AsyncBatchProcessor:
    """Batch inputs through ``chain.abatch`` with per-batch retries.

    Attributes are set from the constructor arguments: ``chain`` (object
    exposing ``abatch``), ``batch_size`` and ``max_retries``.
    """

    def __init__(self, chain, batch_size=50, max_retries=3):
        self.chain = chain
        self.batch_size = batch_size
        self.max_retries = max_retries

    async def process(self, inputs):
        """Process all inputs and return results aligned with ``inputs``.

        Each batch is attempted up to ``max_retries`` times (at least once),
        sleeping ``2 ** attempt`` seconds between attempts. A batch that
        fails every attempt contributes one ``None`` per input.

        Returns:
            A list with one entry per input, in input order.
        """
        results = []
        total = len(inputs)
        # Always try at least once; otherwise a batch would be dropped without
        # placeholders and results would no longer line up with inputs.
        attempts = max(1, self.max_retries)
        for start in range(0, total, self.batch_size):
            batch = inputs[start:start + self.batch_size]
            for attempt in range(attempts):
                try:
                    batch_results = await self.chain.abatch(batch)
                except Exception as e:
                    if attempt < attempts - 1:
                        print("批次失败,重试...")
                        await asyncio.sleep(2 ** attempt)
                    else:
                        print(f"批次最终失败:{e}")
                        # Placeholder results keep indexes aligned.
                        results.extend([None] * len(batch))
                else:
                    results.extend(batch_results)
                    progress = (len(results) / total) * 100
                    print(f"进度:{len(results)}/{total} ({progress:.1f}%)")
                    break
        return results
# 使用
async def use_processor():
processor = AsyncBatchProcessor(
remote_chain,
batch_size=30,
max_retries=3
)
inputs = [{"message": f"问题{i}"} for i in range(150)]
results = await processor.process(inputs)
valid = sum(1 for r in results if r is not None)
print(f"有效结果:{valid}/{len(results)}")
asyncio.run(use_processor())
print("✓ 异步批量配置完成")
---
02.最佳实践
a.批量大小选择
a.功能说明
选择合适的批量大小平衡性能和资源。过小影响吞吐量,过大可能超时或内存不足。根据输入复杂度和服务器性能调整。批量大小是性能调优的关键参数。
b.代码示例
---
from langserve import RemoteRunnable
import time
remote_chain = RemoteRunnable("http://localhost:8000/chat")
# 1. 基准测试不同批量大小
def benchmark_batch_sizes(total_inputs=100, chain=None):
    """Time ``chain.batch`` for batch sizes 10, 20, 50 and 100.

    Args:
        total_inputs: Number of synthetic inputs to generate.
        chain: Object exposing ``batch``; defaults to the module-level
            ``remote_chain``.

    Returns:
        Dict mapping batch size to ``{"time": seconds, "throughput": items/s}``.
    """
    chain = remote_chain if chain is None else chain
    inputs = [{"message": f"问题{i}"} for i in range(total_inputs)]
    batch_sizes = [10, 20, 50, 100]
    results = {}
    for batch_size in batch_sizes:
        start = time.perf_counter()
        for i in range(0, len(inputs), batch_size):
            chain.batch(inputs[i:i + batch_size])
        elapsed = time.perf_counter() - start
        # Guard against a zero reading from a coarse clock.
        throughput = total_inputs / elapsed if elapsed > 0 else float("inf")
        results[batch_size] = {
            "time": elapsed,
            "throughput": throughput
        }
        print(f"批量大小{batch_size}: {elapsed:.2f}秒, {throughput:.2f}个/秒")
    # Pick the size with the highest throughput.
    best_size = max(results, key=lambda k: results[k]['throughput'])
    print(f"\n最优批量大小:{best_size}")
    return results
# 运行测试
# benchmark_batch_sizes()
# 2. 动态批量大小
class DynamicBatchSizer:
    """Keeps a batch size that follows a smoothed per-batch execution time."""

    def __init__(self, initial_size=50, min_size=10, max_size=100):
        self.size = initial_size
        self.min_size = min_size
        self.max_size = max_size
        self.avg_time = None

    def adjust(self, batch_time, batch_size):
        """Fold ``batch_time`` into the moving average and resize.

        The average is seeded with the first sample and afterwards weighted
        0.7 old / 0.3 new. Below 1.0s the size grows by 1.5x (capped at
        ``max_size``); above 2.0s it shrinks to 0.7x (floored at
        ``min_size``). ``batch_size`` is accepted but not used.

        Returns:
            The batch size to use next.
        """
        previous = self.avg_time
        if previous is None:
            smoothed = batch_time
        else:
            smoothed = 0.7 * previous + 0.3 * batch_time
        self.avg_time = smoothed

        if smoothed < 1.0:
            grown = int(self.size * 1.5)
            self.size = min(grown, self.max_size)
        elif smoothed > 2.0:
            shrunk = int(self.size * 0.7)
            self.size = max(shrunk, self.min_size)
        return self.size
# 使用
sizer = DynamicBatchSizer()
inputs = [{"message": f"问题{i}"} for i in range(200)]
# Advance by the number of items actually sent. A ``range`` with step
# ``sizer.size`` freezes the step at its initial value, so once ``adjust``
# changes the size, slices would overlap or leave gaps.
position = 0
while position < len(inputs):
    batch = inputs[position:position + sizer.size]
    position += len(batch)
    start = time.time()
    results = remote_chain.batch(batch)
    elapsed = time.time() - start
    new_size = sizer.adjust(elapsed, len(batch))
    print(f"批量{len(batch)}项,耗时{elapsed:.2f}秒,下次批量{new_size}项")
# 3. 基于输入大小调整
def adaptive_batch_by_input_size(inputs):
    """Pick a batch size from the average textual size of the inputs.

    The size of each input is ``len(str(inp))``. Averages below 100, 500 and
    1000 characters map to batch sizes 100, 50 and 20; anything larger maps
    to 10. An empty ``inputs`` is treated as average size 0.

    Returns:
        The chosen batch size.
    """
    # Empty input would otherwise divide by zero.
    if inputs:
        avg_input_size = sum(len(str(inp)) for inp in inputs) / len(inputs)
    else:
        avg_input_size = 0
    if avg_input_size < 100:
        batch_size = 100  # small inputs, large batches
    elif avg_input_size < 500:
        batch_size = 50
    elif avg_input_size < 1000:
        batch_size = 20
    else:
        batch_size = 10  # large inputs, small batches
    print(f"平均输入大小:{avg_input_size:.0f}字符,批量大小:{batch_size}")
    return batch_size
# 使用
inputs = [{"message": "短" * i} for i in range(1, 101)]
batch_size = adaptive_batch_by_input_size(inputs)
# 4. 内存限制批量
import sys
def batch_with_memory_limit(inputs, max_memory_mb=100, chain=None):
    """Batch ``inputs`` so each batch stays under an approximate size budget.

    The size of an input is estimated as ``sys.getsizeof(str(inp))``, which
    is a rough proxy and not a real memory measurement. An input larger than
    the budget is still sent, alone in its own batch.

    Args:
        inputs: Iterable of chain inputs.
        max_memory_mb: Budget per batch in MiB.
        chain: Object exposing ``batch``; defaults to the module-level
            ``remote_chain``.

    Returns:
        All batch results concatenated in input order.
    """
    chain = remote_chain if chain is None else chain
    all_results = []
    batch = []
    current_size = 0
    max_size = max_memory_mb * 1024 * 1024
    for inp in inputs:
        inp_size = sys.getsizeof(str(inp))
        if current_size + inp_size > max_size and batch:
            # Budget reached: flush the current batch before adding more.
            print(f"处理批次({len(batch)}项,{current_size/1024/1024:.2f}MB)...")
            all_results.extend(chain.batch(batch))
            batch = []
            current_size = 0
        batch.append(inp)
        current_size += inp_size
    # Flush whatever is left.
    if batch:
        print(f"处理最后批次({len(batch)}项,{current_size/1024/1024:.2f}MB)...")
        all_results.extend(chain.batch(batch))
    return all_results
# 5. 超时控制批量
def batch_with_timeout_control(inputs, target_time=2.0, chain=None):
    """Batch ``inputs``, resizing batches so each takes about ``target_time``.

    After every batch the size is scaled by ``target_time / elapsed`` and
    clamped to the range 10..100.

    Args:
        inputs: Sequence of chain inputs.
        target_time: Desired seconds per batch.
        chain: Object exposing ``batch``; defaults to the module-level
            ``remote_chain``.
    """
    chain = remote_chain if chain is None else chain
    batch_size = 50  # initial batch size
    # Walk by the number of items actually sent; a ``range`` with a fixed
    # step would overlap or skip items once ``batch_size`` changes.
    position = 0
    while position < len(inputs):
        batch = inputs[position:position + batch_size]
        position += len(batch)
        start = time.time()
        chain.batch(batch)
        elapsed = time.time() - start
        print(f"批次耗时{elapsed:.2f}秒")
        if elapsed > 0:
            ratio = target_time / elapsed
            batch_size = int(batch_size * ratio)
            batch_size = max(10, min(batch_size, 100))
            print(f"下次批量:{batch_size}")
# 6. 并发级别优化
def batch_with_concurrency(inputs, test_concurrency_levels=(1, 5, 10), chain=None):
    """Time ``chain.batch(inputs)`` at several ``max_concurrency`` settings.

    Args:
        inputs: Inputs sent in full for every concurrency level.
        test_concurrency_levels: Iterable of ``max_concurrency`` values. The
            default is a tuple so no mutable object is shared between calls.
        chain: Object exposing ``batch``; defaults to the module-level
            ``remote_chain``.
    """
    chain = remote_chain if chain is None else chain
    for level in test_concurrency_levels:
        start = time.time()
        chain.batch(
            inputs,
            config={"max_concurrency": level}
        )
        elapsed = time.time() - start
        print(f"并发级别{level}: {elapsed:.2f}秒")
# 7. 批量重试策略
def batch_with_smart_retry(inputs, max_retries=3, chain=None):
    """Batch ``inputs`` in groups of 50, splitting on timeouts and retrying other errors.

    When a batch fails with an error whose text contains "timeout" and the
    batch has more than one item, it is split in two and both halves are
    processed, so no input is dropped. Any other failure is retried after a
    2 second pause, up to ``max_retries`` attempts per batch.

    Args:
        inputs: Sequence of chain inputs.
        max_retries: Attempts per batch (at least one attempt is made).
        chain: Object exposing ``batch``; defaults to the module-level
            ``remote_chain``.

    Returns:
        All results in input order.

    Raises:
        Exception: The last error of a batch that exhausted its attempts.
    """
    chain = remote_chain if chain is None else chain
    batch_size = 50
    attempts = max(1, max_retries)
    results = []
    # Stack of pending batches; reversed so pop() yields them in input order.
    pending = [
        inputs[i:i + batch_size]
        for i in range(0, len(inputs), batch_size)
    ][::-1]
    while pending:
        batch = pending.pop()
        for attempt in range(attempts):
            try:
                results.extend(chain.batch(batch))
                break
            except Exception as e:
                if "timeout" in str(e).lower() and len(batch) > 1:
                    half = len(batch) // 2
                    print(f"超时,减小批量到{half}")
                    # Push the second half first so the first half runs next.
                    pending.append(batch[half:])
                    pending.append(batch[:half])
                    break
                if attempt < attempts - 1:
                    print("失败,重试...")
                    time.sleep(2)
                else:
                    raise
    return results
# 8. 信创环境批量优化
def xinchuang_optimal_batch(inputs):
    """Batch ``inputs`` against the internal server in groups of 20.

    If a batch call fails, its inputs are retried one by one with
    ``invoke``; an input that still fails contributes ``None``.

    Returns:
        A list with one entry per input, in input order.
    """
    xc_chain = RemoteRunnable("http://internal-server:8000/local/chat")
    # Small batches, on the assumption that the local backend is slow.
    batch_size = 20
    results = []
    for i in range(0, len(inputs), batch_size):
        batch = inputs[i:i+batch_size]
        print(f"处理批次{i//batch_size + 1}({len(batch)}项)...")
        try:
            results.extend(xc_chain.batch(batch))
        except Exception as e:
            print(f"批次失败:{e}")
            # Fall back to one request per input for the failed batch.
            for inp in batch:
                try:
                    results.append(xc_chain.invoke(inp))
                except Exception as item_error:
                    # Not a bare ``except``: KeyboardInterrupt and SystemExit
                    # must still propagate.
                    print(f"单项失败:{item_error}")
                    results.append(None)
    return results
print("✓ 批量大小选择配置完成")
---
3.5 流式调用
01.stream方法
a.同步流式
a.功能说明
stream方法以流式方式接收响应,边生成边返回。使用Python生成器逐块处理数据。适用于长文本生成,提升用户体验。流式调用是实时交互的关键技术。
b.代码示例
---
from langserve import RemoteRunnable
remote_chain = RemoteRunnable("http://localhost:8000/chat")
# 1. 基础流式调用
for chunk in remote_chain.stream({"message": "讲一个长故事"}):
print(chunk, end="", flush=True)
print() # 换行
# 2. 收集流式输出
def collect_stream(chain, input_data):
    """Echo a stream to stdout and return the concatenated text.

    Dict chunks with a ``"content"`` key contribute that value; every other
    chunk contributes ``str(chunk)``.
    """
    full_text = ""
    for chunk in chain.stream(input_data):
        carries_content = isinstance(chunk, dict) and "content" in chunk
        piece = chunk["content"] if carries_content else str(chunk)
        print(piece, end="", flush=True)
        full_text = full_text + piece
    print()  # end the echoed line
    return full_text
# 使用
result = collect_stream(remote_chain, {"message": "写一首诗"})
print(f"\\n总长度:{len(result)}字符")
# 3. 流式with进度
def stream_with_progress(chain, input_data):
    """Echo a stream and print a progress marker roughly every 100 characters.

    A marker is printed whenever the running count crosses the next multiple
    of 100. Testing for an exact multiple would almost never fire, because
    chunks are usually longer than one character.
    """
    char_count = 0
    next_mark = 100
    for chunk in chain.stream(input_data):
        content = str(chunk)
        print(content, end="", flush=True)
        char_count += len(content)
        if char_count >= next_mark:
            print(f"\n[已生成{char_count}字符]\n", end="", flush=True)
            # The next marker is due at the following multiple of 100.
            next_mark = (char_count // 100 + 1) * 100
    print(f"\n\n总计:{char_count}字符")
# 使用
stream_with_progress(remote_chain, {"message": "详细解释AI"})
# 4. 流式错误处理
def safe_stream(chain, input_data):
    """Echo a stream to stdout, reporting any error instead of raising.

    Output already printed before a failure stays on screen; the error
    message is printed on a new line.
    """
    try:
        for chunk in chain.stream(input_data):
            print(chunk, end="", flush=True)
        print()
    except Exception as e:
        # A real newline, not the literal two characters backslash and n.
        print(f"\n[流式调用失败:{e}]")
# 使用
safe_stream(remote_chain, {"message": "测试"})
# 5. 流式保存到文件
def stream_to_file(chain, input_data, output_file):
    """Echo a stream to stdout while writing it to ``output_file`` as UTF-8.

    The file is flushed after every chunk so partial output survives a
    failure.

    Args:
        chain: Object exposing ``stream(input_data)``.
        input_data: Input forwarded to ``chain.stream``.
        output_file: Path of the file to create or overwrite.
    """
    with open(output_file, 'w', encoding='utf-8') as f:
        for chunk in chain.stream(input_data):
            content = str(chunk)
            print(content, end="", flush=True)
            f.write(content)
            f.flush()  # write through immediately
    print(f"\n输出已保存到:{output_file}")
# 使用
stream_to_file(
remote_chain,
{"message": "写一篇文章"},
"/tmp/output.txt"
)
# 6. 流式超时处理
import signal
import time
class StreamTimeout(Exception):
    """Raised by the alarm handler when a stream exceeds its time budget."""


def stream_with_timeout(chain, input_data, timeout=30):
    """Echo a stream to stdout, aborting after ``timeout`` seconds.

    Uses ``SIGALRM``, so it only works on Unix and in the main thread. The
    previously installed ``SIGALRM`` handler is restored on exit.

    Args:
        chain: Object exposing ``stream(input_data)``.
        input_data: Input forwarded to ``chain.stream``.
        timeout: Time budget in seconds (fractions allowed); 0 disables it.
    """
    def timeout_handler(signum, frame):
        raise StreamTimeout("流式调用超时")

    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    # setitimer accepts fractional seconds, unlike signal.alarm.
    signal.setitimer(signal.ITIMER_REAL, timeout)
    try:
        for chunk in chain.stream(input_data):
            print(chunk, end="", flush=True)
        signal.setitimer(signal.ITIMER_REAL, 0)  # cancel the pending alarm
        print()
    except StreamTimeout:
        print("\n[超时]")
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)
        # signal.signal returns None for handlers not installed from Python;
        # those cannot be passed back in.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
# 7. 流式统计
class StreamStats:
    """Echoes a stream while counting chunks and characters."""

    def __init__(self):
        self.start_time = None
        self.chunk_count = 0
        self.char_count = 0

    def stream_with_stats(self, chain, input_data):
        """Echo ``chain.stream(input_data)`` and print a statistics summary.

        Counters are reset on every call. The reported speed is 0.0 when no
        measurable time elapsed.
        """
        self.start_time = time.time()
        self.chunk_count = 0
        self.char_count = 0
        for chunk in chain.stream(input_data):
            content = str(chunk)
            print(content, end="", flush=True)
            self.chunk_count += 1
            self.char_count += len(content)
        elapsed = time.time() - self.start_time
        # Avoid ZeroDivisionError on coarse clocks or empty streams.
        speed = self.char_count / elapsed if elapsed > 0 else 0.0
        print("\n\n统计信息:")
        print(f" Chunk数:{self.chunk_count}")
        print(f" 字符数:{self.char_count}")
        print(f" 耗时:{elapsed:.2f}秒")
        print(f" 速度:{speed:.1f}字符/秒")
# 使用
stats = StreamStats()
stats.stream_with_stats(remote_chain, {"message": "长文本生成"})
# 8. 流式缓冲
class StreamBuffer:
    """Collects stream chunks and prints them in groups of ``buffer_size``."""

    def __init__(self, buffer_size=10):
        self.buffer = []
        self.buffer_size = buffer_size

    def stream_buffered(self, chain, input_data):
        """Consume ``chain.stream(input_data)``, flushing whenever the buffer fills."""
        for chunk in chain.stream(input_data):
            self.buffer.append(chunk)
            if len(self.buffer) < self.buffer_size:
                continue
            self.flush()
        # Emit whatever did not fill a whole buffer.
        if self.buffer:
            self.flush()

    def flush(self):
        """Print the buffered chunks as one string and empty the buffer."""
        text = "".join(map(str, self.buffer))
        print(text, end="", flush=True)
        self.buffer = []
# 使用
buffer = StreamBuffer(buffer_size=5)
buffer.stream_buffered(remote_chain, {"message": "测试缓冲"})
print()
# 9. 信创环境流式调用
xc_chain = RemoteRunnable("http://internal-server:8000/local/chat")
print("问:介绍信创技术")
print("答:", end="")
for chunk in xc_chain.stream({"question": "介绍信创技术"}):
content = str(chunk)
print(content, end="", flush=True)
print("\\n")
# 10. 流式重试
def stream_with_retry(chain, input_data, max_retries=3, delay=2):
    """Echo a stream to stdout, restarting it from scratch on failure.

    A retry starts a new stream; chunks printed by a failed attempt are not
    removed, so partial output may be repeated.

    Args:
        chain: Object exposing ``stream(input_data)``.
        input_data: Input forwarded to ``chain.stream``.
        max_retries: Total number of attempts.
        delay: Seconds to sleep between attempts.
    """
    for attempt in range(max_retries):
        try:
            for chunk in chain.stream(input_data):
                print(chunk, end="", flush=True)
            print()
            break  # success
        except Exception as e:
            if attempt < max_retries - 1:
                print("\n[失败,重试...]")
                time.sleep(delay)
            else:
                print(f"\n[所有重试失败:{e}]")
# 使用
stream_with_retry(remote_chain, {"message": "测试"})
print("✓ stream方法配置完成")
---
b.异步流式
a.功能说明
astream方法是stream的异步版本,支持异步流式处理。使用async for循环接收流式数据。结合asyncio实现高性能异步流式应用。异步流式是现代Web应用的核心技术。
b.代码示例
---
from langserve import RemoteRunnable
import asyncio
remote_chain = RemoteRunnable("http://localhost:8000/chat")
# 1. 基础异步流式
async def basic_async_stream():
    """Stream one fixed request with ``astream`` and echo the chunks on a single line."""
    stream = remote_chain.astream({"message": "讲故事"})
    async for piece in stream:
        print(piece, end="", flush=True)
    print()
asyncio.run(basic_async_stream())
# 2. 异步流式收集
async def collect_async_stream(chain, input_data):
    """Echo an async stream to stdout and return all chunks joined as text.

    Every chunk is converted with ``str`` before being printed and collected.
    """
    pieces = []
    async for chunk in chain.astream(input_data):
        text = str(chunk)
        print(text, end="", flush=True)
        pieces.append(text)
    print()
    return "".join(pieces)
# 使用
async def use_collect():
result = await collect_async_stream(
remote_chain,
{"message": "写诗"}
)
print(f"总长度:{len(result)}")
asyncio.run(use_collect())
# 3. 并发异步流式
async def concurrent_async_streams():
    """Run three ``astream`` calls concurrently, tagging each chunk with its stream id.

    Output from the three streams is interleaved on stdout.
    """
    async def stream_one(message, stream_id):
        """Echo one stream, bracketed by start and end markers."""
        print(f"\n[Stream {stream_id}开始]")
        async for chunk in remote_chain.astream({"message": message}):
            print(f"[{stream_id}] {chunk}", end="", flush=True)
        print(f"\n[Stream {stream_id}结束]")

    await asyncio.gather(
        stream_one("故事1", 1),
        stream_one("故事2", 2),
        stream_one("故事3", 3)
    )
asyncio.run(concurrent_async_streams())
# 4. 异步流式with超时
async def async_stream_with_timeout():
    """Echo one ``astream`` call, cancelling it after 30 seconds."""
    async def stream_task():
        async for chunk in remote_chain.astream({"message": "长文本"}):
            print(chunk, end="", flush=True)

    try:
        await asyncio.wait_for(stream_task(), timeout=30.0)
        print()
    except asyncio.TimeoutError:
        # Start the notice on its own line after any partial output.
        print("\n[流式超时]")
asyncio.run(async_stream_with_timeout())
# 5. 异步流式错误处理
async def safe_async_stream(chain, input_data):
    """Echo an async stream to stdout, reporting any error instead of raising.

    Output printed before a failure stays on screen; the error message is
    printed on a new line.
    """
    try:
        async for chunk in chain.astream(input_data):
            print(chunk, end="", flush=True)
        print()
    except Exception as e:
        print(f"\n[异步流式失败:{e}]")
# 使用
asyncio.run(safe_async_stream(remote_chain, {"message": "测试"}))
# 6. 异步流式队列
class AsyncStreamQueue:
    """Producer/consumer bridge that passes stream chunks through an ``asyncio.Queue``.

    ``None`` is the end-of-stream marker, so a stream must not yield ``None``
    as data.
    """

    def __init__(self):
        self.queue = asyncio.Queue()

    async def stream_to_queue(self, chain, input_data):
        """Push every chunk of ``chain.astream(input_data)`` onto the queue.

        The end marker is enqueued even when the stream raises; otherwise
        the consumer would wait forever. The error itself still propagates.
        """
        try:
            async for chunk in chain.astream(input_data):
                await self.queue.put(chunk)
        finally:
            # The queue is unbounded, so put_nowait cannot fail or block.
            self.queue.put_nowait(None)

    async def consume_queue(self):
        """Print chunks from the queue until the end marker arrives."""
        while True:
            chunk = await self.queue.get()
            if chunk is None:
                break
            print(chunk, end="", flush=True)
        print()
# 使用
async def use_stream_queue():
queue = AsyncStreamQueue()
# 生产者和消费者并发运行
await asyncio.gather(
queue.stream_to_queue(remote_chain, {"message": "测试"}),
queue.consume_queue()
)
asyncio.run(use_stream_queue())
# 7. 异步流式批量处理
async def async_stream_batch(inputs: list):
    """Stream every input concurrently, tagging each chunk with its input index.

    Output from the concurrent streams is interleaved on stdout.
    """
    async def stream_one(input_data, index):
        """Echo the stream for one input."""
        print(f"\n[输入{index}]")
        async for chunk in remote_chain.astream(input_data):
            print(f"[{index}] {chunk}", end="", flush=True)
        print()

    tasks = [
        stream_one(inp, i)
        for i, inp in enumerate(inputs)
    ]
    await asyncio.gather(*tasks)
# 使用
inputs = [{"message": f"问题{i}"} for i in range(3)]
asyncio.run(async_stream_batch(inputs))
# 8. FastAPI流式端点
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
app = FastAPI()
@app.post("/stream_chat")
async def stream_chat(message: str):
    """Stream the chain's answer to ``message`` as Server-Sent Events."""
    async def generate():
        """Yield one SSE event per chunk.

        An event is terminated by a blank line, which requires real newline
        characters. A chunk containing newlines is sent as several ``data:``
        lines, so the client reassembles it instead of seeing a truncated
        event.
        """
        async for chunk in remote_chain.astream({"message": message}):
            lines = str(chunk).split("\n")
            yield "".join(f"data: {line}\n" for line in lines) + "\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream"
    )
# 9. 异步流式统计
class AsyncStreamStats:
    """Echoes an async stream while counting chunks and characters."""

    def __init__(self):
        self.start_time = None
        self.chunk_count = 0
        self.char_count = 0

    async def stream_with_stats(self, chain, input_data):
        """Echo ``chain.astream(input_data)`` and print a statistics summary.

        Counters are reset on every call.
        """
        import time
        self.start_time = time.time()
        self.chunk_count = 0
        self.char_count = 0
        async for chunk in chain.astream(input_data):
            content = str(chunk)
            print(content, end="", flush=True)
            self.chunk_count += 1
            self.char_count += len(content)
        elapsed = time.time() - self.start_time
        print("\n\n统计:")
        print(f" Chunk: {self.chunk_count}")
        print(f" 字符: {self.char_count}")
        print(f" 耗时: {elapsed:.2f}秒")
# 使用
async def use_stats():
stats = AsyncStreamStats()
await stats.stream_with_stats(
remote_chain,
{"message": "长文本"}
)
asyncio.run(use_stats())
# 10. 信创环境异步流式
async def xinchuang_async_stream():
    """Stream one fixed question from the internal server and echo the answer."""
    xc_chain = RemoteRunnable(
        "http://internal-server:8000/local/chat",
        verify=False
    )
    print("问:介绍信创技术")
    print("答:", end="")
    async for chunk in xc_chain.astream({
        "question": "介绍信创技术",
        "llm_provider": "ollama"
    }):
        print(chunk, end="", flush=True)
    # End the answer line and leave one blank line after it.
    print("\n")
asyncio.run(xinchuang_async_stream())
# 11. 完整异步流式应用
class AsyncStreamApp:
    """Thin chat wrapper around a ``RemoteRunnable`` that streams replies to stdout."""

    def __init__(self, url: str):
        self.chain = RemoteRunnable(url)

    async def chat_stream(self, message: str):
        """Stream the reply to ``message``, echoing it, and return the full text."""
        full_response = ""
        stream = self.chain.astream({"message": message})
        async for chunk in stream:
            piece = str(chunk)
            print(piece, end="", flush=True)
            full_response = full_response + piece
        print()
        return full_response

    async def multi_chat_stream(self, messages: list):
        """Run ``chat_stream`` for all messages concurrently.

        Returns:
            The replies in the order of ``messages``.
        """
        return await asyncio.gather(*(self.chat_stream(msg) for msg in messages))
# 使用
async def use_app():
app = AsyncStreamApp("http://localhost:8000/chat")
# 单个流式
await app.chat_stream("讲个故事")
# 多个流式
await app.multi_chat_stream([
"问题1",
"问题2",
"问题3"
])
# asyncio.run(use_app())
print("✓ 异步流式配置完成")
---
02.实时应用
a.WebSocket集成
a.功能说明
使用WebSocket实现双向实时通信,传输流式数据。客户端和服务器保持长连接,低延迟传输。适用于聊天、实时协作等场景。WebSocket是构建实时AI应用的标准技术。
b.代码示例
---
from fastapi import FastAPI, WebSocket
from langserve import RemoteRunnable
import asyncio
app = FastAPI()
remote_chain = RemoteRunnable("http://localhost:8000/chat")
# 1. 基础WebSocket流式
@app.websocket("/ws/chat")
async def websocket_chat(websocket: WebSocket):
    """WebSocket chat endpoint: stream the chain's reply to each received message.

    Every reply is followed by the text frame ``[END]``.
    """
    from fastapi import WebSocketDisconnect

    await websocket.accept()
    try:
        while True:
            message = await websocket.receive_text()
            async for chunk in remote_chain.astream({"message": message}):
                await websocket.send_text(str(chunk))
            await websocket.send_text("[END]")
    except WebSocketDisconnect:
        # The peer already closed the connection; calling close() here would
        # raise because nothing can be sent any more.
        return
    except Exception as e:
        print(f"WebSocket错误:{e}")
        try:
            await websocket.close(code=1011)  # 1011 = internal error
        except RuntimeError:
            # The connection was already closed while handling the error.
            pass
# 2. WebSocket客户端(Python)
import websocket
def websocket_client():
    """Blocking WebSocket client built on ``websocket.WebSocketApp``.

    Sends one request when the connection opens, echoes received frames and
    prints a completion notice when ``[END]`` arrives.
    """
    def on_message(ws, message):
        """Echo a frame, or announce completion on the end marker."""
        if message == "[END]":
            print("\n[完成]")
        else:
            print(message, end="", flush=True)

    def on_error(ws, error):
        """Report a connection error."""
        print(f"错误:{error}")

    def on_close(ws, close_status_code, close_msg):
        """Report that the connection was closed."""
        print("连接关闭")

    def on_open(ws):
        """Send the request as soon as the connection is established."""
        print("连接已建立")
        ws.send("讲个故事")

    ws = websocket.WebSocketApp(
        "ws://localhost:8000/ws/chat",
        on_open=on_open,
        on_message=on_message,
        on_error=on_error,
        on_close=on_close
    )
    ws.run_forever()
# 使用
# websocket_client()
# 3. 异步WebSocket客户端
import websockets
async def async_websocket_client():
    """Async WebSocket client: send one request and echo frames until ``[END]``."""
    uri = "ws://localhost:8000/ws/chat"
    async with websockets.connect(uri) as websocket:
        await websocket.send("讲个故事")
        while True:
            message = await websocket.recv()
            if message == "[END]":
                print("\n[完成]")
                break
            print(message, end="", flush=True)
# 使用
# asyncio.run(async_websocket_client())
# 4. WebSocket广播
class ConnectionManager:
    """Tracks accepted WebSocket connections and broadcasts text to all of them."""

    def __init__(self):
        self.active_connections: list["WebSocket"] = []

    async def connect(self, websocket: "WebSocket"):
        """Accept ``websocket`` and start tracking it."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: "WebSocket"):
        """Stop tracking ``websocket``; unknown connections are ignored."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def broadcast(self, message: str):
        """Send ``message`` to every tracked connection.

        A connection whose send fails is dropped, so one dead peer does not
        abort the broadcast for everyone else.
        """
        # Iterate over a copy: the list is modified while failures are handled.
        for connection in list(self.active_connections):
            try:
                await connection.send_text(message)
            except Exception:
                self.disconnect(connection)
manager = ConnectionManager()
@app.websocket("/ws/broadcast")
async def websocket_broadcast(websocket: WebSocket):
    """Broadcast endpoint: the reply to any client's message is sent to all clients.

    Every reply is followed by a broadcast ``[END]`` marker.
    """
    from fastapi import WebSocketDisconnect

    await manager.connect(websocket)
    try:
        while True:
            message = await websocket.receive_text()
            async for chunk in remote_chain.astream({"message": message}):
                await manager.broadcast(str(chunk))
            await manager.broadcast("[END]")
    except WebSocketDisconnect:
        # Normal client departure.
        pass
    except Exception as e:
        print(f"WebSocket广播错误:{e}")
    finally:
        # Runs on cancellation too. A bare ``except`` would have swallowed
        # CancelledError and KeyboardInterrupt instead of letting them through.
        manager.disconnect(websocket)
# 5. WebSocket with认证
from fastapi import Header, HTTPException
async def verify_token(token: str):
    """Return True when ``token`` equals the demo token, else False.

    The comparison is constant-time so response timing does not reveal how
    much of a guessed token was right. A missing or non-string token is
    rejected. A real deployment would look the token up in a data store.
    """
    import hmac

    if not isinstance(token, str):
        return False
    # Compare as bytes: compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest(token.encode("utf-8"), b"valid_token")
@app.websocket("/ws/secure")
async def secure_websocket(
    websocket: WebSocket,
    token: str = Header(None)
):
    """WebSocket endpoint that only accepts connections carrying a valid ``token`` header."""
    authorized = await verify_token(token)
    if authorized:
        await websocket.accept()
        # Normal message handling goes here...
        return
    # Reject before accepting; 1008 is the Policy Violation close code.
    await websocket.close(code=1008)
# 6. WebSocket心跳
@app.websocket("/ws/heartbeat")
async def websocket_with_heartbeat(websocket: WebSocket):
    """WebSocket chat endpoint that also sends ``[PING]`` every 30 seconds.

    ``[PONG]`` frames from the client are ignored; any other frame is
    treated as a chat message and answered with a streamed reply.
    """
    from fastapi import WebSocketDisconnect

    await websocket.accept()

    async def send_heartbeat():
        """Send a ping frame every 30 seconds until cancelled."""
        while True:
            await asyncio.sleep(30)
            await websocket.send_text("[PING]")

    heartbeat_task = asyncio.create_task(send_heartbeat())
    try:
        while True:
            message = await websocket.receive_text()
            if message == "[PONG]":
                continue
            async for chunk in remote_chain.astream({"message": message}):
                await websocket.send_text(str(chunk))
    except WebSocketDisconnect:
        # Normal client departure; the cleanup below still runs.
        pass
    finally:
        heartbeat_task.cancel()
        try:
            await websocket.close()
        except RuntimeError:
            # The peer already closed the connection, so there is nothing
            # left to close.
            pass
# 7. 信创环境WebSocket
xc_chain = RemoteRunnable("http://internal-server:8000/local/chat")
@app.websocket("/ws/xinchuang")
async def xinchuang_websocket(websocket: WebSocket):
    """WebSocket endpoint that streams answers from ``xc_chain`` with the ollama provider.

    Every answer is followed by the text frame ``[END]``.
    """
    from fastapi import WebSocketDisconnect

    await websocket.accept()
    try:
        while True:
            question = await websocket.receive_text()
            async for chunk in xc_chain.astream({
                "question": question,
                "llm_provider": "ollama"
            }):
                await websocket.send_text(str(chunk))
            await websocket.send_text("[END]")
    except WebSocketDisconnect:
        # The peer already closed the connection; close() would raise.
        return
    except Exception as e:
        print(f"信创WebSocket错误:{e}")
        try:
            await websocket.close(code=1011)  # 1011 = internal error
        except RuntimeError:
            # The connection was already closed while handling the error.
            pass
print("✓ WebSocket集成配置完成")
---
4 高级配置
4.1 身份认证
01.认证方式
a.API Key认证
a.功能说明
使用API Key作为身份凭证,简单实用。客户端在请求头中携带Key,服务器验证后授权。适用于服务间调用、内部API。API Key是最常见的API认证方式。
b.代码示例
---
from fastapi import FastAPI, Header, HTTPException, Depends
from langserve import add_routes
from langchain.chat_models import ChatOpenAI
app = FastAPI()
llm = ChatOpenAI()
# 1. 简单API Key验证
VALID_API_KEY = "my_secret_key_12345"
def verify_api_key(x_api_key: str = Header(None)):
    """FastAPI dependency: require the ``X-API-Key`` header to equal ``VALID_API_KEY``.

    Returns:
        The supplied key.

    Raises:
        HTTPException: 401 when the header is missing or does not match.
    """
    import secrets

    # Constant-time comparison so response timing does not reveal how much of
    # a guessed key was right. Compare as bytes: compare_digest rejects
    # non-ASCII str arguments.
    supplied = (x_api_key or "").encode("utf-8")
    expected = VALID_API_KEY.encode("utf-8")
    if x_api_key is None or not secrets.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=401,
            detail="无效的API Key"
        )
    return x_api_key
# 自定义受保护端点
@app.post("/protected/chat")
def protected_chat(
message: str,
api_key: str = Depends(verify_api_key)
):
"""受保护的聊天端点"""
result = llm.invoke(message)
return {"response": result.content}
# 调用示例
# curl -X POST "http://localhost:8000/protected/chat?message=你好" \
# -H "X-API-Key: my_secret_key_12345"
# 2. 多个API Key
VALID_KEYS = {
"key_user1": {"user": "user1", "tier": "basic"},
"key_user2": {"user": "user2", "tier": "premium"},
"key_admin": {"user": "admin", "tier": "admin"}
}
def verify_api_key_advanced(x_api_key: str = Header(None)):
    """FastAPI dependency: map the ``X-API-Key`` header to its ``VALID_KEYS`` entry.

    Returns:
        The user-info dict registered for the key.

    Raises:
        HTTPException: 401 when the header is missing, empty or unknown.
    """
    user_info = VALID_KEYS.get(x_api_key) if x_api_key else None
    if user_info is None:
        raise HTTPException(
            status_code=401,
            detail="无效的API Key"
        )
    return user_info
@app.post("/advanced/chat")
def advanced_chat(
message: str,
user_info: dict = Depends(verify_api_key_advanced)
):
"""带用户信息的聊天"""
return {
"user": user_info["user"],
"tier": user_info["tier"],
"response": llm.invoke(message).content
}
# 3. 数据库存储的API Key
import hashlib
# 模拟数据库
api_keys_db = {
hashlib.sha256("user1_key".encode()).hexdigest(): {
"user_id": "user1",
"permissions": ["chat", "translate"]
},
hashlib.sha256("user2_key".encode()).hexdigest(): {
"user_id": "user2",
"permissions": ["chat"]
}
}
def verify_from_db(x_api_key: str = Header(None)):
    """FastAPI dependency: look up the SHA-256 digest of the key in ``api_keys_db``.

    Returns:
        The record stored for the key.

    Raises:
        HTTPException: 401 when the header is missing or the digest is unknown.
    """
    if not x_api_key:
        raise HTTPException(status_code=401, detail="缺少API Key")
    # Only digests are stored, so hash the supplied key before the lookup.
    digest = hashlib.sha256(x_api_key.encode()).hexdigest()
    record = api_keys_db.get(digest)
    if record is None:
        raise HTTPException(status_code=401, detail="无效的API Key")
    return record
@app.post("/db_auth/chat")
def db_auth_chat(
message: str,
user_data: dict = Depends(verify_from_db)
):
"""数据库认证的聊天"""
# 检查权限
if "chat" not in user_data["permissions"]:
raise HTTPException(status_code=403, detail="无权限")
return {
"user_id": user_data["user_id"],
"response": llm.invoke(message).content
}
# 4. 带速率限制的API Key
from collections import defaultdict
from datetime import datetime, timedelta
class RateLimiter:
    """Sliding-window request limiter keyed by API key.

    Timestamps of accepted requests are kept per key in ``self.requests``.
    """

    def __init__(self, max_requests=10, window_seconds=60):
        self.max_requests = max_requests
        self.window = timedelta(seconds=window_seconds)
        self.requests = defaultdict(list)

    def check(self, api_key: str) -> bool:
        """Record a request for ``api_key`` if it is within the limit.

        Returns:
            True when the request is allowed (and recorded), False when the
            key already has ``max_requests`` requests inside the window.
        """
        now = datetime.now()
        # Keep only timestamps still inside the window.
        recent = [
            stamp for stamp in self.requests[api_key]
            if now - stamp < self.window
        ]
        self.requests[api_key] = recent
        allowed = len(recent) < self.max_requests
        if allowed:
            recent.append(now)
        return allowed
rate_limiter = RateLimiter(max_requests=10, window_seconds=60)
def verify_with_rate_limit(x_api_key: str = Header(None)):
"""带速率限制的验证"""
if x_api_key != VALID_API_KEY:
raise HTTPException(status_code=401, detail="无效的API Key")
if not rate_limiter.check(x_api_key):
raise HTTPException(
status_code=429,
detail="请求过于频繁,请稍后重试"
)
return x_api_key
@app.post("/rate_limited/chat")
def rate_limited_chat(
message: str,
api_key: str = Depends(verify_with_rate_limit)
):
"""速率限制的聊天"""
return llm.invoke(message).content
# 5. 为LangServe路由添加认证
from fastapi import APIRouter
# 创建受保护的路由组
protected_router = APIRouter(
prefix="/protected",
dependencies=[Depends(verify_api_key)]
)
# 添加LangServe路由到受保护组
add_routes(protected_router, llm, path="/llm")
# 注册路由
app.include_router(protected_router)
# 现在/protected/llm/*需要API Key
# 6. RemoteRunnable客户端with API Key
from langserve import RemoteRunnable
# 客户端配置
remote_chain = RemoteRunnable(
"http://localhost:8000/protected/llm",
headers={"X-API-Key": "my_secret_key_12345"}
)
# 使用
result = remote_chain.invoke({"message": "你好"})
# 7. 信创环境API Key认证
# 达梦数据库存储API Key
import dmPython
class DmApiKeyValidator:
    """Validates API keys stored in an ``api_keys`` table reached through dmPython.

    NOTE(review): keys are stored and compared in plaintext; consider
    storing a digest instead.
    """

    def __init__(self, connection_string: str):
        self.conn = dmPython.connect(connection_string)
        self._init_table()

    def _init_table(self):
        """Create the ``api_keys`` table if it does not exist yet."""
        cursor = self.conn.cursor()
        try:
            cursor.execute("""
            CREATE TABLE IF NOT EXISTS api_keys (
            api_key VARCHAR(100) PRIMARY KEY,
            user_id VARCHAR(50),
            permissions VARCHAR(500),
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            expires_at TIMESTAMP,
            is_active INTEGER DEFAULT 1
            )
            """)
            self.conn.commit()
        finally:
            # Release the cursor even when the statement fails.
            cursor.close()

    def validate(self, api_key: str) -> dict:
        """Look up ``api_key`` and return its user id and permission list.

        Returns:
            ``{"user_id": ..., "permissions": [...]}``; permissions are the
            comma-separated column split into a list (empty when unset).

        Raises:
            HTTPException: 401 when the key is missing, unknown, disabled or
                expired.
        """
        if not api_key:
            # No header supplied; skip the database round trip.
            raise HTTPException(status_code=401, detail="无效的API Key")
        cursor = self.conn.cursor()
        try:
            cursor.execute("""
            SELECT user_id, permissions, expires_at, is_active
            FROM api_keys
            WHERE api_key = ?
            """, (api_key,))
            row = cursor.fetchone()
        finally:
            cursor.close()
        if not row:
            raise HTTPException(status_code=401, detail="无效的API Key")
        user_id, permissions, expires_at, is_active = row
        if not is_active:
            raise HTTPException(status_code=401, detail="API Key已禁用")
        if expires_at and datetime.now() > expires_at:
            raise HTTPException(status_code=401, detail="API Key已过期")
        return {
            "user_id": user_id,
            "permissions": permissions.split(",") if permissions else []
        }
# 使用
dm_validator = DmApiKeyValidator(
"user=SYSDBA;password=SYSDBA;server=localhost;port=5236"
)
def verify_dm_api_key(x_api_key: str = Header(None)):
"""达梦API Key验证"""
return dm_validator.validate(x_api_key)
@app.post("/xinchuang/chat")
def xinchuang_chat(
question: str,
user_info: dict = Depends(verify_dm_api_key)
):
"""信创聊天端点"""
return {
"user_id": user_info["user_id"],
"response": "..."
}
print("✓ API Key认证配置完成")
---
b.JWT认证
a.功能说明
使用JWT(JSON Web Token)进行身份认证,安全性更高。Token包含用户信息和签名,可验证完整性。支持过期时间、权限控制。JWT是现代Web应用的标准认证方式。
b.代码示例
---
from fastapi import FastAPI, Depends, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from datetime import datetime, timedelta
from pydantic import BaseModel
app = FastAPI()
# JWT配置
SECRET_KEY = "your-secret-key-keep-it-secret"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
security = HTTPBearer()
# 1. 生成JWT Token
def create_access_token(data: dict, expires_delta: timedelta = None):
    """Return a signed JWT holding ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed; the dict itself is not modified.
        expires_delta: Token lifetime; 15 minutes when omitted.
    """
    lifetime = expires_delta if expires_delta else timedelta(minutes=15)
    claims = data.copy()
    claims["exp"] = datetime.utcnow() + lifetime
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
# 2. 验证JWT Token
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency: decode the bearer JWT and return its payload.

    Raises:
        HTTPException: 401 when the token cannot be decoded or carries no
            ``sub`` claim.
    """
    try:
        payload = jwt.decode(
            credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM]
        )
    except JWTError:
        raise HTTPException(status_code=401, detail="无效的Token")
    if payload.get("sub") is None:
        raise HTTPException(status_code=401, detail="无效的Token")
    return payload
# 3. 登录端点(生成Token)
class LoginRequest(BaseModel):
username: str
password: str
@app.post("/login")
def login(request: LoginRequest):
    """Check the demo credentials and return a bearer access token.

    Raises:
        HTTPException: 401 when the username or password is wrong.
    """
    # Demo only: real code would verify credentials against a user store.
    credentials_ok = request.username == "user" and request.password == "pass"
    if not credentials_ok:
        raise HTTPException(status_code=401, detail="用户名或密码错误")
    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": request.username},
        expires_delta=lifetime
    )
    return {
        "access_token": access_token,
        "token_type": "bearer"
    }
# 4. 受保护的端点
@app.post("/protected/chat")
def protected_chat(
message: str,
token_data: dict = Depends(verify_token)
):
"""受保护的聊天端点"""
user_id = token_data.get("sub")
# 处理聊天...
return {
"user": user_id,
"response": "..."
}
# 5. 带权限的JWT
def create_token_with_permissions(user_id: str, permissions: list):
"""创建带权限的Token"""
token_data = {
"sub": user_id,
"permissions": permissions
}
return create_access_token(token_data)
def verify_permission(required_permission: str):
    """Build a FastAPI dependency that requires ``required_permission``.

    The returned callable depends on ``verify_token`` and inspects the
    ``permissions`` claim of the decoded token.

    Returns:
        A dependency that returns the token payload, or raises
        ``HTTPException`` 403 when the permission is absent.
    """
    def permission_checker(token_data: dict = Depends(verify_token)):
        granted = token_data.get("permissions", [])
        if required_permission in granted:
            return token_data
        raise HTTPException(
            status_code=403,
            detail=f"需要权限:{required_permission}"
        )
    return permission_checker
# 使用
@app.post("/admin/chat")
def admin_chat(
message: str,
token_data: dict = Depends(verify_permission("admin"))
):
"""管理员聊天端点"""
return {"response": "..."}
# 6. RemoteRunnable with JWT
from langserve import RemoteRunnable
# 先登录获取Token
import requests
login_response = requests.post(
"http://localhost:8000/login",
json={"username": "user", "password": "pass"}
)
token = login_response.json()["access_token"]
# 使用Token调用
remote_chain = RemoteRunnable(
"http://localhost:8000/protected/llm",
headers={"Authorization": f"Bearer {token}"}
)
result = remote_chain.invoke({"message": "你好"})
# 7. Token刷新
REFRESH_TOKEN_EXPIRE_DAYS = 7
def create_refresh_token(user_id: str):
"""创建刷新Token"""
return create_access_token(
data={"sub": user_id, "type": "refresh"},
expires_delta=timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
)
@app.post("/refresh")
def refresh_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Exchange a valid refresh token for a new access token.

    Args:
        credentials: Bearer credentials whose token must be a refresh token.

    Returns:
        A dict with the new ``access_token`` and ``token_type`` "bearer".

    Raises:
        HTTPException: 401 when the token cannot be decoded, is not of type
            "refresh", or carries no subject.
    """
    # Only the decode call can raise JWTError, so keep the try body minimal.
    try:
        payload = jwt.decode(
            credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM]
        )
    except JWTError as exc:
        raise HTTPException(status_code=401, detail="无效的Token") from exc

    if payload.get("type") != "refresh":
        raise HTTPException(status_code=401, detail="无效的刷新Token")

    user_id = payload.get("sub")
    if not user_id:
        # Never mint an access token for an anonymous subject.
        raise HTTPException(status_code=401, detail="无效的Token")

    new_access_token = create_access_token(data={"sub": user_id})
    return {
        "access_token": new_access_token,
        "token_type": "bearer",
    }
# 8. 信创环境JWT(国密SM2)
from gmssl import sm2, func
class SM2JWT:
    """Minimal signed-token helper using the SM2 signature algorithm (gmssl).

    Token layout before URL-safe Base64 encoding:
    ``<json payload> | <hex signature>``. The signature is plain hex, so it
    can never contain ``|``; the payload may.
    """

    # Lifetime of an issued token, in seconds.
    TOKEN_TTL_SECONDS = 3600

    def __init__(self, private_key: str, public_key: str):
        """Store the hex-encoded key pair and build the gmssl SM2 context."""
        self.private_key = private_key
        self.public_key = public_key
        self.sm2_crypt = sm2.CryptSM2(
            private_key=private_key,
            public_key=public_key,
        )

    def create_token(self, data: dict) -> str:
        """Return a signed token carrying ``data`` plus an ``exp`` claim.

        The caller's dict is copied, not mutated.
        """
        import base64
        import json
        import time

        claims = dict(data)
        claims["exp"] = time.time() + self.TOKEN_TTL_SECONDS
        payload = json.dumps(claims).encode()
        # sign_with_sm3 hashes the payload (SM3) and draws the per-signature
        # random value itself; it returns a hex string.
        signature = self.sm2_crypt.sign_with_sm3(payload)
        return base64.urlsafe_b64encode(
            payload + b"|" + signature.encode("ascii")
        ).decode()

    def verify_token(self, token: str) -> dict:
        """Validate signature and expiry, returning the decoded claims.

        Raises:
            HTTPException: 401 for a malformed token, a bad signature, or an
                expired / missing ``exp`` claim.
        """
        import base64
        import json
        import time

        try:
            decoded = base64.urlsafe_b64decode(token.encode())
            # Split on the LAST separator: the JSON payload may contain "|".
            payload, separator, signature = decoded.rpartition(b"|")
            if not separator:
                raise ValueError("missing signature separator")
            signature_ok = self.sm2_crypt.verify_with_sm3(
                signature.decode("ascii"), payload
            )
        except ValueError as exc:
            # binascii.Error, UnicodeDecodeError and bad-hex errors are all
            # ValueError subclasses.
            raise HTTPException(status_code=401, detail="签名验证失败") from exc
        if not signature_ok:
            raise HTTPException(status_code=401, detail="签名验证失败")

        try:
            data = json.loads(payload.decode())
            expires_at = float(data["exp"])
        except (ValueError, KeyError, TypeError) as exc:
            raise HTTPException(status_code=401, detail="签名验证失败") from exc

        if time.time() > expires_at:
            raise HTTPException(status_code=401, detail="Token已过期")
        return data
# 使用
sm2_jwt = SM2JWT(
private_key="your_private_key",
public_key="your_public_key"
)
@app.post("/xinchuang/login")
def xinchuang_login(request: LoginRequest):
"""信创登录"""
# 验证用户...
token = sm2_jwt.create_token({"sub": request.username})
return {"access_token": token, "token_type": "bearer"}
def verify_sm2_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""验证SM2 Token"""
return sm2_jwt.verify_token(credentials.credentials)
@app.post("/xinchuang/protected/chat")
def xinchuang_protected_chat(
question: str,
token_data: dict = Depends(verify_sm2_token)
):
"""信创受保护端点"""
return {
"user": token_data["sub"],
"response": "..."
}
print("✓ JWT认证配置完成")
---
02.权限控制
a.基于角色
a.功能说明
根据用户角色分配不同的访问权限。定义角色(如admin、user、guest)和权限(如read、write)。实现细粒度的访问控制。RBAC(基于角色的访问控制)是企业应用的标准模式。
b.代码示例
---
from enum import Enum
from typing import List

from fastapi import Depends, FastAPI, Header, HTTPException
app = FastAPI()
# 1. 定义角色和权限
class Role(str, Enum):
ADMIN = "admin"
USER = "user"
GUEST = "guest"
class Permission(str, Enum):
CHAT = "chat"
TRANSLATE = "translate"
ADMIN_PANEL = "admin_panel"
# 角色权限映射
ROLE_PERMISSIONS = {
Role.ADMIN: [Permission.CHAT, Permission.TRANSLATE, Permission.ADMIN_PANEL],
Role.USER: [Permission.CHAT, Permission.TRANSLATE],
Role.GUEST: [Permission.CHAT]
}
# 2. 用户数据库(模拟)
users_db = {
"admin": {"password": "admin123", "role": Role.ADMIN},
"user1": {"password": "pass123", "role": Role.USER},
"guest1": {"password": "guest123", "role": Role.GUEST}
}
# 3. 获取当前用户
def get_current_user(x_user_id: str = Header(None)):
    """Resolve the caller from the ``X-User-ID`` request header.

    NOTE(review): the header value is trusted as-is, so any client can claim
    any identity. Acceptable for a demo; production code should derive the
    user from a verified credential.

    Returns:
        A dict with ``user_id`` and the user's ``role``.

    Raises:
        HTTPException: 401 when the header is missing or names no known user.
    """
    account = users_db.get(x_user_id) if x_user_id else None
    if account is None:
        raise HTTPException(status_code=401, detail="未认证")
    return {
        "user_id": x_user_id,
        "role": account["role"],
    }
# 4. 检查权限
def has_permission(required_permission: Permission):
    """Build a dependency that requires ``required_permission`` for the caller's role.

    Returns:
        A dependency that returns the current user dict, or raises
        ``HTTPException`` 403 when the role lacks the permission.
    """

    def permission_checker(user: dict = Depends(get_current_user)):
        # Roles missing from the mapping are treated as having no permissions.
        granted = ROLE_PERMISSIONS.get(user["role"], [])
        if required_permission not in granted:
            raise HTTPException(
                status_code=403,
                detail=f"需要权限:{required_permission.value}",
            )
        return user

    return permission_checker
# 5. 受保护的端点
@app.post("/chat")
def chat(
message: str,
user: dict = Depends(has_permission(Permission.CHAT))
):
"""聊天端点(需要chat权限)"""
return {
"user": user["user_id"],
"role": user["role"],
"response": "..."
}
@app.post("/translate")
def translate(
text: str,
user: dict = Depends(has_permission(Permission.TRANSLATE))
):
"""翻译端点(需要translate权限)"""
return {"translated": "..."}
@app.get("/admin")
def admin_panel(
user: dict = Depends(has_permission(Permission.ADMIN_PANEL))
):
"""管理面板(仅管理员)"""
return {"message": "欢迎,管理员"}
# 6. 检查角色
def require_role(required_role: Role):
"""角色检查依赖"""
def role_checker(user: dict = Depends(get_current_user)):
if user["role"] != required_role:
raise HTTPException(
status_code=403,
detail=f"需要角色:{required_role.value}"
)
return user
return role_checker
@app.post("/admin/chat")
def admin_chat(
message: str,
user: dict = Depends(require_role(Role.ADMIN))
):
"""管理员聊天"""
return {"response": "..."}
# 7. 多角色检查
def require_any_role(roles: List[Role]):
    """Build a dependency that admits callers holding any of ``roles``.

    Returns:
        A dependency that returns the current user dict, or raises
        ``HTTPException`` 403 when the user's role is not in ``roles``.
    """

    def role_checker(user: dict = Depends(get_current_user)):
        if user["role"] not in roles:
            raise HTTPException(
                status_code=403,
                detail=f"需要角色:{[r.value for r in roles]}",
            )
        return user

    return role_checker
@app.post("/premium/chat")
def premium_chat(
message: str,
user: dict = Depends(require_any_role([Role.ADMIN, Role.USER]))
):
"""高级聊天(管理员或用户)"""
return {"response": "..."}
# 8. 信创环境RBAC(达梦数据库)
import dmPython
class DmRBACManager:
    """RBAC lookups backed by a Dameng (DM) database through ``dmPython``.

    NOTE(review): one connection is shared by every caller; confirm the
    driver's thread-safety before using this under a threaded server.
    """

    # DDL for the four RBAC tables, executed in order at start-up.
    _SCHEMA = (
        # Roles
        """
        CREATE TABLE IF NOT EXISTS roles (
            role_id VARCHAR(50) PRIMARY KEY,
            role_name VARCHAR(100),
            description VARCHAR(500)
        )
        """,
        # Permissions
        """
        CREATE TABLE IF NOT EXISTS permissions (
            permission_id VARCHAR(50) PRIMARY KEY,
            permission_name VARCHAR(100),
            resource VARCHAR(100)
        )
        """,
        # Role -> permission links
        """
        CREATE TABLE IF NOT EXISTS role_permissions (
            role_id VARCHAR(50),
            permission_id VARCHAR(50),
            PRIMARY KEY (role_id, permission_id)
        )
        """,
        # User -> role links
        """
        CREATE TABLE IF NOT EXISTS user_roles (
            user_id VARCHAR(50),
            role_id VARCHAR(50),
            PRIMARY KEY (user_id, role_id)
        )
        """,
    )

    def __init__(self, connection_string: str):
        """Open the database connection and make sure the tables exist."""
        self.conn = dmPython.connect(connection_string)
        self._init_tables()

    def _init_tables(self):
        """Create the RBAC tables if they are missing, then commit."""
        cursor = self.conn.cursor()
        try:
            for ddl in self._SCHEMA:
                cursor.execute(ddl)
            self.conn.commit()
        finally:
            # Release the cursor even when a statement fails.
            cursor.close()

    def get_user_permissions(self, user_id: str) -> List[str]:
        """Return the distinct permission ids granted to ``user_id`` via roles."""
        cursor = self.conn.cursor()
        try:
            # Parameterized query: user_id is never interpolated into SQL.
            cursor.execute(
                """
                SELECT DISTINCT p.permission_id
                FROM permissions p
                JOIN role_permissions rp ON p.permission_id = rp.permission_id
                JOIN user_roles ur ON rp.role_id = ur.role_id
                WHERE ur.user_id = ?
                """,
                (user_id,),
            )
            return [row[0] for row in cursor.fetchall()]
        finally:
            cursor.close()

    def has_permission(self, user_id: str, permission: str) -> bool:
        """Return True when ``user_id`` holds ``permission``."""
        return permission in self.get_user_permissions(user_id)
# 使用
rbac = DmRBACManager(
"user=SYSDBA;password=SYSDBA;server=localhost;port=5236"
)
def verify_xinchuang_permission(required_permission: str):
    """Build a dependency that checks a permission through the ``rbac`` manager.

    Returns:
        A dependency that returns the caller's user id, raising
        ``HTTPException`` 401 when the ``X-User-ID`` header is absent and 403
        when the user lacks ``required_permission``.
    """

    def checker(x_user_id: str = Header(None)):
        if not x_user_id:
            raise HTTPException(status_code=401, detail="未认证")
        if not rbac.has_permission(x_user_id, required_permission):
            raise HTTPException(
                status_code=403,
                detail=f"缺少权限:{required_permission}",
            )
        return x_user_id

    return checker
@app.post("/xinchuang/chat")
def xinchuang_chat(
question: str,
user_id: str = Depends(verify_xinchuang_permission("chat"))
):
"""信创聊天端点"""
return {
"user_id": user_id,
"response": "..."
}
print("✓ 基于角色的权限控制配置完成")
---
4.2 CORS配置
01.跨域设置
a.允许来源
a.功能说明
配置CORS(跨源资源共享)允许不同域名的Web应用访问API。设置允许的来源(origins)、方法(methods)、请求头(headers)。CORS是前后端分离架构的必备配置。正确配置CORS确保安全性和可用性。
b.代码示例
---
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langserve import add_routes
from langchain.chat_models import ChatOpenAI
app = FastAPI()
# 1. 基础CORS配置
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000"], # 允许的前端域名
allow_credentials=True,
allow_methods=["*"], # 允许所有HTTP方法
allow_headers=["*"] # 允许所有请求头
)
# 添加LangServe路由
llm = ChatOpenAI()
add_routes(app, llm, path="/chat")
# 2. 允许多个来源
app.add_middleware(
CORSMiddleware,
allow_origins=[
"http://localhost:3000",
"http://localhost:8080",
"https://app.example.com",
"https://admin.example.com"
],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"]
)
# 3. 允许所有来源(仅开发环境)
import os
# Development only: accept requests from any origin.
if os.getenv("ENVIRONMENT") == "development":
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        # A wildcard origin must not be combined with credentials: browsers
        # reject "Access-Control-Allow-Origin: *" on credentialed requests,
        # and reflecting every origin with cookies enabled is unsafe.
        allow_credentials=False,
        allow_methods=["*"],
        allow_headers=["*"],
    )
# 4. 限制HTTP方法
app.add_middleware(
CORSMiddleware,
allow_origins=["https://app.example.com"],
allow_credentials=True,
allow_methods=["GET", "POST"], # 仅允许GET和POST
allow_headers=["*"]
)
# 5. 限制请求头
app.add_middleware(
CORSMiddleware,
allow_origins=["https://app.example.com"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=[
"Content-Type",
"Authorization",
"X-API-Key"
]
)
# 6. 设置响应头
app.add_middleware(
CORSMiddleware,
allow_origins=["https://app.example.com"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["X-Custom-Header", "X-Request-ID"], # 暴露给客户端的响应头
max_age=3600 # 预检请求缓存时间(秒)
)
# 7. 动态CORS配置
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
class DynamicCORSMiddleware(BaseHTTPMiddleware):
    """Add CORS headers when the request origin belongs to an allowed domain.

    An origin is accepted when its host equals an allowed domain or is a
    subdomain of one. A plain substring test is not enough: it would also
    accept ``evilexample.com`` and ``example.com.evil.net``.
    """

    def __init__(self, app, allowed_domains: list):
        super().__init__(app)
        self.allowed_domains = allowed_domains

    def _is_allowed(self, origin: str) -> bool:
        """Return True when the host of ``origin`` is inside an allowed domain."""
        from urllib.parse import urlsplit

        try:
            host = urlsplit(origin).hostname
        except ValueError:
            return False
        if not host:
            return False
        host = host.lower()
        return any(
            host == domain.lower() or host.endswith("." + domain.lower())
            for domain in self.allowed_domains
        )

    async def dispatch(self, request: Request, call_next):
        """Run the request, then attach CORS headers for allowed origins."""
        origin = request.headers.get("origin")
        response = await call_next(request)
        if origin and self._is_allowed(origin):
            response.headers["Access-Control-Allow-Origin"] = origin
            response.headers["Access-Control-Allow-Credentials"] = "true"
            # NOTE(review): with credentials enabled, browsers read "*" in
            # the two headers below literally, not as a wildcard.
            response.headers["Access-Control-Allow-Methods"] = "*"
            response.headers["Access-Control-Allow-Headers"] = "*"
            # The response differs per origin, so caches must key on it.
            response.headers.add_vary_header("Origin")
        return response
# 使用
app.add_middleware(
DynamicCORSMiddleware,
allowed_domains=["example.com", "myapp.com"]
)
# 8. 环境特定配置
from typing import List
def get_allowed_origins() -> List[str]:
    """Return the CORS origins for the current ``ENVIRONMENT``.

    ``development`` allows every origin, ``staging`` a fixed pair, and any
    other value (including an unset variable) falls back to the production
    list.
    """
    env = os.getenv("ENVIRONMENT", "production")
    origins_by_env = {
        "development": ["*"],
        "staging": [
            "http://localhost:3000",
            "https://staging.example.com",
        ],
    }
    production_origins = [
        "https://app.example.com",
        "https://admin.example.com",
    ]
    return origins_by_env.get(env, production_origins)
# Register CORS with the environment-specific origin list.
_environment_origins = get_allowed_origins()
app.add_middleware(
    CORSMiddleware,
    allow_origins=_environment_origins,
    # Credentials are only safe with an explicit origin list; the
    # development wildcard must not be combined with them.
    allow_credentials="*" not in _environment_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# 9. 子域名通配符
# 注意:CORSMiddleware的allow_origins不支持子域名通配符;可改用其allow_origin_regex参数,或像下面这样自定义中间件
class WildcardCORSMiddleware(BaseHTTPMiddleware):
    """CORS middleware whose allowed origins may contain ``*`` wildcards."""

    def __init__(self, app, allowed_origin_patterns: list):
        super().__init__(app)
        self.patterns = allowed_origin_patterns

    async def dispatch(self, request: Request, call_next):
        """Run the request, then attach CORS headers if the origin matches."""
        origin = request.headers.get("origin", "")
        response = await call_next(request)
        if any(self.match_pattern(origin, pattern) for pattern in self.patterns):
            response.headers["Access-Control-Allow-Origin"] = origin
            response.headers["Access-Control-Allow-Credentials"] = "true"
            response.headers["Access-Control-Allow-Methods"] = "*"
            response.headers["Access-Control-Allow-Headers"] = "*"
            # The response differs per origin, so caches must key on it.
            response.headers.add_vary_header("Origin")
        return response

    def match_pattern(self, origin: str, pattern: str) -> bool:
        """Return True when ``origin`` matches ``pattern``.

        ``*`` stands for one or more host/port characters (letters, digits,
        ``.`` and ``-``); everything else must match literally.
        """
        import re

        # Escape first, then expand the wildcard. Doing it the other way
        # round escapes the dot of ".*" and turns the wildcard into
        # "zero or more literal dots".
        regex = re.escape(pattern).replace(r"\*", "[A-Za-z0-9.-]+")
        return re.fullmatch(regex, origin) is not None
# 使用
app.add_middleware(
WildcardCORSMiddleware,
allowed_origin_patterns=[
"https://*.example.com",
"http://localhost:*"
]
)
# 10. 信创环境CORS配置
# 信创浏览器(如360、红莲花)CORS配置
xinchuang_origins = [
"http://localhost:8080", # 本地开发
"http://192.168.1.100:8080", # 内网IP
"https://internal.company.com" # 内部域名
]
app.add_middleware(
CORSMiddleware,
allow_origins=xinchuang_origins,
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE"],
allow_headers=[
"Content-Type",
"Authorization",
"X-API-Key",
"X-User-ID"
],
expose_headers=["X-Request-ID", "X-Response-Time"],
max_age=86400 # 24小时缓存
)
# 添加自定义响应头
@app.middleware("http")
async def add_xinchuang_headers(request: Request, call_next):
"""添加信创环境响应头"""
response = await call_next(request)
# 添加服务器信息
response.headers["X-Powered-By"] = "LangServe-Xinchuang"
response.headers["X-Server-Type"] = "Kylin-OS"
return response
print("✓ CORS配置完成")
---
b.预检请求
a.功能说明
浏览器在发送跨域请求前会发送OPTIONS预检请求,检查是否允许。服务器需正确响应预检请求,返回允许的方法和头。预检请求缓存可以减少网络开销。理解预检机制是配置CORS的关键。
b.代码示例
---
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
# 1. 自动处理预检请求
# CORSMiddleware会自动处理OPTIONS请求
app.add_middleware(
CORSMiddleware,
allow_origins=["https://app.example.com"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
max_age=3600 # 预检缓存1小时
)
# 2. 手动处理预检请求
@app.options("/custom/chat")
async def chat_options():
"""手动处理预检请求"""
return Response(
status_code=200,
headers={
"Access-Control-Allow-Origin": "https://app.example.com",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization",
"Access-Control-Max-Age": "3600"
}
)
@app.post("/custom/chat")
async def chat(request: Request):
"""实际的聊天端点"""
# 添加CORS响应头
return Response(
content='{"response": "..."}',
media_type="application/json",
headers={
"Access-Control-Allow-Origin": "https://app.example.com",
"Access-Control-Allow-Credentials": "true"
}
)
# 3. 日志预检请求
@app.middleware("http")
async def log_preflight(request: Request, call_next):
"""记录预检请求"""
if request.method == "OPTIONS":
print(f"预检请求:{request.url}")
print(f"来源:{request.headers.get('origin')}")
print(f"请求方法:{request.headers.get('access-control-request-method')}")
print(f"请求头:{request.headers.get('access-control-request-headers')}")
response = await call_next(request)
return response
# 4. 条件预检缓存
class AdaptivePreflightMiddleware:
    """Pure ASGI middleware that adapts ``Access-Control-Max-Age`` on OPTIONS.

    A preflight that asks for many request headers gets a short cache
    lifetime; a small one gets a long lifetime. Non-OPTIONS and non-HTTP
    traffic passes through untouched.

    Args:
        app: The wrapped ASGI application.
        complex_header_threshold: A preflight listing more request headers
            than this is treated as complex.
        complex_max_age: Cache lifetime in seconds for complex preflights.
        simple_max_age: Cache lifetime in seconds for all other preflights.
    """

    def __init__(
        self,
        app,
        complex_header_threshold=5,
        complex_max_age=600,
        simple_max_age=86400,
    ):
        self.app = app
        self.complex_header_threshold = complex_header_threshold
        self.complex_max_age = complex_max_age
        self.simple_max_age = simple_max_age

    def _pick_max_age(self, scope) -> bytes:
        """Choose the max-age value from the requested-header count."""
        headers = dict(scope["headers"])
        requested = headers.get(b"access-control-request-headers", b"")
        header_count = len(requested.decode("latin-1").split(","))
        if header_count > self.complex_header_threshold:
            return str(self.complex_max_age).encode("ascii")
        return str(self.simple_max_age).encode("ascii")

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http" or scope["method"] != "OPTIONS":
            await self.app(scope, receive, send)
            return

        max_age = self._pick_max_age(scope)

        async def send_with_max_age(message):
            # Headers can only be changed on the response-start message.
            if message["type"] == "http.response.start":
                headers = [
                    (name, value)
                    for name, value in message.get("headers", [])
                    if bytes(name).lower() != b"access-control-max-age"
                ]
                headers.append((b"access-control-max-age", max_age))
                message = {**message, "headers": headers}
            await send(message)

        await self.app(scope, receive, send_with_max_age)
# 5. 预检失败处理
from starlette.middleware.base import BaseHTTPMiddleware
class PreflightErrorHandler(BaseHTTPMiddleware):
"""预检错误处理"""
async def dispatch(self, request: Request, call_next):
if request.method == "OPTIONS":
origin = request.headers.get("origin")
# 检查来源
allowed_origins = ["https://app.example.com"]
if origin not in allowed_origins:
return Response(
status_code=403,
content="不允许的来源",
headers={
"Content-Type": "text/plain"
}
)
return await call_next(request)
app.add_middleware(PreflightErrorHandler)
# 6. 预检统计
preflight_stats = {"total": 0, "by_origin": {}}
@app.middleware("http")
async def track_preflight(request: Request, call_next):
"""统计预检请求"""
if request.method == "OPTIONS":
preflight_stats["total"] += 1
origin = request.headers.get("origin", "unknown")
preflight_stats["by_origin"][origin] = \
preflight_stats["by_origin"].get(origin, 0) + 1
return await call_next(request)
@app.get("/cors/stats")
def get_cors_stats():
"""获取CORS统计"""
return preflight_stats
# 7. 优化预检性能
# 使用长缓存减少预检请求
app.add_middleware(
CORSMiddleware,
allow_origins=["https://app.example.com"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
max_age=86400 # 24小时,减少重复预检
)
# 8. 信创浏览器预检适配
# 某些国产浏览器对预检的处理可能有差异
@app.middleware("http")
async def xinchuang_preflight_adapter(request: Request, call_next):
"""信创浏览器预检适配"""
if request.method == "OPTIONS":
# 检测浏览器类型
user_agent = request.headers.get("user-agent", "").lower()
response = await call_next(request)
# 针对特定浏览器添加额外头
if "360" in user_agent or "qihu" in user_agent:
# 360浏览器
response.headers["X-Browser-Compat"] = "360"
return response
return await call_next(request)
# 9. 预检请求验证
@app.middleware("http")
async def validate_preflight(request: Request, call_next):
"""验证预检请求合法性"""
if request.method == "OPTIONS":
# 检查必需的预检头
required_headers = [
"origin",
"access-control-request-method"
]
for header in required_headers:
if header not in request.headers:
return Response(
status_code=400,
content=f"缺少预检头:{header}"
)
return await call_next(request)
print("✓ 预检请求配置完成")
---
02.安全实践
a.限制来源
a.功能说明
严格限制允许的来源,防止未授权访问。生产环境不使用通配符(*),明确指定域名。定期审查和更新允许的来源列表。限制来源是CORS安全的核心。
b.代码示例
---
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import os
app = FastAPI()
# 1. 生产环境严格配置
if os.getenv("ENVIRONMENT") == "production":
# 生产环境:仅允许特定域名
app.add_middleware(
CORSMiddleware,
allow_origins=[
"https://app.example.com",
"https://admin.example.com"
],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE"],
allow_headers=["Content-Type", "Authorization"]
)
else:
# 开发环境:宽松配置
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=False, # 通配符时不能为True
allow_methods=["*"],
allow_headers=["*"]
)
# 2. 白名单管理
ALLOWED_ORIGINS_WHITELIST = [
"https://app.example.com",
"https://admin.example.com",
"https://partner.example.com"
]
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
class WhitelistCORSMiddleware(BaseHTTPMiddleware):
    """Attach CORS headers only for origins found in an exact-match whitelist."""

    def __init__(self, app, whitelist: list):
        super().__init__(app)
        # A set gives O(1) membership tests on every request.
        self.whitelist = set(whitelist)

    async def dispatch(self, request: Request, call_next):
        """Run the request, then add CORS headers for whitelisted origins."""
        origin = request.headers.get("origin")
        response = await call_next(request)
        if origin and origin in self.whitelist:
            response.headers["Access-Control-Allow-Origin"] = origin
            response.headers["Access-Control-Allow-Credentials"] = "true"
            response.headers["Access-Control-Allow-Methods"] = "GET, POST"
            response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
            # The response differs per origin, so caches must key on it.
            response.headers.add_vary_header("Origin")
        return response
app.add_middleware(WhitelistCORSMiddleware, whitelist=ALLOWED_ORIGINS_WHITELIST)
# 3. 动态白名单(从数据库加载)
class DynamicWhitelistManager:
    """In-memory set of allowed CORS origins that can change at runtime."""

    def __init__(self):
        self.whitelist = set()
        self.load_whitelist()

    def load_whitelist(self):
        """Reset the whitelist to its initial contents.

        The values are hard-coded here; a real deployment would read them
        from a database.
        """
        self.whitelist = {
            "https://app.example.com",
            "https://admin.example.com",
        }

    def is_allowed(self, origin: str) -> bool:
        """Return True when ``origin`` is exactly one of the allowed origins."""
        return origin in self.whitelist

    def add_origin(self, origin: str):
        """Allow ``origin``; adding an existing entry is a no-op."""
        self.whitelist.add(origin)

    def remove_origin(self, origin: str):
        """Disallow ``origin``; removing an unknown entry is a no-op."""
        self.whitelist.discard(origin)
whitelist_manager = DynamicWhitelistManager()
@app.middleware("http")
async def dynamic_cors(request: Request, call_next):
"""动态CORS检查"""
origin = request.headers.get("origin")
response = await call_next(request)
if origin and whitelist_manager.is_allowed(origin):
response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Access-Control-Allow-Credentials"] = "true"
return response
# 管理端点
@app.post("/admin/cors/add")
async def add_origin(origin: str):
"""添加允许的来源"""
whitelist_manager.add_origin(origin)
return {"message": f"已添加:{origin}"}
@app.delete("/admin/cors/remove")
async def remove_origin(origin: str):
"""移除允许的来源"""
whitelist_manager.remove_origin(origin)
return {"message": f"已移除:{origin}"}
# 4. IP白名单
from ipaddress import ip_address, ip_network
ALLOWED_IP_RANGES = [
"192.168.1.0/24", # 内网
"10.0.0.0/8" # 内网
]
@app.middleware("http")
async def ip_whitelist(request: Request, call_next):
    """Reject requests whose client address is outside ``ALLOWED_IP_RANGES``.

    Returns a 403 response when the address is unknown, is not an IP
    literal, or falls outside every allowed network.
    """
    # Local import: this snippet never imports Response at module level.
    from starlette.responses import Response

    client = request.client  # may be None, e.g. behind some ASGI servers
    try:
        address = ip_address(client.host) if client else None
    except ValueError:
        # Host is not an IP literal; treat it as not allowed.
        address = None

    allowed = address is not None and any(
        address in ip_network(range_str) for range_str in ALLOWED_IP_RANGES
    )
    if not allowed:
        return Response(status_code=403, content="IP不在白名单")
    return await call_next(request)
# 5. 时间限制
from datetime import datetime, time
@app.middleware("http")
async def time_based_cors(request: Request, call_next):
"""基于时间的CORS"""
current_hour = datetime.now().hour
# 仅工作时间(9-18点)允许跨域
if 9 <= current_hour < 18:
response = await call_next(request)
origin = request.headers.get("origin")
if origin:
response.headers["Access-Control-Allow-Origin"] = origin
return response
else:
return Response(status_code=403, content="非工作时间")
# 6. 请求频率限制
from collections import defaultdict
from datetime import datetime, timedelta
origin_requests = defaultdict(list)
@app.middleware("http")
async def rate_limit_by_origin(request: Request, call_next):
    """Limit each ``Origin`` header value to 100 requests per rolling minute.

    Requests without an Origin header share the "unknown" bucket. Returns a
    429 response once the bucket is full.
    """
    # Local import: this snippet never imports Response at module level.
    from starlette.responses import Response

    origin = request.headers.get("origin", "unknown")
    now = datetime.now()
    window = timedelta(minutes=1)

    # Keep only the timestamps that are still inside the window.
    recent = [stamp for stamp in origin_requests[origin] if now - stamp < window]
    origin_requests[origin] = recent

    if len(recent) >= 100:
        return Response(
            status_code=429,
            content="请求过于频繁",
        )

    recent.append(now)
    return await call_next(request)
# 7. 信创环境安全配置
# 仅允许内网来源
XINCHUANG_ALLOWED_ORIGINS = [
"http://localhost:8080",
"http://192.168.1.100:8080",
"https://internal.company.com"
]
app.add_middleware(
CORSMiddleware,
allow_origins=XINCHUANG_ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["GET", "POST"], # 限制方法
allow_headers=["Content-Type", "Authorization", "X-User-ID"], # 限制头
max_age=3600
)
# 额外安全检查
@app.middleware("http")
async def xinchuang_security_check(request: Request, call_next):
    """Reject cross-origin requests that do not come from the intranet.

    An origin is accepted when it is listed in ``XINCHUANG_ALLOWED_ORIGINS``,
    or when it is plain-HTTP ``localhost`` or a ``192.168.0.0/16`` address on
    any port. The host is parsed and compared exactly: string-prefix checks
    would also accept ``http://localhost.evil.com`` or
    ``https://internal-evil.com``. Requests without an Origin header pass.
    """
    from ipaddress import ip_address, ip_network
    from urllib.parse import urlsplit

    # Local import: this snippet never imports Response at module level.
    from starlette.responses import Response

    origin = request.headers.get("origin")
    if origin:
        try:
            parts = urlsplit(origin)
            scheme, host = parts.scheme, (parts.hostname or "").lower()
        except ValueError:
            scheme, host = "", ""

        try:
            is_lan_ip = ip_address(host) in ip_network("192.168.0.0/16")
        except ValueError:
            # Host is a name (or empty), not an IP literal.
            is_lan_ip = False

        is_intranet = origin in XINCHUANG_ALLOWED_ORIGINS or (
            scheme == "http" and (host == "localhost" or is_lan_ip)
        )
        if not is_intranet:
            return Response(
                status_code=403,
                content="仅允许内网访问",
            )
    return await call_next(request)
print("✓ CORS安全实践配置完成")
---
4.3 中间件
01.请求处理
a.日志中间件
a.功能说明
中间件在请求处理前后执行,实现横切关注点。日志中间件记录所有请求和响应,便于调试和监控。记录请求时间、状态码、处理时长等信息。日志是生产环境排查问题的关键。
b.代码示例
---
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
import time
import logging
app = FastAPI()
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# 1. 基础日志中间件
@app.middleware("http")
async def log_requests(request: Request, call_next):
"""记录请求日志"""
start_time = time.time()
# 记录请求
logger.info(f"请求开始:{request.method} {request.url}")
# 处理请求
response = await call_next(request)
# 记录响应
process_time = time.time() - start_time
logger.info(
f"请求完成:{request.method} {request.url} "
f"状态码={response.status_code} 耗时={process_time:.3f}秒"
)
return response
# 2. 详细日志中间件
class DetailedLoggingMiddleware(BaseHTTPMiddleware):
    """Log request details, outcome and duration under a per-request id."""

    async def dispatch(self, request: Request, call_next):
        """Log the request, run it, and tag the response with id and timing.

        Exceptions from downstream are logged and re-raised unchanged.
        """
        import uuid

        started = time.perf_counter()
        # A random id stays unique under concurrency; a timestamp does not.
        request_id = uuid.uuid4().hex
        # request.client can be None when the server supplies no peer info.
        client_host = request.client.host if request.client else "unknown"

        logger.info(f"[{request_id}] 请求开始")
        logger.info(f"[{request_id}] 方法: {request.method}")
        logger.info(f"[{request_id}] URL: {request.url}")
        logger.info(f"[{request_id}] 客户端: {client_host}")
        logger.info(f"[{request_id}] User-Agent: {request.headers.get('user-agent')}")

        try:
            response = await call_next(request)
        except Exception as e:
            process_time = time.perf_counter() - started
            logger.error(f"[{request_id}] 错误: {str(e)}")
            logger.error(f"[{request_id}] 耗时: {process_time:.3f}秒")
            raise

        process_time = time.perf_counter() - started
        logger.info(f"[{request_id}] 状态码: {response.status_code}")
        logger.info(f"[{request_id}] 耗时: {process_time:.3f}秒")
        # Expose the id and timing so clients can correlate with the logs.
        response.headers["X-Request-ID"] = request_id
        response.headers["X-Process-Time"] = str(process_time)
        return response
app.add_middleware(DetailedLoggingMiddleware)
# 3. 结构化日志
import json
class StructuredLoggingMiddleware(BaseHTTPMiddleware):
"""结构化日志中间件"""
async def dispatch(self, request: Request, call_next):
start_time = time.time()
# 构建日志对象
log_data = {
"timestamp": time.time(),
"method": request.method,
"url": str(request.url),
"client_ip": request.client.host,
"user_agent": request.headers.get("user-agent")
}
response = await call_next(request)
# 添加响应信息
log_data.update({
"status_code": response.status_code,
"process_time": time.time() - start_time
})
# 输出JSON格式日志
logger.info(json.dumps(log_data))
return response
# 4. 条件日志
@app.middleware("http")
async def conditional_logging(request: Request, call_next):
"""条件日志"""
# 仅记录特定路径
should_log = request.url.path.startswith("/api/")
if should_log:
start_time = time.time()
logger.info(f"API请求:{request.url}")
response = await call_next(request)
if should_log:
logger.info(f"API响应:耗时{time.time() - start_time:.3f}秒")
return response
# 5. 日志轮转(文件日志)
from logging.handlers import RotatingFileHandler
# 配置文件日志
file_handler = RotatingFileHandler(
"langserve.log",
maxBytes=10*1024*1024, # 10MB
backupCount=5
)
file_handler.setFormatter(logging.Formatter(
'%(asctime)s - %(levelname)s - %(message)s'
))
logger.addHandler(file_handler)
# 6. 敏感信息过滤
class SecureLoggingMiddleware(BaseHTTPMiddleware):
"""安全日志中间件(过滤敏感信息)"""
SENSITIVE_HEADERS = ["authorization", "x-api-key", "cookie"]
async def dispatch(self, request: Request, call_next):
start_time = time.time()
# 过滤敏感头
safe_headers = {
k: v if k.lower() not in self.SENSITIVE_HEADERS else "***"
for k, v in request.headers.items()
}
logger.info(f"请求:{request.method} {request.url}")
logger.debug(f"请求头:{safe_headers}")
response = await call_next(request)
logger.info(
f"响应:状态={response.status_code} "
f"耗时={time.time() - start_time:.3f}秒"
)
return response
# 7. 日志聚合
class AggregatedLoggingMiddleware(BaseHTTPMiddleware):
"""日志聚合中间件"""
def __init__(self, app):
super().__init__(app)
self.stats = {
"total_requests": 0,
"by_method": {},
"by_status": {},
"total_time": 0.0
}
async def dispatch(self, request: Request, call_next):
start_time = time.time()
response = await call_next(request)
# 更新统计
self.stats["total_requests"] += 1
method = request.method
self.stats["by_method"][method] = \
self.stats["by_method"].get(method, 0) + 1
status = response.status_code
self.stats["by_status"][status] = \
self.stats["by_status"].get(status, 0) + 1
process_time = time.time() - start_time
self.stats["total_time"] += process_time
# 每100个请求打印统计
if self.stats["total_requests"] % 100 == 0:
logger.info(f"统计:{self.stats}")
return response
# 8. 信创环境日志
class XinchuangLoggingMiddleware(BaseHTTPMiddleware):
"""信创环境日志中间件"""
async def dispatch(self, request: Request, call_next):
start_time = time.time()
# 记录信创环境特定信息
log_data = {
"timestamp": time.time(),
"method": request.method,
"url": str(request.url),
"client_ip": request.client.host,
"server_type": "Kylin-OS", # 麒麟系统
"llm_provider": "ollama", # 本地LLM
"db_type": "dameng" # 达梦数据库
}
try:
response = await call_next(request)
log_data.update({
"status_code": response.status_code,
"process_time": time.time() - start_time,
"success": True
})
# 记录到达梦数据库
# self.log_to_dameng(log_data)
logger.info(f"信创日志:{json.dumps(log_data, ensure_ascii=False)}")
return response
except Exception as e:
log_data.update({
"success": False,
"error": str(e),
"process_time": time.time() - start_time
})
logger.error(f"信创错误:{json.dumps(log_data, ensure_ascii=False)}")
raise
print("✓ 日志中间件配置完成")
---
b.性能监控
a.功能说明
监控中间件追踪请求性能,识别慢请求和瓶颈。记录响应时间、内存使用、CPU占用等指标。生成性能报告,优化系统性能。性能监控是保障服务质量的基础。
b.代码示例
---
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
import time
import psutil
app = FastAPI()
# 1. 响应时间监控
@app.middleware("http")
async def monitor_response_time(request: Request, call_next):
    """Report handling time in ``X-Process-Time`` and warn on slow requests.

    A request taking longer than one second is logged at WARNING level.
    """
    # This snippet defines no module-level logger, so obtain one here.
    import logging

    log = logging.getLogger(__name__)

    started = time.perf_counter()
    response = await call_next(request)
    process_time = time.perf_counter() - started

    response.headers["X-Process-Time"] = f"{process_time:.3f}"
    if process_time > 1.0:
        log.warning(
            f"慢请求:{request.method} {request.url} "
            f"耗时{process_time:.3f}秒"
        )
    return response
# 2. 资源使用监控
class ResourceMonitoringMiddleware(BaseHTTPMiddleware):
    """Log process CPU and memory use per request and expose them as headers."""

    async def dispatch(self, request: Request, call_next):
        """Run the request and attach ``X-CPU-Usage`` / ``X-Memory-MB``."""
        # This snippet defines no module-level logger, so obtain one here.
        import logging

        log = logging.getLogger(__name__)
        process = psutil.Process()

        # Priming call: cpu_percent() reports usage since the previous call
        # on the same Process object, so this sets the baseline and its
        # return value is meaningless.
        process.cpu_percent()
        started = time.perf_counter()

        response = await call_next(request)

        cpu_after = process.cpu_percent()
        memory_after = process.memory_info().rss / 1024 / 1024  # bytes -> MB
        process_time = time.perf_counter() - started

        log.info(
            f"资源使用:CPU={cpu_after:.1f}% "
            f"内存={memory_after:.1f}MB "
            f"耗时={process_time:.3f}秒"
        )
        response.headers["X-CPU-Usage"] = f"{cpu_after:.1f}"
        response.headers["X-Memory-MB"] = f"{memory_after:.1f}"
        return response
app.add_middleware(ResourceMonitoringMiddleware)
# 3. 性能统计
class PerformanceStatsMiddleware(BaseHTTPMiddleware):
    """Collect request durations and summarise them on demand."""

    # Number of most-recent samples kept, overall and per endpoint.
    MAX_SAMPLES = 1000

    def __init__(self, app):
        super().__init__(app)
        self.request_times = []
        self.endpoint_times = {}

    async def dispatch(self, request: Request, call_next):
        """Time the request and record the duration."""
        started = time.perf_counter()
        response = await call_next(request)
        process_time = time.perf_counter() - started

        self.request_times.append(process_time)
        endpoint = f"{request.method} {request.url.path}"
        samples = self.endpoint_times.setdefault(endpoint, [])
        samples.append(process_time)

        # Bound both histories so a long-running process cannot grow them
        # without limit.
        if len(self.request_times) > self.MAX_SAMPLES:
            self.request_times = self.request_times[-self.MAX_SAMPLES:]
        if len(samples) > self.MAX_SAMPLES:
            del samples[:-self.MAX_SAMPLES]
        return response

    def get_stats(self):
        """Return summary statistics, or an empty dict before any request."""
        import statistics

        samples = list(self.request_times)
        if not samples:
            return {}
        if len(samples) >= 2:
            p95 = statistics.quantiles(samples, n=20)[18]
            p99 = statistics.quantiles(samples, n=100)[98]
        else:
            # quantiles() needs at least two data points.
            p95 = p99 = samples[0]
        return {
            "total_requests": len(samples),
            "avg_time": statistics.mean(samples),
            "median_time": statistics.median(samples),
            "min_time": min(samples),
            "max_time": max(samples),
            "p95_time": p95,
            "p99_time": p99,
        }


# perf_stats only records and reports. Registering its bound ``dispatch``
# lets Starlette wrap the real downstream app; handing Starlette an instance
# that already wraps ``app`` would route the app back into itself.
perf_stats = PerformanceStatsMiddleware(app)
app.add_middleware(BaseHTTPMiddleware, dispatch=perf_stats.dispatch)


@app.get("/metrics/performance")
def get_performance_metrics():
    """Return the collected performance statistics."""
    return perf_stats.get_stats()
# 4. 实时性能追踪
from collections import deque
from datetime import datetime, timedelta
class RealTimePerformanceTracker(BaseHTTPMiddleware):
"""实时性能追踪器"""
def __init__(self, app, window_seconds=60):
super().__init__(app)
self.window = timedelta(seconds=window_seconds)
self.requests = deque() # (timestamp, duration)
async def dispatch(self, request: Request, call_next):
start_time = time.time()
response = await call_next(request)
duration = time.time() - start_time
now = datetime.now()
# 添加新请求
self.requests.append((now, duration))
# 清理过期数据
while self.requests and now - self.requests[0][0] > self.window:
self.requests.popleft()
# 计算当前窗口统计
if self.requests:
recent_times = [d for _, d in self.requests]
avg_time = sum(recent_times) / len(recent_times)
response.headers["X-Window-Requests"] = str(len(self.requests))
response.headers["X-Window-Avg-Time"] = f"{avg_time:.3f}"
return response
# 5. 性能阈值告警
class PerformanceAlertMiddleware(BaseHTTPMiddleware):
"""性能阈值告警中间件"""
def __init__(self, app, slow_threshold=1.0, very_slow_threshold=3.0):
super().__init__(app)
self.slow_threshold = slow_threshold
self.very_slow_threshold = very_slow_threshold
async def dispatch(self, request: Request, call_next):
start_time = time.time()
response = await call_next(request)
duration = time.time() - start_time
# 性能分级告警
if duration > self.very_slow_threshold:
logger.error(
f"[严重]超慢请求:{request.method} {request.url} "
f"耗时{duration:.3f}秒"
)
# 发送告警通知
self.send_alert("critical", request, duration)
elif duration > self.slow_threshold:
logger.warning(
f"[警告]慢请求:{request.method} {request.url} "
f"耗时{duration:.3f}秒"
)
return response
def send_alert(self, level: str, request: Request, duration: float):
"""发送告警"""
# 实际应发送到告警系统(钉钉、企业微信等)
print(f"告警[{level}]:{request.url} 耗时{duration:.3f}秒")
# 6. 性能分析
import cProfile
import pstats
import io
class ProfilingMiddleware(BaseHTTPMiddleware):
    """Profile a request with cProfile when it carries ``?profile=1``.

    Intended for development use only.
    """

    async def dispatch(self, request: Request, call_next):
        """Run the request, logging the top 20 cumulative entries if asked."""
        if request.query_params.get("profile") != "1":
            return await call_next(request)

        # This snippet defines no module-level logger, so obtain one here.
        import logging

        profiler = cProfile.Profile()
        profiler.enable()
        try:
            response = await call_next(request)
        finally:
            # Stop profiling even when the handler raises; otherwise the
            # profiler keeps running for the rest of the process.
            profiler.disable()

        report = io.StringIO()
        stats = pstats.Stats(profiler, stream=report).sort_stats('cumulative')
        stats.print_stats(20)
        logging.getLogger(__name__).info(f"性能分析:\n{report.getvalue()}")
        return response
# 7. 信创环境性能监控
class XinchuangPerformanceMonitor(BaseHTTPMiddleware):
    """Log a per-stage timing breakdown and expose it in response headers.

    The database stage is a placeholder that wraps no work, and the "LLM"
    stage is simply the time spent in the downstream handler.
    """

    async def dispatch(self, request: Request, call_next):
        """Time each stage, log the breakdown, and annotate the response."""
        # This snippet imports neither json nor a logger at module level.
        import json
        import logging

        started = time.perf_counter()

        # Database stage (simulated): nothing runs between the two readings.
        db_started = time.perf_counter()
        # ... database work would go here
        db_time = time.perf_counter() - db_started

        # LLM stage (simulated): time spent in the downstream handler.
        llm_started = time.perf_counter()
        response = await call_next(request)
        llm_time = time.perf_counter() - llm_started

        total_time = time.perf_counter() - started

        perf_data = {
            "total_time": total_time,
            "db_time": db_time,
            "llm_time": llm_time,
            "other_time": total_time - db_time - llm_time,
            "timestamp": time.time(),
        }
        logging.getLogger(__name__).info(f"信创性能:{json.dumps(perf_data)}")

        response.headers["X-Total-Time"] = f"{total_time:.3f}"
        response.headers["X-DB-Time"] = f"{db_time:.3f}"
        response.headers["X-LLM-Time"] = f"{llm_time:.3f}"
        return response
print("✓ 性能监控中间件配置完成")
---
02.响应处理
a.压缩中间件
a.功能说明
压缩响应内容,减少网络传输数据量。本节示例基于gzip实现(brotli等其他算法需要额外依赖,示例未包含)。是否压缩由客户端的Accept-Encoding请求头决定,并且只对超过最小体积的响应进行压缩。压缩可以显著提升传输速度,尤其对大响应。
b.代码示例
---
from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware
app = FastAPI()
# 1. 基础Gzip压缩
app.add_middleware(GZipMiddleware, minimum_size=1000) # 1KB以上才压缩
# 2. 自定义压缩中间件
import gzip
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
class CustomCompressionMiddleware(BaseHTTPMiddleware):
    """Gzip-compress buffered responses for clients that accept gzip.

    Args:
        app: The ASGI application being wrapped.
        minimum_size: Responses smaller than this many bytes are left as is.
        compression_level: Level passed to ``gzip.compress``.
    """

    def __init__(self, app, minimum_size=1024, compression_level=6):
        super().__init__(app)
        self.minimum_size = minimum_size
        self.compression_level = compression_level

    async def dispatch(self, request: Request, call_next):
        """Return the downstream response, gzip-encoded when worthwhile."""
        response = await call_next(request)

        accept_encoding = request.headers.get("accept-encoding", "")
        if "gzip" not in accept_encoding:
            return response
        # Never re-encode a body that an inner layer already encoded.
        if response.headers.get("content-encoding"):
            return response

        body = getattr(response, "body", None)
        if body is None:
            # call_next() hands back a streaming wrapper that has no ``body``
            # attribute. Only buffer it when a Content-Length is declared:
            # open-ended streams (e.g. server-sent events) pass through.
            declared = response.headers.get("content-length", "")
            chunks = getattr(response, "body_iterator", None)
            if chunks is None or not declared.isdigit():
                return response
            if int(declared) < self.minimum_size:
                return response
            body = b"".join([chunk async for chunk in chunks])
        elif len(body) < self.minimum_size:
            return response

        compressed_body = gzip.compress(
            body,
            compresslevel=self.compression_level
        )
        # NOTE: dict() keeps one value per header name, as the original did.
        compressed = Response(
            content=compressed_body,
            status_code=response.status_code,
            headers=dict(response.headers),
            media_type=response.media_type
        )
        compressed.headers["Content-Encoding"] = "gzip"
        # Overwrite the copied length so it matches the compressed payload.
        compressed.headers["Content-Length"] = str(len(compressed_body))
        # Caches must key on Accept-Encoding now that the body depends on it.
        compressed.headers.add_vary_header("Accept-Encoding")
        return compressed
# 3. 条件压缩
@app.middleware("http")
async def conditional_compression(request: Request, call_next):
    """Decide from the content type whether a response is worth compressing.

    The compression step itself is a placeholder, so the response is always
    returned unchanged.
    """
    response = await call_next(request)

    # Only textual and JSON/JavaScript payloads are candidates.
    compressible_markers = ("text/", "application/json", "application/javascript")
    content_type = response.headers.get("content-type", "")
    should_compress = False
    for marker in compressible_markers:
        if marker in content_type:
            should_compress = True
            break

    if should_compress and hasattr(response, "body"):
        # Placeholder: the actual compression is not implemented here.
        pass
    return response
# 4. 压缩统计
class CompressionStatsMiddleware(BaseHTTPMiddleware):
    """Count responses and bytes so a compression ratio can be reported.

    Args:
        app: The ASGI application being wrapped.
        stats: Optional counter dict to update in place. Passing a dict lets
            code outside the middleware stack read the same counters.
    """

    def __init__(self, app, stats=None):
        super().__init__(app)
        if stats is None:
            stats = {
                "total_responses": 0,
                "compressed_responses": 0,
                "bytes_before": 0,
                "bytes_after": 0
            }
        self.stats = stats

    async def dispatch(self, request: Request, call_next):
        """Update the counters from the response and pass it through."""
        response = await call_next(request)
        self.stats["total_responses"] += 1

        # Responses produced by call_next() carry no ``body`` attribute, so
        # fall back to the declared Content-Length when there is one.
        declared = response.headers.get("content-length", "")
        if hasattr(response, "body"):
            original_size = len(response.body)
        elif declared.isdigit():
            original_size = int(declared)
        else:
            return response

        # NOTE(review): both counters are derived from the same, final
        # response, so the ratio stays at 0 unless an inner layer reports the
        # uncompressed size; confirm which metric is intended.
        self.stats["bytes_before"] += original_size
        if response.headers.get("content-encoding") == "gzip":
            self.stats["compressed_responses"] += 1
            compressed_size = int(response.headers.get("content-length", original_size))
            self.stats["bytes_after"] += compressed_size
        else:
            self.stats["bytes_after"] += original_size
        return response

    def get_compression_ratio(self):
        """Return ``1 - bytes_after / bytes_before``, or 0 when nothing was counted."""
        if self.stats["bytes_before"] == 0:
            return 0
        return 1 - (self.stats["bytes_after"] / self.stats["bytes_before"])


# Reporting handle only: this instance is never placed in the middleware stack.
compression_stats = CompressionStatsMiddleware(app)
# Register the class itself and share the counter dict. Returning the
# pre-built instance from a factory would wrap the FastAPI app in itself and
# recurse on every request.
app.add_middleware(CompressionStatsMiddleware, stats=compression_stats.stats)


@app.get("/metrics/compression")
def get_compression_metrics():
    """Return the raw counters together with the derived compression ratio."""
    return {
        **compression_stats.stats,
        "compression_ratio": compression_stats.get_compression_ratio()
    }
# 5. 信创环境压缩配置
# 信创浏览器可能对某些压缩算法支持有限
class XinchuangCompressionMiddleware(BaseHTTPMiddleware):
    """Apply low-level gzip for browsers matched by user-agent tokens.

    Responses for other clients are returned untouched.
    """

    async def dispatch(self, request: Request, call_next):
        """Gzip bodies larger than 1024 bytes for matched, gzip-capable clients."""
        response = await call_next(request)

        user_agent = request.headers.get("user-agent", "").lower()
        is_xinchuang_browser = any(
            browser in user_agent
            for browser in ["360", "qihu", "honglianghua"]
        )
        if not is_xinchuang_browser:
            return response
        # Only send gzip to clients that advertise it, and never re-encode.
        if "gzip" not in request.headers.get("accept-encoding", ""):
            return response
        if response.headers.get("content-encoding"):
            return response

        body = getattr(response, "body", None)
        if body is None:
            # call_next() returns a streaming wrapper without ``body``; buffer
            # it only when a Content-Length above the threshold is declared.
            declared = response.headers.get("content-length", "")
            chunks = getattr(response, "body_iterator", None)
            if chunks is None or not declared.isdigit() or int(declared) <= 1024:
                return response
            body = b"".join([chunk async for chunk in chunks])
        elif len(body) <= 1024:
            return response

        # Level 4 trades some ratio for lower CPU cost.
        compressed = gzip.compress(body, compresslevel=4)
        rebuilt = Response(
            content=compressed,
            status_code=response.status_code,
            headers=dict(response.headers)
        )
        rebuilt.headers["Content-Encoding"] = "gzip"
        # The copied Content-Length still describes the uncompressed body.
        rebuilt.headers["Content-Length"] = str(len(compressed))
        rebuilt.headers.add_vary_header("Accept-Encoding")
        return rebuilt
print("✓ 压缩中间件配置完成")
---
4.4 错误处理
01.异常捕获
a.全局异常处理
a.功能说明
捕获所有未处理的异常,防止服务崩溃。返回友好的错误信息给客户端。记录错误日志便于排查问题。全局异常处理是服务稳定性的最后防线。
b.代码示例
---
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
import logging
app = FastAPI()
logger = logging.getLogger(__name__)
# 1. 全局异常处理器
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Log any otherwise unhandled exception and answer with a 500 JSON body."""
    exc_name = type(exc).__name__
    exc_text = str(exc)
    logger.error(
        f"未处理异常:{exc_name} - {exc_text}",
        exc_info=True
    )

    # NOTE(review): "detail" echoes the exception text to the client;
    # consider whether that is acceptable outside development.
    payload = {
        "error": "服务器内部错误",
        "detail": exc_text,
        "path": request.url.path
    }
    return JSONResponse(status_code=500, content=payload)
# 2. HTTP异常处理
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Log an HTTPException and return it as a JSON error body.

    The status code and any headers attached to the exception are preserved.
    """
    logger.warning(
        f"HTTP异常:{exc.status_code} - {exc.detail} "
        f"路径={request.url.path}"
    )
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": exc.detail,
            "status_code": exc.status_code,
            "path": request.url.path
        },
        # Headers such as WWW-Authenticate are part of the error contract;
        # dropping them breaks clients that rely on them.
        headers=getattr(exc, "headers", None)
    )
# 3. 验证错误处理
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Log request validation errors and return them in a 422 JSON body."""
    from fastapi.encoders import jsonable_encoder

    errors = exc.errors()
    logger.warning(f"验证错误:{errors}")
    return JSONResponse(
        status_code=422,
        content={
            "error": "请求数据验证失败",
            # Error entries may contain values that json cannot serialize
            # directly, so convert them first.
            "details": jsonable_encoder(errors)
        }
    )
# 4. 自定义异常
class BusinessException(Exception):
    """Error raised for a business-rule violation.

    Attributes:
        message: Human-readable description; also used as the exception text.
        code: Status code associated with the error (400 by default).
    """

    def __init__(self, message: str, code: int = 400):
        super().__init__(message)
        self.message = message
        self.code = code
@app.exception_handler(BusinessException)
async def business_exception_handler(request: Request, exc: BusinessException):
    """Turn a BusinessException into a JSON response using its own code."""
    logger.info(f"业务异常:{exc.message}")
    body = {
        "error": "业务错误",
        "message": exc.message,
        "code": exc.code
    }
    return JSONResponse(status_code=exc.code, content=body)


# Usage example
@app.post("/test")
async def test_endpoint():
    """Always fail with a BusinessException to demonstrate the handler."""
    raise BusinessException("余额不足", code=400)
# 5. 超时异常处理
import asyncio
class TimeoutException(Exception):
    """Application-level error signalling that processing took too long."""
@app.exception_handler(TimeoutException)
async def timeout_exception_handler(request: Request, exc: TimeoutException):
    """Log the timed-out path and reply with a 408 JSON body."""
    logger.error(f"请求超时:{request.url.path}")
    body = {
        "error": "请求超时",
        "message": "处理时间过长,请稍后重试"
    }
    return JSONResponse(status_code=408, content=body)
@app.post("/long_task")
async def long_task():
    """Run a simulated long task with a 5-second limit.

    Raises:
        TimeoutException: When the task does not finish within the limit.
    """
    try:
        await asyncio.wait_for(
            asyncio.sleep(100),  # simulated long-running work
            timeout=5.0
        )
    except asyncio.TimeoutError as exc:
        # Chain explicitly so the traceback shows the timeout as the cause.
        raise TimeoutException("任务超时") from exc
# 6. 数据库异常处理
class DatabaseException(Exception):
    """Error raised when a database operation fails."""
@app.exception_handler(DatabaseException)
async def database_exception_handler(request: Request, exc: DatabaseException):
    """Log a database failure and reply with a 503 JSON body."""
    logger.error(f"数据库错误:{str(exc)}")
    body = {
        "error": "数据库服务不可用",
        "message": "请稍后重试"
    }
    return JSONResponse(status_code=503, content=body)
# 7. 外部API异常处理
import requests
class ExternalAPIException(Exception):
    """Error raised when a call to an external service fails.

    Attributes:
        service: Name of the failing service.
        message: Description of the failure; also used as the exception text.
    """

    def __init__(self, service: str, message: str):
        super().__init__(message)
        self.service = service
        self.message = message
@app.exception_handler(ExternalAPIException)
async def external_api_exception_handler(request: Request, exc: ExternalAPIException):
    """Log an upstream service failure and reply with a 502 JSON body."""
    logger.error(f"外部服务错误:{exc.service} - {exc.message}")
    body = {
        "error": f"{exc.service}服务不可用",
        "message": exc.message
    }
    return JSONResponse(status_code=502, content=body)
# 8. 结构化错误响应
from pydantic import BaseModel
from typing import Optional, Any
class ErrorResponse(BaseModel):
    """Schema of a structured error body."""

    # Short error identifier.
    error: str
    # Human-readable description.
    message: str
    # Numeric error/status code.
    code: int
    # Request path the error relates to.
    path: str
    # Time of the error as a float timestamp.
    timestamp: float
    # Optional extra payload; absent by default.
    details: Optional[Any] = None
@app.exception_handler(Exception)
async def structured_exception_handler(request: Request, exc: Exception):
    """Answer unhandled exceptions with an ``ErrorResponse``-shaped 500 body.

    Handlers are keyed by exception class, so registering this one for
    ``Exception`` replaces any handler registered for ``Exception`` earlier.
    """
    import time

    error_response = ErrorResponse(
        error=type(exc).__name__,
        message=str(exc),
        code=500,
        path=request.url.path,
        timestamp=time.time()
    )
    logger.error(
        f"异常:{error_response.error} - {error_response.message}",
        exc_info=True
    )
    # pydantic v2 renamed ``dict()`` to ``model_dump()`` and deprecated the
    # old name; fall back to ``dict()`` only when running on pydantic v1.
    to_dict = getattr(error_response, "model_dump", None) or error_response.dict
    return JSONResponse(
        status_code=500,
        content=to_dict()
    )
# 9. 错误追踪
import traceback
import sys
@app.exception_handler(Exception)
async def traceable_exception_handler(request: Request, exc: Exception):
    """Log the full traceback under an error id and return only that id.

    Handlers are keyed by exception class, so registering this one for
    ``Exception`` replaces any handler registered for ``Exception`` earlier.
    """
    # ``time`` is not imported at module level in this example, so import it
    # here; without it the handler itself fails with NameError.
    import time

    # Format from the exception object rather than sys.exc_info(), which is
    # only populated while an ``except`` block is still active.
    tb = traceback.format_exception(type(exc), exc, exc.__traceback__)
    error_id = str(time.time())
    logger.error(
        f"错误ID={error_id}\n"
        f"路径={request.url.path}\n"
        f"异常={''.join(tb)}"
    )
    return JSONResponse(
        status_code=500,
        content={
            "error": "服务器错误",
            "error_id": error_id,
            "message": "错误已记录,请联系技术支持"
        }
    )
# 10. 信创环境异常处理
class XinchuangException(Exception):
    """Error attributed to a specific component of the Xinchuang stack.

    Attributes:
        component: Identifier of the failing component.
        message: Description of the failure; also used as the exception text.
    """

    def __init__(self, component: str, message: str):
        super().__init__(message)
        self.component = component
        self.message = message
@app.exception_handler(XinchuangException)
async def xinchuang_exception_handler(request: Request, exc: XinchuangException):
    """Log a component failure and reply with a 503 JSON body.

    Known components get a fixed message; unknown ones fall back to the
    exception's own message.
    """
    logger.error(
        f"信创组件错误:{exc.component} - {exc.message}"
    )

    component_messages = {
        "dameng": "达梦数据库连接失败",
        "ollama": "本地模型服务不可用",
        "kylin": "系统服务错误"
    }
    if exc.component in component_messages:
        user_message = component_messages[exc.component]
    else:
        user_message = exc.message

    body = {
        "error": "信创环境错误",
        "component": exc.component,
        "message": user_message,
        "suggestion": "请检查信创组件状态或联系管理员"
    }
    return JSONResponse(status_code=503, content=body)
# 使用示例
@app.post("/xinchuang/query")
async def xinchuang_query():
    """Open a Dameng connection, reporting any failure as XinchuangException.

    Raises:
        XinchuangException: With component ``"dameng"`` when importing the
            driver or connecting fails.
    """
    try:
        import dmPython
        conn = dmPython.connect("...")
    except Exception as e:
        # Keep the driver error as the explicit cause for debugging.
        raise XinchuangException("dameng", str(e)) from e
    # The example never uses the connection, so release it instead of leaking
    # it. Assumes the driver follows DB-API 2.0 ``close()`` — TODO confirm.
    conn.close()
print("✓ 全局异常处理配置完成")
---
b.错误恢复
a.功能说明
实现错误恢复机制,提高服务容错能力。自动重试失败的操作。使用降级策略保障基本功能。错误恢复是构建高可用服务的关键。
b.代码示例
---
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
import asyncio
import logging
app = FastAPI()
logger = logging.getLogger(__name__)
# 1. 自动重试中间件
class RetryMiddleware(BaseHTTPMiddleware):
    """Retry requests that end in a 5xx response or an exception.

    Args:
        app: The ASGI application being wrapped.
        max_retries: Total number of attempts; values below 1 are treated
            as 1 so the application is always called at least once.
        retry_delay: Base delay in seconds; after attempt ``k`` (1-based) the
            middleware waits ``retry_delay * k`` before trying again.
        retry_methods: HTTP methods that may be retried. Defaults to GET,
            HEAD and OPTIONS, because replaying other methods can repeat
            side effects. NOTE(review): a request body may not be readable a
            second time through ``call_next``; verify before enabling POST.
    """

    DEFAULT_RETRY_METHODS = ("GET", "HEAD", "OPTIONS")

    def __init__(self, app, max_retries=3, retry_delay=1.0, retry_methods=None):
        super().__init__(app)
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        methods = self.DEFAULT_RETRY_METHODS if retry_methods is None else retry_methods
        self.retry_methods = frozenset(method.upper() for method in methods)

    async def dispatch(self, request: Request, call_next):
        """Call downstream, retrying eligible requests with a growing delay."""
        if request.method.upper() not in self.retry_methods:
            return await call_next(request)

        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            is_last = attempt == attempts - 1
            try:
                response = await call_next(request)
            except Exception as e:
                if is_last:
                    raise
                logger.error(f"请求失败,重试{attempt+1}:{str(e)}")
            else:
                # Success, or nothing left to try: hand back what we have.
                if response.status_code < 500 or is_last:
                    return response
                logger.warning(
                    f"服务器错误,重试{attempt+1}/{self.max_retries-1}..."
                )
            await asyncio.sleep(self.retry_delay * (attempt + 1))
app.add_middleware(RetryMiddleware, max_retries=3, retry_delay=1.0)
# 2. 熔断器
from collections import deque
from datetime import datetime, timedelta
class CircuitOpenError(Exception):
    """Raised by ``CircuitBreaker.call`` while the circuit is open."""


class CircuitBreaker:
    """Stop calling a failing function until a recovery timeout has passed.

    States: ``"closed"`` (calls pass through), ``"open"`` (calls are rejected)
    and ``"half_open"`` (one probe call decides whether to close or re-open).

    Args:
        failure_threshold: Failures within the window that open the circuit.
        recovery_timeout: Seconds to stay open before allowing a probe call.
        expected_exception: Exception type(s) counted as failures.
        failure_window: Length in seconds of the sliding failure window.
    """

    def __init__(
        self,
        failure_threshold=5,
        recovery_timeout=60,
        expected_exception=Exception,
        failure_window=60
    ):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception
        self.failure_window = failure_window
        self.failures = deque()
        self.state = "closed"  # closed, open, half_open
        self.opened_at = None

    def call(self, func, *args, **kwargs):
        """Invoke ``func(*args, **kwargs)`` under circuit-breaker protection.

        Returns:
            Whatever ``func`` returns.

        Raises:
            CircuitOpenError: When the circuit is open and the recovery
                timeout has not elapsed (an ``Exception`` subclass, so
                existing ``except Exception`` callers keep working).
            Exception: Any exception raised by ``func`` is re-raised.
        """
        if self.state == "open":
            if datetime.now() - self.opened_at > timedelta(seconds=self.recovery_timeout):
                self.state = "half_open"
                logger.info("熔断器进入半开状态")
            else:
                raise CircuitOpenError("服务熔断中,请稍后重试")

        try:
            result = func(*args, **kwargs)
        except self.expected_exception:
            self._record_failure()
            raise

        # A successful probe closes the circuit and forgets old failures.
        if self.state == "half_open":
            self.state = "closed"
            self.failures.clear()
            logger.info("熔断器恢复关闭状态")
        return result

    def _record_failure(self):
        """Record a failure and open the circuit when warranted."""
        now = datetime.now()
        self.failures.append(now)

        # Drop failures that fell out of the sliding window.
        window = timedelta(seconds=self.failure_window)
        while self.failures and now - self.failures[0] > window:
            self.failures.popleft()

        # A failed probe re-opens immediately: by the time the recovery
        # timeout has elapsed the earlier failures have usually expired, so
        # the threshold alone would leave the breaker stuck half-open.
        if self.state == "half_open" or len(self.failures) >= self.failure_threshold:
            self.state = "open"
            self.opened_at = now
            logger.warning(f"熔断器打开:失败次数={len(self.failures)}")
# 使用熔断器
circuit_breaker = CircuitBreaker(failure_threshold=5, recovery_timeout=60)


@app.post("/protected_endpoint")
async def protected_endpoint():
    """Run the risky call through the shared circuit breaker.

    Any exception, including the breaker's own rejection, is returned as
    ``{"error": <text>}`` instead of propagating.
    """
    # NOTE(review): some_risky_function is not defined in this example.
    try:
        outcome = circuit_breaker.call(some_risky_function)
    except Exception as e:
        return {"error": str(e)}
    return outcome
# 3. 降级策略
class FallbackMiddleware(BaseHTTPMiddleware):
    """Replace failing responses with a canned fallback response."""

    async def dispatch(self, request: Request, call_next):
        """Return the real response, or the fallback on 5xx or an exception."""
        try:
            response = await call_next(request)
            if response.status_code < 500:
                return response
            # Server-side failure: degrade instead of forwarding it.
            return await self.get_fallback_response(request)
        except Exception as e:
            logger.error(f"服务错误,使用降级:{str(e)}")
            return await self.get_fallback_response(request)

    async def get_fallback_response(self, request: Request):
        """Build the fallback: a chat-style reply for chat paths, else a 503."""
        from fastapi.responses import JSONResponse

        if "/chat" not in request.url.path:
            return JSONResponse(
                status_code=503,
                content={"error": "服务降级中"}
            )
        chat_payload = {
            "response": "抱歉,服务暂时不可用,请稍后重试。",
            "fallback": True
        }
        return JSONResponse(content=chat_payload)
app.add_middleware(FallbackMiddleware)
# 4. 缓存降级
from functools import lru_cache
class CachedFallbackMiddleware(BaseHTTPMiddleware):
    """Serve the last successful body for a route when the handler raises.

    ``cache`` maps ``"<METHOD>:<path>"`` to the raw body bytes of the most
    recent 200 response. The cache is unbounded and ignores the query string.
    """

    def __init__(self, app):
        super().__init__(app)
        self.cache = {}
        # Content type of each cached body, so a replay is still parseable.
        self._content_types = {}

    async def dispatch(self, request: Request, call_next):
        """Cache successful bodies; replay the cached one on failure."""
        from fastapi.responses import Response

        cache_key = f"{request.method}:{request.url.path}"
        try:
            response = await call_next(request)
        except Exception as e:
            if cache_key not in self.cache:
                raise
            logger.warning(f"服务错误,使用缓存:{str(e)}")
            return Response(
                content=self.cache[cache_key],
                headers={"X-From-Cache": "true"},
                media_type=self._content_types.get(cache_key)
            )

        if response.status_code != 200:
            return response

        body = getattr(response, "body", None)
        if body is not None:
            self._remember(cache_key, body, response)
            return response

        # call_next() returns a streaming wrapper without ``body``. Buffer it
        # only when a Content-Length is declared, so open-ended streams are
        # never held in memory.
        chunks = getattr(response, "body_iterator", None)
        if chunks is None or "content-length" not in response.headers:
            return response
        body = b"".join([chunk async for chunk in chunks])
        self._remember(cache_key, body, response)
        # The iterator is consumed, so the response has to be rebuilt.
        return Response(
            content=body,
            status_code=response.status_code,
            headers=dict(response.headers),
            media_type=response.media_type
        )

    def _remember(self, cache_key, body, response):
        """Store a body and its content type under ``cache_key``."""
        self.cache[cache_key] = body
        self._content_types[cache_key] = response.headers.get("content-type")
# 5. 优雅降级
class GracefulDegradationMiddleware(BaseHTTPMiddleware):
    """Answer with a simplified payload when the handler raises."""

    async def dispatch(self, request: Request, call_next):
        """Pass the response through, or return a degraded JSON body on error."""
        try:
            return await call_next(request)
        except Exception as e:
            logger.error(f"错误,优雅降级:{str(e)}")
            from fastapi.responses import JSONResponse

            degraded = {
                "success": False,
                "message": "服务部分功能不可用,已提供基础响应",
                "data": self.get_basic_response(request)
            }
            return JSONResponse(content=degraded)

    def get_basic_response(self, request: Request):
        """Return the minimal data for the path: a busy notice for chat, else empty."""
        is_chat = "/chat" in request.url.path
        return {"response": "系统繁忙,请稍后重试"} if is_chat else {}
# 6. 健康检查与自动恢复
class HealthCheckMiddleware(BaseHTTPMiddleware):
    """Reject traffic with 503 after repeated failures, then recover.

    Args:
        app: The ASGI application being wrapped.
        max_failures: Consecutive failures that mark the service unhealthy.
        recovery_delay: Seconds to wait before marking it healthy again.
    """

    def __init__(self, app, max_failures=3, recovery_delay=10):
        super().__init__(app)
        self.healthy = True
        self.failure_count = 0
        self.max_failures = max_failures
        self.recovery_delay = recovery_delay
        # Strong reference to the running recovery task; the event loop only
        # keeps weak references, so an unreferenced task can be collected.
        self._recovery = None

    async def dispatch(self, request: Request, call_next):
        """Short-circuit with 503 while unhealthy; otherwise track failures."""
        if not self.healthy:
            from fastapi.responses import JSONResponse
            return JSONResponse(
                status_code=503,
                content={"error": "服务不健康,正在恢复"}
            )
        try:
            response = await call_next(request)
        except Exception:
            self.failure_count += 1
            # ``self.healthy`` guards against in-flight requests that fail
            # together each starting their own recovery task.
            if self.failure_count >= self.max_failures and self.healthy:
                self.healthy = False
                logger.error("服务标记为不健康")
                self._recovery = asyncio.create_task(self.recovery_task())
            raise
        # Success resets the consecutive-failure counter.
        self.failure_count = 0
        return response

    async def recovery_task(self):
        """Wait ``recovery_delay`` seconds, then mark the service healthy."""
        logger.info("开始恢复任务")
        await asyncio.sleep(self.recovery_delay)
        self.healthy = True
        self.failure_count = 0
        self._recovery = None
        logger.info("服务已恢复健康")
# 7. 信创环境错误恢复
class XinchuangRecoveryMiddleware(BaseHTTPMiddleware):
    """React to failures attributed to Dameng or Ollama by message text.

    Args:
        app: The ASGI application being wrapped.
        dm_dsn: Connection string used for the Dameng reconnect attempt.
            Defaults to the example credentials; supply real ones from
            configuration rather than source code.
    """

    DEFAULT_DM_DSN = "user=SYSDBA;password=SYSDBA;server=localhost;port=5236"

    def __init__(self, app, dm_dsn=None):
        super().__init__(app)
        self.dm_dsn = self.DEFAULT_DM_DSN if dm_dsn is None else dm_dsn

    async def dispatch(self, request: Request, call_next):
        """Pass the response through; on error try a component-specific recovery.

        A Dameng error triggers a reconnect attempt and is then re-raised.
        An Ollama error is answered with a fallback JSON body. Anything else
        is re-raised unchanged.
        """
        try:
            return await call_next(request)
        except Exception as e:
            error_msg = str(e).lower()
            # NOTE(review): the bare "dm" token also matches unrelated text
            # (for example "admin"); consider a stricter marker.
            if "dameng" in error_msg or "dm" in error_msg:
                logger.error("达梦数据库错误,尝试重连")
                self._try_reconnect_dameng()
            elif "ollama" in error_msg:
                logger.error("Ollama服务错误,切换到备用模型")
                from fastapi.responses import JSONResponse
                return JSONResponse(
                    content={
                        "response": "当前模型不可用,请稍后重试",
                        "fallback": True
                    }
                )
            raise

    def _try_reconnect_dameng(self):
        """Attempt one Dameng connection and log, rather than hide, a failure."""
        try:
            import dmPython
            conn = dmPython.connect(self.dm_dsn)
            # Nothing uses the connection here, so release it immediately.
            conn.close()
        except Exception as reconnect_error:
            # ``except Exception`` instead of a bare ``except`` so task
            # cancellation and interpreter shutdown are not swallowed.
            logger.error(f"达梦数据库重连失败:{reconnect_error}")
print("✓ 错误恢复配置完成")
---
02.监控告警
a.错误统计
a.功能说明
统计各类错误的发生频率,识别系统问题。按错误类型、端点、时间段分类统计。生成错误报告,指导系统优化。错误统计是监控体系的重要组成部分。
b.代码示例
---
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from collections import defaultdict
from datetime import datetime, timedelta
import logging
app = FastAPI()
logger = logging.getLogger(__name__)
# 1. 错误统计中间件
class ErrorStatsMiddleware(BaseHTTPMiddleware):
    """Count errors by type, endpoint and status code.

    Args:
        app: The ASGI application being wrapped.
        stats: Optional statistics dict to update in place. Passing a dict
            lets code outside the middleware stack read the same counters.
    """

    def __init__(self, app, stats=None):
        super().__init__(app)
        if stats is None:
            stats = {
                "total_errors": 0,
                "by_type": defaultdict(int),
                "by_endpoint": defaultdict(int),
                "by_status_code": defaultdict(int),
                "recent_errors": []
            }
        self.error_stats = stats

    async def dispatch(self, request: Request, call_next):
        """Record 4xx/5xx responses and raised exceptions, then pass them on."""
        try:
            response = await call_next(request)
        except Exception as e:
            self._record_error(
                error_type=type(e).__name__,
                endpoint=request.url.path,
                status_code=500,
                message=str(e)
            )
            raise
        if response.status_code >= 400:
            self._record_error(
                error_type="HTTPError",
                endpoint=request.url.path,
                status_code=response.status_code,
                message=f"HTTP {response.status_code}"
            )
        return response

    def _record_error(self, error_type: str, endpoint: str, status_code: int, message: str):
        """Update all counters and append to the bounded recent-error list."""
        stats = self.error_stats
        stats["total_errors"] += 1
        stats["by_type"][error_type] += 1
        stats["by_endpoint"][endpoint] += 1
        stats["by_status_code"][status_code] += 1

        recent = stats["recent_errors"]
        recent.append({
            "timestamp": datetime.now().isoformat(),
            "type": error_type,
            "endpoint": endpoint,
            "status_code": status_code,
            "message": message
        })
        # Keep only the latest 100 entries, trimming the shared list in place.
        if len(recent) > 100:
            del recent[:-100]


# Reporting handle only: this instance is never placed in the middleware stack.
error_stats = ErrorStatsMiddleware(app)
# Register the class itself and share the statistics dict. Returning the
# pre-built instance from a factory would wrap the FastAPI app in itself and
# recurse on every request.
app.add_middleware(ErrorStatsMiddleware, stats=error_stats.error_stats)


@app.get("/metrics/errors")
def get_error_metrics():
    """Return the collected error statistics."""
    return error_stats.error_stats
# 2. 时间窗口错误统计
from collections import deque
class TimeWindowErrorStats(BaseHTTPMiddleware):
    """Keep error events from a sliding time window and summarise them.

    Args:
        app: The ASGI application being wrapped.
        window_minutes: Length of the sliding window in minutes.
    """

    def __init__(self, app, window_minutes=60):
        super().__init__(app)
        self.window = timedelta(minutes=window_minutes)
        self.errors = deque()  # (timestamp, error_type, endpoint)

    async def dispatch(self, request: Request, call_next):
        """Record 4xx/5xx responses and raised exceptions, then pass them on."""
        try:
            response = await call_next(request)
        except Exception as e:
            self._add_error(type(e).__name__, request.url.path)
            raise
        if response.status_code >= 400:
            self._add_error("HTTPError", request.url.path)
        return response

    def _add_error(self, error_type: str, endpoint: str):
        """Append an error event and drop events older than the window."""
        now = datetime.now()
        self.errors.append((now, error_type, endpoint))
        self._prune(now)

    def _prune(self, now):
        """Remove events that are older than the window relative to ``now``."""
        while self.errors and now - self.errors[0][0] > self.window:
            self.errors.popleft()

    def get_stats(self):
        """Summarise the errors currently inside the window.

        Returns:
            An empty dict when there are no errors; otherwise totals by type
            and endpoint plus the average number of errors per minute.
        """
        # Prune on read as well, so a quiet period does not report stale events.
        self._prune(datetime.now())
        if not self.errors:
            return {}

        by_type = defaultdict(int)
        by_endpoint = defaultdict(int)
        for _, error_type, endpoint in self.errors:
            by_type[error_type] += 1
            by_endpoint[endpoint] += 1

        # total_seconds() covers the whole span; ``.seconds`` wraps at one
        # day, which made a 1440-minute window divide by zero.
        minutes = self.window.total_seconds() / 60
        return {
            "window_minutes": int(minutes),
            "total_errors": len(self.errors),
            "by_type": dict(by_type),
            "by_endpoint": dict(by_endpoint),
            "error_rate": len(self.errors) / minutes if minutes > 0 else 0.0  # per minute
        }
# 3. 错误趋势分析
class ErrorTrendAnalyzer:
    """Count errors per clock hour and report the recent hourly trend."""

    def __init__(self):
        # Maps the start of each hour (a datetime) to its error count.
        self.hourly_errors = defaultdict(int)

    def record_error(self):
        """Add one error to the bucket of the current hour."""
        bucket = datetime.now().replace(minute=0, second=0, microsecond=0)
        self.hourly_errors[bucket] = self.hourly_errors[bucket] + 1

    def get_trend(self, hours=24):
        """Return per-hour counts for the last ``hours`` hours, oldest first.

        Each entry is ``{"hour": <ISO timestamp>, "errors": <count>}``; hours
        without recorded errors are reported with a count of 0.
        """
        current_hour = datetime.now().replace(minute=0, second=0, microsecond=0)
        buckets = (
            current_hour - timedelta(hours=offset)
            for offset in range(hours - 1, -1, -1)
        )
        return [
            {"hour": bucket.isoformat(), "errors": self.hourly_errors.get(bucket, 0)}
            for bucket in buckets
        ]
trend_analyzer = ErrorTrendAnalyzer()


@app.get("/metrics/error_trend")
def get_error_trend():
    """Return the hourly error counts for the last 24 hours."""
    lookback_hours = 24
    return trend_analyzer.get_trend(hours=lookback_hours)
# 4. 错误告警
class ErrorAlertMiddleware(BaseHTTPMiddleware):
    """Raise an alert when too many server errors occur in a time window.

    Args:
        app: The ASGI application being wrapped.
        threshold: Number of errors inside the window that triggers an alert.
        window_seconds: Length of the sliding window in seconds.
    """

    def __init__(self, app, threshold=10, window_seconds=60):
        super().__init__(app)
        self.threshold = threshold
        self.window = timedelta(seconds=window_seconds)
        self.recent_errors = deque()

    async def dispatch(self, request: Request, call_next):
        """Count 5xx responses and raised exceptions, then pass them on."""
        try:
            response = await call_next(request)
        except Exception:
            await self._check_alert()
            raise
        if response.status_code >= 500:
            await self._check_alert()
        return response

    async def _check_alert(self):
        """Record one error, prune expired ones, and alert at the threshold."""
        now = datetime.now()
        self.recent_errors.append(now)
        while self.recent_errors and now - self.recent_errors[0] > self.window:
            self.recent_errors.popleft()
        # NOTE(review): once the threshold is reached every further error in
        # the window alerts again; add a cooldown if that is too noisy.
        if len(self.recent_errors) >= self.threshold:
            await self._send_alert()

    async def _send_alert(self):
        """Log the alert; a real deployment would notify a chat/paging channel."""
        # total_seconds() reports the whole window; ``.seconds`` wraps at one
        # day and would print a wrong figure for windows of a day or more.
        window_seconds = int(self.window.total_seconds())
        logger.critical(
            f"错误告警:{window_seconds}秒内发生{len(self.recent_errors)}个错误"
        )
# 5. 信创环境错误监控
class XinchuangErrorMonitor(BaseHTTPMiddleware):
    """Count raised exceptions per Xinchuang component, by message text.

    Args:
        app: The ASGI application being wrapped.
        component_errors: Optional counter dict to update in place. Passing a
            dict lets code outside the middleware stack read the counters.
    """

    def __init__(self, app, component_errors=None):
        super().__init__(app)
        if component_errors is None:
            component_errors = {
                "dameng": 0,
                "ollama": 0,
                "kylin": 0,
                "other": 0
            }
        self.component_errors = component_errors

    async def dispatch(self, request: Request, call_next):
        """Pass the response through; classify and count any exception."""
        try:
            return await call_next(request)
        except Exception as e:
            component = self._classify(str(e).lower())
            self.component_errors[component] += 1
            logger.error(f"信创组件错误统计:{self.component_errors}")
            raise

    @staticmethod
    def _classify(error_msg):
        """Map a lower-cased error message to a component counter key."""
        # NOTE(review): the bare "dm" token also matches unrelated text
        # (for example "admin"); consider a stricter marker.
        if "dameng" in error_msg or "dm" in error_msg:
            return "dameng"
        if "ollama" in error_msg:
            return "ollama"
        if "kylin" in error_msg:
            return "kylin"
        return "other"


# Reporting handle only: this instance is never placed in the middleware stack.
xc_monitor = XinchuangErrorMonitor(app)
# Register the class itself and share the counter dict. Returning the
# pre-built instance from a factory would wrap the FastAPI app in itself and
# recurse on every request.
app.add_middleware(XinchuangErrorMonitor, component_errors=xc_monitor.component_errors)


@app.get("/metrics/xinchuang_errors")
def get_xinchuang_error_metrics():
    """Return the per-component error counters."""
    return xc_monitor.component_errors
print("✓ 错误统计配置完成")
---
5 Playground
5.1 交互界面
01.Playground功能
a.自动生成
a.功能说明
LangServe自动为每个Chain生成交互式Playground界面。提供Web UI测试Chain功能,无需编写客户端代码。支持输入表单、实时响应显示。Playground是快速测试和演示的利器。
b.代码示例
---
from fastapi import FastAPI
from langserve import add_routes
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
app = FastAPI(
title="LangServe Playground Demo",
description="带Playground的LangServe服务"
)
# 1. Basic Playground
llm = ChatOpenAI(model="gpt-3.5-turbo")
prompt = ChatPromptTemplate.from_template("回答:{question}")
chain = prompt | llm
# Registering the chain also exposes its Playground page.
add_routes(app, chain, path="/chat")
# Open http://localhost:8000/chat/playground
# to use the generated interactive page.
# 2. Choosing the Playground type (this call does not set a page title)
add_routes(
app,
chain,
path="/qa",
playground_type="default", # NOTE(review): LangServe documents "default" and "chat" here - confirm for the installed version
# Config keys that clients are allowed to pass with a request
config_keys=["configurable"]
)
# 3. One Playground per chain
# Translation chain
translate_prompt = ChatPromptTemplate.from_template(
"将'{text}'翻译为{target_lang}"
)
translate_chain = translate_prompt | llm
add_routes(app, translate_chain, path="/translate")
# Summary chain
summary_prompt = ChatPromptTemplate.from_template("总结:{content}")
summary_chain = summary_prompt | llm
add_routes(app, summary_chain, path="/summary")
# Playground pages for the chains registered in this section:
# /chat/playground
# /translate/playground
# /summary/playground
from pydantic import BaseModel, Field
class QuestionInput(BaseModel):
    """Input schema for a question with optional context."""

    # ``examples`` (a list) is the supported Field argument; the singular
    # ``example`` is an extra keyword that pydantic v2 reports as deprecated.
    question: str = Field(
        ...,
        description="用户问题",
        examples=["什么是LangServe?"]
    )
    context: str = Field(
        default="",
        description="上下文",
        examples=["LangServe是LangChain的部署工具"]
    )


# NOTE(review): the model is defined here but not attached to any chain in
# this example; presumably it must be set as the chain's input type before
# its examples can appear in the Playground - confirm.
# 5. Disable the Playground
add_routes(
    app,
    chain,
    path="/private_chat",
    # ``enable_playground`` is not an add_routes parameter and raises
    # TypeError; endpoints are switched off through ``disabled_endpoints``.
    # NOTE(review): confirm the installed LangServe version supports it.
    disabled_endpoints=["playground"]
)
# /private_chat/playground is not registered, so it returns 404
# 6. Playground主页
@app.get("/")
def playground_home():
    """Return an index of the available Playground pages."""
    playground_links = [
        {"name": label, "url": f"/{slug}/playground"}
        for label, slug in (
            ("Chat", "chat"),
            ("Translate", "translate"),
            ("Summary", "summary"),
        )
    ]
    return {
        "message": "LangServe Playground",
        "playgrounds": playground_links
    }
# 7. 自定义Playground页面
from fastapi.responses import HTMLResponse
@app.get("/custom_playground", response_class=HTMLResponse)
def custom_playground():
    """Serve a minimal hand-written Playground page.

    The page posts the textarea content to ``/chat/invoke`` and shows the
    JSON reply as plain text.
    """
    return """
<!DOCTYPE html>
<html>
<head>
<title>自定义Playground</title>
<style>
body { font-family: Arial; max-width: 800px; margin: 50px auto; }
textarea { width: 100%; height: 100px; margin: 10px 0; }
button { padding: 10px 20px; background: #007bff; color: white; border: none; cursor: pointer; }
#output { margin-top: 20px; padding: 10px; border: 1px solid #ccc; }
</style>
</head>
<body>
<h1>LangServe Playground</h1>
<textarea id="input" placeholder="输入你的问题..."></textarea>
<button onclick="sendRequest()">发送</button>
<div id="output"></div>
<script>
async function sendRequest() {
const input = document.getElementById('input').value;
const output = document.getElementById('output');
try {
const response = await fetch('/chat/invoke', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({input: {question: input}})
});
const data = await response.json();
output.textContent = JSON.stringify(data, null, 2);
} catch (error) {
output.textContent = '错误:' + error.message;
}
}
</script>
</body>
</html>
"""
# 8. Playground统计
playground_stats = {"visits": 0, "requests": 0}


@app.middleware("http")
async def track_playground(request, call_next):
    """Count Playground page visits and invoke/stream calls by URL path."""
    path = request.url.path
    if "/playground" in path:
        playground_stats["visits"] += 1
    is_chain_call = any(marker in path for marker in ("/invoke", "/stream"))
    if is_chain_call:
        playground_stats["requests"] += 1
    return await call_next(request)


@app.get("/metrics/playground")
def get_playground_metrics():
    """Return the Playground usage counters."""
    return playground_stats
# 9. 信创环境Playground
# 本地模型Playground
from langchain.llms import Ollama
# Chain backed by a model served by Ollama at the given base URL.
local_llm = Ollama(model="qwen:7b", base_url="http://localhost:11434")
local_prompt = ChatPromptTemplate.from_template("问题:{question}")
local_chain = local_prompt | local_llm
# Expose the local chain, with its Playground, under /local_chat.
add_routes(
app,
local_chain,
path="/local_chat",
playground_type="default"
)
# 信创Playground说明页
@app.get("/xinchuang", response_class=HTMLResponse)
def xinchuang_playground_info():
    """Serve a static page describing the Xinchuang setup and its Playgrounds."""
    info_page = """
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>信创环境Playground</title>
</head>
<body>
<h1>信创环境LangServe Playground</h1>
<p>系统信息:</p>
<ul>
<li>操作系统:麒麟OS</li>
<li>数据库:达梦DM8</li>
<li>LLM:Ollama + Qwen</li>
</ul>
<p>可用Playground:</p>
<ul>
<li><a href="/local_chat/playground">本地聊天</a></li>
<li><a href="/chat/playground">标准聊天</a></li>
</ul>
</body>
</html>
"""
    return info_page
# 10. 启动服务
if __name__ == "__main__":
    import uvicorn

    # Print the startup banner with the Playground and docs URLs.
    startup_banner = (
        "启动LangServe服务...",
        "Playground地址:",
        " - http://localhost:8000/chat/playground",
        " - http://localhost:8000/translate/playground",
        " - http://localhost:8000/summary/playground",
        " - http://localhost:8000/docs (API文档)",
    )
    for banner_line in startup_banner:
        print(banner_line)

    # Serve on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
print("✓ Playground自动生成配置完成")
---
b.UI定制
a.功能说明
自定义Playground界面的外观和行为。添加品牌元素、自定义样式。配置输入提示、示例数据。UI定制提升用户体验和品牌形象。
b.代码示例
---
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from langserve import add_routes
from langchain.chat_models import ChatOpenAI
app = FastAPI()
# 1. 自定义HTML Playground
@app.get("/custom_ui", response_class=HTMLResponse)
def custom_ui_playground():
    """Serve a styled Playground page that posts questions to ``/chat/invoke``.

    Model output and error text are inserted as text nodes, never through
    ``innerHTML``, so markup contained in a reply is displayed literally
    instead of being executed by the browser.
    """
    return """
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>AI助手 - Playground</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
height: 100vh;
display: flex;
justify-content: center;
align-items: center;
}
.container {
background: white;
border-radius: 10px;
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
width: 90%;
max-width: 800px;
padding: 30px;
}
h1 {
color: #667eea;
margin-bottom: 20px;
text-align: center;
}
.input-group {
margin-bottom: 20px;
}
label {
display: block;
margin-bottom: 5px;
color: #555;
font-weight: bold;
}
textarea {
width: 100%;
padding: 10px;
border: 2px solid #e0e0e0;
border-radius: 5px;
font-size: 14px;
resize: vertical;
}
textarea:focus {
outline: none;
border-color: #667eea;
}
button {
width: 100%;
padding: 12px;
background: #667eea;
color: white;
border: none;
border-radius: 5px;
font-size: 16px;
cursor: pointer;
transition: background 0.3s;
}
button:hover {
background: #5568d3;
}
button:disabled {
background: #ccc;
cursor: not-allowed;
}
#output {
margin-top: 20px;
padding: 15px;
background: #f5f5f5;
border-radius: 5px;
min-height: 100px;
white-space: pre-wrap;
font-family: 'Courier New', monospace;
display: none;
}
.loading {
text-align: center;
color: #667eea;
}
.example {
margin-top: 10px;
font-size: 12px;
color: #888;
}
.example-btn {
background: #e0e0e0;
color: #555;
padding: 5px 10px;
border-radius: 3px;
cursor: pointer;
display: inline-block;
margin-right: 5px;
margin-top: 5px;
}
.example-btn:hover {
background: #d0d0d0;
}
</style>
</head>
<body>
<div class="container">
<h1>🤖 AI助手 Playground</h1>
<div class="input-group">
<label for="input">输入你的问题:</label>
<textarea id="input" rows="4" placeholder="例如:解释什么是人工智能"></textarea>
<div class="example">
<span>示例问题:</span>
<span class="example-btn" onclick="setExample('什么是LangServe?')">LangServe介绍</span>
<span class="example-btn" onclick="setExample('如何部署AI服务?')">部署指南</span>
<span class="example-btn" onclick="setExample('解释向量数据库')">向量数据库</span>
</div>
</div>
<button id="submitBtn" onclick="sendRequest()">发送</button>
<div id="output"></div>
</div>
<script>
function setExample(text) {
document.getElementById('input').value = text;
}
async function sendRequest() {
const input = document.getElementById('input').value;
const output = document.getElementById('output');
const btn = document.getElementById('submitBtn');
if (!input.trim()) {
alert('请输入问题');
return;
}
// 显示加载状态
output.style.display = 'block';
output.innerHTML = '<div class="loading">⏳ 正在思考...</div>';
btn.disabled = true;
btn.textContent = '处理中...';
try {
const response = await fetch('/chat/invoke', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({input: {question: input}})
});
const data = await response.json();
// 显示结果
if (data.output && data.output.content) {
const label = document.createElement('strong');
label.textContent = '回答:';
output.textContent = '';
output.append('✅ ', label, '\\n\\n' + data.output.content);
} else {
output.textContent = JSON.stringify(data, null, 2);
}
} catch (error) {
const label = document.createElement('strong');
label.textContent = '错误:';
output.textContent = '';
output.append('❌ ', label, '\\n' + error.message);
} finally {
btn.disabled = false;
btn.textContent = '发送';
}
}
// Enter键发送(Ctrl+Enter)
document.getElementById('input').addEventListener('keydown', function(e) {
if (e.key === 'Enter' && e.ctrlKey) {
sendRequest();
}
});
</script>
</body>
</html>
"""
# 2. 流式UI
@app.get("/stream_ui", response_class=HTMLResponse)
def stream_ui_playground():
    """Serve a page that renders the ``/chat/stream`` reply incrementally.

    The script decodes the byte stream in streaming mode and keeps the
    trailing partial line in a buffer until the rest arrives. A multi-byte
    character or a ``data:`` line split across two network chunks is
    therefore reassembled instead of being corrupted or dropped.
    """
    return """
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>流式 Playground</title>
<style>
body { font-family: Arial; max-width: 900px; margin: 50px auto; padding: 20px; }
h1 { color: #333; }
textarea { width: 100%; padding: 10px; border: 2px solid #ddd; border-radius: 5px; }
button { padding: 10px 20px; background: #28a745; color: white; border: none; border-radius: 5px; cursor: pointer; margin-top: 10px; }
#output { margin-top: 20px; padding: 15px; border: 1px solid #ddd; border-radius: 5px; min-height: 150px; background: #f9f9f9; white-space: pre-wrap; }
.cursor { animation: blink 1s infinite; }
@keyframes blink { 0%, 50% { opacity: 1; } 51%, 100% { opacity: 0; } }
</style>
</head>
<body>
<h1>🔄 流式输出 Playground</h1>
<textarea id="input" rows="3" placeholder="输入问题..."></textarea>
<button onclick="streamRequest()">流式发送</button>
<div id="output"></div>
<script>
async function streamRequest() {
const input = document.getElementById('input').value;
const output = document.getElementById('output');
output.innerHTML = '<span class="cursor">▊</span>';
try {
const response = await fetch('/chat/stream', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({input: {question: input}})
});
const reader = response.body.getReader();
const decoder = new TextDecoder();
let fullText = '';
let pending = '';
while (true) {
const {done, value} = await reader.read();
if (done) break;
pending += decoder.decode(value, {stream: true});
const lines = pending.split('\\n');
pending = lines.pop();
for (const line of lines) {
if (line.startsWith('data: ')) {
try {
const data = JSON.parse(line.slice(6));
if (data.content) {
fullText += data.content;
output.textContent = fullText;
}
} catch (e) {}
}
}
}
} catch (error) {
output.textContent = '错误:' + error.message;
}
}
</script>
</body>
</html>
"""
# 3. 多功能Playground
@app.get("/multi_feature", response_class=HTMLResponse)
def multi_feature_playground():
    """Serve a tabbed page exposing the chat, translate and summary chains.

    Each tab POSTs to ``/chat/invoke``, ``/translate/invoke`` or
    ``/summary/invoke`` and shows ``output.content`` when present, otherwise
    the raw JSON response.

    Returns:
        The HTML document as a string.
    """
    return """
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>多功能 Playground</title>
<style>
body { font-family: Arial; max-width: 1200px; margin: 20px auto; padding: 20px; }
.tabs { display: flex; border-bottom: 2px solid #ddd; margin-bottom: 20px; }
.tab { padding: 10px 20px; cursor: pointer; background: #f0f0f0; margin-right: 5px; border-radius: 5px 5px 0 0; }
.tab.active { background: #007bff; color: white; }
.tab-content { display: none; }
.tab-content.active { display: block; }
textarea { width: 100%; padding: 10px; margin: 10px 0; }
button { padding: 10px 20px; background: #007bff; color: white; border: none; border-radius: 5px; cursor: pointer; }
.output { margin-top: 15px; padding: 15px; border: 1px solid #ddd; border-radius: 5px; background: #f9f9f9; min-height: 100px; }
</style>
</head>
<body>
<h1>🚀 多功能 Playground</h1>
<div class="tabs">
<div class="tab active" onclick="showTab('chat', this)">聊天</div>
<div class="tab" onclick="showTab('translate', this)">翻译</div>
<div class="tab" onclick="showTab('summary', this)">摘要</div>
</div>
<div id="chat" class="tab-content active">
<h2>聊天</h2>
<textarea id="chat-input" rows="3" placeholder="输入你的问题..."></textarea>
<button onclick="chat()">发送</button>
<div id="chat-output" class="output"></div>
</div>
<div id="translate" class="tab-content">
<h2>翻译</h2>
<textarea id="translate-input" rows="3" placeholder="输入要翻译的文本..."></textarea>
<select id="target-lang">
<option value="zh">中文</option>
<option value="en">英文</option>
<option value="ja">日文</option>
</select>
<button onclick="runTranslate()">翻译</button>
<div id="translate-output" class="output"></div>
</div>
<div id="summary" class="tab-content">
<h2>摘要</h2>
<textarea id="summary-input" rows="5" placeholder="输入要总结的内容..."></textarea>
<button onclick="summary()">生成摘要</button>
<div id="summary-output" class="output"></div>
</div>
<script>
// The clicked tab is passed in explicitly instead of relying on the
// implicit global `event` object.
function showTab(tabName, tabEl) {
    document.querySelectorAll('.tab').forEach(t => t.classList.remove('active'));
    document.querySelectorAll('.tab-content').forEach(c => c.classList.remove('active'));
    tabEl.classList.add('active');
    document.getElementById(tabName).classList.add('active');
}
// Prefer the message content; fall back to the raw payload (for example an
// error body without an `output` field) instead of throwing.
function renderOutput(data) {
    return (data.output && data.output.content) || JSON.stringify(data);
}
async function chat() {
    const input = document.getElementById('chat-input').value;
    const output = document.getElementById('chat-output');
    output.textContent = '处理中...';
    try {
        const response = await fetch('/chat/invoke', {
            method: 'POST',
            headers: {'Content-Type': 'application/json'},
            body: JSON.stringify({input: {question: input}})
        });
        const data = await response.json();
        output.textContent = renderOutput(data);
    } catch (error) {
        output.textContent = '错误:' + error.message;
    }
}
// Not named `translate`: inside an inline onclick handler that identifier
// resolves to the element's boolean `translate` property, not this function.
async function runTranslate() {
    const input = document.getElementById('translate-input').value;
    const targetLang = document.getElementById('target-lang').value;
    const output = document.getElementById('translate-output');
    output.textContent = '翻译中...';
    try {
        const response = await fetch('/translate/invoke', {
            method: 'POST',
            headers: {'Content-Type': 'application/json'},
            body: JSON.stringify({input: {text: input, target_lang: targetLang}})
        });
        const data = await response.json();
        output.textContent = renderOutput(data);
    } catch (error) {
        output.textContent = '错误:' + error.message;
    }
}
async function summary() {
    const input = document.getElementById('summary-input').value;
    const output = document.getElementById('summary-output');
    output.textContent = '生成中...';
    try {
        const response = await fetch('/summary/invoke', {
            method: 'POST',
            headers: {'Content-Type': 'application/json'},
            body: JSON.stringify({input: {content: input}})
        });
        const data = await response.json();
        output.textContent = renderOutput(data);
    } catch (error) {
        output.textContent = '错误:' + error.message;
    }
}
</script>
</body>
</html>
"""
# 4. 信创环境自定义UI
@app.get("/xinchuang_ui", response_class=HTMLResponse)
def xinchuang_ui():
    """Serve the themed question page that calls ``/local_chat/invoke``.

    Dynamic text (model output, error messages, raw JSON) is inserted as
    text nodes rather than through ``innerHTML``, so markup contained in a
    model answer is displayed literally instead of being interpreted.

    Returns:
        The HTML document as a string.
    """
    return """
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>信创AI助手</title>
<style>
body {
    font-family: "Microsoft YaHei", Arial, sans-serif;
    background: #f5f5f5;
    margin: 0;
    padding: 20px;
}
.header {
    background: linear-gradient(135deg, #c31432 0%, #240b36 100%);
    color: white;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
    text-align: center;
}
.container {
    max-width: 900px;
    margin: 0 auto;
    background: white;
    padding: 30px;
    border-radius: 10px;
    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
.system-info {
    background: #e3f2fd;
    padding: 15px;
    border-radius: 5px;
    margin-bottom: 20px;
}
textarea {
    width: 100%;
    padding: 12px;
    border: 2px solid #ddd;
    border-radius: 5px;
    font-size: 14px;
}
button {
    width: 100%;
    padding: 12px;
    background: #c31432;
    color: white;
    border: none;
    border-radius: 5px;
    font-size: 16px;
    cursor: pointer;
    margin-top: 10px;
}
button:hover {
    background: #a01025;
}
#output {
    margin-top: 20px;
    padding: 15px;
    background: #f9f9f9;
    border-radius: 5px;
    border-left: 4px solid #c31432;
    display: none;
}
</style>
</head>
<body>
<div class="header">
<h1>🇨🇳 信创环境 AI 助手</h1>
<p>基于国产化技术栈的智能问答系统</p>
</div>
<div class="container">
<div class="system-info">
<strong>系统信息:</strong><br>
📌 操作系统:麒麟OS (Kylin)<br>
📌 数据库:达梦DM8<br>
📌 LLM:Ollama + Qwen-7B (本地部署)<br>
📌 服务状态:<span style="color: green;">● 正常</span>
</div>
<h3>请输入您的问题:</h3>
<textarea id="input" rows="4" placeholder="例如:介绍一下信创技术体系"></textarea>
<button onclick="sendRequest()">🚀 提交问题</button>
<div id="output"></div>
</div>
<script>
// Render "<prefix><strong>label</strong><br>*n text" using DOM nodes so the
// text part is never parsed as HTML.
function showMessage(target, prefix, label, text, lineBreaks) {
    target.textContent = prefix;
    const strong = document.createElement('strong');
    strong.textContent = label;
    target.appendChild(strong);
    for (let i = 0; i < lineBreaks; i++) {
        target.appendChild(document.createElement('br'));
    }
    target.appendChild(document.createTextNode(text));
}
async function sendRequest() {
    const input = document.getElementById('input').value;
    const output = document.getElementById('output');
    if (!input.trim()) {
        alert('请输入问题');
        return;
    }
    output.style.display = 'block';
    output.textContent = '⏳ 本地模型正在处理...';
    try {
        const response = await fetch('/local_chat/invoke', {
            method: 'POST',
            headers: {'Content-Type': 'application/json'},
            body: JSON.stringify({input: {question: input}})
        });
        const data = await response.json();
        if (data.output) {
            showMessage(output, '', '📝 AI回答:', data.output, 2);
        } else {
            output.textContent = JSON.stringify(data, null, 2);
        }
    } catch (error) {
        showMessage(output, '❌ ', '错误:', error.message, 0);
    }
}
</script>
</body>
</html>
"""
print("✓ UI定制配置完成")
---
02.测试功能
a.输入测试
a.功能说明
使用Playground测试各种输入场景,验证Chain功能。测试边界情况、异常输入。快速迭代和调试。Playground是开发过程中的重要测试工具。
b.代码示例
---
# Playground测试场景示例
# 1. 正常输入测试
# 在Playground中输入:
# {"question": "什么是LangServe?"}
# 预期:返回LangServe的介绍
# 2. 空输入测试
# {"question": ""}
# 预期:验证错误或返回提示
# 3. 超长输入测试
# {"question": "A" * 10000}
# 预期:检查是否有长度限制
# 4. 特殊字符测试
# {"question": "测试<script>alert('xss')</script>"}
# 预期:正确处理HTML/JS
# 5. 多语言测试
# {"question": "こんにちは"} # 日语
# {"question": "안녕하세요"} # 韩语
# 预期:正确处理非英文输入
# 6. JSON格式测试
# {"question": "测试\\n换行\\t制表符"}
# 预期:正确处理转义字符
# 7. 配置参数测试
# {
# "input": {"question": "测试"},
# "config": {"configurable": {"temperature": 0.9}}
# }
# 预期:使用指定配置
# 8. 批量测试脚本
import requests
import json
def test_playground_inputs():
    """POST a fixed set of inputs to the chat endpoint and report the results.

    Each case is sent to ``http://localhost:8000/chat/invoke`` and passes when
    the HTTP status equals its ``expected_status``. Request errors are
    recorded as failures instead of aborting the run.

    Returns:
        A list with one result dict per case (``test_name``, ``success`` and
        either ``status_code``/``response`` or ``error``).
    """
    base_url = "http://localhost:8000/chat/invoke"
    test_cases = [
        {
            "name": "正常输入",
            "input": {"question": "你好"},
            "expected_status": 200,
        },
        {
            "name": "空输入",
            "input": {"question": ""},
            "expected_status": 200,  # or 422, depending on validation rules
        },
        {
            "name": "超长输入",
            "input": {"question": "A" * 5000},
            "expected_status": 200,
        },
        {
            "name": "特殊字符",
            "input": {"question": "测试@#$%^&*()"},
            "expected_status": 200,
        },
        {
            "name": "多行输入",
            # Real newlines: the doubled backslash sent a literal "\n" pair,
            # so the multi-line case never contained a line break.
            "input": {"question": "第一行\n第二行\n第三行"},
            "expected_status": 200,
        },
    ]
    results = []
    for test in test_cases:
        print(f"测试:{test['name']}")
        try:
            response = requests.post(
                base_url,
                json={"input": test["input"]},
                timeout=30,
            )
            result = {
                "test_name": test["name"],
                "status_code": response.status_code,
                "success": response.status_code == test["expected_status"],
                "response": response.json() if response.ok else response.text,
            }
            results.append(result)
            if result["success"]:
                print(f" ✓ 通过")
            else:
                print(f" ✗ 失败:状态码{response.status_code}")
        except Exception as e:
            # Best-effort batch run: keep going and record the failure.
            print(f" ✗ 错误:{e}")
            results.append({
                "test_name": test["name"],
                "error": str(e),
                "success": False,
            })
    # Summary report (blank line first, not a literal "\n").
    print("\n=== 测试报告 ===")
    passed = sum(1 for r in results if r.get("success"))
    total = len(results)
    print(f"通过:{passed}/{total}")
    return results
# 运行测试
# test_results = test_playground_inputs()
print("✓ Playground测试功能配置完成")
---
5.2 测试调试
01.调试工具
a.日志输出
a.功能说明
配置详细的日志输出,追踪请求处理流程。记录输入输出、中间结果、错误信息。使用不同日志级别(DEBUG、INFO、WARNING、ERROR)。日志是调试问题的第一工具。
b.代码示例
---
import logging
from fastapi import FastAPI, Request
from langserve import add_routes
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
# 1. 配置日志
logging.basicConfig(
level=logging.DEBUG, # 开发环境使用DEBUG
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("langserve_debug.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
app = FastAPI()
# 2. 记录Chain执行
from langchain.callbacks import StdOutCallbackHandler
from langchain.schema.runnable import RunnableConfig
llm = ChatOpenAI(model="gpt-3.5-turbo")
prompt = ChatPromptTemplate.from_template("回答:{question}")
chain = prompt | llm
# 使用回调记录执行
@app.post("/debug_chat")
def debug_chat(question: str):
    """Run the chain for one question with debug logging.

    Logs the question and the generated answer at DEBUG level and attaches a
    ``StdOutCallbackHandler`` to the invocation. Any exception is logged with
    its traceback and re-raised.
    """
    logger.debug(f"收到问题:{question}")
    run_config = RunnableConfig(callbacks=[StdOutCallbackHandler()])
    try:
        answer = chain.invoke({"question": question}, config=run_config)
        logger.debug(f"生成回答:{answer.content}")
        return {"response": answer.content}
    except Exception as e:
        logger.error(f"执行失败:{e}", exc_info=True)
        raise
# 3. 中间件日志
@app.middleware("http")
async def debug_middleware(request: Request, call_next):
    """Log method, URL, headers, POST body and response status of each request.

    Intended for debugging only: headers and bodies are logged verbatim and
    may contain credentials.
    """
    logger.debug(f"请求:{request.method} {request.url}")
    logger.debug(f"Headers: {dict(request.headers)}")
    if request.method == "POST":
        # NOTE(review): the original comment says the request must be rebuilt
        # after its body is read here; this code does not do that. Confirm
        # the behaviour on the Starlette version in use.
        body = await request.body()
        # errors="replace": a binary or non-UTF-8 payload must not make the
        # logging middleware itself raise.
        logger.debug(f"Body: {body.decode(errors='replace')}")
    response = await call_next(request)
    logger.debug(f"响应状态:{response.status_code}")
    logger.debug(f"响应头:{dict(response.headers)}")
    return response
# 4. 详细错误追踪
import traceback
import sys
@app.exception_handler(Exception)
async def debug_exception_handler(request: Request, exc: Exception):
    """Log an unhandled exception with its stack and return it as JSON.

    The traceback is built from ``exc`` itself rather than from
    ``sys.exc_info()``, so it is correct even if the handler is not running
    inside the ``except`` block that caught the exception.

    Returns:
        A 500 ``JSONResponse`` with ``error``, ``type`` and ``traceback``.
        The traceback is sent to the client, so this is for debugging only.
    """
    from fastapi.responses import JSONResponse

    tb_lines = traceback.format_exception(type(exc), exc, exc.__traceback__)
    tb_text = ''.join(tb_lines)
    logger.error(
        f"异常详情:\n"
        f"类型:{type(exc).__name__}\n"
        f"消息:{str(exc)}\n"
        f"路径:{request.url.path}\n"
        f"堆栈:\n{tb_text}"
    )
    return JSONResponse(
        status_code=500,
        content={
            "error": str(exc),
            "type": type(exc).__name__,
            "traceback": tb_lines,
        },
    )
# 5. 性能调试
import time
@app.middleware("http")
async def performance_debug(request: Request, call_next):
    """Log how long each request takes and warn when it exceeds one second.

    Uses ``time.perf_counter()`` so the measured interval is not affected by
    system clock adjustments.
    """
    start = time.perf_counter()
    logger.debug(f"[开始] {request.url.path}")
    response = await call_next(request)
    elapsed = time.perf_counter() - start
    logger.debug(f"[完成] {request.url.path} 耗时{elapsed:.3f}秒")
    # Slow-request warning threshold: 1 second.
    if elapsed > 1.0:
        logger.warning(f"慢请求:{request.url.path} 耗时{elapsed:.3f}秒")
    return response
# 6. 请求重放
class RequestRecorder:
    """In-memory recorder of incoming requests that can replay them later.

    Every recorded request is kept for the lifetime of the object, so this
    is meant for short debugging sessions only.
    """

    def __init__(self):
        # One dict per request: method, url, headers, body (decoded text).
        self.recorded_requests = []

    async def record(self, request: "Request"):
        """Store method, URL, headers and body of ``request``."""
        body = await request.body()
        self.recorded_requests.append({
            "method": request.method,
            "url": str(request.url),
            "headers": dict(request.headers),
            # errors="replace": recording must not fail on non-UTF-8 payloads.
            "body": body.decode(errors="replace"),
        })

    def replay(self, index: int):
        """Re-send the recorded request at ``index``.

        Returns:
            The ``requests`` response, or ``None`` when ``index`` is out of
            range. Negative indexes are also treated as out of range.
        """
        if not 0 <= index < len(self.recorded_requests):
            return None
        recorded = self.recorded_requests[index]
        import requests

        return requests.request(
            method=recorded["method"],
            url=recorded["url"],
            headers=recorded["headers"],
            data=recorded["body"],
        )
recorder = RequestRecorder()
@app.middleware("http")
async def record_middleware(request: Request, call_next):
    """Record every incoming request, then pass it down the stack."""
    await recorder.record(request)
    return await call_next(request)
@app.get("/debug/replay/{index}")
def replay_request(index: int):
    """Replay the recorded request at ``index`` and return its response.

    Raises:
        HTTPException: 404 when no request is recorded at ``index``.
    """
    from fastapi import HTTPException

    response = recorder.replay(index)
    if response is None:
        raise HTTPException(
            status_code=404,
            detail=f"no recorded request at index {index}",
        )
    try:
        payload = response.json()
    except ValueError:
        # The replayed endpoint did not answer with JSON; return the raw text.
        payload = response.text
    return {"replayed": True, "response": payload}
# 7. 断点调试辅助
def debug_checkpoint(name: str, data: object):
    """Log a named checkpoint with the type name and value of ``data``."""
    logger.debug(f"[{name}] {type(data).__name__}: {data}")
@app.post("/debug_chain")
def debug_chain_execution(question: str):
    """Run prompt and LLM as separate steps, logging a checkpoint after each."""
    debug_checkpoint("输入", question)
    # Step 1: render the prompt.
    rendered_prompt = prompt.invoke({"question": question})
    debug_checkpoint("Prompt结果", rendered_prompt)
    # Step 2: call the model with the rendered prompt.
    model_message = llm.invoke(rendered_prompt)
    debug_checkpoint("LLM结果", model_message)
    return {"response": model_message.content}
# 8. 信创环境调试
import dmPython
from langchain.llms import Ollama
@app.post("/xinchuang_debug")
def xinchuang_debug_endpoint(question: str):
    """Check the DM database and the Ollama model, recording each step.

    Opens a DM connection, creates an Ollama client, sends ``question`` to
    it and records every completed step in the returned dict. Errors are
    logged and reported in the dict rather than raised. The database
    connection is always closed before returning.

    Returns:
        A dict with ``question``, ``steps`` and either ``result`` or
        ``error``/``status``.
    """
    debug_info = {
        "question": question,
        "steps": [],
    }
    conn = None
    try:
        # 1. Database connection.
        logger.debug("连接达梦数据库...")
        conn = dmPython.connect(
            "user=SYSDBA;password=SYSDBA;server=localhost;port=5236"
        )
        debug_info["steps"].append({"step": "db_connect", "status": "success"})
        # 2. Ollama client.
        logger.debug("连接Ollama服务...")
        local_llm = Ollama(model="qwen:7b", base_url="http://localhost:11434")
        debug_info["steps"].append({"step": "llm_connect", "status": "success"})
        # 3. Run the query.
        logger.debug("执行LLM查询...")
        result = local_llm.invoke(question)
        debug_info["steps"].append({"step": "llm_invoke", "status": "success"})
        debug_info["result"] = result
        return debug_info
    except Exception as e:
        logger.error(f"信创调试错误:{e}", exc_info=True)
        debug_info["error"] = str(e)
        debug_info["status"] = "failed"
        return debug_info
    finally:
        # The connection is only a connectivity probe; never leave it open.
        if conn is not None:
            conn.close()
print("✓ 日志调试配置完成")
---
b.性能分析
a.功能说明
使用Playground分析Chain性能,识别瓶颈。测量各步骤耗时,优化慢操作。对比不同配置的性能差异。性能分析是优化系统的科学方法。
b.代码示例
---
from fastapi import FastAPI
from langserve import add_routes
from langchain.chat_models import ChatOpenAI
import time
import cProfile
import pstats
import io
app = FastAPI()
# 1. 计时装饰器
def timing_decorator(func):
    """Wrap ``func`` so that every successful call logs its duration.

    The wrapper keeps the name and docstring of ``func`` (``functools.wraps``)
    and measures with ``time.perf_counter()``.

    Args:
        func: The callable to time.

    Returns:
        A wrapper with the same call signature and return value as ``func``.
    """
    import functools
    import logging

    # Same logger a module-level ``logging.getLogger(__name__)`` would give;
    # resolved here so the decorator does not depend on a global ``logger``.
    log = logging.getLogger(__name__)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        log.info(f"{func.__name__} 耗时:{elapsed:.3f}秒")
        return result

    return wrapper
# 2. 分步计时
@app.post("/timed_chat")
def timed_chat(question: str):
    """Answer ``question`` and report how long each stage took.

    Returns:
        A dict with ``response`` and ``performance``, which holds the seconds
        spent in ``prompt``, ``llm`` and ``parse`` plus their ``total``.
    """
    durations = {}
    # Prompt stage (includes the import and template construction).
    tick = time.time()
    from langchain.prompts import ChatPromptTemplate
    template = ChatPromptTemplate.from_template("回答:{question}")
    rendered = template.invoke({"question": question})
    durations["prompt"] = time.time() - tick
    # LLM stage.
    tick = time.time()
    model = ChatOpenAI()
    message = model.invoke(rendered)
    durations["llm"] = time.time() - tick
    # Parse stage.
    tick = time.time()
    answer = message.content
    durations["parse"] = time.time() - tick
    durations["total"] = sum(durations.values())
    return {"response": answer, "performance": durations}
# 3. 批量性能测试
def benchmark_chain(chain, test_inputs: list, iterations=10):
    """Time ``chain.invoke`` over every input, repeated ``iterations`` times.

    Args:
        chain: Any object with an ``invoke(input)`` method.
        test_inputs: Inputs passed to ``chain.invoke`` one at a time.
        iterations: How many times the whole input list is replayed.

    Returns:
        A dict with ``total_requests`` and timing statistics in seconds
        (``avg_time``, ``median_time``, ``min_time``, ``max_time``,
        ``std_dev``). When nothing was invoked (empty ``test_inputs`` or
        ``iterations`` of 0) all values are 0 instead of raising.
    """
    import statistics

    times = []
    for _ in range(iterations):
        for inp in test_inputs:
            start = time.perf_counter()
            chain.invoke(inp)
            times.append(time.perf_counter() - start)
    if not times:
        # statistics.mean/median and min/max all raise on an empty sample.
        return {
            "total_requests": 0,
            "avg_time": 0,
            "median_time": 0,
            "min_time": 0,
            "max_time": 0,
            "std_dev": 0,
        }
    return {
        "total_requests": len(times),
        "avg_time": statistics.mean(times),
        "median_time": statistics.median(times),
        "min_time": min(times),
        "max_time": max(times),
        "std_dev": statistics.stdev(times) if len(times) > 1 else 0,
    }
# 使用
@app.post("/benchmark")
def run_benchmark():
    """Benchmark a prompt-plus-LLM chain on three sample questions.

    Returns:
        The statistics dict produced by ``benchmark_chain`` (5 iterations).
    """
    # Imported locally: this example only imports ChatPromptTemplate inside
    # ``timed_chat``, so the name is otherwise undefined here.
    from langchain.prompts import ChatPromptTemplate

    llm = ChatOpenAI()
    prompt = ChatPromptTemplate.from_template("回答:{question}")
    chain = prompt | llm
    test_inputs = [
        {"question": "问题1"},
        {"question": "问题2"},
        {"question": "问题3"},
    ]
    return benchmark_chain(chain, test_inputs, iterations=5)
# 4. 内存分析
import tracemalloc
@app.post("/memory_analysis")
def memory_analysis_chat(question: str):
    """Answer ``question`` while measuring traced memory with ``tracemalloc``.

    Tracing is stopped in a ``finally`` block so a failing LLM call does not
    leave ``tracemalloc`` enabled for the rest of the process.

    Returns:
        A dict with ``response`` and ``memory`` (``current_kb``, ``peak_kb``).
    """
    tracemalloc.start()
    try:
        llm = ChatOpenAI()
        result = llm.invoke(question)
        current, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return {
        "response": result.content,
        "memory": {
            "current_kb": current / 1024,
            "peak_kb": peak / 1024,
        },
    }
# 5. 性能profiling
@app.post("/profile_chat")
def profile_chat(question: str):
    """Answer ``question`` under ``cProfile`` and return the top 10 entries.

    The profiler is disabled in a ``finally`` block so an exception in the
    LLM call cannot leave profiling switched on.

    Returns:
        A dict with ``response`` and ``profile`` (the ``pstats`` text, sorted
        by cumulative time).
    """
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        llm = ChatOpenAI()
        result = llm.invoke(question)
    finally:
        profiler.disable()
    # Render the statistics into a string.
    s = io.StringIO()
    ps = pstats.Stats(profiler, stream=s).sort_stats('cumulative')
    ps.print_stats(10)  # top 10 entries
    profile_output = s.getvalue()
    logger.info(f"性能分析:\n{profile_output}")
    return {
        "response": result.content,
        "profile": profile_output,
    }
# 6. 并发测试
import asyncio
from langserve import RemoteRunnable
async def concurrency_test(url: str, num_requests=10):
    """Fire ``num_requests`` concurrent ``ainvoke`` calls at a remote chain.

    Args:
        url: Base URL of the LangServe route, passed to ``RemoteRunnable``.
        num_requests: Number of calls issued at once.

    Returns:
        A dict with ``total_requests``, ``success``, ``failed``,
        ``total_time`` (seconds) and ``requests_per_second``. The rate is
        0.0 when no measurable time elapsed, instead of dividing by zero.
    """
    remote_chain = RemoteRunnable(url)
    start = time.perf_counter()
    tasks = [
        remote_chain.ainvoke({"question": f"问题{i}"})
        for i in range(num_requests)
    ]
    # return_exceptions=True: one failed call must not cancel the others.
    results = await asyncio.gather(*tasks, return_exceptions=True)
    elapsed = time.perf_counter() - start
    success = sum(1 for r in results if not isinstance(r, Exception))
    return {
        "total_requests": num_requests,
        "success": success,
        "failed": num_requests - success,
        "total_time": elapsed,
        "requests_per_second": num_requests / elapsed if elapsed > 0 else 0.0,
    }
@app.post("/test_concurrency")
async def test_concurrency(num_requests: int = 10):
    """Run ``concurrency_test`` against the local ``/chat`` route."""
    target_url = "http://localhost:8000/chat"
    return await concurrency_test(target_url, num_requests=num_requests)
# 7. 错误注入测试
@app.post("/error_injection_test")
def error_injection_test(inject_error: bool = False):
    """Return an OK status, or raise a test exception when asked to.

    Raises:
        Exception: when ``inject_error`` is true.
    """
    logger.debug(f"错误注入:{inject_error}")
    if not inject_error:
        return {"status": "ok"}
    # Simulated failure.
    raise Exception("注入的测试错误")
# 8. 断言测试
@app.post("/assertion_test")
def assertion_test(question: str):
    """Call the LLM and verify basic properties of its answer.

    The checks raise ``AssertionError`` explicitly, so they still run when
    Python is started with ``-O`` (which strips ``assert`` statements).

    Raises:
        AssertionError: if the result is ``None``, has no ``content``
            attribute, or its content is empty.
    """
    llm = ChatOpenAI()
    result = llm.invoke(question)
    if result is None:
        raise AssertionError("结果不能为空")
    if not hasattr(result, 'content'):
        raise AssertionError("结果缺少content属性")
    if len(result.content) == 0:
        raise AssertionError("回答不能为空")
    logger.debug(f"断言通过,响应长度:{len(result.content)}")
    return {"response": result.content}
# 9. 信创环境调试
@app.post("/xinchuang_debug_test")
def xinchuang_debug_test(question: str):
    """Probe the DM database and Ollama, then run and time one query.

    Every step appends a line to ``debug_log``. Failures are logged and
    reported in the returned dict instead of being raised. The database
    connection is closed before returning.

    Returns:
        On success: ``success``, ``response``, ``debug_log`` and
        ``performance`` (``query_time`` in seconds). On failure:
        ``success``, ``error`` and ``debug_log``.
    """
    debug_log = []
    conn = None
    try:
        # 1. DM database connectivity.
        debug_log.append("测试达梦数据库连接...")
        import dmPython
        conn = dmPython.connect(
            "user=SYSDBA;password=SYSDBA;server=localhost;port=5236"
        )
        debug_log.append("✓ 达梦连接成功")
        # 2. Ollama client.
        debug_log.append("测试Ollama服务...")
        from langchain.llms import Ollama
        llm = Ollama(model="qwen:7b", base_url="http://localhost:11434")
        debug_log.append("✓ Ollama连接成功")
        # 3. Timed query.
        debug_log.append(f"执行查询:{question}")
        start = time.time()
        result = llm.invoke(question)
        elapsed = time.time() - start
        debug_log.append(f"✓ 查询完成,耗时{elapsed:.3f}秒")
        return {
            "success": True,
            "response": result,
            "debug_log": debug_log,
            "performance": {"query_time": elapsed},
        }
    except Exception as e:
        debug_log.append(f"✗ 错误:{str(e)}")
        logger.error(f"信创调试失败:{e}", exc_info=True)
        return {
            "success": False,
            "error": str(e),
            "debug_log": debug_log,
        }
    finally:
        # The connection is only a connectivity probe; never leave it open.
        if conn is not None:
            conn.close()
# 10. 调试配置开关
import os
DEBUG_MODE = os.getenv("DEBUG", "false").lower() == "true"
if DEBUG_MODE:
logger.info("调试模式已启用")
# 启用详细日志
logging.getLogger("langchain").setLevel(logging.DEBUG)
logging.getLogger("langserve").setLevel(logging.DEBUG)
# 添加调试端点
@app.get("/debug/status")
def debug_status():
    """Report debug mode, Python version and the logger's effective level.

    ``getEffectiveLevel()`` is used because the level is usually configured
    on the root logger, in which case ``logger.level`` is ``NOTSET``.
    """
    import logging
    import sys

    return {
        "debug_mode": DEBUG_MODE,
        "python_version": sys.version,
        "log_level": logging.getLevelName(logger.getEffectiveLevel()),
    }
print("✓ 调试工具配置完成")
---
02.测试用例
a.单元测试
a.功能说明
编写单元测试验证Chain功能的正确性。使用pytest等测试框架。测试正常情况和边界情况。单元测试是保证代码质量的基础。
b.代码示例
---
import pytest
from fastapi.testclient import TestClient
from langserve import add_routes
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from fastapi import FastAPI
# 1. 创建测试应用
def create_test_app():
    """Build a FastAPI app exposing a prompt-plus-LLM chain under ``/chat``."""
    test_app = FastAPI()
    model = ChatOpenAI(model="gpt-3.5-turbo")
    template = ChatPromptTemplate.from_template("回答:{question}")
    add_routes(test_app, template | model, path="/chat")
    return test_app
# 2. 基础测试
def test_invoke_endpoint():
    """The invoke endpoint answers 200 with an ``output`` field."""
    client = TestClient(create_test_app())
    payload = {"input": {"question": "测试"}}
    response = client.post("/chat/invoke", json=payload)
    assert response.status_code == 200
    assert "output" in response.json()
# 3. 测试不同输入
@pytest.mark.parametrize("question", [
    "你好",
    "What is AI?",
    "123",
    "测试@#$%",
])
def test_various_inputs(question):
    """The invoke endpoint accepts each parametrized question."""
    client = TestClient(create_test_app())
    payload = {"input": {"question": question}}
    response = client.post("/chat/invoke", json=payload)
    assert response.status_code == 200
# 4. 测试错误情况
def test_empty_input():
    """An empty question is either handled (200) or rejected (400/422)."""
    client = TestClient(create_test_app())
    payload = {"input": {"question": ""}}
    response = client.post("/chat/invoke", json=payload)
    acceptable_statuses = [200, 400, 422]
    assert response.status_code in acceptable_statuses
def test_invalid_json():
    """A body that is not valid JSON is rejected with 422."""
    app = create_test_app()
    client = TestClient(app)
    response = client.post(
        "/chat/invoke",
        # Raw text goes through ``content=``; passing a string via ``data=``
        # is deprecated in the httpx-based TestClient.
        content="invalid json",
        headers={"Content-Type": "application/json"},
    )
    assert response.status_code == 422  # validation error
# 5. 测试批量调用
def test_batch_endpoint():
    """The batch endpoint returns one output per submitted input."""
    client = TestClient(create_test_app())
    batch_inputs = [
        {"question": "问题1"},
        {"question": "问题2"},
    ]
    response = client.post("/chat/batch", json={"inputs": batch_inputs})
    assert response.status_code == 200
    data = response.json()
    assert "outputs" in data
    assert len(data["outputs"]) == 2
# 6. 测试流式端点
def test_stream_endpoint():
    """The stream endpoint answers 200 and yields at least one non-empty line."""
    client = TestClient(create_test_app())
    payload = {"input": {"question": "测试流式"}}
    with client.stream("POST", "/chat/stream", json=payload) as response:
        assert response.status_code == 200
        # Collect the non-empty lines of the streamed body.
        chunks = [line for line in response.iter_lines() if line]
        assert len(chunks) > 0
# 7. 集成测试
def test_full_flow():
    """Invoke, batch and input-schema endpoints all answer 200."""
    client = TestClient(create_test_app())
    # 1. invoke
    invoke_payload = {"input": {"question": "你好"}}
    invoke_resp = client.post("/chat/invoke", json=invoke_payload)
    assert invoke_resp.status_code == 200
    # 2. batch
    batch_payload = {"inputs": [{"question": "问题1"}, {"question": "问题2"}]}
    batch_resp = client.post("/chat/batch", json=batch_payload)
    assert batch_resp.status_code == 200
    # 3. schema
    schema_resp = client.get("/chat/input_schema")
    assert schema_resp.status_code == 200
# 8. 性能回归测试
def test_performance_regression():
    """A single invoke call answers 200 within the 5-second budget."""
    # Imported locally: this example's import block does not import ``time``.
    import time

    app = create_test_app()
    client = TestClient(app)
    MAX_RESPONSE_TIME = 5.0  # seconds
    start = time.perf_counter()
    response = client.post(
        "/chat/invoke",
        json={"input": {"question": "快速问题"}},
    )
    elapsed = time.perf_counter() - start
    assert response.status_code == 200
    assert elapsed < MAX_RESPONSE_TIME, f"响应时间{elapsed:.2f}秒超过阈值{MAX_RESPONSE_TIME}秒"
# 9. Mock测试
from unittest.mock import Mock, patch
def test_with_mock():
    """Invoke the endpoint with ``ChatOpenAI.invoke`` patched by a mock."""
    client = TestClient(create_test_app())
    patch_target = 'langchain.chat_models.ChatOpenAI.invoke'
    # NOTE(review): this assumes the /invoke route reaches the synchronous
    # ``ChatOpenAI.invoke``; if the server calls the async path instead, the
    # mock is never hit. Confirm against the LangServe version in use.
    with patch(patch_target) as mock_invoke:
        mock_invoke.return_value = Mock(content="Mock响应")
        payload = {"input": {"question": "测试"}}
        response = client.post("/chat/invoke", json=payload)
        assert response.status_code == 200
        # The mock must have been used exactly once.
        mock_invoke.assert_called_once()
# 10. 信创环境测试
def test_xinchuang_components():
    """Smoke-test the DM database and the Ollama model.

    Runs ``SELECT 1`` against the database and sends one prompt to Ollama.
    Any error, including a failed check, is reported through ``pytest.fail``.
    The database connection is closed even when the query fails.
    """
    # DM database.
    try:
        import dmPython
        conn = dmPython.connect(
            "user=SYSDBA;password=SYSDBA;server=localhost;port=5236"
        )
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT 1")
            result = cursor.fetchone()
        finally:
            conn.close()
        assert result[0] == 1, "达梦数据库查询失败"
        print("✓ 达梦数据库测试通过")
    except Exception as e:
        pytest.fail(f"达梦数据库测试失败:{e}")
    # Ollama.
    try:
        from langchain.llms import Ollama
        llm = Ollama(model="qwen:7b", base_url="http://localhost:11434")
        result = llm.invoke("测试")
        assert result is not None, "Ollama响应为空"
        print("✓ Ollama测试通过")
    except Exception as e:
        pytest.fail(f"Ollama测试失败:{e}")
# 运行测试
# pytest test_langserve.py -v
# 11. 测试报告
def generate_test_report():
    """Print the notice that the HTML test report is available.

    This function only prints; the report itself comes from running
    ``pytest test_langserve.py --html=report.html --self-contained-html``.
    """
    notice = "测试报告已生成:report.html"
    print(notice)
print("✓ 测试用例配置完成")
---
6 生产部署
6.1 Docker部署
01.容器化
a.Dockerfile编写
a.功能说明
编写Dockerfile将LangServe应用打包为容器镜像。包含Python环境、依赖安装、应用代码。配置启动命令和环境变量。容器化是现代应用部署的标准方式。
b.代码示例
---
# 1. 基础Dockerfile
# Dockerfile
"""
FROM python:3.11-slim
WORKDIR /app
# 复制依赖文件
COPY requirements.txt .
# 安装依赖
RUN pip install --no-cache-dir -r requirements.txt
# 复制应用代码
COPY . .
# 暴露端口
EXPOSE 8000
# 启动命令
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
"""
# requirements.txt
"""
langserve[all]>=0.0.30
langchain>=0.1.0
openai>=1.0.0
fastapi>=0.104.0
uvicorn[standard]>=0.24.0
"""
# 2. 多阶段构建
"""
# 构建阶段
FROM python:3.11 AS builder
WORKDIR /app
COPY requirements.txt .
RUN pip install --user --no-cache-dir -r requirements.txt
# 运行阶段
FROM python:3.11-slim
WORKDIR /app
# 从构建阶段复制依赖
COPY --from=builder /root/.local /root/.local
# 复制应用
COPY . .
# 环境变量
ENV PATH=/root/.local/bin:$PATH
EXPOSE 8000
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
"""
# 3. 优化的Dockerfile
"""
FROM python:3.11-slim
# 设置环境变量
ENV PYTHONUNBUFFERED=1 \\
    PYTHONDONTWRITEBYTECODE=1 \\
    PIP_NO_CACHE_DIR=1 \\
    PIP_DISABLE_PIP_VERSION_CHECK=1
WORKDIR /app
# 安装系统依赖
# curl is required by the HEALTHCHECK below; the slim image does not ship it.
RUN apt-get update && apt-get install -y \\
    gcc \\
    curl \\
    && rm -rf /var/lib/apt/lists/*
# 复制并安装依赖
COPY requirements.txt .
RUN pip install -r requirements.txt
# 复制应用代码
COPY server.py .
COPY chains/ chains/
# 创建非root用户
RUN useradd -m -u 1000 langserve && \\
    chown -R langserve:langserve /app
USER langserve
EXPOSE 8000
# 健康检查
HEALTHCHECK --interval=30s --timeout=3s --start-period=40s \\
    CMD curl -f http://localhost:8000/health || exit 1
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
"""
# 4. 构建镜像
# 在项目目录执行
# docker build -t langserve-app:latest .
# 查看镜像
# docker images | grep langserve
# 5. 运行容器
# docker run -d \
# --name langserve \
# -p 8000:8000 \
# -e OPENAI_API_KEY=your_key \
# langserve-app:latest
# 6. docker-compose.yml
"""
version: '3.8'
services:
langserve:
build: .
ports:
- "8000:8000"
environment:
- OPENAI_API_KEY=${OPENAI_API_KEY}
- LANGCHAIN_TRACING_V2=true
- LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
volumes:
- ./data:/app/data
restart: unless-stopped
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s
timeout: 3s
retries: 3
"""
# 启动
# docker-compose up -d
# 7. 多服务docker-compose
"""
version: '3.8'
services:
langserve:
build: .
ports:
- "8000:8000"
environment:
- OPENAI_API_KEY=${OPENAI_API_KEY}
depends_on:
- redis
restart: unless-stopped
redis:
image: redis:7-alpine
ports:
- "6379:6379"
volumes:
- redis-data:/data
restart: unless-stopped
nginx:
image: nginx:alpine
ports:
- "80:80"
volumes:
- ./nginx.conf:/etc/nginx/nginx.conf
depends_on:
- langserve
restart: unless-stopped
volumes:
redis-data:
"""
# 8. 信创环境Dockerfile
"""
# 基于麒麟系统基础镜像
FROM registry.cn-beijing.aliyuncs.com/kylin/python:3.11
WORKDIR /app
# 使用国内镜像源
RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
COPY requirements.txt .
RUN pip install -r requirements.txt
# 复制应用
COPY server.py .
COPY chains/ chains/
# 信创环境变量
ENV XINCHUANG_MODE=true \\
DB_TYPE=dameng \\
LLM_PROVIDER=ollama
EXPOSE 8000
CMD ["python3", "-m", "uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
"""
# 9. 信创docker-compose(含达梦和Ollama)
"""
version: '3.8'
services:
langserve:
build: .
ports:
- "8000:8000"
environment:
- XINCHUANG_MODE=true
- DM_HOST=dameng
- DM_PORT=5236
- OLLAMA_HOST=ollama:11434
depends_on:
- dameng
- ollama
networks:
- xinchuang-net
restart: unless-stopped
dameng:
image: dameng/dm8:latest
ports:
- "5236:5236"
environment:
- DM_PWD=SYSDBA
volumes:
- dameng-data:/opt/dmdbms/data
networks:
- xinchuang-net
restart: unless-stopped
ollama:
image: ollama/ollama:latest
ports:
- "11434:11434"
volumes:
- ollama-data:/root/.ollama
networks:
- xinchuang-net
restart: unless-stopped
networks:
xinchuang-net:
driver: bridge
volumes:
dameng-data:
ollama-data:
"""
# 启动信创环境
# docker-compose up -d
# 10. 构建脚本
# build.sh
"""
#!/bin/bash
echo "构建LangServe Docker镜像..."
# 构建镜像
docker build -t langserve-app:latest .
# 打标签
docker tag langserve-app:latest langserve-app:$(date +%Y%m%d)
echo "镜像构建完成"
docker images | grep langserve
"""
# chmod +x build.sh
# ./build.sh
print("✓ Docker容器化配置完成")
---
b.镜像优化
a.功能说明
优化Docker镜像大小和构建速度。使用轻量级基础镜像。利用构建缓存,合理组织层级。清理临时文件和缓存。镜像优化提升部署效率和资源利用。
b.代码示例
---
# 1. 精简基础镜像
"""
# 使用Alpine(更小)
FROM python:3.11-alpine
WORKDIR /app
# Alpine需要额外的构建工具
RUN apk add --no-cache gcc musl-dev linux-headers
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
EXPOSE 8000
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
"""
# 镜像大小对比:
# python:3.11 - ~1GB
# python:3.11-slim - ~150MB
# python:3.11-alpine - ~50MB
# 2. 利用缓存层
"""
FROM python:3.11-slim
WORKDIR /app
# 先复制依赖文件(变化少,缓存命中率高)
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# 后复制代码(变化多)
COPY server.py .
COPY chains/ chains/
EXPOSE 8000
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
"""
# 3. 清理临时文件
"""
FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt .
# 安装并清理
RUN pip install --no-cache-dir -r requirements.txt && \\
rm -rf /root/.cache/pip && \\
find /usr/local -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true
COPY . .
EXPOSE 8000
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
"""
# 4. .dockerignore
# .dockerignore
"""
__pycache__
*.pyc
*.pyo
*.pyd
.Python
env/
venv/
.git
.gitignore
.vscode
.idea
*.log
*.md
tests/
docs/
.env
"""
# 5. 精简依赖
# requirements.txt
"""
# 仅安装必需的包
langserve==0.0.30
langchain==0.1.0
openai==1.0.0
fastapi==0.104.0
uvicorn==0.24.0
# 不安装开发依赖
# pytest
# black
# flake8
"""
# 6. 多阶段构建优化
"""
# 构建阶段
FROM python:3.11 AS builder
WORKDIR /build
COPY requirements.txt .
# 安装到用户目录
RUN pip install --user --no-cache-dir -r requirements.txt
# 运行阶段(更小的镜像)
FROM python:3.11-slim
WORKDIR /app
# 仅复制已安装的包
COPY --from=builder /root/.local /root/.local
# 复制应用
COPY server.py .
COPY chains/ chains/
ENV PATH=/root/.local/bin:$PATH
EXPOSE 8000
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
"""
# 7. 构建优化脚本
# optimize_build.sh
"""
#!/bin/bash
echo "开始优化构建..."
# 1. 清理旧镜像
docker image prune -f
# 2. 使用BuildKit构建(更快)
export DOCKER_BUILDKIT=1
# 3. 构建并显示层信息
docker build \\
--progress=plain \\
--no-cache \\
-t langserve-app:latest .
# 4. 分析镜像大小
docker images langserve-app:latest
# 5. 使用dive分析层
# dive langserve-app:latest
echo "构建完成"
"""
# 8. 压缩镜像
# 导出并压缩
# docker save langserve-app:latest | gzip > langserve-app.tar.gz
# 传输并导入
# gunzip -c langserve-app.tar.gz | docker load
# 9. 信创环境Dockerfile
"""
# 使用国产基础镜像
FROM registry.cn-beijing.aliyuncs.com/kylin/python:3.11
WORKDIR /app
# 使用国内镜像源
RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \\
pip config set global.trusted-host pypi.tuna.tsinghua.edu.cn
COPY requirements.txt .
# 安装依赖
RUN pip install --no-cache-dir -r requirements.txt && \\
pip install dmPython && \\
rm -rf /root/.cache/pip
# 复制应用
COPY server.py .
COPY chains/ chains/
# 信创环境变量
ENV XINCHUANG_MODE=true \\
DB_TYPE=dameng \\
LLM_PROVIDER=ollama \\
TZ=Asia/Shanghai
EXPOSE 8000
# 使用国产达梦数据库和本地Ollama
CMD ["python3", "-m", "uvicorn", "server:app", \\
"--host", "0.0.0.0", \\
"--port", "8000", \\
"--workers", "4"]
"""
# 10. 镜像大小对比
"""
构建不同版本的镜像并对比:
docker build -f Dockerfile.full -t langserve:full .
docker build -f Dockerfile.slim -t langserve:slim .
docker build -f Dockerfile.alpine -t langserve:alpine .
docker images | grep langserve
# 输出示例:
# langserve full 1.2GB
# langserve slim 180MB
# langserve alpine 80MB
"""
print("✓ Dockerfile配置完成")
---
02.部署运行
a.容器启动
a.功能说明
使用docker run启动容器,配置端口映射、环境变量、卷挂载。管理容器生命周期(启动、停止、重启)。查看容器日志和状态。容器启动是部署的核心操作。
b.代码示例
---
# 1. 基础启动
# docker run -d \\
# --name langserve \\
# -p 8000:8000 \\
# langserve-app:latest
# 2. 完整启动配置
# docker run -d \\
# --name langserve-prod \\
# -p 8000:8000 \\
# -e OPENAI_API_KEY=your_key \\
# -e LANGCHAIN_TRACING_V2=true \\
# -e ENVIRONMENT=production \\
# -v $(pwd)/data:/app/data \\
# -v $(pwd)/logs:/app/logs \\
# --restart unless-stopped \\
# --memory="2g" \\
# --cpus="2" \\
# langserve-app:latest
# 3. 使用.env文件
# 创建.env文件
"""
OPENAI_API_KEY=your_key
LANGCHAIN_TRACING_V2=true
LANGCHAIN_PROJECT=langserve_prod
"""
# 启动时加载
# docker run -d \\
# --name langserve \\
# -p 8000:8000 \\
# --env-file .env \\
# langserve-app:latest
# 4. 容器管理命令
# 查看运行容器
# docker ps
# 查看所有容器
# docker ps -a
# 查看容器日志
# docker logs langserve
# docker logs -f langserve # 实时日志
# docker logs --tail 100 langserve # 最后100行
# 进入容器
# docker exec -it langserve /bin/bash
# 查看容器资源使用
# docker stats langserve
# 停止容器
# docker stop langserve
# 启动容器
# docker start langserve
# 重启容器
# docker restart langserve
# 删除容器
# docker rm langserve
# docker rm -f langserve # 强制删除
# 5. docker-compose管理
# 启动所有服务
# docker-compose up -d
# 查看状态
# docker-compose ps
# 查看日志
# docker-compose logs -f langserve
# 重启服务
# docker-compose restart langserve
# 停止所有服务
# docker-compose down
# 停止并删除数据卷
# docker-compose down -v
# 6. 健康检查
# 检查容器健康状态
# docker inspect --format='{{.State.Health.Status}}' langserve
# 查看健康检查日志
# docker inspect langserve | grep -A 10 Health
# 7. 容器更新
# 拉取新镜像
# docker pull langserve-app:latest
# 停止旧容器
# docker stop langserve
# 删除旧容器
# docker rm langserve
# 启动新容器
# docker run -d \\
# --name langserve \\
# -p 8000:8000 \\
# --env-file .env \\
# langserve-app:latest
# 或使用docker-compose
# docker-compose pull
# docker-compose up -d
# 8. 零停机更新
# 使用蓝绿部署
# 启动新版本(绿)
# docker run -d \\
# --name langserve-green \\
# -p 8001:8000 \\
# langserve-app:v2
# 测试新版本
# curl http://localhost:8001/health
# 切换流量(修改Nginx配置或更新端口映射)
# ...
# 停止旧版本(蓝)
# docker stop langserve-blue
# 9. 信创环境启动脚本
# start_xinchuang.sh
"""
#!/bin/bash
echo "启动信创环境LangServe..."
# 1. 检查网络
docker network inspect xinchuang-net >/dev/null 2>&1 || \\
docker network create xinchuang-net
# 2. 启动达梦数据库
docker run -d \\
--name dameng \\
--network xinchuang-net \\
-p 5236:5236 \\
-e DM_PWD=SYSDBA \\
-v dameng-data:/opt/dmdbms/data \\
dameng/dm8:latest
echo "等待达梦数据库启动..."
sleep 10
# 3. 启动Ollama
docker run -d \\
--name ollama \\
--network xinchuang-net \\
-p 11434:11434 \\
-v ollama-data:/root/.ollama \\
ollama/ollama:latest
echo "等待Ollama启动..."
sleep 5
# 拉取Qwen模型
docker exec ollama ollama pull qwen:7b
# 4. 启动LangServe
docker run -d \\
--name langserve \\
--network xinchuang-net \\
-p 8000:8000 \\
-e XINCHUANG_MODE=true \\
-e DM_HOST=dameng \\
-e OLLAMA_HOST=ollama \\
langserve-xinchuang:latest
echo "✓ 信创环境启动完成"
echo "访问:http://localhost:8000/docs"
"""
# chmod +x start_xinchuang.sh
# ./start_xinchuang.sh
# 10. 监控脚本
# monitor.sh
"""
#!/bin/bash
echo "LangServe容器监控"
echo "==================="
# 容器状态
echo "容器状态:"
docker ps --filter name=langserve --format "table {{.Names}}\\t{{.Status}}\\t{{.Ports}}"
# 资源使用
echo -e "\\n资源使用:"
docker stats --no-stream langserve
# 健康检查
echo -e "\\n健康状态:"
docker inspect --format='{{.State.Health.Status}}' langserve
# 最近日志
echo -e "\\n最近日志(最后20行):"
docker logs --tail 20 langserve
"""
# chmod +x monitor.sh
# ./monitor.sh
print("✓ 容器启动配置完成")
---
6.2 负载均衡
01.Nginx配置
a.反向代理
a.功能说明
使用Nginx作为反向代理,分发请求到多个LangServe实例。提供负载均衡、SSL终止、静态文件服务。提升系统可用性和性能。Nginx是生产环境的标准网关。
b.代码示例
---
# 1. 基础Nginx配置
# nginx.conf
"""
upstream langserve_backend {
server localhost:8001;
server localhost:8002;
server localhost:8003;
}
server {
listen 80;
server_name api.example.com;
location / {
proxy_pass http://langserve_backend;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
}
"""
# 2. 负载均衡策略
# 轮询(默认)
"""
upstream langserve_backend {
server localhost:8001;
server localhost:8002;
server localhost:8003;
}
"""
# 权重
"""
upstream langserve_backend {
server localhost:8001 weight=3; # 权重3
server localhost:8002 weight=2; # 权重2
server localhost:8003 weight=1; # 权重1
}
"""
# IP哈希
"""
upstream langserve_backend {
ip_hash; # 同一IP请求同一后端
server localhost:8001;
server localhost:8002;
server localhost:8003;
}
"""
# 最少连接
"""
upstream langserve_backend {
least_conn; # 选择连接数最少的后端
server localhost:8001;
server localhost:8002;
server localhost:8003;
}
"""
# 3. 健康检查
"""
upstream langserve_backend {
server localhost:8001 max_fails=3 fail_timeout=30s;
server localhost:8002 max_fails=3 fail_timeout=30s;
server localhost:8003 max_fails=3 fail_timeout=30s;
}
"""
# 4. 流式请求配置
"""
server {
listen 80;
server_name api.example.com;
location /chat/stream {
proxy_pass http://langserve_backend;
proxy_http_version 1.1;
proxy_set_header Connection "";
proxy_buffering off; # 关键:禁用缓冲
proxy_cache off;
proxy_set_header X-Accel-Buffering no;
# 超时设置
proxy_connect_timeout 60s;
proxy_send_timeout 60s;
proxy_read_timeout 300s; # 流式请求可能很长
}
location / {
proxy_pass http://langserve_backend;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
}
}
"""
# 5. SSL配置
"""
server {
listen 443 ssl http2;
server_name api.example.com;
ssl_certificate /etc/nginx/ssl/cert.pem;
ssl_certificate_key /etc/nginx/ssl/key.pem;
ssl_protocols TLSv1.2 TLSv1.3;
ssl_ciphers HIGH:!aNULL:!MD5;
ssl_prefer_server_ciphers on;
location / {
proxy_pass http://langserve_backend;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-Proto https;
}
}
# HTTP重定向到HTTPS
server {
listen 80;
server_name api.example.com;
return 301 https://$server_name$request_uri;
}
"""
# 6. 缓存配置
"""
proxy_cache_path /var/cache/nginx levels=1:2 keys_zone=langserve_cache:10m max_size=1g inactive=60m;
server {
listen 80;
server_name api.example.com;
# 不缓存动态API
location /chat/ {
proxy_cache off;
proxy_pass http://langserve_backend;
}
# 缓存静态资源
location /static/ {
proxy_cache langserve_cache;
proxy_cache_valid 200 1h;
proxy_pass http://langserve_backend;
}
}
"""
# 7. 限流配置
"""
# 定义限流区域
limit_req_zone $binary_remote_addr zone=langserve_limit:10m rate=10r/s;
server {
listen 80;
location / {
# 应用限流
limit_req zone=langserve_limit burst=20 nodelay;
proxy_pass http://langserve_backend;
}
}
"""
# 8. 完整生产配置
"""
# nginx.conf
user nginx;
worker_processes auto;
error_log /var/log/nginx/error.log warn;
pid /var/run/nginx.pid;
events {
worker_connections 1024;
}
http {
include /etc/nginx/mime.types;
default_type application/octet-stream;
log_format main '$remote_addr - $remote_user [$time_local] "$request" '
'$status $body_bytes_sent "$http_referer" '
'"$http_user_agent" "$http_x_forwarded_for" '
'rt=$request_time';
access_log /var/log/nginx/access.log main;
sendfile on;
tcp_nopush on;
keepalive_timeout 65;
gzip on;
upstream langserve_backend {
least_conn;
server langserve1:8000 max_fails=3 fail_timeout=30s;
server langserve2:8000 max_fails=3 fail_timeout=30s;
server langserve3:8000 max_fails=3 fail_timeout=30s;
keepalive 32;
}
server {
listen 80;
server_name api.example.com;
location / {
proxy_pass http://langserve_backend;
proxy_http_version 1.1;
proxy_set_header Connection "";
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
# 超时
proxy_connect_timeout 10s;
proxy_send_timeout 60s;
proxy_read_timeout 60s;
}
location /chat/stream {
proxy_pass http://langserve_backend;
proxy_buffering off;
proxy_cache off;
proxy_read_timeout 300s;
}
location /health {
access_log off;
proxy_pass http://langserve_backend;
}
}
}
"""
# 9. 信创环境Nginx配置
"""
# 信创环境nginx.conf
upstream xinchuang_langserve {
least_conn;
server 192.168.1.101:8000 max_fails=3 fail_timeout=30s;
server 192.168.1.102:8000 max_fails=3 fail_timeout=30s;
server 192.168.1.103:8000 max_fails=3 fail_timeout=30s;
}
server {
listen 80;
server_name internal-ai.company.com;
# 仅允许内网访问
allow 192.168.0.0/16;
allow 10.0.0.0/8;
deny all;
location / {
proxy_pass http://xinchuang_langserve;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Server-Type "Kylin-Nginx";
# 超时配置(本地LLM可能较慢)
proxy_connect_timeout 30s;
proxy_send_timeout 120s;
proxy_read_timeout 120s;
}
location /local/chat/stream {
proxy_pass http://xinchuang_langserve;
proxy_buffering off;
proxy_cache off;
proxy_read_timeout 600s; # 本地模型可能很慢
}
# 健康检查
location /health {
access_log off;
proxy_pass http://xinchuang_langserve;
}
}
"""
# 10. Nginx管理脚本
# nginx_manage.sh
"""
#!/bin/bash
case "$1" in
start)
docker run -d \\
--name nginx-lb \\
-p 80:80 \\
-v $(pwd)/nginx.conf:/etc/nginx/nginx.conf:ro \\
--network xinchuang-net \\
nginx:alpine
echo "Nginx已启动"
;;
reload)
docker exec nginx-lb nginx -s reload
echo "Nginx配置已重载"
;;
stop)
docker stop nginx-lb
docker rm nginx-lb
echo "Nginx已停止"
;;
logs)
docker logs -f nginx-lb
;;
test)
docker exec nginx-lb nginx -t
;;
*)
echo "用法: $0 {start|reload|stop|logs|test}"
exit 1
esac
"""
# chmod +x nginx_manage.sh
# ./nginx_manage.sh start
print("✓ Nginx负载均衡配置完成")
---
b.负载策略
a.功能说明
选择合适的负载均衡策略分配请求。轮询、权重、最少连接、IP哈希等策略。根据业务特点和服务器性能选择策略。合理的负载策略提升系统性能和稳定性。
b.代码示例
---
# 1. 轮询(Round Robin)
# 默认策略,按顺序分配请求
"""
upstream backend {
server server1:8000;
server server2:8000;
server server3:8000;
}
# 请求分配:
# Request 1 -> server1
# Request 2 -> server2
# Request 3 -> server3
# Request 4 -> server1
# ...
"""
# 2. 加权轮询(Weighted Round Robin)
# 性能好的服务器分配更多请求
"""
upstream backend {
server server1:8000 weight=5; # 高性能服务器
server server2:8000 weight=3; # 中等性能
server server3:8000 weight=2; # 较低性能
}
# 10个请求中:
# server1处理5个
# server2处理3个
# server3处理2个
"""
# 3. 最少连接(Least Connections)
# 选择当前连接数最少的服务器
"""
upstream backend {
least_conn;
server server1:8000;
server server2:8000;
server server3:8000;
}
# 适用于请求处理时间差异大的场景
"""
# 4. IP哈希(IP Hash)
# 同一IP请求同一服务器,保持会话
"""
upstream backend {
ip_hash;
server server1:8000;
server server2:8000;
server server3:8000;
}
# 同一用户的请求总是到同一服务器
# 适用于有状态服务
"""
# 5. 一致性哈希
"""
upstream backend {
hash $request_uri consistent;
server server1:8000;
server server2:8000;
server server3:8000;
}
# 根据请求URI哈希
# 相同URL请求同一服务器
# 适用于缓存场景
"""
# 6. 备份服务器
"""
upstream backend {
server server1:8000;
server server2:8000;
server server3:8000 backup; # 仅当前两个都不可用时使用
}
"""
# 7. 服务器状态控制
"""
upstream backend {
server server1:8000;
server server2:8000 down; # 临时下线
server server3:8000 max_fails=3 fail_timeout=30s;
server server4:8000 max_conns=100; # 最大连接数
}
"""
# 8. 动态负载均衡
# 使用Nginx Plus(商业版)或Consul
"""
# 使用Consul服务发现
# 需要nginx-upsync-module模块
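upstream backend {
# upsync的参数必须写在同一条指令内(不是独立指令),地址不带consul://前缀
upsync consul-server:8500/v1/kv/upstreams/langserve/ upsync_timeout=6m upsync_interval=500ms upsync_type=consul strong_dependency=off;
# 将后端列表持久化到本地,Consul不可用时Nginx仍可启动
upsync_dump_path /etc/nginx/upstreams/servers_langserve.conf;
include /etc/nginx/upstreams/servers_langserve.conf;
}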
"""
# 9. 会话保持
"""
upstream backend {
ip_hash; # 或使用sticky模块
server server1:8000;
server server2:8000;
server server3:8000;
}
server {
listen 80;
location / {
proxy_pass http://backend;
# 为后端返回的Cookie追加HttpOnly/Secure属性(会话保持由上面的ip_hash完成,此处并不会创建新的会话Cookie)
proxy_cookie_path / "/; HTTPOnly; Secure";
}
}
"""
# 10. 信创环境负载配置
"""
# 信创环境nginx配置
upstream xinchuang_langserve {
# 使用最少连接策略(本地LLM处理时间可能不同)
least_conn;
# 内网服务器
server 192.168.1.101:8000 weight=3 max_fails=2 fail_timeout=20s;
server 192.168.1.102:8000 weight=3 max_fails=2 fail_timeout=20s;
server 192.168.1.103:8000 weight=2 max_fails=2 fail_timeout=20s; # 较低配置
server 192.168.1.104:8000 backup; # 备份服务器
keepalive 16;
}
# 限流区域(limit_req_zone只能定义在http上下文,不能写在server块内)
limit_req_zone $binary_remote_addr zone=xc_limit:10m rate=20r/s;
server {
listen 80;
server_name internal-ai.company.com;
# IP白名单(仅内网)
allow 192.168.0.0/16;
allow 10.0.0.0/8;
deny all;
# 限流
limit_req zone=xc_limit burst=50 nodelay;
location / {
proxy_pass http://xinchuang_langserve;
proxy_http_version 1.1;
proxy_set_header Connection "";
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Environment "Xinchuang";
# 超时(本地模型较慢)
proxy_connect_timeout 30s;
proxy_send_timeout 180s;
proxy_read_timeout 180s;
}
# 流式端点
location ~ ^/.*/(stream|astream) {
proxy_pass http://xinchuang_langserve;
proxy_buffering off;
proxy_cache off;
proxy_read_timeout 600s;
}
# 健康检查
location /health {
access_log off;
proxy_pass http://xinchuang_langserve;
}
}
"""
# 启动Nginx
# docker run -d \\
# --name nginx-xinchuang \\
# -p 80:80 \\
# -v $(pwd)/nginx.conf:/etc/nginx/nginx.conf:ro \\
# --network xinchuang-net \\
# nginx:alpine
print("✓ 负载均衡策略配置完成")
---
02.高可用部署
a.多实例部署
a.功能说明
部署多个LangServe实例,提供冗余和容错能力。使用不同端口或不同服务器。配置自动故障转移。多实例是高可用的基础。
b.代码示例
---
# 1. docker-compose多实例
# docker-compose.yml
"""
version: '3.8'
services:
langserve-1:
build: .
ports:
- "8001:8000"
environment:
- OPENAI_API_KEY=${OPENAI_API_KEY}
- INSTANCE_ID=1
restart: unless-stopped
langserve-2:
build: .
ports:
- "8002:8000"
environment:
- OPENAI_API_KEY=${OPENAI_API_KEY}
- INSTANCE_ID=2
restart: unless-stopped
langserve-3:
build: .
ports:
- "8003:8000"
environment:
- OPENAI_API_KEY=${OPENAI_API_KEY}
- INSTANCE_ID=3
restart: unless-stopped
nginx:
image: nginx:alpine
ports:
- "80:80"
volumes:
- ./nginx.conf:/etc/nginx/nginx.conf:ro
depends_on:
- langserve-1
- langserve-2
- langserve-3
restart: unless-stopped
"""
# 启动(三个实例已在上面显式定义为langserve-1/2/3,无需--scale;
# --scale只适用于单个名为langserve且未固定宿主机端口的服务)
# docker-compose up -d
# 2. Kubernetes部署
# deployment.yaml
"""
apiVersion: apps/v1
kind: Deployment
metadata:
name: langserve
spec:
replicas: 3
selector:
matchLabels:
app: langserve
template:
metadata:
labels:
app: langserve
spec:
containers:
- name: langserve
image: langserve-app:latest
ports:
- containerPort: 8000
env:
- name: OPENAI_API_KEY
valueFrom:
secretKeyRef:
name: langserve-secret
key: openai-api-key
resources:
limits:
memory: "2Gi"
cpu: "1000m"
requests:
memory: "1Gi"
cpu: "500m"
livenessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 30
periodSeconds: 10
readinessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 10
periodSeconds: 5
---
apiVersion: v1
kind: Service
metadata:
name: langserve-service
spec:
selector:
app: langserve
ports:
- protocol: TCP
port: 80
targetPort: 8000
type: LoadBalancer
"""
# 部署
# kubectl apply -f deployment.yaml
# 扩缩容
# kubectl scale deployment langserve --replicas=5
# 3. 健康检查端点
import os
import time

from fastapi import FastAPI

app = FastAPI()


@app.get("/health")
def health_check():
    """Report instance health for load balancers and orchestrator probes.

    Returns:
        A dict with ``status`` ("healthy", or "degraded" when a component
        check fails), ``instance_id`` (the ``INSTANCE_ID`` environment
        variable, "unknown" when unset), ``timestamp`` (Unix epoch seconds)
        and one entry per checked component.
    """
    instance_id = os.getenv("INSTANCE_ID", "unknown")
    status = {
        "status": "healthy",
        "instance_id": instance_id,
        "timestamp": time.time()
    }
    # Database probe (placeholder): put the real connectivity test inside the
    # try body so a failure downgrades the status instead of failing the probe.
    try:
        status["database"] = "connected"
    except Exception:
        status["database"] = "disconnected"
        status["status"] = "degraded"
    # LLM probe (placeholder): same pattern as the database probe.
    try:
        status["llm"] = "available"
    except Exception:
        status["llm"] = "unavailable"
        status["status"] = "degraded"
    return status
# 4. 故障转移测试
# test_failover.sh
"""
#!/bin/bash
echo "测试故障转移..."
# 停止一个实例
docker stop langserve-1
# 继续发送请求
for i in {1..10}; do
curl http://localhost/chat/invoke \\
-X POST \\
-H "Content-Type: application/json" \\
-d '{"input": {"question": "测试'$i'"}}'
echo ""
sleep 1
done
# 恢复实例
docker start langserve-1
echo "故障转移测试完成"
"""
# 5. 滚动更新
# rolling_update.sh
"""
#!/bin/bash
# Rolling update: replace the instances one at a time and roll back to v1
# (then abort) as soon as a freshly started instance fails its health check.
echo "滚动更新LangServe实例..."
INSTANCES=("langserve-1" "langserve-2" "langserve-3")
for instance in "${INSTANCES[@]}"; do
    # Host port must match the 8001/8002/8003 mapping used by the Nginx
    # upstream: langserve-1 -> 8001, langserve-2 -> 8002, ...
    port="800${instance#langserve-}"
    echo "更新$instance..."
    # Stop and remove the old container
    docker stop "$instance"
    docker rm "$instance"
    # Start the new version
    docker run -d \\
        --name "$instance" \\
        -p "${port}:8000" \\
        --env-file .env \\
        langserve-app:v2
    # Give the service time to start
    sleep 10
    # Health check
    if curl -f "http://localhost:${port}/health"; then
        echo "✓ $instance更新成功"
    else
        echo "✗ $instance更新失败,回滚..."
        docker stop "$instance"
        docker rm "$instance"
        # The rollback container needs the same environment as the original
        docker run -d \\
            --name "$instance" \\
            -p "${port}:8000" \\
            --env-file .env \\
            langserve-app:v1
        # Abort with a non-zero status so a failed rollout is not reported as done
        exit 1
    fi
    echo "等待30秒后更新下一个实例..."
    sleep 30
done
echo "滚动更新完成"
"""
# 6. 信创环境多实例部署
# xinchuang_deploy.yaml
"""
version: '3.8'
services:
langserve-1:
image: langserve-xinchuang:latest
container_name: langserve-kylin-1
ports:
- "8001:8000"
environment:
- XINCHUANG_MODE=true
- INSTANCE_ID=kylin-1
- DM_HOST=dameng
- OLLAMA_HOST=ollama
networks:
- xc-net
restart: unless-stopped
deploy:
resources:
limits:
cpus: '2'
memory: 4G
langserve-2:
image: langserve-xinchuang:latest
container_name: langserve-kylin-2
ports:
- "8002:8000"
environment:
- XINCHUANG_MODE=true
- INSTANCE_ID=kylin-2
- DM_HOST=dameng
- OLLAMA_HOST=ollama
networks:
- xc-net
restart: unless-stopped
deploy:
resources:
limits:
cpus: '2'
memory: 4G
langserve-3:
image: langserve-xinchuang:latest
container_name: langserve-kylin-3
ports:
- "8003:8000"
environment:
- XINCHUANG_MODE=true
- INSTANCE_ID=kylin-3
- DM_HOST=dameng
- OLLAMA_HOST=ollama
networks:
- xc-net
restart: unless-stopped
deploy:
resources:
limits:
cpus: '1'
memory: 2G
dameng:
image: dameng/dm8:latest
container_name: dameng-db
ports:
- "5236:5236"
environment:
- DM_PWD=SYSDBA
volumes:
- dameng-data:/opt/dmdbms/data
networks:
- xc-net
restart: unless-stopped
ollama:
image: ollama/ollama:latest
container_name: ollama-service
ports:
- "11434:11434"
volumes:
- ollama-data:/root/.ollama
networks:
- xc-net
restart: unless-stopped
nginx:
image: nginx:alpine
container_name: nginx-lb
ports:
- "80:80"
volumes:
- ./nginx.conf:/etc/nginx/nginx.conf:ro
depends_on:
- langserve-1
- langserve-2
- langserve-3
networks:
- xc-net
restart: unless-stopped
networks:
xc-net:
driver: bridge
volumes:
dameng-data:
ollama-data:
"""
# 部署
# docker-compose -f xinchuang_deploy.yaml up -d
# 查看状态
# docker-compose -f xinchuang_deploy.yaml ps
# 查看日志
# docker-compose -f xinchuang_deploy.yaml logs -f langserve-1
print("✓ 高可用部署配置完成")
---
6.3 监控日志
01.日志收集
a.集中式日志
a.功能说明
集中收集所有服务实例的日志,统一管理和查询。使用ELK(Elasticsearch、Logstash、Kibana)或Loki等工具。支持日志搜索、过滤、聚合分析。集中式日志是分布式系统的必备设施。
b.代码示例
---
# 1. 配置JSON格式日志
# server.py
import json
import logging
from datetime import datetime, timezone


class JsonFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object."""

    def format(self, record):
        """Serialize ``record`` to a JSON string.

        The timestamp is taken from ``record.created`` (when the event was
        logged, not when it is formatted) and rendered as a timezone-aware
        UTC ISO-8601 string. Non-ASCII text is written as-is rather than as
        ``\\uXXXX`` escapes so messages stay readable in log viewers.

        Args:
            record: The ``logging.LogRecord`` to serialize.

        Returns:
            A JSON document with timestamp, level, logger, message, module,
            function and line, plus ``exception`` when the record carries
            exception info.
        """
        log_obj = {
            "timestamp": datetime.fromtimestamp(
                record.created, tz=timezone.utc
            ).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno
        }
        if record.exc_info:
            log_obj["exception"] = self.formatException(record.exc_info)
        return json.dumps(log_obj, ensure_ascii=False)


# Logging setup: emit JSON lines on the stream handler at INFO and above.
handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())
logger = logging.getLogger(__name__)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
# 2. Docker日志驱动
# docker-compose.yml
"""
version: '3.8'
services:
langserve:
build: .
ports:
- "8000:8000"
logging:
driver: "json-file"
options:
max-size: "10m"
max-file: "3"
"""
# 3. Fluentd日志收集
# docker-compose.yml
"""
version: '3.8'
services:
langserve:
build: .
ports:
- "8000:8000"
logging:
driver: "fluentd"
options:
fluentd-address: localhost:24224
tag: langserve.{{.ID}}
fluentd:
image: fluent/fluentd:latest
ports:
- "24224:24224"
volumes:
- ./fluentd.conf:/fluentd/etc/fluent.conf
- fluentd-logs:/fluentd/log
volumes:
fluentd-logs:
"""
# fluentd.conf
"""
<source>
@type forward
port 24224
</source>
<match langserve.**>
@type file
path /fluentd/log/langserve
<format>
@type json
</format>
<buffer>
timekey 1h
timekey_wait 10m
</buffer>
</match>
"""
# 4. Prometheus监控
import time

from fastapi import Request, Response
from prometheus_client import CONTENT_TYPE_LATEST, Counter, Histogram, generate_latest

# Metric definitions: request count and latency, labelled by HTTP method and
# path (the counter additionally carries the response status code).
REQUEST_COUNT = Counter(
    'langserve_requests_total',
    'Total requests',
    ['method', 'endpoint', 'status']
)
REQUEST_DURATION = Histogram(
    'langserve_request_duration_seconds',
    'Request duration',
    ['method', 'endpoint']
)


# NOTE(review): `app` is not defined in this snippet; it is assumed to be the
# FastAPI instance created earlier in server.py.
@app.middleware("http")
async def prometheus_middleware(request: Request, call_next):
    """Record a count and a latency observation for every HTTP request.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request down the stack.

    Returns:
        The downstream response, unchanged.
    """
    # perf_counter is monotonic, so the duration cannot go negative when the
    # wall clock is adjusted.
    start_time = time.perf_counter()
    response = await call_next(request)
    duration = time.perf_counter() - start_time
    REQUEST_COUNT.labels(
        method=request.method,
        endpoint=request.url.path,
        status=response.status_code
    ).inc()
    REQUEST_DURATION.labels(
        method=request.method,
        endpoint=request.url.path
    ).observe(duration)
    return response


@app.get("/metrics")
def metrics():
    """Expose all registered metrics in the Prometheus text format."""
    return Response(
        content=generate_latest(),
        # Official exposition content type (includes the format version).
        media_type=CONTENT_TYPE_LATEST
    )
# 5. Grafana仪表板
# prometheus.yml
"""
global:
scrape_interval: 15s
scrape_configs:
- job_name: 'langserve'
static_configs:
- targets: ['langserve:8000']
"""
# 添加到docker-compose
"""
services:
prometheus:
image: prom/prometheus:latest
ports:
- "9090:9090"
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml
- prometheus-data:/prometheus
grafana:
image: grafana/grafana:latest
ports:
- "3000:3000"
environment:
- GF_SECURITY_ADMIN_PASSWORD=admin
volumes:
- grafana-data:/var/lib/grafana
volumes:
prometheus-data:
grafana-data:
"""
# 6. 日志查询API
@app.get("/logs/recent")
def get_recent_logs(limit: int = 100):
    """Return the last ``limit`` entries of ``langserve.log``.

    Lines that parse as JSON are returned as parsed objects; any other line
    is wrapped as ``{"raw": <stripped line>}``. A missing or unreadable log
    file yields an empty result instead of an error.

    Args:
        limit: Maximum number of trailing lines to return. Values <= 0 return
            nothing (previously ``limit=0`` returned the whole file because
            ``lines[-0:]`` is the full list).

    Returns:
        ``{"logs": [...], "count": <number of entries>}``.
    """
    from collections import deque

    if limit <= 0:
        return {"logs": [], "count": 0}
    try:
        # Assumes the log file is UTF-8; undecodable bytes are replaced so a
        # single bad byte cannot hide the whole tail.
        with open("langserve.log", "r", encoding="utf-8", errors="replace") as f:
            # A bounded deque keeps only the tail while streaming the file,
            # so memory use does not grow with the log size.
            tail = deque(f, maxlen=limit)
    except OSError:
        tail = ()
    logs = []
    for line in tail:
        try:
            logs.append(json.loads(line))
        except ValueError:
            logs.append({"raw": line.strip()})
    return {"logs": logs, "count": len(logs)}
@app.get("/logs/errors")
def get_error_logs(limit: int = 50):
    """Return up to ``limit`` ERROR/CRITICAL entries from ``langserve.log``.

    The file is scanned from the beginning and scanning stops once ``limit``
    matching entries were collected. Lines that are not JSON objects are
    skipped. A missing or unreadable file yields whatever was collected so
    far (possibly nothing).

    Args:
        limit: Maximum number of entries to return. Values <= 0 return
            nothing (previously one entry was still returned).

    Returns:
        ``{"errors": [...], "count": <number of entries>}``.
    """
    logs = []
    if limit <= 0:
        return {"errors": logs, "count": 0}
    try:
        # Assumes the log file is UTF-8; undecodable bytes are replaced.
        with open("langserve.log", "r", encoding="utf-8", errors="replace") as f:
            for line in f:
                try:
                    log_obj = json.loads(line)
                except ValueError:
                    continue
                # Valid JSON that is not an object (e.g. a bare number) has
                # no "level" field to inspect.
                if not isinstance(log_obj, dict):
                    continue
                if log_obj.get("level") in ("ERROR", "CRITICAL"):
                    logs.append(log_obj)
                    if len(logs) >= limit:
                        break
    except OSError:
        # Best effort: an unreadable log file is reported as "no more errors".
        pass
    return {"errors": logs, "count": len(logs)}
# 7. 实时日志流
from fastapi.responses import StreamingResponse
import asyncio


@app.get("/logs/stream")
async def stream_logs():
    """Stream new lines of ``langserve.log`` as Server-Sent Events.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream`` that
        emits one ``data:`` event per log line.
    """
    async def log_generator():
        # Follow the file with an asyncio subprocess so waiting for the next
        # line suspends this coroutine instead of blocking the event loop.
        proc = await asyncio.create_subprocess_exec(
            'tail', '-f', 'langserve.log',
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.DEVNULL
        )
        try:
            while True:
                line = await proc.stdout.readline()
                if not line:
                    # EOF: tail exited (e.g. the file does not exist).
                    break
                text = line.decode(errors="replace").rstrip("\r\n")
                # An SSE event is terminated by a blank line.
                yield f"data: {text}\n\n"
        finally:
            # Runs when the client disconnects as well, so no orphaned
            # `tail` process is left behind per request.
            if proc.returncode is None:
                proc.terminate()
            await proc.wait()

    return StreamingResponse(
        log_generator(),
        media_type="text/event-stream"
    )
# 8. 信创环境日志配置
# 日志存储到达梦数据库
import dmPython
class DamengLogHandler(logging.Handler):
    """Logging handler that stores each record as a row in ``application_logs``.

    The connection is opened once in ``__init__`` through ``dmPython`` and
    reused for every record.
    """

    def __init__(self, connection_string):
        """Connect to the database and make sure the log table exists.

        Args:
            connection_string: Passed unchanged to ``dmPython.connect``.
        """
        super().__init__()
        self.conn = dmPython.connect(connection_string)
        self._init_table()

    def _init_table(self):
        """Create the ``application_logs`` table when it does not exist yet."""
        cursor = self.conn.cursor()
        try:
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS application_logs (
                    id BIGINT IDENTITY(1,1) PRIMARY KEY,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    level VARCHAR(20),
                    logger VARCHAR(100),
                    message TEXT,
                    module VARCHAR(100),
                    function VARCHAR(100),
                    line_no INTEGER,
                    exception TEXT
                )
            """)
            self.conn.commit()
        finally:
            cursor.close()

    def emit(self, record):
        """Insert ``record`` into ``application_logs``.

        Any failure is routed to ``handleError`` so that logging never raises
        into application code.

        Args:
            record: The ``logging.LogRecord`` to persist.
        """
        try:
            exception_text = None
            if record.exc_info:
                # formatException takes the exc_info tuple; Handler.format()
                # expects a LogRecord and fails on a tuple, which previously
                # caused every record with an exception to be dropped.
                formatter = self.formatter or logging.Formatter()
                exception_text = formatter.formatException(record.exc_info)
            cursor = self.conn.cursor()
            try:
                cursor.execute("""
                    INSERT INTO application_logs
                    (level, logger, message, module, function, line_no, exception)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                """, (
                    record.levelname,
                    record.name,
                    record.getMessage(),
                    record.module,
                    record.funcName,
                    record.lineno,
                    exception_text
                ))
                self.conn.commit()
            finally:
                cursor.close()
        except Exception:
            self.handleError(record)
# Usage: attach the database handler to the application logger.
# NOTE(review): the connection string embeds the SYSDBA credentials for
# demonstration only — presumably these should come from configuration or
# environment variables in a real deployment; confirm before reuse.
dm_handler = DamengLogHandler(
"user=SYSDBA;password=SYSDBA;server=localhost;port=5236"
)
logger.addHandler(dm_handler)
# 查询日志API
@app.get("/xinchuang/logs")
def query_xinchuang_logs(level: str = None, limit: int = 100):
    """Return the most recent rows of ``application_logs``.

    Args:
        level: When given (non-empty), only rows with this level are returned.
        limit: Maximum number of rows, newest first.

    Returns:
        ``{"logs": [{id, timestamp, level, logger, message}, ...],
        "count": <number of rows>}``.
    """
    conn = dmPython.connect(
        "user=SYSDBA;password=SYSDBA;server=localhost;port=5236"
    )
    # A connection is opened per request, so it must be released on every
    # path; otherwise each call (and each failed query) leaks a session.
    try:
        cursor = conn.cursor()
        try:
            if level:
                cursor.execute(
                    "SELECT * FROM application_logs WHERE level = ? ORDER BY timestamp DESC LIMIT ?",
                    (level, limit)
                )
            else:
                cursor.execute(
                    "SELECT * FROM application_logs ORDER BY timestamp DESC LIMIT ?",
                    (limit,)
                )
            logs = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        conn.close()
    return {
        "logs": [
            {
                "id": row[0],
                "timestamp": str(row[1]),
                "level": row[2],
                "logger": row[3],
                "message": row[4]
            }
            for row in logs
        ],
        "count": len(logs)
    }
print("✓ 日志收集配置完成")
---
b.监控指标
a.功能说明
收集系统运行指标,监控服务健康状态。包括请求数、响应时间、错误率、资源使用等。使用Prometheus、Grafana等工具可视化。监控指标是运维决策的数据基础。
b.代码示例
---
from prometheus_client import Counter, Histogram, Gauge, Summary
from fastapi import FastAPI, Request
import time
import psutil
app = FastAPI()
# 1. 定义Prometheus指标
# Counter: total number of requests, by method / endpoint / status.
REQUEST_TOTAL = Counter(
    name='langserve_requests_total',
    documentation='Total number of requests',
    labelnames=['method', 'endpoint', 'status'],
)
# Histogram: distribution of response times, by method / endpoint.
REQUEST_DURATION = Histogram(
    name='langserve_request_duration_seconds',
    documentation='Request duration in seconds',
    labelnames=['method', 'endpoint'],
    buckets=[0.1, 0.5, 1.0, 2.0, 5.0, 10.0],
)
# Gauge: number of requests currently in flight.
ACTIVE_REQUESTS = Gauge(
    name='langserve_active_requests',
    documentation='Number of active requests',
)
# Summary: request size in bytes.
REQUEST_SIZE = Summary(
    name='langserve_request_size_bytes',
    documentation='Request size in bytes',
)
# 2. 收集指标的中间件
@app.middleware("http")
async def collect_metrics(request: Request, call_next):
    """Track in-flight requests and record count and latency per request.

    The in-flight gauge is decremented in ``finally`` so it is restored even
    when the downstream handler raises; count and latency are recorded only
    for requests that produced a response.
    """
    ACTIVE_REQUESTS.inc()
    began = time.time()
    try:
        result = await call_next(request)
        elapsed = time.time() - began
        verb = request.method
        path = request.url.path
        REQUEST_TOTAL.labels(method=verb, endpoint=path, status=result.status_code).inc()
        REQUEST_DURATION.labels(method=verb, endpoint=path).observe(elapsed)
        return result
    finally:
        ACTIVE_REQUESTS.dec()
# 3. 系统资源指标
SYSTEM_CPU = Gauge('system_cpu_percent', 'System CPU usage')
SYSTEM_MEMORY = Gauge('system_memory_percent', 'System memory usage')
SYSTEM_DISK = Gauge('system_disk_percent', 'System disk usage')
import asyncio

# Strong reference to the background task: the event loop only keeps weak
# references, so an unreferenced task may be garbage-collected mid-run.
_system_metrics_task = None


async def collect_system_metrics():
    """Refresh the CPU, memory and disk gauges in an endless loop.

    One cycle takes a 1-second CPU sample and then sleeps for 10 seconds.
    """
    loop = asyncio.get_running_loop()
    while True:
        # cpu_percent(interval=1) sleeps for a full second; run it in the
        # default executor so the event loop keeps serving requests meanwhile.
        cpu_percent = await loop.run_in_executor(None, psutil.cpu_percent, 1)
        SYSTEM_CPU.set(cpu_percent)
        # Memory
        memory = psutil.virtual_memory()
        SYSTEM_MEMORY.set(memory.percent)
        # Disk (root filesystem)
        disk = psutil.disk_usage('/')
        SYSTEM_DISK.set(disk.percent)
        await asyncio.sleep(10)


@app.on_event("startup")
async def start_metrics_collection():
    """Start the background system-metrics collector on application startup."""
    global _system_metrics_task
    _system_metrics_task = asyncio.create_task(collect_system_metrics())
# 4. LLM调用指标
# LLM call counter, by model and outcome ("success" / "error").
LLM_REQUESTS = Counter(
    'langserve_llm_requests_total',
    'Total LLM requests',
    ['model', 'status']
)
# Token counter, by model and token type ("prompt" / "completion").
LLM_TOKENS = Counter(
    'langserve_llm_tokens_total',
    'Total LLM tokens used',
    ['model', 'type']
)
# LLM call latency, by model.
LLM_DURATION = Histogram(
    'langserve_llm_duration_seconds',
    'LLM request duration',
    ['model']
)
# Recorded from a LangChain callback handler attached to the LLM call.
from langchain.callbacks import BaseCallbackHandler


class MetricsCallbackHandler(BaseCallbackHandler):
    """Callback handler that feeds the LLM Prometheus metrics.

    Start times are kept per ``run_id`` so that overlapping LLM calls sharing
    one handler instance do not overwrite each other's timing.
    """

    def __init__(self):
        super().__init__()
        # run_id -> start timestamp of the LLM call still in flight
        self._start_times = {}

    def on_llm_start(self, serialized, prompts, **kwargs):
        """Remember when the LLM call identified by ``run_id`` started."""
        self._start_times[kwargs.get("run_id")] = time.time()

    def on_llm_end(self, response, **kwargs):
        """Record success count, latency and token usage for a finished call."""
        started = self._start_times.pop(kwargs.get("run_id"), None)
        # llm_output is optional on the result; treat a missing one as empty
        # instead of failing with AttributeError on None.
        llm_output = response.llm_output or {}
        model = llm_output.get("model_name", "unknown")
        LLM_REQUESTS.labels(model=model, status="success").inc()
        if started is not None:
            LLM_DURATION.labels(model=model).observe(time.time() - started)
        # Token accounting (only when the provider reports usage)
        token_usage = llm_output.get("token_usage") or {}
        if token_usage:
            LLM_TOKENS.labels(model=model, type="prompt").inc(
                token_usage.get("prompt_tokens", 0)
            )
            LLM_TOKENS.labels(model=model, type="completion").inc(
                token_usage.get("completion_tokens", 0)
            )

    def on_llm_error(self, error, **kwargs):
        """Count a failed call and drop its pending start time."""
        self._start_times.pop(kwargs.get("run_id"), None)
        LLM_REQUESTS.labels(model="unknown", status="error").inc()
# 5. 暴露指标端点
from fastapi import Response
from prometheus_client import CONTENT_TYPE_LATEST, generate_latest


@app.get("/metrics")
def prometheus_metrics():
    """Expose all registered metrics in the Prometheus text format.

    Returns:
        A ``Response`` whose body is the ``generate_latest()`` output, served
        with the exposition content type.
    """
    return Response(
        content=generate_latest(),
        # Official exposition content type (includes the format version).
        media_type=CONTENT_TYPE_LATEST
    )
# 6. Grafana仪表板配置
# grafana_dashboard.json
"""
{
"dashboard": {
"title": "LangServe监控",
"panels": [
{
"title": "请求速率",
"targets": [{
"expr": "rate(langserve_requests_total[5m])"
}]
},
{
"title": "响应时间",
"targets": [{
"expr": "histogram_quantile(0.95, rate(langserve_request_duration_seconds_bucket[5m]))"
}]
},
{
"title": "错误率",
"targets": [{
"expr": "rate(langserve_requests_total{status=~\"5..\"}[5m])"
}]
},
{
"title": "活跃请求",
"targets": [{
"expr": "langserve_active_requests"
}]
}
]
}
}
"""
# 7. 告警规则
# prometheus_rules.yml
"""
groups:
- name: langserve_alerts
interval: 30s
rules:
- alert: HighErrorRate
expr: sum(rate(langserve_requests_total{status=~"5.."}[5m])) / sum(rate(langserve_requests_total[5m])) > 0.1
for: 5m
labels:
severity: critical
annotations:
summary: "LangServe错误率过高"
description: "5分钟内错误率超过10%"
- alert: SlowResponse
expr: histogram_quantile(0.95, rate(langserve_request_duration_seconds_bucket[5m])) > 5
for: 5m
labels:
severity: warning
annotations:
summary: "LangServe响应缓慢"
description: "P95响应时间超过5秒"
- alert: HighCPUUsage
expr: system_cpu_percent > 80
for: 5m
labels:
severity: warning
annotations:
summary: "CPU使用率过高"
"""
# 8. 信创环境监控指标
# 定义信创特定指标
# Dameng DB query counter, by operation and outcome.
DAMENG_QUERIES = Counter(
    name='xinchuang_dameng_queries_total',
    documentation='Total Dameng DB queries',
    labelnames=['operation', 'status'],
)
# Ollama request counter, by model and outcome.
OLLAMA_REQUESTS = Counter(
    name='xinchuang_ollama_requests_total',
    documentation='Total Ollama requests',
    labelnames=['model', 'status'],
)
# Per-component response time; components used below: dameng, ollama, app.
XINCHUANG_RESPONSE_TIME = Histogram(
    name='xinchuang_response_time_seconds',
    documentation='Xinchuang component response time',
    labelnames=['component'],
    buckets=[0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0],
)


@app.middleware("http")
async def collect_xinchuang_metrics(request: Request, call_next):
    """Record per-component timings for each request.

    The database phase is a placeholder (nothing runs between its two
    timestamps), and the time spent in ``call_next`` is what gets reported
    under the "ollama" component.
    """
    request_began = time.time()

    # Database phase (placeholder for the real Dameng operation)
    db_began = time.time()
    db_elapsed = time.time() - db_began
    XINCHUANG_RESPONSE_TIME.labels(component="dameng").observe(db_elapsed)
    DAMENG_QUERIES.labels(operation="query", status="success").inc()

    # LLM phase: the downstream handler stands in for the model call
    llm_began = time.time()
    result = await call_next(request)
    llm_elapsed = time.time() - llm_began
    XINCHUANG_RESPONSE_TIME.labels(component="ollama").observe(llm_elapsed)
    OLLAMA_REQUESTS.labels(model="qwen:7b", status="success").inc()

    # Whole request
    XINCHUANG_RESPONSE_TIME.labels(component="app").observe(time.time() - request_began)
    return result
print("✓ 监控指标配置完成")
---
02.性能监控
a.实时监控
a.功能说明
实时监控服务运行状态,及时发现和处理问题。监控CPU、内存、网络、磁盘等资源。追踪请求量、响应时间、错误率等业务指标。实时监控是保障服务稳定的关键。
b.代码示例
---
from fastapi import FastAPI
import psutil
import time
from collections import deque
from datetime import datetime
app = FastAPI()
# 1. 实时状态监控
class RealTimeMonitor:
    """In-process collector of request counters and recent response times."""

    def __init__(self):
        # Bounded window: only the latest 1000 durations are kept.
        self.response_times = deque(maxlen=1000)
        self.start_time = time.time()
        self.request_count = 0
        self.error_count = 0

    def record_request(self, duration: float, is_error: bool = False):
        """Account for one finished request and remember its duration."""
        self.response_times.append(duration)
        self.request_count += 1
        if is_error:
            self.error_count += 1

    def get_stats(self):
        """Build a snapshot of request statistics and system resource usage.

        Response-time aggregates are included only once at least one request
        was recorded; the ``system`` entry is always present.
        """
        elapsed = time.time() - self.start_time
        total = self.request_count

        if total > 0:
            error_rate = self.error_count / total
        else:
            error_rate = 0
        if elapsed > 0:
            throughput = total / elapsed
        else:
            throughput = 0

        stats = {
            "uptime_seconds": elapsed,
            "total_requests": total,
            "total_errors": self.error_count,
            "error_rate": error_rate,
            "requests_per_second": throughput
        }

        samples = self.response_times
        if samples:
            import statistics
            stats["avg_response_time"] = statistics.mean(samples)
            stats["median_response_time"] = statistics.median(samples)
            stats["min_response_time"] = min(samples)
            stats["max_response_time"] = max(samples)

        # System resources
        cpu = psutil.cpu_percent()
        mem = psutil.virtual_memory().percent
        disk = psutil.disk_usage('/').percent
        stats["system"] = {
            "cpu_percent": cpu,
            "memory_percent": mem,
            "disk_percent": disk
        }
        return stats


monitor = RealTimeMonitor()
# Request is needed for the middleware signature below; this snippet's header
# only imports FastAPI, so the annotation would raise NameError without it.
from fastapi import Request


@app.middleware("http")
async def monitoring_middleware(request: Request, call_next):
    """Feed every request's duration and outcome into ``monitor``.

    Responses with a status code >= 500 and requests whose handler raised are
    recorded as errors; an exception is re-raised unchanged after recording.
    """
    start_time = time.time()
    try:
        response = await call_next(request)
    except Exception:
        monitor.record_request(time.time() - start_time, is_error=True)
        raise
    duration = time.time() - start_time
    monitor.record_request(duration, response.status_code >= 500)
    return response
@app.get("/monitoring/status")
def get_monitoring_status():
    """Expose the shared monitor's current statistics snapshot."""
    snapshot = monitor.get_stats()
    return snapshot
# 2. 实时仪表板
from fastapi.responses import HTMLResponse
@app.get("/dashboard", response_class=HTMLResponse)
def realtime_dashboard():
    """Serve a self-refreshing HTML dashboard that polls /monitoring/status.

    The status payload omits ``avg_response_time`` until the first request has
    been recorded, so the page script falls back to 0 instead of rendering
    "NaNms".
    """
    return """
<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <title>LangServe实时监控</title>
    <style>
        body { font-family: Arial; margin: 20px; background: #f5f5f5; }
        .container { max-width: 1200px; margin: 0 auto; }
        h1 { color: #333; text-align: center; }
        .metrics { display: grid; grid-template-columns: repeat(auto-fit, minmax(250px, 1fr)); gap: 20px; margin: 20px 0; }
        .metric-card { background: white; padding: 20px; border-radius: 10px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); }
        .metric-value { font-size: 32px; font-weight: bold; color: #007bff; }
        .metric-label { color: #666; margin-top: 5px; }
        .chart { background: white; padding: 20px; border-radius: 10px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); margin: 20px 0; }
    </style>
    <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
</head>
<body>
    <div class="container">
        <h1>📊 LangServe实时监控</h1>
        <div class="metrics">
            <div class="metric-card">
                <div class="metric-value" id="total-requests">0</div>
                <div class="metric-label">总请求数</div>
            </div>
            <div class="metric-card">
                <div class="metric-value" id="error-rate">0%</div>
                <div class="metric-label">错误率</div>
            </div>
            <div class="metric-card">
                <div class="metric-value" id="avg-time">0ms</div>
                <div class="metric-label">平均响应时间</div>
            </div>
            <div class="metric-card">
                <div class="metric-value" id="cpu-usage">0%</div>
                <div class="metric-label">CPU使用率</div>
            </div>
        </div>
        <div class="chart">
            <canvas id="responseTimeChart"></canvas>
        </div>
    </div>
    <script>
        // 实时更新数据
        async function updateMetrics() {
            try {
                const response = await fetch('/monitoring/status');
                const data = await response.json();
                document.getElementById('total-requests').textContent = data.total_requests;
                document.getElementById('error-rate').textContent =
                    (data.error_rate * 100).toFixed(2) + '%';
                document.getElementById('avg-time').textContent =
                    ((data.avg_response_time || 0) * 1000).toFixed(0) + 'ms';
                document.getElementById('cpu-usage').textContent =
                    data.system.cpu_percent.toFixed(1) + '%';
            } catch (error) {
                console.error('更新失败:', error);
            }
        }
        // 每2秒更新一次
        setInterval(updateMetrics, 2000);
        updateMetrics(); // 立即更新
    </script>
</body>
</html>
"""
# 3. WebSocket实时推送
from fastapi import WebSocket
from fastapi import WebSocketDisconnect


@app.websocket("/ws/monitoring")
async def websocket_monitoring(websocket: WebSocket):
    """Push a monitoring snapshot to the connected client once per second.

    A client disconnect ends the loop quietly. Any other failure is logged
    and the socket is closed on a best-effort basis.
    """
    import asyncio  # not imported at the top of this example

    await websocket.accept()
    try:
        while True:
            await websocket.send_json(monitor.get_stats())
            await asyncio.sleep(1)
    except WebSocketDisconnect:
        # The peer is already gone; closing again would raise.
        return
    except Exception as e:
        print(f"WebSocket错误:{e}")
        try:
            await websocket.close()
        except RuntimeError:
            # The connection was already closed by the time we got here.
            pass
# 4. 告警检测
class AlertDetector:
    """Compare a monitoring snapshot against fixed thresholds and report breaches."""

    def __init__(self):
        self.thresholds = {
            "error_rate": 0.05,  # 5%
            "avg_response_time": 5.0,  # seconds
            "cpu_percent": 80,
            "memory_percent": 85,
        }

    def check_alerts(self, stats: dict) -> list:
        """Return one alert dict per threshold that ``stats`` exceeds.

        Missing metrics are treated as 0, so a partial snapshot never raises.
        Each alert carries ``type``, ``value``, ``threshold`` and ``message``.
        """
        alerts = []

        error_rate = stats.get("error_rate", 0)
        if error_rate > self.thresholds["error_rate"]:
            alerts.append({
                "type": "high_error_rate",
                "value": error_rate,
                "threshold": self.thresholds["error_rate"],
                "message": f"错误率{error_rate:.2%}超过阈值",
            })

        avg_time = stats.get("avg_response_time", 0)
        if avg_time > self.thresholds["avg_response_time"]:
            alerts.append({
                "type": "slow_response",
                "value": avg_time,
                "threshold": self.thresholds["avg_response_time"],
                "message": f"平均响应时间{avg_time:.2f}秒超过阈值",
            })

        system = stats.get("system", {})
        cpu = system.get("cpu_percent", 0)
        if cpu > self.thresholds["cpu_percent"]:
            alerts.append({
                "type": "high_cpu",
                "value": cpu,
                "threshold": self.thresholds["cpu_percent"],
                "message": f"CPU使用率{cpu:.1f}%超过阈值",
            })

        # The memory threshold was configured but never evaluated.
        memory = system.get("memory_percent", 0)
        if memory > self.thresholds["memory_percent"]:
            alerts.append({
                "type": "high_memory",
                "value": memory,
                "threshold": self.thresholds["memory_percent"],
                "message": f"内存使用率{memory:.1f}%超过阈值",
            })
        return alerts
alert_detector = AlertDetector()
@app.get("/monitoring/alerts")
def check_current_alerts():
    """Evaluate the current snapshot and return the triggered alerts."""
    triggered = alert_detector.check_alerts(monitor.get_stats())
    payload = {"alerts": triggered}
    payload["count"] = len(triggered)
    payload["timestamp"] = time.time()
    return payload
# 5. 信创环境实时监控
def _detect_kylin_os() -> bool:
    """Report whether the host identifies itself as Kylin OS.

    ``platform.system()`` only returns the kernel family (e.g. "Linux"), so
    the distribution name is looked up in /etc/os-release and, as a fallback,
    in the full platform string.
    """
    import platform
    try:
        with open("/etc/os-release", encoding="utf-8") as release_file:
            if "kylin" in release_file.read().lower():
                return True
    except OSError:
        pass
    return "kylin" in platform.platform().lower()


@app.get("/xinchuang/monitoring")
def xinchuang_monitoring():
    """Return the base statistics plus health of the Xinchuang components.

    Each component is probed on a best-effort basis; a failing probe marks
    that component "unhealthy" and never fails the request.
    """
    stats = monitor.get_stats()
    xc_status = {
        "dameng": "unknown",
        "ollama": "unknown",
        "kylin_os": "unknown",
    }

    # Dameng database probe; the connection is always released.
    conn = None
    try:
        import dmPython
        conn = dmPython.connect(
            "user=SYSDBA;password=SYSDBA;server=localhost;port=5236"
        )
        cursor = conn.cursor()
        cursor.execute("SELECT 1")
        xc_status["dameng"] = "healthy"
    except Exception:
        xc_status["dameng"] = "unhealthy"
    finally:
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass  # closing is best effort for a health probe

    # Ollama probe: anything other than HTTP 200 counts as unhealthy.
    try:
        import requests
        resp = requests.get("http://localhost:11434/api/tags", timeout=2)
        xc_status["ollama"] = "healthy" if resp.status_code == 200 else "unhealthy"
    except Exception:
        xc_status["ollama"] = "unhealthy"

    xc_status["kylin_os"] = "healthy" if _detect_kylin_os() else "not_kylin"
    stats["xinchuang_components"] = xc_status
    return stats
print("✓ 实时监控配置完成")
---
6.4 安全加固
01.访问控制
a.认证授权
a.功能说明
实施严格的认证和授权机制,保护API不被未授权访问。支持API Key、JWT、OAuth等多种认证方式。基于角色的权限控制(RBAC)。认证授权是API安全的第一道防线。
b.代码示例
---
from fastapi import FastAPI, Depends, HTTPException, Security
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from datetime import datetime, timedelta
import hashlib
app = FastAPI()
security = HTTPBearer()
# 1. API Key认证
API_KEYS = {
"key123": {"user": "user1", "permissions": ["read", "write"]},
"key456": {"user": "user2", "permissions": ["read"]}
}
def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)):
    """Map the bearer credential to its API_KEYS entry, or reject with 401."""
    try:
        return API_KEYS[credentials.credentials]
    except KeyError:
        raise HTTPException(status_code=401, detail="Invalid API Key")
@app.post("/protected")
def protected_endpoint(user=Depends(verify_api_key)):
    """Greet the caller identified by a valid API key."""
    greeting = f"Hello {user['user']}"
    return {"message": greeting}
# 2. JWT认证
SECRET_KEY = "your-secret-key-change-me"
ALGORITHM = "HS256"
def create_access_token(data: dict, expires_delta: timedelta = None):
    """Create a signed JWT carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed; the dict is copied, not modified.
        expires_delta: Token lifetime. Defaults to one hour when omitted.
            An explicit zero or negative delta is honoured.

    Returns:
        The encoded token string.
    """
    # ``timedelta(0)`` is falsy, so test for None explicitly.
    lifetime = timedelta(hours=1) if expires_delta is None else expires_delta
    to_encode = data.copy()
    to_encode.update({"exp": datetime.utcnow() + lifetime})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
    """Decode the bearer JWT and return its payload, or reject with 401.

    A token that fails to decode, or that has no ``sub`` claim, is rejected.
    """
    try:
        claims = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise HTTPException(status_code=401, detail="Invalid token")
    if claims.get("sub") is None:
        raise HTTPException(status_code=401, detail="Invalid token")
    return claims
@app.post("/login")
def login(username: str, password: str):
    """Exchange demo credentials for a bearer token.

    The credentials are hard-coded for the example; a real service should
    look the user up in a database. Comparison is constant-time so response
    timing does not leak how much of a guess matched.

    NOTE(review): both parameters arrive as query-string values, which end up
    in access logs; consider moving them to a form or JSON body.
    """
    import hmac

    user_ok = hmac.compare_digest(username.encode('utf-8'), b"admin")
    password_ok = hmac.compare_digest(password.encode('utf-8'), b"password")
    if user_ok and password_ok:
        access_token = create_access_token(
            data={"sub": username, "role": "admin"}
        )
        return {"access_token": access_token, "token_type": "bearer"}
    raise HTTPException(status_code=401, detail="Invalid credentials")
@app.get("/secure")
def secure_endpoint(user=Depends(verify_token)):
    """Echo back the verified JWT payload of the caller."""
    return dict(user=user)
# 3. 权限控制
def require_permission(permission: str):
    """Build a dependency that admits only callers holding ``permission``.

    Permissions are read from the token's ``permissions`` claim. Tokens issued
    by ``/login`` carry only a ``role`` claim, so when ``permissions`` is
    absent the role is expanded through ``get_user_permissions``.

    Raises (from the returned dependency):
        HTTPException: 403 when the permission is missing.
    """
    def permission_checker(user=Depends(verify_token)):
        user_permissions = user.get("permissions")
        if user_permissions is None:
            # Unknown or missing roles resolve to no permissions at all.
            user_permissions = get_user_permissions(user.get("role"))
        if permission not in user_permissions:
            raise HTTPException(status_code=403, detail="Permission denied")
        return user
    return permission_checker
@app.post("/admin/action")
def admin_action(user=Depends(require_permission("admin"))):
    """Perform an action reserved for callers with the admin permission."""
    result = {"message": "Admin action"}
    return result
# 4. RBAC(基于角色的访问控制)
ROLES = dict(
    user=["read"],
    editor=["read", "write"],
    admin=["read", "write", "delete", "admin"],
)


def get_user_permissions(role: str) -> list:
    """Return the permission list for ``role``; unknown roles get an empty list."""
    if role in ROLES:
        return ROLES[role]
    return []
def require_role(required_role: str):
    """Build a dependency that admits callers whose role covers ``required_role``.

    A caller passes when its role's permissions include every permission of
    the required role; otherwise the dependency raises a 403.
    """
    def role_checker(user=Depends(verify_token)):
        granted = set(get_user_permissions(user.get("role", "user")))
        needed = set(get_user_permissions(required_role))
        missing = needed - granted
        if missing:
            raise HTTPException(status_code=403, detail="Insufficient permissions")
        return user
    return role_checker
@app.delete("/resource/{id}")
def delete_resource(id: int, user=Depends(require_role("admin"))):
    """Delete a resource; restricted to callers with the admin role."""
    confirmation = f"Resource {id} deleted"
    return {"message": confirmation}
# 5. 密码哈希
import hashlib
import secrets
def hash_password(password: str) -> tuple:
    """Derive a PBKDF2-HMAC-SHA256 hash for ``password`` with a fresh salt.

    Returns:
        ``(hash_hex, salt)`` where ``salt`` is a 64-character hex string.
    """
    iterations = 100000
    salt = secrets.token_hex(32)
    secret_bytes = password.encode('utf-8')
    salt_bytes = salt.encode('utf-8')
    digest = hashlib.pbkdf2_hmac('sha256', secret_bytes, salt_bytes, iterations)
    return digest.hex(), salt
def verify_password(password: str, pwd_hash: str, salt: str) -> bool:
    """Check ``password`` against a stored PBKDF2-HMAC-SHA256 hash.

    Args:
        password: Candidate password.
        pwd_hash: Stored hash as a hex string.
        salt: Salt string that was used when the hash was created.

    Returns:
        True when the derived hash equals ``pwd_hash``.
    """
    import hmac

    new_hash = hashlib.pbkdf2_hmac(
        'sha256',
        password.encode('utf-8'),
        salt.encode('utf-8'),
        100000
    )
    # Constant-time comparison so timing does not reveal matching prefixes.
    return hmac.compare_digest(new_hash.hex().encode('ascii'), pwd_hash.encode('utf-8'))
# 6. 信创环境:国密SM2 JWT
from gmssl import sm2, func
class SM2JWT:
    """Minimal signed-token helper using the SM2 signature algorithm (gmssl).

    Token layout: urlsafe-base64( JSON payload + b"." + hex signature ).
    """

    def __init__(self, private_key: str, public_key: str):
        """Args are the SM2 key pair as hex strings, which is what gmssl expects."""
        self.sm2_crypt = sm2.CryptSM2(
            private_key=private_key,
            public_key=public_key
        )

    def create_token(self, data: dict) -> str:
        """Return a signed token for ``data`` that expires one hour from now.

        ``data`` is copied; the caller's dict is not modified.
        """
        import base64
        import json
        import time

        claims = dict(data)
        claims["exp"] = time.time() + 3600
        payload = json.dumps(claims).encode('utf-8')
        # SM3-hashed SM2 signature; gmssl returns it as a hex string.
        signature = self.sm2_crypt.sign_with_sm3(payload)
        return base64.urlsafe_b64encode(
            payload + b"." + signature.encode('ascii')
        ).decode('utf-8')

    def verify_token(self, token: str) -> dict:
        """Validate ``token`` and return its claims.

        Raises:
            HTTPException: 401 for a malformed token, a bad signature or an
                expired token.
        """
        import base64
        import json
        import time

        try:
            decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
        except ValueError:
            raise HTTPException(status_code=401, detail="Invalid token format")
        # The JSON payload itself contains "." (float exp), while the hex
        # signature never does, so split on the LAST separator.
        payload, separator, signature = decoded.rpartition(b".")
        if not separator:
            raise HTTPException(status_code=401, detail="Invalid token format")
        try:
            signature_ok = self.sm2_crypt.verify_with_sm3(signature.decode('ascii'), payload)
        except Exception:
            # A signature that is not valid hex makes gmssl raise; treat as invalid.
            signature_ok = False
        if not signature_ok:
            raise HTTPException(status_code=401, detail="Signature verification failed")
        try:
            data = json.loads(payload.decode('utf-8'))
        except ValueError:
            raise HTTPException(status_code=401, detail="Invalid token format")
        if time.time() > data.get("exp", 0):
            raise HTTPException(status_code=401, detail="Token expired")
        return data
# 7. 信创环境:达梦数据库用户管理
import dmPython
class DamengUserManager:
    """User store backed by a Dameng database (tables ``users`` and ``permissions``)."""

    def __init__(self, connection_string: str):
        self.conn = dmPython.connect(connection_string)
        self._init_tables()

    def _init_tables(self):
        """Create the ``users`` and ``permissions`` tables when they are missing."""
        cursor = self.conn.cursor()
        try:
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS users (
                    id INTEGER IDENTITY(1,1) PRIMARY KEY,
                    username VARCHAR(50) UNIQUE NOT NULL,
                    password_hash VARCHAR(200) NOT NULL,
                    salt VARCHAR(100) NOT NULL,
                    role VARCHAR(20) DEFAULT 'user',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    is_active INTEGER DEFAULT 1
                )
            """)
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS permissions (
                    id INTEGER IDENTITY(1,1) PRIMARY KEY,
                    user_id INTEGER,
                    permission VARCHAR(50),
                    FOREIGN KEY (user_id) REFERENCES users(id)
                )
            """)
            self.conn.commit()
        finally:
            cursor.close()

    def create_user(self, username: str, password: str, role: str = "user"):
        """Insert a new user with a salted password hash."""
        pwd_hash, salt = hash_password(password)
        cursor = self.conn.cursor()
        try:
            cursor.execute(
                "INSERT INTO users (username, password_hash, salt, role) VALUES (?, ?, ?, ?)",
                (username, pwd_hash, salt, role)
            )
            self.conn.commit()
        finally:
            cursor.close()

    def verify_user(self, username: str, password: str) -> dict:
        """Authenticate a user and return ``{"username", "role"}``.

        Raises:
            HTTPException: 401 for unknown users, wrong passwords, or
                disabled accounts.
        """
        cursor = self.conn.cursor()
        try:
            cursor.execute(
                "SELECT password_hash, salt, role, is_active FROM users WHERE username = ?",
                (username,)
            )
            row = cursor.fetchone()
        finally:
            cursor.close()
        if not row:
            raise HTTPException(status_code=401, detail="Invalid credentials")
        pwd_hash, salt, role, is_active = row
        # Check the password BEFORE the active flag so that only a caller who
        # knows the password can learn that the account is disabled.
        if not verify_password(password, pwd_hash, salt):
            raise HTTPException(status_code=401, detail="Invalid credentials")
        if not is_active:
            raise HTTPException(status_code=401, detail="User is disabled")
        return {"username": username, "role": role}
print("✓ 认证授权配置完成")
---
b.数据加密
a.功能说明
加密敏感数据的传输和存储。使用HTTPS/TLS加密网络传输。使用AES等算法加密数据库中的敏感字段。信创环境使用国密算法(SM2、SM4)。数据加密保护用户隐私和业务机密。
b.代码示例
---
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
import base64
import os
# 1. 对称加密(AES)
class DataEncryptor:
    """Symmetric string encryption built on Fernet.

    Ciphertexts are wrapped in an additional urlsafe-base64 layer, which
    ``decrypt`` removes again.
    """

    def __init__(self, key: bytes = None):
        """Args:
            key: A Fernet key. When omitted a random key is generated, which
                means ciphertexts cannot be decrypted by another instance or
                after a restart — pass a securely stored key for persistent data.
        """
        self.key = key if key is not None else Fernet.generate_key()
        self.cipher = Fernet(self.key)

    def encrypt(self, data: str) -> str:
        """Encrypt ``data`` and return it as a text-safe string."""
        encrypted = self.cipher.encrypt(data.encode())
        return base64.urlsafe_b64encode(encrypted).decode()

    def decrypt(self, encrypted_data: str) -> str:
        """Reverse ``encrypt`` and return the original string."""
        encrypted = base64.urlsafe_b64decode(encrypted_data.encode())
        decrypted = self.cipher.decrypt(encrypted)
        return decrypted.decode()
encryptor = DataEncryptor()
@app.post("/encrypt")
def encrypt_data(data: str):
    """Encrypt the supplied text with the shared encryptor."""
    return {"encrypted": encryptor.encrypt(data)}
@app.post("/decrypt")
def decrypt_data(encrypted: str):
    """Decrypt text previously produced by the /encrypt endpoint."""
    return {"decrypted": encryptor.decrypt(encrypted)}
# 2. SSL/TLS配置
# 生成自签名证书(开发环境)
# openssl req -x509 -newkey rsa:4096 -nodes -out cert.pem -keyout key.pem -days 365
# 启动HTTPS服务
# uvicorn server:app --host 0.0.0.0 --port 443 --ssl-keyfile key.pem --ssl-certfile cert.pem
# 3. 数据库字段加密
import dmPython
class EncryptedFieldHandler:
    """Store and load user rows whose e-mail and phone columns are encrypted."""

    def __init__(self, encryptor: DataEncryptor):
        self.encryptor = encryptor

    def save_user(self, conn, username: str, email: str, phone: str):
        """Insert a user row, encrypting ``email`` and ``phone`` first."""
        cursor = conn.cursor()
        secured = tuple(self.encryptor.encrypt(value) for value in (email, phone))
        insert_sql = (
            "\n"
            "INSERT INTO users (username, email_encrypted, phone_encrypted)\n"
            "VALUES (?, ?, ?)\n"
        )
        cursor.execute(insert_sql, (username,) + secured)
        conn.commit()

    def get_user(self, conn, username: str) -> dict:
        """Load a user and decrypt its sensitive columns; None when not found."""
        cursor = conn.cursor()
        cursor.execute(
            "SELECT email_encrypted, phone_encrypted FROM users WHERE username = ?",
            (username,)
        )
        row = cursor.fetchone()
        if row:
            email = self.encryptor.decrypt(row[0])
            phone = self.encryptor.decrypt(row[1])
            return dict(username=username, email=email, phone=phone)
        return None
# 4. 国密SM4加密(信创环境)
from gmssl import sm4, func
class SM4Encryptor:
    """String encryption with the SM4 block cipher (gmssl), base64 encoded.

    NOTE(review): ECB mode leaks equality of plaintext blocks. It is kept here
    so existing ciphertexts stay readable; prefer CBC with a random IV for
    new data.
    """

    def __init__(self, key: bytes = None):
        """Args:
            key: 16-byte SM4 key; a random key is generated when omitted.

        Raises:
            ValueError: when ``key`` is not exactly 16 bytes long.
        """
        self.key = key or os.urandom(16)
        if len(self.key) != 16:
            raise ValueError("SM4 key must be exactly 16 bytes")
        self.sm4_crypt = sm4.CryptSM4()

    def encrypt(self, data: str) -> str:
        """Pad, encrypt and base64-encode ``data``."""
        data_bytes = data.encode('utf-8')
        # Always adds 1..16 bytes, each holding the padding length.
        padding_len = 16 - len(data_bytes) % 16
        padded_data = data_bytes + bytes([padding_len] * padding_len)
        self.sm4_crypt.set_key(self.key, sm4.SM4_ENCRYPT)
        encrypted = self.sm4_crypt.crypt_ecb(padded_data)
        return base64.b64encode(encrypted).decode('utf-8')

    def decrypt(self, encrypted_data: str) -> str:
        """Reverse ``encrypt``.

        Raises:
            ValueError: when the decrypted data is empty or its padding is
                malformed (wrong key or corrupted ciphertext).
        """
        encrypted = base64.b64decode(encrypted_data.encode('utf-8'))
        self.sm4_crypt.set_key(self.key, sm4.SM4_DECRYPT)
        decrypted = self.sm4_crypt.crypt_ecb(encrypted)
        if not decrypted:
            raise ValueError("Ciphertext decrypted to empty data")
        padding_len = decrypted[-1]
        padding_valid = (
            1 <= padding_len <= 16
            and padding_len <= len(decrypted)
            and decrypted[-padding_len:] == bytes([padding_len] * padding_len)
        )
        if not padding_valid:
            raise ValueError("Invalid padding in decrypted data")
        return decrypted[:-padding_len].decode('utf-8')
# 5. 信创环境:达梦数据库 + SM4加密
class XinchuangSecureStorage:
    """Persist SM4-encrypted values in a Dameng ``secure_data`` table."""

    def __init__(self, dm_conn_str: str, sm4_key: bytes):
        self.conn = dmPython.connect(dm_conn_str)
        self.encryptor = SM4Encryptor(sm4_key)
        self._init_tables()

    def _init_tables(self):
        """Create the ``secure_data`` table when it does not exist yet."""
        create_sql = (
            "\n"
            "CREATE TABLE IF NOT EXISTS secure_data (\n"
            "id INTEGER IDENTITY(1,1) PRIMARY KEY,\n"
            "user_id VARCHAR(50),\n"
            "data_type VARCHAR(50),\n"
            "encrypted_data TEXT,\n"
            "created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP\n"
            ")\n"
        )
        self.conn.cursor().execute(create_sql)
        self.conn.commit()

    def store_secure_data(self, user_id: str, data_type: str, data: str):
        """Encrypt ``data`` and insert it for the given user and data type."""
        ciphertext = self.encryptor.encrypt(data)
        insert_sql = (
            "\n"
            "INSERT INTO secure_data (user_id, data_type, encrypted_data)\n"
            "VALUES (?, ?, ?)\n"
        )
        cursor = self.conn.cursor()
        cursor.execute(insert_sql, (user_id, data_type, ciphertext))
        self.conn.commit()

    def get_secure_data(self, user_id: str, data_type: str) -> str:
        """Return the newest decrypted value for the user/type, or None."""
        select_sql = (
            "\n"
            "SELECT encrypted_data FROM secure_data\n"
            "WHERE user_id = ? AND data_type = ?\n"
            "ORDER BY created_at DESC\n"
            "LIMIT 1\n"
        )
        cursor = self.conn.cursor()
        cursor.execute(select_sql, (user_id, data_type))
        row = cursor.fetchone()
        if row:
            return self.encryptor.decrypt(row[0])
        return None
# 6. 敏感数据脱敏
def mask_email(email: str) -> str:
    """Mask the local part of an e-mail address, keeping the domain visible.

    Local parts of up to two characters keep only their first character;
    longer ones keep the first and last. Input that is not of the form
    ``name@domain`` (no "@", several "@", or an empty name) is returned
    unchanged.
    """
    parts = email.split('@')
    if len(parts) != 2 or not parts[0]:
        # An empty local part has nothing to mask (and would break indexing).
        return email
    name, domain = parts
    if len(name) <= 2:
        masked_name = name[0] + '*'
    else:
        masked_name = name[0] + '*' * (len(name) - 2) + name[-1]
    return f"{masked_name}@{domain}"
def mask_phone(phone: str) -> str:
    """Keep the first 3 and last 4 characters of a phone number, star the rest.

    Numbers of 7 characters or fewer are returned unchanged.
    """
    if len(phone) > 7:
        head, tail = phone[:3], phone[-4:]
        return f"{head}****{tail}"
    return phone
def mask_id_card(id_card: str) -> str:
    """Keep the first 6 and last 4 characters of an ID number, star the middle.

    Values of 8 characters or fewer are returned unchanged.
    """
    if len(id_card) > 8:
        return "".join((id_card[:6], '*' * 8, id_card[-4:]))
    return id_card
@app.get("/user/profile")
def get_user_profile(user_id: str):
    """Return a user's profile with every sensitive field masked.

    The record is a hard-coded sample standing in for a database lookup.
    """
    # Sample record; the e-mail literal had been mangled by an e-mail
    # obfuscation artifact and contained no "@", so it was never masked.
    user = {
        "email": "zhangsan@example.com",
        "phone": "13812345678",
        "id_card": "110101199001011234"
    }
    return {
        "email": mask_email(user["email"]),
        "phone": mask_phone(user["phone"]),
        "id_card": mask_id_card(user["id_card"])
    }
print("✓ 数据加密配置完成")
---
02.安全防护
a.防护措施
a.功能说明
实施多层安全防护措施,抵御常见攻击。防止SQL注入、XSS、CSRF、DDoS等攻击。设置请求频率限制、IP黑白名单。定期安全审计和漏洞扫描。多层防护构建纵深防御体系。
b.代码示例
---
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from starlette.middleware.cors import CORSMiddleware
from collections import defaultdict
from datetime import datetime, timedelta
import re
app = FastAPI()
# 1. SQL注入防护
# 使用参数化查询,避免字符串拼接
def safe_query(conn, username: str):
    """Fetch all user rows for ``username`` using a parameterized query.

    The value is passed as a bound parameter, never interpolated into the
    SQL text; building the statement with an f-string such as
    ``f"... WHERE username = '{username}'"`` would be open to SQL injection.
    """
    statement = "SELECT * FROM users WHERE username = ?"
    cursor = conn.cursor()
    cursor.execute(statement, (username,))
    rows = cursor.fetchall()
    return rows
# 2. XSS防护
import html
def sanitize_input(text: str) -> str:
    """Escape HTML special characters in ``text`` to defuse XSS payloads."""
    escaped = html.escape(text)
    return escaped
@app.post("/comment")
def post_comment(content: str):
    """Accept a comment and return it with HTML special characters escaped."""
    return {"content": sanitize_input(content)}
# 3. CSRF防护
import secrets
csrf_tokens = {}
@app.get("/csrf-token")
def get_csrf_token():
    """Issue a single-use CSRF token that is valid for one hour.

    Tokens that expired without being redeemed are purged on every call so
    the in-memory store cannot grow without bound.
    """
    now = datetime.utcnow()
    # Iterate over a snapshot; entries may be removed concurrently.
    for stale_token, expires_at in list(csrf_tokens.items()):
        if expires_at < now:
            csrf_tokens.pop(stale_token, None)
    token = secrets.token_urlsafe(32)
    csrf_tokens[token] = now + timedelta(hours=1)
    return {"csrf_token": token}
def verify_csrf_token(token: str):
    """Consume a CSRF token, raising 403 when it is unknown or expired.

    The token is removed from the store whether it turns out valid or expired.
    """
    unknown = object()
    expires_at = csrf_tokens.pop(token, unknown)
    if expires_at is unknown:
        raise HTTPException(status_code=403, detail="Invalid CSRF token")
    if datetime.utcnow() > expires_at:
        raise HTTPException(status_code=403, detail="CSRF token expired")
@app.post("/sensitive-action")
def sensitive_action(csrf_token: str, data: dict):
    """Run a sensitive operation after the CSRF token has been consumed."""
    verify_csrf_token(csrf_token)
    outcome = {"message": "Action completed"}
    return outcome
# 4. 限流(防DDoS)
from collections import deque
class RateLimiter:
    """Sliding-window request limiter keyed by client identifier."""

    def __init__(self, max_requests: int, window_seconds: int):
        self.requests = defaultdict(deque)  # client id -> request timestamps
        self.window = timedelta(seconds=window_seconds)
        self.max_requests = max_requests

    def is_allowed(self, client_id: str) -> bool:
        """Record and admit the request unless the client's window is full."""
        now = datetime.utcnow()
        history = self.requests[client_id]
        # Forget timestamps that have slid out of the window.
        while history and now - history[0] > self.window:
            history.popleft()
        if len(history) < self.max_requests:
            history.append(now)
            return True
        return False
rate_limiter = RateLimiter(max_requests=100, window_seconds=60)
from fastapi.responses import JSONResponse


@app.middleware("http")
async def rate_limit_middleware(request: Request, call_next):
    """Reject requests from clients that exceeded the rate limit with a 429.

    The response is returned directly: an HTTPException raised inside an
    ``http`` middleware is not converted by FastAPI's exception handlers and
    would reach the client as a 500.
    """
    # ``request.client`` can be None (e.g. some test or proxy transports).
    client_ip = request.client.host if request.client else "unknown"
    if not rate_limiter.is_allowed(client_ip):
        return JSONResponse(
            status_code=429,
            content={"detail": "Too many requests"}
        )
    return await call_next(request)
# 5. IP黑白名单
WHITELIST = ["192.168.1.0/24", "10.0.0.0/8"]
BLACKLIST = ["1.2.3.4", "5.6.7.8"]
from ipaddress import ip_address, ip_network
from fastapi.responses import JSONResponse


@app.middleware("http")
async def ip_filter_middleware(request: Request, call_next):
    """Block blacklisted clients and, when a whitelist is set, all others.

    Rejections are returned as 403 JSON responses rather than raised, because
    an HTTPException raised inside an ``http`` middleware reaches the client
    as a 500. A client address that is missing or not a valid IP is treated
    as not whitelisted.
    """
    client_ip = request.client.host if request.client else ""
    if client_ip in BLACKLIST:
        return JSONResponse(status_code=403, content={"detail": "IP blocked"})
    if WHITELIST:
        try:
            address = ip_address(client_ip)
        except ValueError:
            address = None
        allowed = address is not None and any(
            address in ip_network(cidr)
            for cidr in WHITELIST
        )
        if not allowed:
            return JSONResponse(status_code=403, content={"detail": "IP not whitelisted"})
    return await call_next(request)
# 6. 输入验证
from pydantic import BaseModel, validator, Field
import re
class UserInput(BaseModel):
    """Registration payload with format checks on every field."""
    username: str = Field(..., min_length=3, max_length=50)
    email: str
    phone: str

    # ``re.fullmatch`` is used throughout: ``re.match`` with a trailing ``$``
    # also accepts a value that ends in a newline.

    @validator('username')
    def validate_username(cls, v):
        """Allow only ASCII letters, digits and underscores."""
        if not re.fullmatch(r'[a-zA-Z0-9_]+', v):
            raise ValueError('用户名只能包含字母、数字和下划线')
        return v

    @validator('email')
    def validate_email(cls, v):
        """Require the shape name@host.tld without whitespace."""
        if not re.fullmatch(r'[^@\s]+@[^@\s]+\.[^@\s]+', v):
            raise ValueError('无效的邮箱格式')
        return v

    @validator('phone')
    def validate_phone(cls, v):
        """Require an 11-digit mainland China mobile number (1[3-9]xxxxxxxxx)."""
        if not re.fullmatch(r'1[3-9]\d{9}', v):
            raise ValueError('无效的手机号')
        return v
@app.post("/register")
def register(user: UserInput):
    """Register a user; the request body is validated by ``UserInput``."""
    reply = {"message": "注册成功"}
    return reply
# 7. 安全响应头
@app.middleware("http")
async def security_headers_middleware(request: Request, call_next):
    """Attach a fixed set of security-related headers to every response."""
    hardening_headers = (
        ("X-Frame-Options", "DENY"),  # clickjacking protection
        ("X-Content-Type-Options", "nosniff"),  # no MIME sniffing
        ("X-XSS-Protection", "1; mode=block"),
        ("Content-Security-Policy", "default-src 'self'"),
        # HSTS: ask browsers to use HTTPS only.
        ("Strict-Transport-Security", "max-age=31536000; includeSubDomains"),
    )
    response = await call_next(request)
    for header_name, header_value in hardening_headers:
        response.headers[header_name] = header_value
    return response
# 8. 信创环境安全配置
# IP白名单(仅内网)
XINCHUANG_WHITELIST = [
    "192.168.0.0/16",
    "10.0.0.0/8",
    "172.16.0.0/12"
]
from fastapi.responses import JSONResponse


@app.middleware("http")
async def xinchuang_security_middleware(request: Request, call_next):
    """Admit only clients from the private networks in XINCHUANG_WHITELIST.

    Other clients receive a 403 JSON response (returned, not raised — an
    HTTPException raised inside an ``http`` middleware reaches the client as
    a 500). A missing or unparsable client address is rejected. Admitted
    responses are tagged with environment headers.
    """
    client_ip = request.client.host if request.client else ""
    try:
        address = ip_address(client_ip)
    except ValueError:
        address = None
    allowed = address is not None and any(
        address in ip_network(cidr)
        for cidr in XINCHUANG_WHITELIST
    )
    if not allowed:
        return JSONResponse(
            status_code=403,
            content={"detail": "仅允许内网访问"}
        )
    response = await call_next(request)
    response.headers["X-Environment"] = "Xinchuang"
    response.headers["X-Security-Level"] = "High"
    return response
# 9. 请求审计日志
import dmPython
class SecurityAuditLogger:
    """Write one audit row per request into a Dameng ``security_audit`` table."""

    def __init__(self, dm_conn_str: str):
        self.conn = dmPython.connect(dm_conn_str)
        self._init_table()

    def _init_table(self):
        """Create the audit table when it does not exist yet."""
        cursor = self.conn.cursor()
        try:
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS security_audit (
                    id BIGINT IDENTITY(1,1) PRIMARY KEY,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    client_ip VARCHAR(50),
                    user_id VARCHAR(50),
                    method VARCHAR(10),
                    path VARCHAR(200),
                    status_code INTEGER,
                    risk_level VARCHAR(20),
                    description TEXT
                )
            """)
            self.conn.commit()
        finally:
            cursor.close()

    def log_request(self, request: Request, response_status: int, user_id: str = None):
        """Insert an audit row for ``request``.

        Risk level is "high" for status >= 500, "medium" for >= 400, else
        "low". A missing client address is stored as "unknown" and the path
        is truncated to the 200-character column width, so unusual requests
        are still audited instead of failing the insert.
        """
        if response_status >= 500:
            risk_level = "high"
        elif response_status >= 400:
            risk_level = "medium"
        else:
            risk_level = "low"
        client_ip = request.client.host if request.client else 'unknown'
        cursor = self.conn.cursor()
        try:
            cursor.execute("""
                INSERT INTO security_audit
                (client_ip, user_id, method, path, status_code, risk_level)
                VALUES (?, ?, ?, ?, ?, ?)
            """, (
                client_ip,
                user_id or 'anonymous',
                request.method,
                request.url.path[:200],
                response_status,
                risk_level
            ))
            self.conn.commit()
        finally:
            cursor.close()
# 10. 安全检查清单
"""
信创环境安全检查清单:
1. 网络安全
✓ 仅允许内网IP访问
✓ 使用VPN连接
✓ 配置防火墙规则
2. 认证授权
✓ 使用国密SM2 JWT
✓ 达梦数据库用户管理
✓ 基于角色的权限控制
3. 数据加密
✓ 使用国密SM4加密敏感数据
✓ 数据库字段加密
✓ 传输层TLS加密
4. 访问控制
✓ IP白名单限制
✓ API限流
✓ CSRF防护
5. 审计日志
✓ 请求审计记录
✓ 操作日志
✓ 安全事件追踪
6. 安全加固
✓ 最小权限原则
✓ 定期安全扫描
✓ 漏洞修复
"""
print("✓ 安全防护配置完成")
---
7 最佳实践
7.1 性能优化
01.缓存策略
a.响应缓存
a.功能说明
缓存API响应减少重复计算和LLM调用。使用Redis、Memcached等缓存中间件。设置合理的缓存过期时间。缓存是提升性能最有效的手段之一。
b.代码示例
---
from fastapi import FastAPI, Request
from functools import lru_cache
import hashlib
import json
import redis
app = FastAPI()
# 1. 内存缓存(LRU)
@lru_cache(maxsize=1000)
def cached_llm_call(prompt: str) -> str:
    """Return the LLM answer for ``prompt``, memoized for up to 1000 prompts."""
    from langchain.chat_models import ChatOpenAI
    return ChatOpenAI().invoke(prompt).content
@app.post("/cached_chat")
def cached_chat(question: str):
    """Answer ``question`` through the in-memory LRU cache."""
    return {"response": cached_llm_call(question)}
# 2. Redis缓存
redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)
def get_cache_key(request: dict) -> str:
    """Derive a cache key from ``request``: MD5 hex of its key-sorted JSON form."""
    canonical = json.dumps(request, sort_keys=True).encode()
    digest = hashlib.md5(canonical)
    return digest.hexdigest()
@app.post("/redis_cached_chat")
def redis_cached_chat(question: str):
    """Answer ``question``, serving from Redis when a cached answer exists.

    Fresh answers are cached for one hour. The response's ``cached`` flag
    tells whether the LLM was skipped.
    """
    cache_key = get_cache_key({"question": question})
    cached_response = redis_client.get(cache_key)
    # Compare with None: an empty cached answer is still a cache hit.
    if cached_response is not None:
        return {"response": cached_response, "cached": True}
    from langchain.chat_models import ChatOpenAI
    llm = ChatOpenAI()
    result = llm.invoke(question)
    redis_client.setex(cache_key, 3600, result.content)
    return {"response": result.content, "cached": False}
# 3. 智能缓存(根据相似度)
from sklearn.metrics.pairwise import cosine_similarity
from langchain.embeddings import OpenAIEmbeddings
class SmartCache:
    """Cache that matches queries by embedding similarity rather than equality."""

    def __init__(self, similarity_threshold=0.95):
        self.cache = []  # [(embedding, response)], oldest first
        self.embeddings_model = OpenAIEmbeddings()
        self.threshold = similarity_threshold

    def get(self, query: str):
        """Return the oldest cached response whose similarity reaches the threshold.

        Returns None when the cache is empty or nothing is similar enough.
        """
        if not self.cache:
            return None
        query_emb = self.embeddings_model.embed_query(query)
        # One vectorised similarity call instead of one call per cached entry.
        cached_embs = [cached_emb for cached_emb, _ in self.cache]
        similarities = cosine_similarity([query_emb], cached_embs)[0]
        for (_, cached_response), similarity in zip(self.cache, similarities):
            if similarity >= self.threshold:
                return cached_response
        return None

    def set(self, query: str, response: str):
        """Store ``response`` under the embedding of ``query`` (max 1000 entries)."""
        query_emb = self.embeddings_model.embed_query(query)
        self.cache.append((query_emb, response))
        # Keep only the newest 1000 entries.
        if len(self.cache) > 1000:
            self.cache = self.cache[-1000:]
smart_cache = SmartCache()
@app.post("/smart_cached_chat")
def smart_cached_chat(question: str):
    """Answer ``question``, reusing the answer of a sufficiently similar one.

    The response's ``cached`` flag tells whether the LLM was skipped.
    """
    cached = smart_cache.get(question)
    # Compare with None: an empty cached answer is still a cache hit.
    if cached is not None:
        return {"response": cached, "cached": True}
    from langchain.chat_models import ChatOpenAI
    llm = ChatOpenAI()
    result = llm.invoke(question)
    smart_cache.set(question, result.content)
    return {"response": result.content, "cached": False}
# 4. 分层缓存
class TieredCache:
    """Two-level cache: a process-local dict (L1) in front of Redis (L2).

    L1 entries remember their expiry, so a value is never served from memory
    after its TTL has passed.
    """

    def __init__(self):
        self.memory_cache = {}  # L1: key -> (value, expiry or None)
        self.redis_client = redis.Redis(host='localhost', port=6379)  # L2

    def get(self, key: str):
        """Return the cached string for ``key``, or None on a miss."""
        import time

        entry = self.memory_cache.get(key)
        if entry is not None:
            value, expires_at = entry
            if expires_at is None or time.monotonic() < expires_at:
                return value
            # Expired in L1: drop it and fall through to Redis.
            self.memory_cache.pop(key, None)

        raw = self.redis_client.get(key)
        if not raw:
            return None
        value = raw.decode() if isinstance(raw, bytes) else raw
        # Promote to L1 with the remaining Redis TTL (-1 means no expiry,
        # -2 means the key vanished in the meantime: do not promote).
        remaining = self.redis_client.ttl(key)
        if remaining == -1:
            self.memory_cache[key] = (value, None)
        elif remaining is not None and remaining > 0:
            self.memory_cache[key] = (value, time.monotonic() + remaining)
        return value

    def set(self, key: str, value: str, ttl: int = 3600):
        """Store ``value`` in both levels with a lifetime of ``ttl`` seconds."""
        import time

        self.memory_cache[key] = (value, time.monotonic() + ttl)
        self.redis_client.setex(key, ttl, value)
tiered_cache = TieredCache()
# 5. 缓存预热
def warm_up_cache():
    """Pre-populate Redis with answers to a fixed list of common questions.

    Questions that already have a cached answer are skipped; new answers
    are stored for 24 hours.
    """
    from langchain.chat_models import ChatOpenAI

    common_questions = [
        "什么是LangServe?",
        "如何使用LangChain?",
        "什么是RAG?"
    ]
    llm = ChatOpenAI()
    for question in common_questions:
        cache_key = get_cache_key({"question": question})
        if redis_client.exists(cache_key):
            continue
        answer = llm.invoke(question)
        redis_client.setex(cache_key, 86400, answer.content)
    print(f"✓ 缓存预热完成,预热{len(common_questions)}个问题")
# Strong references to in-flight background tasks. The event loop only keeps
# weak references, so an otherwise unreferenced task may be garbage collected
# before it finishes.
_background_tasks = set()


def _finish_warm_up(task):
    """Release the finished warm-up task and report a failure, if any."""
    _background_tasks.discard(task)
    if not task.cancelled() and task.exception() is not None:
        print(f"缓存预热失败:{task.exception()}")


@app.on_event("startup")
async def startup_warm_up():
    """Start cache warm-up in a worker thread without delaying startup."""
    import asyncio
    task = asyncio.create_task(asyncio.to_thread(warm_up_cache))
    _background_tasks.add(task)
    task.add_done_callback(_finish_warm_up)
# 6. 缓存失效策略
class CacheInvalidator:
    """Remove cache entries from Redis by key pattern or by tag."""

    def __init__(self, redis_client):
        self.redis = redis_client

    def invalidate_by_pattern(self, pattern: str):
        """Delete every key matching ``pattern`` and return how many matched.

        Keys are discovered with incremental SCAN instead of KEYS, which
        blocks the Redis server while it walks the whole keyspace.
        """
        batch_size = 500
        matched = 0
        batch = []
        for key in self.redis.scan_iter(match=pattern, count=batch_size):
            batch.append(key)
            if len(batch) >= batch_size:
                self.redis.delete(*batch)
                matched += len(batch)
                batch = []
        if batch:
            self.redis.delete(*batch)
            matched += len(batch)
        return matched

    def invalidate_by_tags(self, tags: list):
        """Delete all keys registered under each tag set ``tag:<name>``, then the set."""
        for tag in tags:
            tag_key = f"tag:{tag}"
            keys = self.redis.smembers(tag_key)
            if keys:
                self.redis.delete(*keys)
                self.redis.delete(tag_key)
invalidator = CacheInvalidator(redis_client)
@app.post("/admin/invalidate_cache")
def invalidate_cache(pattern: str = "*"):
    """Invalidate cache keys matching ``pattern`` (default: every key)."""
    removed = invalidator.invalidate_by_pattern(pattern)
    return {"message": f"失效{removed}个缓存项"}
# 7. 信创环境:达梦数据库缓存
import dmPython
class DamengCache:
    """Key/value cache persisted in a Dameng table with per-row expiry."""

    def __init__(self, conn_str: str):
        self.conn = dmPython.connect(conn_str)
        self._init_table()

    def _init_table(self):
        """Create the cache table and its expiry index when they are absent."""
        ddl_statements = (
            "\n"
            "CREATE TABLE IF NOT EXISTS cache_store (\n"
            "cache_key VARCHAR(100) PRIMARY KEY,\n"
            "cache_value TEXT,\n"
            "created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n"
            "expires_at TIMESTAMP\n"
            ")\n",
            # Index on the expiry column.
            "\n"
            "CREATE INDEX IF NOT EXISTS idx_expires_at ON cache_store(expires_at)\n",
        )
        cursor = self.conn.cursor()
        for ddl in ddl_statements:
            cursor.execute(ddl)
        self.conn.commit()

    def get(self, key: str):
        """Return the unexpired value stored under ``key``, or None."""
        lookup_sql = (
            "\n"
            "SELECT cache_value FROM cache_store\n"
            "WHERE cache_key = ? AND expires_at > CURRENT_TIMESTAMP\n"
        )
        cursor = self.conn.cursor()
        cursor.execute(lookup_sql, (key,))
        row = cursor.fetchone()
        if row:
            return row[0]
        return None

    def set(self, key: str, value: str, ttl_seconds: int = 3600):
        """Insert or update ``key`` so that it expires ``ttl_seconds`` from now."""
        from datetime import datetime, timedelta
        cursor = self.conn.cursor()
        expires_at = datetime.now() + timedelta(seconds=ttl_seconds)
        upsert_sql = (
            "\n"
            "MERGE INTO cache_store USING (\n"
            "SELECT ? AS cache_key FROM DUAL\n"
            ") tmp ON (cache_store.cache_key = tmp.cache_key)\n"
            "WHEN MATCHED THEN\n"
            "UPDATE SET cache_value = ?, expires_at = ?\n"
            "WHEN NOT MATCHED THEN\n"
            "INSERT (cache_key, cache_value, expires_at)\n"
            "VALUES (?, ?, ?)\n"
        )
        # The same (key, value, expiry) triple feeds the match/update branch
        # and the insert branch.
        cursor.execute(upsert_sql, (key, value, expires_at) * 2)
        self.conn.commit()

    def cleanup_expired(self):
        """Delete expired rows and return the cursor's reported row count."""
        purge_sql = (
            "\n"
            "DELETE FROM cache_store\n"
            "WHERE expires_at < CURRENT_TIMESTAMP\n"
        )
        cursor = self.conn.cursor()
        cursor.execute(purge_sql)
        self.conn.commit()
        return cursor.rowcount
# 8. 缓存监控
class CacheMonitor:
    """Counter of cache hits and misses with a derived hit rate."""

    def __init__(self):
        self.hits = self.misses = 0

    def record_hit(self):
        """Count one cache hit."""
        self.hits = self.hits + 1

    def record_miss(self):
        """Count one cache miss."""
        self.misses = self.misses + 1

    def get_stats(self):
        """Return hits, misses, their total and the hit rate (0 when idle)."""
        total = self.hits + self.misses
        if total > 0:
            hit_rate = self.hits / total
        else:
            hit_rate = 0
        return dict(hits=self.hits, misses=self.misses, total=total, hit_rate=hit_rate)
cache_monitor = CacheMonitor()
@app.get("/cache/stats")
def get_cache_stats():
    """Expose the hit/miss statistics of the shared cache monitor."""
    stats = cache_monitor.get_stats()
    return stats
print("✓ 缓存策略配置完成")
---
b.异步处理
a.功能说明
使用异步处理提升并发性能。异步调用LLM、数据库等IO操作。使用asyncio、异步框架。异步处理显著提升系统吞吐量。
b.代码示例
---
from fastapi import FastAPI
import asyncio
from langchain.chat_models import ChatOpenAI
from langchain.callbacks import AsyncCallbackManager
app = FastAPI()
# 1. 异步LLM调用
@app.post("/async_chat")
async def async_chat(question: str):
    """Answer ``question`` with a non-blocking LLM call."""
    answer = await ChatOpenAI().ainvoke(question)
    return {"response": answer.content}
# 2. 并发多个LLM调用
@app.post("/parallel_chat")
async def parallel_chat(questions: list[str]):
    """Answer several questions concurrently; responses keep the input order."""
    llm = ChatOpenAI()
    pending = (llm.ainvoke(question) for question in questions)
    answers = await asyncio.gather(*pending)
    responses = []
    for answer in answers:
        responses.append(answer.content)
    return {"responses": responses}
# 3. 异步数据库操作
import aiomysql
async def async_db_query(query: str):
    """Run ``query`` on a fresh aiomysql connection and return all rows.

    The connection is closed even when the query fails.

    NOTE(review): ``query`` is executed verbatim; callers must not build it
    from untrusted input.
    """
    conn = await aiomysql.connect(
        host='localhost',
        user='user',
        password='password',
        db='database'
    )
    try:
        async with conn.cursor() as cursor:
            await cursor.execute(query)
            result = await cursor.fetchall()
    finally:
        conn.close()
    return result
# 4. 异步Redis操作
import aioredis
async def async_redis_get(key: str):
    """Fetch ``key`` from Redis through a short-lived aioredis pool.

    The pool is closed even when the lookup fails.
    """
    redis = await aioredis.create_redis_pool('redis://localhost')
    try:
        value = await redis.get(key)
    finally:
        redis.close()
        await redis.wait_closed()
    return value
# 5. 异步HTTP请求
import httpx
async def async_http_call(url: str):
    """GET ``url`` with a temporary async client and return the decoded JSON body."""
    client = httpx.AsyncClient()
    async with client:
        reply = await client.get(url)
        payload = reply.json()
    return payload
# 6. 任务队列(Celery)
from celery import Celery
# A result backend is required for AsyncResult status/result lookups; with a
# broker only, querying a task's state fails.
celery_app = Celery(
    'langserve',
    broker='redis://localhost:6379',
    backend='redis://localhost:6379'
)


@celery_app.task
def process_long_task(data: dict):
    """Celery task standing in for long-running work (LLM calls, data processing).

    Sleeps 10 seconds to simulate the workload and returns a completion dict.
    """
    import time
    time.sleep(10)  # simulated long-running work
    return {"status": "completed", "result": "..."}
@app.post("/submit_task")
def submit_task(data: dict):
    """Queue the long-running task and return its id right away."""
    queued = process_long_task.delay(data)
    return dict(task_id=queued.id, status="submitted")
@app.get("/task_status/{task_id}")
def get_task_status(task_id: str):
    """Report the status of a queued task and its result once it has finished.

    For a failed task ``result`` holds the error text: the raw result is an
    exception instance, which cannot be serialised to JSON.
    """
    task = celery_app.AsyncResult(task_id)
    result = None
    if task.ready():
        result = task.result if task.successful() else str(task.result)
    return {
        "task_id": task_id,
        "status": task.status,
        "result": result
    }
# 7. 流式异步处理
from fastapi.responses import StreamingResponse
@app.post("/async_stream_chat")
async def async_stream_chat(question: str):
    """Stream the LLM answer to the client as server-sent events.

    Args:
        question: Prompt forwarded to ``ChatOpenAI.astream``.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    async def generate():
        llm = ChatOpenAI()
        async for chunk in llm.astream(question):
            # An SSE event is terminated by a blank line, i.e. two real
            # newline characters. The previous "\\n\\n" produced the four
            # literal characters backslash-n-backslash-n, so no event was
            # ever terminated.
            yield f"data: {chunk.content}\n\n"
    return StreamingResponse(generate(), media_type="text/event-stream")
# 8. 信创环境:异步达梦数据库
# 注意:达梦数据库Python驱动可能不支持异步,需要使用线程池
from concurrent.futures import ThreadPoolExecutor
import dmPython
executor = ThreadPoolExecutor(max_workers=10)
async def async_dameng_query(query: str):
    """Run a Dameng query in the thread pool so the event loop stays free.

    Args:
        query: SQL text handed unchanged to ``cursor.execute``.

    Returns:
        The rows returned by ``cursor.fetchall()``.
    """
    # Inside a coroutine the running loop is the correct one to use;
    # get_event_loop() is deprecated for this purpose.
    loop = asyncio.get_running_loop()

    def sync_query():
        conn = dmPython.connect(
            "user=SYSDBA;password=SYSDBA;server=localhost;port=5236"
        )
        try:
            cursor = conn.cursor()
            cursor.execute(query)
            return cursor.fetchall()
        finally:
            # Close the connection even when the query raises.
            conn.close()

    # Blocking driver call executes on the shared executor.
    return await loop.run_in_executor(executor, sync_query)
@app.get("/xinchuang/async_query")
async def xinchuang_async_query(sql: str):
"""信创异步查询"""
result = await async_dameng_query(sql)
return {"result": result}
print("✓ 异步处理配置完成")
---
02.资源优化
a.连接池
a.功能说明
使用连接池管理数据库和HTTP连接,减少连接开销。复用连接提升性能。配置合理的池大小和超时。连接池是高性能系统的标配。
b.代码示例
---
from fastapi import FastAPI
import httpx
from sqlalchemy import create_engine
from sqlalchemy.pool import QueuePool
import redis
from redis import ConnectionPool
app = FastAPI()
# 1. HTTP连接池
http_client = httpx.Client(
limits=httpx.Limits(
max_keepalive_connections=20,
max_connections=100,
keepalive_expiry=30.0
),
timeout=httpx.Timeout(30.0)
)
@app.get("/http_pooled")
def http_pooled_request(url: str):
"""使用连接池的HTTP请求"""
response = http_client.get(url)
return response.json()
# 异步HTTP连接池
async_http_client = httpx.AsyncClient(
limits=httpx.Limits(
max_keepalive_connections=20,
max_connections=100
)
)
@app.get("/async_http_pooled")
async def async_http_pooled_request(url: str):
"""异步HTTP连接池"""
response = await async_http_client.get(url)
return response.json()
# 2. Redis连接池
redis_pool = ConnectionPool(
host='localhost',
port=6379,
max_connections=50,
decode_responses=True
)
redis_client = redis.Redis(connection_pool=redis_pool)
@app.get("/redis_pooled")
def redis_pooled_get(key: str):
"""使用连接池的Redis操作"""
value = redis_client.get(key)
return {"value": value}
# 3. 数据库连接池(SQLAlchemy)
engine = create_engine(
'postgresql://user:password@localhost/dbname',
poolclass=QueuePool,
pool_size=20,
max_overflow=10,
pool_timeout=30,
pool_recycle=3600,
pool_pre_ping=True
)
@app.get("/db_pooled")
def db_pooled_query():
    """Run ``SELECT 1`` on a connection borrowed from the engine's pool.

    Returns:
        A dict whose ``result`` key holds the first row.
    """
    # Imported locally so this block is self-contained.
    from sqlalchemy import text

    with engine.connect() as conn:
        # SQLAlchemy 2.x no longer accepts a plain string here; text() works
        # on 1.x and 2.x alike.
        result = conn.execute(text("SELECT 1"))
        return {"result": result.fetchone()}
# 4. 信创环境:达梦数据库连接池
import dmPython
from queue import Queue
class DamengConnectionPool:
    """Fixed-size pool of dmPython connections held in a queue."""

    def __init__(self, conn_str: str, pool_size: int = 10):
        """Open ``pool_size`` connections up front and queue them."""
        self.conn_str = conn_str
        self.pool_size = pool_size
        self.pool = Queue(maxsize=pool_size)
        # Eagerly create every connection so requests never pay connect cost.
        for _slot in range(pool_size):
            self.pool.put(dmPython.connect(conn_str))

    def get_connection(self):
        """Borrow a connection, waiting up to 10 seconds for one."""
        borrowed = self.pool.get(timeout=10)
        return borrowed

    def return_connection(self, conn):
        """Hand a borrowed connection back to the pool."""
        self.pool.put(conn)

    def close_all(self):
        """Drain the queue and close every connection currently in it."""
        idle = self.pool
        while True:
            if idle.empty():
                break
            idle.get().close()
dm_pool = DamengConnectionPool(
"user=SYSDBA;password=SYSDBA;server=localhost;port=5236",
pool_size=20
)
@app.get("/xinchuang/pooled_query")
def xinchuang_pooled_query(sql: str):
    """Execute a query on a pooled Dameng connection.

    Args:
        sql: SQL text handed unchanged to ``cursor.execute``.

    Returns:
        A dict whose ``result`` key holds all fetched rows.
    """
    # NOTE(review): ``sql`` comes straight from the request and is executed
    # as-is -- restrict this endpoint or switch to parameterized statements.
    conn = dm_pool.get_connection()
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(sql)
            return {"result": cursor.fetchall()}
        finally:
            # Close the cursor so it does not linger on a connection that
            # goes back into the pool.
            cursor.close()
    finally:
        dm_pool.return_connection(conn)
# 5. 连接池监控
class PoolMonitor:
    """Report occupancy figures for a queue-backed connection pool."""

    def __init__(self, pool):
        """Remember the queue object to inspect."""
        self.pool = pool

    def get_stats(self):
        """Return current size, capacity, available and in-use counts."""
        queue_ = self.pool
        report = {}
        report["size"] = queue_.qsize()
        report["max_size"] = queue_.maxsize
        report["available"] = queue_.qsize()
        # Connections not sitting in the queue are assumed to be checked out.
        report["in_use"] = queue_.maxsize - queue_.qsize()
        return report
pool_monitor = PoolMonitor(dm_pool.pool)
@app.get("/pool/stats")
def get_pool_stats():
"""获取连接池统计"""
return pool_monitor.get_stats()
print("✓ 连接池配置完成")
---
7.2 版本管理
01.API版本控制
a.路径版本
a.功能说明
在API路径中包含版本号,明确标识API版本。支持多版本并存,平滑升级。客户端可以选择使用的版本。路径版本是最常见的版本控制方式。
b.代码示例
---
from fastapi import FastAPI, APIRouter
from langserve import add_routes
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
app = FastAPI()
# 1. 基础路径版本
# API v1
v1_router = APIRouter(prefix="/v1")
@v1_router.post("/chat")
def chat_v1(question: str):
"""聊天v1"""
llm = ChatOpenAI(model="gpt-3.5-turbo")
result = llm.invoke(question)
return {"response": result.content}
# API v2(新版本,增加上下文)
v2_router = APIRouter(prefix="/v2")
@v2_router.post("/chat")
def chat_v2(question: str, context: str = ""):
    """Chat endpoint v2: answers with GPT-4 and accepts optional context.

    Args:
        question: The user's question.
        context: Optional background text placed before the question.

    Returns:
        A dict with the answer text plus ``version`` and ``model`` fields.
    """
    # Use a real newline between context and question; the previous "\\n"
    # inserted the two literal characters backslash and "n" into the prompt.
    prompt = f"上下文:{context}\n问题:{question}" if context else question
    llm = ChatOpenAI(model="gpt-4")
    result = llm.invoke(prompt)
    return {
        "response": result.content,
        "version": "2.0",
        "model": "gpt-4"
    }
app.include_router(v1_router)
app.include_router(v2_router)
# 2. LangServe版本路由
# v1 Chain
llm_v1 = ChatOpenAI(model="gpt-3.5-turbo")
prompt_v1 = ChatPromptTemplate.from_template("回答:{question}")
chain_v1 = prompt_v1 | llm_v1
add_routes(app, chain_v1, path="/v1/langserve/chat")
# v2 chain: GPT-4 with a template taking both context and question.
llm_v2 = ChatOpenAI(model="gpt-4")
prompt_v2 = ChatPromptTemplate.from_template(
    # Real newline between the two parts; "\\n" rendered a literal
    # backslash-n inside the prompt.
    "上下文:{context}\n问题:{question}"
)
chain_v2 = prompt_v2 | llm_v2
add_routes(app, chain_v2, path="/v2/langserve/chat")
# 3. 版本信息端点
@app.get("/version")
def get_api_versions():
"""获取API版本信息"""
return {
"versions": [
{
"version": "v1",
"status": "deprecated",
"endpoints": ["/v1/chat", "/v1/langserve/chat"],
"deprecation_date": "2024-12-31"
},
{
"version": "v2",
"status": "stable",
"endpoints": ["/v2/chat", "/v2/langserve/chat"],
"features": ["上下文支持", "GPT-4模型"]
}
],
"latest": "v2"
}
# 4. 版本弃用警告
from fastapi import Request
@app.middleware("http")
async def deprecation_warning_middleware(request: Request, call_next):
    """Attach deprecation headers to every response served under ``/v1/``.

    Args:
        request: Incoming request; only its URL path is inspected.
        call_next: Callable that produces the downstream response.

    Returns:
        The downstream response, with extra headers for v1 paths.
    """
    response = await call_next(request)
    if request.url.path.startswith("/v1/"):
        # Header values are encoded as latin-1 when set; the earlier Chinese
        # message raised UnicodeEncodeError on every v1 request, so the
        # warning text is kept ASCII-only.
        response.headers["X-API-Deprecation-Warning"] = (
            "API v1 is deprecated and will be removed on 2024-12-31; "
            "please upgrade to v2"
        )
        response.headers["X-API-Deprecation-Date"] = "2024-12-31"
        response.headers["X-API-Migration-Guide"] = "/docs/v2-migration"
    return response
# 5. 版本降级处理
@app.post("/v2/chat")
def chat_v2_with_fallback(question: str, context: str = ""):
    """Chat v2 with graceful degradation from GPT-4 to GPT-3.5.

    Args:
        question: The user's question.
        context: Optional background text placed before the question.

    Returns:
        A dict with the answer; on fallback it also carries
        ``"fallback": True`` and version ``"2.0-fallback"``.
    """
    # NOTE(review): "/v2/chat" is also registered through v2_router earlier
    # in this snippet; presumably the first registration wins -- confirm
    # which handler is meant to serve this path.
    import logging

    # ``logger`` was never defined in this snippet, so the fallback path
    # raised NameError instead of degrading.
    log = logging.getLogger(__name__)

    # Honour ``context`` the same way chat_v2 does instead of dropping it.
    prompt = f"上下文:{context}\n问题:{question}" if context else question

    try:
        # Preferred path: GPT-4.
        result = ChatOpenAI(model="gpt-4").invoke(prompt)
    except Exception as e:
        # Deliberately broad: any failure of the primary model triggers the
        # downgrade to GPT-3.5.
        log.warning("v2失败,降级到v1:%s", e)
        result = ChatOpenAI(model="gpt-3.5-turbo").invoke(prompt)
        return {
            "response": result.content,
            "version": "2.0-fallback",
            "model": "gpt-3.5-turbo",
            "fallback": True
        }
    return {
        "response": result.content,
        "version": "2.0",
        "model": "gpt-4"
    }
# 6. 版本兼容层
class VersionAdapter:
    """Translate payloads between the v1 and v2 API shapes."""

    @staticmethod
    def adapt_v1_to_v2(v1_request: dict) -> dict:
        """Build a v2 request from a v1 one; v1 has no context, so it is empty."""
        question = v1_request.get("question")
        return dict(question=question, context="")

    @staticmethod
    def adapt_v2_to_v1(v2_response: dict) -> dict:
        """Reduce a v2 response to the single field v1 clients expect."""
        answer = v2_response.get("response")
        return dict(response=answer)
# 7. 信创环境版本配置
XINCHUANG_API_VERSIONS = {
"v1": {
"llm": "ollama",
"model": "qwen:7b",
"db": "dameng"
},
"v2": {
"llm": "ollama",
"model": "qwen:14b",
"db": "dameng",
"features": ["streaming", "context"]
}
}
@app.get("/xinchuang/version")
def get_xinchuang_versions():
"""信创环境版本信息"""
return {
"environment": "xinchuang",
"versions": XINCHUANG_API_VERSIONS,
"os": "Kylin",
"llm_provider": "Ollama"
}
# 8. 版本使用统计
from collections import defaultdict
version_stats = defaultdict(int)
@app.middleware("http")
async def version_stats_middleware(request: Request, call_next):
"""版本统计中间件"""
# 提取版本
version = None
if request.url.path.startswith("/v1/"):
version = "v1"
elif request.url.path.startswith("/v2/"):
version = "v2"
if version:
version_stats[version] += 1
response = await call_next(request)
return response
@app.get("/admin/version_stats")
def get_version_stats():
"""获取版本统计"""
total = sum(version_stats.values())
return {
"stats": dict(version_stats),
"total": total,
"percentages": {
v: (count / total * 100) if total > 0 else 0
for v, count in version_stats.items()
}
}
print("✓ API版本控制配置完成")
---
b.版本迁移
a.功能说明
提供版本迁移指南和工具,帮助客户端升级:通过自动化迁移脚本转换请求和响应格式,并记录迁移进度和遇到的问题。平滑的版本迁移保障系统稳定升级。
b.代码示例
---
from fastapi import FastAPI
from pydantic import BaseModel
app = FastAPI()
# 1. 迁移指南API
@app.get("/docs/migration/v1-to-v2")
def get_migration_guide():
"""获取迁移指南"""
return {
"title": "API v1到v2迁移指南",
"changes": [
{
"type": "breaking",
"description": "响应格式变更",
"before": {"response": "text"},
"after": {"response": "text", "version": "2.0", "model": "gpt-4"}
},
{
"type": "feature",
"description": "新增context参数",
"example": {
"question": "问题",
"context": "上下文"
}
}
],
"migration_steps": [
"1. 更新请求格式,添加context字段(可选)",
"2. 更新响应处理逻辑,支持新字段",
"3. 测试迁移后的功能",
"4. 更新客户端版本号"
],
"code_examples": {
"python": """
# v1
response = requests.post('/v1/chat', json={'question': 'test'})
result = response.json()['response']
# v2
response = requests.post('/v2/chat', json={
'question': 'test',
'context': 'context' # 新增
})
result = response.json()['response']
"""
}
}
# 2. 迁移检查器
class MigrationChecker:
    """Check request payloads for compatibility with a target API version."""

    def __init__(self):
        """Start with an empty issue list."""
        self.issues = []

    def check_request(self, version: str, request_data: dict) -> list:
        """Return the issues found in ``request_data`` for ``version``.

        Only ``"v2"`` has rules: ``question`` is required and ``context`` is
        recommended. Any other version yields an empty list.
        """
        if version != "v2":
            return []
        # (field, message) pairs, reported in this order when the field is absent.
        rules = (
            ('question', "缺少必需字段:question"),
            ('context', "建议添加context字段以获得更好效果"),
        )
        return [message for field, message in rules if field not in request_data]
checker = MigrationChecker()
@app.post("/migrate/check")
def check_migration_compatibility(version: str, request_data: dict):
"""检查迁移兼容性"""
issues = checker.check_request(version, request_data)
return {
"compatible": len(issues) == 0,
"issues": issues,
"recommendations": [
"升级到v2以使用最新功能",
"添加context参数提升回答质量"
]
}
# 3. 自动迁移工具
class AutoMigrator:
    """Convert payloads between API versions automatically."""

    def migrate_v1_to_v2(self, v1_data: dict) -> dict:
        """Lift a v1 request into v2 form with an empty default context."""
        question = v1_data.get("question")
        return dict(question=question, context="")

    def migrate_response_v2_to_v1(self, v2_response: dict) -> dict:
        """Downgrade a v2 response to the v1 shape (``response`` only)."""
        answer = v2_response.get("response")
        return dict(response=answer)
migrator = AutoMigrator()
@app.post("/migrate/auto")
def auto_migrate_request(from_version: str, to_version: str, data: dict):
"""自动迁移请求"""
if from_version == "v1" and to_version == "v2":
migrated = migrator.migrate_v1_to_v2(data)
return {
"migrated_data": migrated,
"warnings": ["请添加合适的context以获得更好效果"]
}
return {"error": "不支持的迁移路径"}
# 4. 迁移进度追踪
from typing import Dict
migration_progress: Dict[str, dict] = {}
class MigrationTracker:
    """Track per-client migration state in memory."""

    def __init__(self):
        """Start with no recorded migrations."""
        # Maps client_id -> migration record dict.
        self.migrations = {}

    @staticmethod
    def _timestamp() -> str:
        """Return the current local time as an ISO-8601 string."""
        # Imported here because this snippet never imports datetime at
        # module level; the bare ``datetime.now()`` calls raised NameError.
        from datetime import datetime

        return datetime.now().isoformat()

    def start_migration(self, client_id: str, from_v: str, to_v: str):
        """Record a new in-progress migration, replacing any earlier record.

        Args:
            client_id: Key under which the record is stored.
            from_v: Source API version.
            to_v: Target API version.
        """
        self.migrations[client_id] = {
            "from_version": from_v,
            "to_version": to_v,
            "status": "in_progress",
            "started_at": self._timestamp(),
            "progress": 0
        }

    def update_progress(self, client_id: str, progress: int):
        """Set the progress value; unknown clients are ignored."""
        if client_id in self.migrations:
            self.migrations[client_id]["progress"] = progress

    def complete_migration(self, client_id: str):
        """Mark a migration as completed; unknown clients are ignored."""
        if client_id in self.migrations:
            self.migrations[client_id].update({
                "status": "completed",
                "completed_at": self._timestamp(),
                "progress": 100
            })
tracker = MigrationTracker()
@app.post("/migrate/start")
def start_migration(client_id: str, from_version: str, to_version: str):
"""开始迁移"""
tracker.start_migration(client_id, from_version, to_version)
return {
"message": "迁移已启动",
"migration_id": client_id
}
@app.get("/migrate/progress/{client_id}")
def get_migration_progress(client_id: str):
"""获取迁移进度"""
if client_id in tracker.migrations:
return tracker.migrations[client_id]
return {"error": "未找到迁移记录"}
# 5. 批量迁移
@app.post("/migrate/batch")
async def batch_migrate(client_ids: list[str], from_version: str, to_version: str):
"""批量迁移"""
results = []
for client_id in client_ids:
try:
tracker.start_migration(client_id, from_version, to_version)
# 模拟迁移过程
import asyncio
await asyncio.sleep(1)
tracker.update_progress(client_id, 50)
await asyncio.sleep(1)
tracker.complete_migration(client_id)
results.append({
"client_id": client_id,
"status": "success"
})
except Exception as e:
results.append({
"client_id": client_id,
"status": "failed",
"error": str(e)
})
return {"results": results}
# 6. 信创环境迁移
@app.post("/xinchuang/migrate")
def xinchuang_migrate(from_version: str, to_version: str):
"""信创环境迁移"""
changes = []
if from_version == "v1" and to_version == "v2":
changes = [
{
"component": "LLM",
"from": "qwen:7b",
"to": "qwen:14b",
"reason": "性能提升"
},
{
"component": "Database",
"from": "dameng",
"to": "dameng",
"reason": "无变化"
},
{
"component": "Features",
"new": ["streaming", "context"],
"reason": "新增功能"
}
]
return {
"environment": "xinchuang",
"changes": changes,
"estimated_time": "30分钟",
"downtime_required": False
}
# 7. 回滚支持
class RollbackManager:
    """Keep one configuration backup per client for rollback."""

    def __init__(self):
        """Start with no backups."""
        # Maps client_id -> backup record dict.
        self.backups = {}

    def create_backup(self, client_id: str, version: str, config: dict):
        """Store (or overwrite) the backup for ``client_id``.

        Args:
            client_id: Key under which the backup is stored.
            version: Version label saved with the backup.
            config: Configuration payload saved with the backup.
        """
        # Imported here because this snippet never imports datetime at
        # module level; the bare ``datetime.now()`` raised NameError.
        from datetime import datetime

        self.backups[client_id] = {
            "version": version,
            "config": config,
            "backed_up_at": datetime.now().isoformat()
        }

    def rollback(self, client_id: str) -> dict:
        """Return the stored backup for ``client_id``.

        Raises:
            LookupError: When no backup exists. It is an ``Exception``
                subclass, so callers catching ``Exception`` keep working and
                ``str(exc)`` is still the plain message.
        """
        if client_id not in self.backups:
            raise LookupError("没有备份可回滚")
        return self.backups[client_id]
rollback_mgr = RollbackManager()
@app.post("/migrate/rollback/{client_id}")
def rollback_migration(client_id: str):
    """Roll a client back to its stored backup.

    Args:
        client_id: Client whose backup should be restored.

    Returns:
        A success payload with the restored version and a timestamp, or
        ``{"error": ...}`` when the rollback fails.
    """
    # Imported here because this snippet never imports datetime at module
    # level. Previously the NameError from ``datetime.now()`` was caught by
    # the handler below and reported as a failed rollback.
    from datetime import datetime

    try:
        backup = rollback_mgr.rollback(client_id)
    except Exception as e:
        # Endpoint boundary: report the failure in the body instead of a 500.
        return {"error": str(e)}
    return {
        "message": "回滚成功",
        "restored_version": backup["version"],
        "restored_at": datetime.now().isoformat()
    }
print("✓ 版本迁移配置完成")
---
02.配置管理
a.环境配置
a.功能说明
管理不同环境(开发、测试、生产)的配置。使用环境变量、配置文件。支持配置热更新。良好的配置管理提升运维效率。
b.代码示例
---
from fastapi import FastAPI
from pydantic import BaseSettings
import os
from typing import Optional
# 1. Pydantic配置模型
class Settings(BaseSettings):
"""应用配置"""
# 环境
environment: str = "development"
# 服务
host: str = "0.0.0.0"
port: int = 8000
# LLM
openai_api_key: Optional[str] = None
llm_model: str = "gpt-3.5-turbo"
llm_temperature: float = 0.7
# 数据库
db_host: str = "localhost"
db_port: int = 5432
db_user: str = "user"
db_password: str = "password"
db_name: str = "database"
# Redis
redis_host: str = "localhost"
redis_port: int = 6379
# 日志
log_level: str = "INFO"
class Config:
env_file = ".env"
settings = Settings()
# 2. 环境特定配置
class DevelopmentSettings(Settings):
"""开发环境配置"""
environment: str = "development"
log_level: str = "DEBUG"
llm_model: str = "gpt-3.5-turbo"
class ProductionSettings(Settings):
"""生产环境配置"""
environment: str = "production"
log_level: str = "WARNING"
llm_model: str = "gpt-4"
class Config:
env_file = ".env.production"
def get_settings() -> Settings:
"""获取环境配置"""
env = os.getenv("ENVIRONMENT", "development")
if env == "production":
return ProductionSettings()
elif env == "testing":
return Settings(environment="testing")
else:
return DevelopmentSettings()
current_settings = get_settings()
# 3. 配置端点
app = FastAPI()
@app.get("/config")
def get_config():
"""获取当前配置"""
return {
"environment": current_settings.environment,
"llm_model": current_settings.llm_model,
"log_level": current_settings.log_level
# 不暴露敏感信息
}
# 4. 配置热更新
class ConfigManager:
    """Hold the active settings object and allow reloading or patching it."""

    def __init__(self):
        """Load the settings for the current environment."""
        self.settings = get_settings()

    def reload(self):
        """Rebuild the settings via ``get_settings()``."""
        self.settings = get_settings()
        print(f"配置已重新加载:{self.settings.environment}")

    def update(self, key: str, value: object):
        """Overwrite one existing setting in place.

        Args:
            key: Name of an attribute already present on the settings object.
            value: New value; it is stored as given, without conversion.

        Raises:
            AttributeError: If ``key`` is not an existing setting, so a typo
                fails loudly instead of being accepted silently.
        """
        # ``any`` is the builtin function, not a type; ``object`` states
        # "any value" correctly without needing an extra import.
        if not hasattr(self.settings, key):
            raise AttributeError(f"unknown setting: {key}")
        setattr(self.settings, key, value)
        print(f"配置已更新:{key}={value}")
config_mgr = ConfigManager()
@app.post("/admin/config/reload")
def reload_config():
"""重新加载配置"""
config_mgr.reload()
return {"message": "配置已重新加载"}
@app.post("/admin/config/update")
def update_config(key: str, value: str):
"""更新配置"""
config_mgr.update(key, value)
return {"message": f"已更新{key}"}
# 5. 信创环境配置
class XinchuangSettings(Settings):
"""信创环境配置"""
environment: str = "xinchuang"
# 达梦数据库
dm_host: str = "localhost"
dm_port: int = 5236
dm_user: str = "SYSDBA"
dm_password: str = "SYSDBA"
# Ollama
ollama_host: str = "localhost"
ollama_port: int = 11434
ollama_model: str = "qwen:7b"
# 国密配置
sm2_private_key: Optional[str] = None
sm2_public_key: Optional[str] = None
sm4_key: Optional[str] = None
class Config:
env_file = ".env.xinchuang"
@app.get("/xinchuang/config")
def get_xinchuang_config():
"""获取信创配置"""
xc_settings = XinchuangSettings()
return {
"environment": "xinchuang",
"database": {
"type": "dameng",
"host": xc_settings.dm_host,
"port": xc_settings.dm_port
},
"llm": {
"provider": "ollama",
"host": xc_settings.ollama_host,
"model": xc_settings.ollama_model
},
"encryption": {
"algorithm": "SM2/SM4"
}
}
# 6. 配置验证
def validate_config(settings: Settings) -> list:
    """Collect human-readable problems found in ``settings``.

    Checks, in order: an OpenAI key outside the "xinchuang" environment, a
    database host, and an LLM temperature within 0-2.

    Returns:
        A list of error messages; empty when every check passes.
    """
    problems = []

    # The OpenAI key is not required in the "xinchuang" environment.
    if not (settings.openai_api_key or settings.environment == "xinchuang"):
        problems.append("缺少OPENAI_API_KEY")

    if not settings.db_host:
        problems.append("缺少数据库主机配置")

    temperature = settings.llm_temperature
    in_range = 0 <= temperature <= 2
    if not in_range:
        problems.append("LLM温度必须在0-2之间")

    return problems
@app.get("/config/validate")
def validate_current_config():
"""验证当前配置"""
errors = validate_config(current_settings)
return {
"valid": len(errors) == 0,
"errors": errors
}
print("✓ 配置管理配置完成")
---
7.3 灰度发布
01.灰度策略
a.流量分配
a.功能说明
按比例分配流量到新旧版本,逐步验证新版本稳定性。支持按用户、按IP、按随机等分配策略。动态调整流量比例。灰度发布降低新版本上线风险。
b.代码示例
---
from fastapi import FastAPI, Request
import random
import hashlib
app = FastAPI()
# 1. 随机流量分配
CANARY_PERCENTAGE = 10 # 10%流量到新版本
@app.middleware("http")
async def canary_routing_middleware(request: Request, call_next):
"""金丝雀路由中间件"""
# 随机决定版本
use_canary = random.random() < (CANARY_PERCENTAGE / 100)
# 添加版本标记
request.state.version = "canary" if use_canary else "stable"
response = await call_next(request)
# 添加响应头
response.headers["X-Version"] = request.state.version
return response
@app.post("/chat")
def chat(request: Request, question: str):
"""聊天(灰度)"""
version = request.state.version
if version == "canary":
# 新版本逻辑
from langchain.chat_models import ChatOpenAI
llm = ChatOpenAI(model="gpt-4")
result = llm.invoke(question)
return {
"response": result.content,
"version": "canary",
"model": "gpt-4"
}
else:
# 稳定版本逻辑
from langchain.chat_models import ChatOpenAI
llm = ChatOpenAI(model="gpt-3.5-turbo")
result = llm.invoke(question)
return {
"response": result.content,
"version": "stable",
"model": "gpt-3.5-turbo"
}
# 2. 用户ID灰度
def is_canary_user(user_id: str, percentage: int = 10) -> bool:
    """Decide whether ``user_id`` falls into the canary group.

    The MD5 digest of the id is reduced to a bucket in 0-99, so the same
    user always gets the same answer for a given percentage.
    """
    digest = hashlib.md5(user_id.encode()).digest()
    # Big-endian bytes give the same integer as parsing the hex digest.
    bucket = int.from_bytes(digest, "big") % 100
    return bucket < percentage
@app.post("/user_canary_chat")
def user_canary_chat(user_id: str, question: str):
"""基于用户的灰度聊天"""
is_canary = is_canary_user(user_id, CANARY_PERCENTAGE)
version = "canary" if is_canary else "stable"
model = "gpt-4" if is_canary else "gpt-3.5-turbo"
from langchain.chat_models import ChatOpenAI
llm = ChatOpenAI(model=model)
result = llm.invoke(question)
return {
"response": result.content,
"version": version,
"user_id": user_id
}
# 3. IP灰度
def is_canary_ip(ip: str, percentage: int = 10) -> bool:
    """Decide whether the address ``ip`` falls into the canary group.

    The MD5 hex digest of the address is reduced to a bucket in 0-99, so a
    given address always gets the same answer for a given percentage.
    """
    hex_digest = hashlib.md5(ip.encode()).hexdigest()
    bucket = int(hex_digest, 16) % 100
    return bucket < percentage
@app.middleware("http")
async def ip_based_canary(request: Request, call_next):
    """Pick the release track for a request from the client's IP address.

    Stores ``"canary"`` or ``"stable"`` on ``request.state.version`` and
    echoes it in the ``X-Version`` response header.
    """
    # ``request.client`` can be None when the server supplies no peer
    # address; such requests are kept on the stable track rather than
    # failing with AttributeError.
    client = request.client
    is_canary = client is not None and is_canary_ip(client.host, CANARY_PERCENTAGE)
    request.state.version = "canary" if is_canary else "stable"
    response = await call_next(request)
    response.headers["X-Version"] = request.state.version
    return response
# 4. 白名单灰度
CANARY_WHITELIST = ["user1", "user2", "user3"] # 始终使用新版本的用户
@app.post("/whitelist_canary_chat")
def whitelist_canary_chat(user_id: str, question: str):
"""白名单灰度聊天"""
# 白名单用户使用新版本
if user_id in CANARY_WHITELIST:
version = "canary"
model = "gpt-4"
else:
# 其他用户按比例灰度
is_canary = is_canary_user(user_id, CANARY_PERCENTAGE)
version = "canary" if is_canary else "stable"
model = "gpt-4" if is_canary else "gpt-3.5-turbo"
from langchain.chat_models import ChatOpenAI
llm = ChatOpenAI(model=model)
result = llm.invoke(question)
return {
"response": result.content,
"version": version
}
# 5. 动态调整流量比例
class CanaryController:
    """Runtime switch and traffic share for the canary release."""

    def __init__(self):
        """Default to 10% canary traffic with the canary enabled."""
        self.percentage = 10
        self.enabled = True

    def set_percentage(self, percentage: int):
        """Change the canary share; values outside 0-100 are ignored."""
        if not (0 <= percentage <= 100):
            return
        self.percentage = percentage
        print(f"灰度比例已调整为{percentage}%")

    def enable(self):
        """Turn canary routing on."""
        self.enabled = True

    def disable(self):
        """Turn canary routing off."""
        self.enabled = False

    def should_use_canary(self, user_id: str = None) -> bool:
        """Say whether this request should go to the canary version.

        With a user id the decision is the sticky hash-based one; without
        it the decision is a random draw against the current percentage.
        """
        if not self.enabled:
            return False
        if user_id:
            return is_canary_user(user_id, self.percentage)
        share = self.percentage / 100
        return random.random() < share
canary_controller = CanaryController()
@app.post("/admin/canary/set_percentage")
def set_canary_percentage(percentage: int):
"""设置灰度比例"""
canary_controller.set_percentage(percentage)
return {"message": f"灰度比例已设置为{percentage}%"}
@app.post("/admin/canary/enable")
def enable_canary():
"""启用灰度"""
canary_controller.enable()
return {"message": "灰度已启用"}
@app.post("/admin/canary/disable")
def disable_canary():
"""禁用灰度(全部流量到稳定版)"""
canary_controller.disable()
return {"message": "灰度已禁用"}
# 6. 灰度监控
from collections import defaultdict
canary_stats = {
"stable": {"requests": 0, "errors": 0},
"canary": {"requests": 0, "errors": 0}
}
@app.middleware("http")
async def canary_monitoring_middleware(request: Request, call_next):
    """Count requests and server errors per release track.

    A request counts as an error when the downstream stack raises or when
    the response status is 500 or above. Exceptions are re-raised.
    """
    def _counters():
        # Read the track only after the inner middlewares have run. This
        # middleware is registered last and therefore runs outermost, so
        # reading request.state.version before call_next always fell back
        # to "stable" and canary traffic was never counted.
        version = getattr(request.state, "version", "stable")
        return canary_stats.setdefault(version, {"requests": 0, "errors": 0})

    try:
        response = await call_next(request)
    except Exception:
        counters = _counters()
        counters["requests"] += 1
        counters["errors"] += 1
        raise

    counters = _counters()
    counters["requests"] += 1
    if response.status_code >= 500:
        counters["errors"] += 1
    return response
@app.get("/canary/stats")
def get_canary_stats():
"""获取灰度统计"""
stats = {}
for version in ["stable", "canary"]:
data = canary_stats[version]
total = data["requests"]
errors = data["errors"]
stats[version] = {
"requests": total,
"errors": errors,
"error_rate": errors / total if total > 0 else 0
}
return stats
# 7. 自动回滚
class AutoRollback:
    """Decide from error rates whether the canary must be rolled back."""

    def __init__(self, error_threshold: float = 0.05):
        """Store the tolerated error rate (a fraction, default 5%)."""
        self.error_threshold = error_threshold

    def check_rollback(self, stats: dict) -> bool:
        """Return True when the canary error rate is unacceptable.

        Triggers when the canary rate exceeds the stable rate by more than
        the threshold, or exceeds the threshold on its own.
        """
        canary_rate = stats["canary"]["error_rate"]
        stable_rate = stats["stable"]["error_rate"]
        limit = self.error_threshold
        worse_than_stable = canary_rate > stable_rate + limit
        return worse_than_stable or canary_rate > limit
auto_rollback = AutoRollback(error_threshold=0.05)
@app.get("/canary/check_rollback")
def check_rollback():
"""检查是否需要回滚"""
stats = get_canary_stats()
should_rollback = auto_rollback.check_rollback(stats)
if should_rollback:
canary_controller.disable()
return {
"rollback": True,
"reason": "灰度版本错误率过高",
"stats": stats
}
return {
"rollback": False,
"stats": stats
}
# 8. 信创环境灰度
class XinchuangCanary:
"""信创环境灰度"""
def __init__(self):
self.percentage = 10
self.stable_config = {
"llm_model": "qwen:7b",
"db": "dameng_primary"
}
self.canary_config = {
"llm_model": "qwen:14b", # 新版本模型
"db": "dameng_primary"
}
def get_config(self, user_id: str) -> dict:
"""获取配置"""
is_canary = is_canary_user(user_id, self.percentage)
if is_canary:
return {
**self.canary_config,
"version": "canary"
}
else:
return {
**self.stable_config,
"version": "stable"
}
xc_canary = XinchuangCanary()
@app.post("/xinchuang/canary_chat")
def xinchuang_canary_chat(user_id: str, question: str):
"""信创环境灰度聊天"""
config = xc_canary.get_config(user_id)
from langchain.llms import Ollama
llm = Ollama(
model=config["llm_model"],
base_url="http://localhost:11434"
)
result = llm.invoke(question)
return {
"response": result,
"version": config["version"],
"model": config["llm_model"]
}
print("✓ 灰度策略配置完成")
---
b.发布流程
a.功能说明
规范化灰度发布流程,确保安全可控。分阶段逐步放量。实时监控指标,及时发现问题。建立回滚机制。标准化的发布流程保障线上稳定。
b.代码示例
---
from fastapi import FastAPI
from enum import Enum
from datetime import datetime, timedelta
app = FastAPI()
# 1. 发布阶段定义
class ReleaseStage(Enum):
    """Stages a staged rollout can be in."""
    CANARY = "canary"          # canary (1-5%)
    BETA = "beta"              # beta (5-20%)
    STAGING = "staging"        # pre-release (20-50%)
    PRODUCTION = "production"  # full rollout (100%)
    ROLLBACK = "rollback"      # rolled back


# 2. Release plan
class ReleasePlan:
    """Ordered list of rollout stages plus a cursor into it."""

    def __init__(self):
        """Define the stages; the plan is idle until ``start()`` is called."""
        self.stages = [
            {"stage": ReleaseStage.CANARY, "percentage": 5, "duration_minutes": 30},
            {"stage": ReleaseStage.BETA, "percentage": 20, "duration_minutes": 60},
            {"stage": ReleaseStage.STAGING, "percentage": 50, "duration_minutes": 120},
            {"stage": ReleaseStage.PRODUCTION, "percentage": 100, "duration_minutes": 0}
        ]
        self.current_stage_index = 0
        # None until start() or advance_to_next_stage() stamps a time.
        self.stage_started_at = None

    def start(self):
        """Begin at the first stage and return it."""
        self.current_stage_index = 0
        self.stage_started_at = datetime.now()
        return self.get_current_stage()

    def get_current_stage(self):
        """Return the current stage dict, or None once past the last one."""
        if self.current_stage_index < len(self.stages):
            return self.stages[self.current_stage_index]
        return None

    def advance_to_next_stage(self):
        """Move the cursor forward one stage and return the new stage.

        Returns None when the plan has run past its final stage.
        """
        # Clamp so repeated calls after completion do not grow the index.
        self.current_stage_index = min(
            self.current_stage_index + 1, len(self.stages)
        )
        self.stage_started_at = datetime.now()
        return self.get_current_stage()

    def can_advance(self) -> bool:
        """Say whether the current stage may hand over to the next one.

        Requires a started stage whose minimum duration has elapsed and a
        canary error rate of at most 5%.
        """
        stage = self.get_current_stage()
        # Before start() there is no timestamp; subtracting None raised
        # TypeError, so an unstarted plan simply cannot advance.
        if not stage or self.stage_started_at is None:
            return False

        # Minimum soak time for this stage.
        elapsed = (datetime.now() - self.stage_started_at).total_seconds() / 60
        if elapsed < stage["duration_minutes"]:
            return False

        # Error-rate gate. NOTE(review): get_canary_stats is not defined in
        # this snippet; presumably it comes from the canary monitoring
        # example -- confirm it is in scope.
        stats = get_canary_stats()
        canary_error_rate = stats.get("canary", {}).get("error_rate", 0)
        if canary_error_rate > 0.05:  # 5% error-rate threshold
            return False
        return True
release_plan = ReleasePlan()
# 3. 发布管理端点
@app.post("/release/start")
def start_release():
"""开始灰度发布"""
stage = release_plan.start()
return {
"message": "发布已启动",
"current_stage": stage["stage"].value,
"percentage": stage["percentage"],
"duration_minutes": stage["duration_minutes"]
}
@app.post("/release/advance")
def advance_release():
"""进入下一阶段"""
if not release_plan.can_advance():
return {
"error": "当前阶段未满足进入下一阶段的条件"
}
stage = release_plan.advance_to_next_stage()
if not stage:
return {
"message": "发布已完成",
"stage": "production"
}
# 更新灰度比例
canary_controller.set_percentage(stage["percentage"])
return {
"message": f"已进入{stage['stage'].value}阶段",
"percentage": stage["percentage"],
"duration_minutes": stage["duration_minutes"]
}
@app.get("/release/status")
def get_release_status():
    """Report the current rollout stage and its timing.

    Returns:
        ``{"status": "no_release"}`` when no stage is active or the plan has
        not been started; otherwise the stage name, traffic percentage,
        elapsed and remaining minutes, and whether the plan can advance.
    """
    stage = release_plan.get_current_stage()
    started_at = release_plan.stage_started_at
    # A fresh plan already points at stage 0 but has no start time;
    # subtracting None raised TypeError, so treat it as "no release".
    if not stage or started_at is None:
        return {"status": "no_release"}

    elapsed = (datetime.now() - started_at).total_seconds() / 60
    return {
        "stage": stage["stage"].value,
        "percentage": stage["percentage"],
        "elapsed_minutes": elapsed,
        "remaining_minutes": max(0, stage["duration_minutes"] - elapsed),
        "can_advance": release_plan.can_advance()
    }
# 4. 自动推进
import asyncio
async def auto_advance_release():
"""自动推进发布"""
while True:
await asyncio.sleep(60) # 每分钟检查一次
if release_plan.can_advance():
stage = release_plan.advance_to_next_stage()
if stage:
canary_controller.set_percentage(stage["percentage"])
print(f"自动推进到{stage['stage'].value}阶段")
else:
print("发布完成")
break
# Strong references to running auto-advance tasks. The event loop keeps only
# weak references, so a task nobody holds may be garbage-collected before it
# finishes.
_auto_release_tasks = set()


@app.post("/release/start_auto")
async def start_auto_release():
    """Start the rollout and schedule the background auto-advance loop.

    Returns:
        A confirmation message; the advance loop keeps running afterwards.
    """
    release_plan.start()
    task = asyncio.create_task(auto_advance_release())
    _auto_release_tasks.add(task)
    # Drop the reference once the task finishes so the set does not grow.
    task.add_done_callback(_auto_release_tasks.discard)
    return {"message": "自动发布已启动"}
# 5. 回滚流程
@app.post("/release/rollback")
def rollback_release():
"""回滚发布"""
# 禁用灰度,全部流量到稳定版
canary_controller.disable()
# 重置发布计划
release_plan.current_stage_index = 0
return {
"message": "已回滚到稳定版本",
"timestamp": datetime.now().isoformat()
}
# 6. 发布审批
class ReleaseApproval:
    """In-memory register of release approval requests."""

    def __init__(self):
        """Start with no approval requests."""
        self.approvals = {}

    def request_approval(self, release_id: str, stage: str):
        """File (or overwrite) a pending approval request for a release."""
        requested_at = datetime.now().isoformat()
        self.approvals[release_id] = {
            "stage": stage,
            "status": "pending",
            "requested_at": requested_at
        }

    def approve(self, release_id: str, approver: str):
        """Mark a request approved; return False when it does not exist."""
        if release_id not in self.approvals:
            return False
        record = self.approvals[release_id]
        record.update({
            "status": "approved",
            "approver": approver,
            "approved_at": datetime.now().isoformat()
        })
        return True

    def is_approved(self, release_id: str) -> bool:
        """Say whether the release has an approved request on file."""
        return self.approvals.get(release_id, {}).get("status") == "approved"
approval_mgr = ReleaseApproval()
@app.post("/release/request_approval")
def request_release_approval(release_id: str, stage: str):
"""请求发布审批"""
approval_mgr.request_approval(release_id, stage)
return {
"message": "审批请求已提交",
"release_id": release_id
}
@app.post("/release/approve")
def approve_release(release_id: str, approver: str):
"""批准发布"""
success = approval_mgr.approve(release_id, approver)
if success:
return {"message": "发布已批准"}
else:
return {"error": "未找到发布请求"}
# 7. 信创环境发布流程
@app.post("/xinchuang/release/start")
def start_xinchuang_release():
"""启动信创环境发布"""
plan = {
"environment": "xinchuang",
"stages": [
{
"stage": "canary",
"percentage": 5,
"duration": "30分钟",
"actions": [
"启动Qwen-14B灰度实例",
"配置达梦数据库连接池",
"监控错误率和响应时间"
]
},
{
"stage": "beta",
"percentage": 20,
"duration": "1小时",
"actions": [
"扩展灰度实例数量",
"验证国密加密功能",
"监控系统资源使用"
]
},
{
"stage": "production",
"percentage": 100,
"duration": "N/A",
"actions": [
"全量切换到新版本",
"停止旧版本实例",
"更新监控配置"
]
}
]
}
return {
"message": "信创环境发布已启动",
"plan": plan
}
# 8. 发布报告
@app.get("/release/report")
def get_release_report():
"""获取发布报告"""
stats = get_canary_stats()
status = get_release_status()
report = {
"release_status": status,
"version_stats": stats,
"health_checks": {
"stable_error_rate": stats.get("stable", {}).get("error_rate", 0),
"canary_error_rate": stats.get("canary", {}).get("error_rate", 0),
"acceptable": stats.get("canary", {}).get("error_rate", 0) < 0.05
},
"recommendations": []
}
# 生成建议
if report["health_checks"]["canary_error_rate"] > 0.05:
report["recommendations"].append("建议回滚:灰度版本错误率过高")
elif release_plan.can_advance():
report["recommendations"].append("可以进入下一阶段")
return report
print("✓ 灰度发布配置完成")
# ========== 总结 ==========
"""
## LangServe完整部署和优化总结
本文档涵盖了LangServe从开发到生产的完整生命周期:
### 1. 基础部署
- Docker容器化
- Nginx负载均衡
- 多实例高可用
### 2. 监控运维
- 集中式日志(ELK/Loki)
- Prometheus指标监控
- Grafana可视化仪表板
- 实时告警
### 3. 安全防护
- 认证授权(API Key/JWT/RBAC)
- 数据加密(AES/SM4)
- 防护措施(SQL注入、XSS、CSRF、DDoS)
- 信创国密适配
### 4. 性能优化
- 缓存策略(内存/Redis/智能缓存)
- 异步处理(asyncio/Celery)
- 连接池管理
### 5. 版本管理
- API版本控制(路径版本)
- 版本迁移(指南/工具/追踪)
- 配置管理(环境配置/热更新)
### 6. 灰度发布
- 流量分配(随机/用户/IP/白名单)
- 发布流程(阶段化/审批/自动推进)
- 自动回滚机制
### 信创环境特别适配
- 麒麟OS + 达梦DM8 + Ollama
- 国密算法(SM2/SM4)
- 内网安全加固
- 本地化部署方案
通过以上完整的体系,可以构建生产级的LangServe服务。
"""
---