# examples/mcp_modes/direct_tool_loop.py

import asyncio
import json

from langchain_core.messages import HumanMessage, ToolMessage

from langgraph_qwen.chat_model import ChatQwenOpenAICompat
from langgraph_qwen.mcp import load_mcp_tools


# MCP server config for the first tool set: a "math" server spawned as a
# subprocess and spoken to over stdio.
PHASE1_SERVERS = {
    "math": {
        "command": "python",
        # Placeholder path: replace with the absolute path of your server script.
        "args": ["/abs/path/to/math_server.py"],
        "transport": "stdio",
    },
}

# MCP server config for the second tool set: a "weather" server reached over
# streamable HTTP at a local URL.
PHASE2_SERVERS = {
    "weather": {
        "url": "http://127.0.0.1:8000/mcp/",
        "transport": "streamable_http",
    },
}


def _normalize_tool_call(call: dict) -> dict:
    """Return *call* as a plain ``{"name", "args", "id"}`` dict.

    Two input shapes are accepted:

    * LangChain's parsed form (entries of ``AIMessage.tool_calls``), which
      already has top-level ``name`` / ``args`` / ``id`` keys.
    * The raw OpenAI-compatible form (entries of
      ``additional_kwargs["tool_calls"]``), which nests the tool name under
      ``function.name`` and carries the arguments as a JSON string under
      ``function.arguments``.

    ``args`` is ``None`` when raw arguments cannot be decoded into a dict, so
    the caller can report the problem instead of invoking the tool with
    made-up input. ``id`` falls back to ``""`` when missing.
    """
    call_id = call.get("id") or ""
    function = call.get("function")
    if not isinstance(function, dict):
        # Already LangChain-parsed: fields live at the top level.
        return {"name": call.get("name"), "args": call.get("args") or {}, "id": call_id}

    raw_args = function.get("arguments")
    if isinstance(raw_args, str):
        try:
            args = json.loads(raw_args) if raw_args.strip() else {}
        except json.JSONDecodeError:
            args = None
    else:
        args = raw_args if raw_args is not None else {}
    if not isinstance(args, dict):
        # Tools are invoked with a mapping of arguments; anything else is malformed.
        args = None
    return {"name": function.get("name"), "args": args, "id": call_id}


def _extract_tool_calls(ai) -> list:
    """Return the normalized tool calls requested by model reply *ai*.

    Prefers LangChain's parsed ``ai.tool_calls``. Falls back to the raw
    ``additional_kwargs["tool_calls"]`` list when the parsed list is empty or
    absent. Returns ``[]`` when the reply requests no tools.
    """
    raw_calls = getattr(ai, "tool_calls", None)
    if not raw_calls:
        extra = getattr(ai, "additional_kwargs", None) or {}
        raw_calls = extra.get("tool_calls") or []
    return [_normalize_tool_call(c) for c in raw_calls]


async def _execute_tool_call(call: dict, tool_map: dict) -> "ToolMessage":
    """Run one normalized tool call and wrap the outcome in a ``ToolMessage``.

    Unknown tool names, undecodable arguments and exceptions raised by the
    tool are all returned as message content rather than raised, so the model
    sees the failure on its next turn and can react to it.
    """
    name, args, call_id = call["name"], call["args"], call["id"]
    tool = tool_map.get(name)
    if tool is None:
        return ToolMessage(tool_call_id=call_id, content=f"Unknown tool: {name}")
    if args is None:
        return ToolMessage(
            tool_call_id=call_id,
            content=f"Error: could not decode JSON arguments for tool {name}",
        )
    try:
        # Prefer the async entry point when the tool provides one.
        if hasattr(tool, "ainvoke"):
            out = await tool.ainvoke(args)
        else:
            out = tool.invoke(args)
        content = str(out)
    except Exception as e:  # deliberate boundary: feed any tool failure back to the model
        content = f"Error: {e}"
    return ToolMessage(tool_call_id=call_id, content=content)


async def main(max_steps: int = 8):
    """Run a hand-written tool-calling loop against MCP-provided tools.

    Loads the tools of ``PHASE1_SERVERS`` and ``PHASE2_SERVERS``, binds them
    all to a ``ChatQwenOpenAICompat`` model with ``tool_choice="auto"``, then
    repeatedly asks the model for its next step. Each requested tool is
    executed and its result is appended as a ``ToolMessage``. The loop stops
    when the model replies without requesting tools, or after *max_steps*
    model calls. The final answer is printed.

    Args:
        max_steps: Maximum number of model turns inside the tool loop.
    """
    # ★ e.g. load a different tool set for each phase.
    tools_phase1 = await load_mcp_tools(servers=PHASE1_SERVERS)
    tools_phase2 = await load_mcp_tools(servers=PHASE2_SERVERS)
    all_tools = tools_phase1 + tools_phase2

    model = ChatQwenOpenAICompat(temperature=0).bind_tools(all_tools).bind(tool_choice="auto")
    tool_map = {t.name: t for t in all_tools}

    msgs = [HumanMessage(content="先用数学工具算 12*(3+5),再查北京天气,最后给一句总结。")]

    final = None
    for _ in range(max_steps):
        ai = await model.ainvoke(msgs)
        msgs.append(ai)
        calls = _extract_tool_calls(ai)
        if not calls:
            # No tools requested: this reply already is the final answer, so
            # there is no need to query the model again.
            final = ai
            break
        for call in calls:
            msgs.append(await _execute_tool_call(call, tool_map))

    if final is None:
        # Step budget exhausted while tools were still being requested (or
        # max_steps <= 0): ask the model once more. NOTE(review): tools remain
        # bound here, so this reply could itself contain tool calls.
        final = await model.ainvoke(msgs)

    print("=== Final ===")
    print(final.content)


if __name__ == "__main__":
    # Executed as a script: drive the async demo to completion on a fresh event loop.
    asyncio.run(main())