# langgraph_qwen/mcp.py
|
||
from __future__ import annotations
|
||
|
||
import os
|
||
import json
|
||
import asyncio
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from langgraph.prebuilt import create_react_agent
|
||
from .chat_model import ChatQwenOpenAICompat
|
||
|
||
# Names exported by `from langgraph_qwen.mcp import *` — the module's public API.
__all__ = [
    "load_mcp_tools",
    "create_qwen_agent_with_mcp_async",
    "create_qwen_agent_with_mcp",
    "resolve_servers_config",
]
|
||
|
||
def _env(name: str, default: str = "") -> str:
|
||
return os.getenv(name) or default
|
||
|
||
# ---------- 配置加载 & 合并 ----------
|
||
|
||
def _read_config_file(path: str) -> Dict[str, Dict[str, Any]]:
|
||
p = Path(path)
|
||
if not p.exists():
|
||
raise FileNotFoundError(f"MCP config file not found: {path}")
|
||
|
||
ext = p.suffix.lower()
|
||
if ext in {".json", ""}:
|
||
data = json.loads(p.read_text(encoding="utf-8"))
|
||
elif ext in {".yml", ".yaml"}:
|
||
try:
|
||
import yaml # type: ignore
|
||
except Exception as e:
|
||
raise RuntimeError("需要解析 YAML:请安装 `pip install pyyaml`") from e
|
||
data = yaml.safe_load(p.read_text(encoding="utf-8"))
|
||
else:
|
||
raise ValueError(f"Unsupported config extension: {ext}")
|
||
|
||
# 支持两种结构:
|
||
# 1) {"servers": { ... }}
|
||
# 2) { ... } 直接就是 servers 映射
|
||
if isinstance(data, dict) and "servers" in data and isinstance(data["servers"], dict):
|
||
return data["servers"]
|
||
if isinstance(data, dict):
|
||
return data
|
||
raise ValueError("Invalid config structure: expect a dict (optionally under key 'servers').")
|
||
|
||
def _servers_from_env_json() -> Optional[Dict[str, Dict[str, Any]]]:
|
||
raw = os.getenv("MCP_SERVERS_JSON")
|
||
if not raw:
|
||
return None
|
||
try:
|
||
data = json.loads(raw)
|
||
if isinstance(data, dict) and data:
|
||
return data # 直接就是 servers 映射
|
||
except Exception:
|
||
pass
|
||
return None
|
||
|
||
def _fallback_servers() -> Dict[str, Dict[str, Any]]:
    """Last-resort default: a single local HTTP test server (FastAPI/fastmcp-http)."""
    weather = {
        "url": _env("WEATHER_MCP_URL", "http://127.0.0.1:8000/mcp/"),
        "transport": _env("WEATHER_TRANSPORT", "streamable_http"),
    }
    return {"weather": weather}
|
||
|
||
def resolve_servers_config(
    servers: Optional[Dict[str, Dict[str, Any]]] = None,
    config_path: Optional[str] = None,
) -> Dict[str, Dict[str, Any]]:
    """Single entry point for resolving the servers config.

    Precedence (later sources override earlier ones):
      1) file (config_path or the MCP_CONFIG_PATH env var)
      2) env var MCP_SERVERS_JSON
      3) servers passed in code (final override)
      4) fallback: local weather HTTP server

    Raises:
        RuntimeError: when the config file cannot be read or parsed.
    """
    merged: Dict[str, Dict[str, Any]] = {}

    # 1) file
    path = config_path or os.getenv("MCP_CONFIG_PATH")
    if path:
        try:
            merged.update(_read_config_file(path))
        except Exception as e:
            # FIX: chain the original exception (`from e`) so the root cause
            # stays visible in the traceback instead of being discarded.
            raise RuntimeError(f"读取 MCP 配置文件失败:{e}") from e

    # 2) env var (JSON string)
    env_servers = _servers_from_env_json()
    if env_servers:
        merged.update(env_servers)

    # 3) passed in code (applied last, so it overrides file/env entries)
    if servers:
        merged.update(servers)

    # 4) fallback when nothing was configured anywhere
    if not merged:
        merged = _fallback_servers()
    return merged
|
||
|
||
# ---------- MCP 工具加载 ----------
|
||
|
||
async def load_mcp_tools(
    servers: Optional[Dict[str, Dict[str, Any]]] = None,
    config_path: Optional[str] = None,
) -> List[Any]:
    """Asynchronously load MCP tools (HTTP streamable and local stdio/npx servers).

    Requires the `langchain-mcp-adapters` package.

    Args:
        servers: explicit servers mapping (overrides file/env config).
        config_path: path to a JSON/YAML config file.

    Raises:
        RuntimeError: when langchain-mcp-adapters is not installed.
    """
    try:
        from langchain_mcp_adapters.client import MultiServerMCPClient  # type: ignore
    except Exception as e:
        # FIX: the PyPI distribution is named with hyphens ("langchain-mcp-adapters");
        # the previous hint ("langchain-mcp_adapters") pointed at a non-existent package.
        raise RuntimeError("请安装:`uv pip install -e '.[mcp-adapters]'` 或 `pip install langchain-mcp-adapters`") from e

    resolved_servers = resolve_servers_config(servers=servers, config_path=config_path)
    client = MultiServerMCPClient(resolved_servers)

    tools = await client.get_tools()

    # Best-effort session shutdown; the close API differs across adapter versions.
    try:
        closer = getattr(client, "close", None) or getattr(client, "close_all_sessions", None)
        if callable(closer):
            await closer()  # type: ignore
    except Exception:
        pass
    return tools
|
||
|
||
# ---------- Agent 工厂 ----------
|
||
|
||
async def create_qwen_agent_with_mcp_async(
    *,
    servers: Optional[Dict[str, Dict[str, Any]]] = None,
    config_path: Optional[str] = None,
    tool_choice: str = "auto",
    model: Optional[ChatQwenOpenAICompat] = None,
):
    """Async factory: load MCP tools -> bind them to the Qwen model -> build a LangGraph ReAct agent."""
    tools = await load_mcp_tools(servers=servers, config_path=config_path)
    base_model = model or ChatQwenOpenAICompat(temperature=0)
    bound_model = base_model.bind_tools(tools).bind(tool_choice=tool_choice)
    return create_react_agent(bound_model, tools)
|
||
|
||
def _run_coro_sync(coro):
|
||
try:
|
||
asyncio.get_running_loop()
|
||
except RuntimeError:
|
||
return asyncio.run(coro)
|
||
raise RuntimeError("检测到事件循环,改用:`await create_qwen_agent_with_mcp_async(...)`")
|
||
|
||
def create_qwen_agent_with_mcp(
    *,
    model: Optional[ChatQwenOpenCompat] = None,  # type: ignore[name-defined]
    servers: Optional[Dict[str, Dict[str, Any]]] = None,
    config_path: Optional[str] = None,
    tool_choice: str = "auto",
):
    """Sync factory: spins up a temporary event loop to fetch tools, then builds the ReAct agent.

    Intended for quickstart/script scenarios; inside an async framework use
    create_qwen_agent_with_mcp_async instead.
    """
    chosen_model = model or ChatQwenOpenAICompat(temperature=0)
    agent_coro = create_qwen_agent_with_mcp_async(
        servers=servers,
        config_path=config_path,
        tool_choice=tool_choice,
        model=chosen_model,
    )
    return _run_coro_sync(agent_coro)
|