forked from lingyuzeng/agent
langgraph adapter
This commit is contained in:
29
langgraph_qwen/__init__.py
Normal file
29
langgraph_qwen/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""
|
||||||
|
Lightweight adapter to use Qwen3-Coder(-Flash) models with LangGraph via
|
||||||
|
LangChain ChatModels. Focuses on an OpenAI-compatible path and a clean
|
||||||
|
factory API, plus helpers for tool binding and streaming examples.
|
||||||
|
|
||||||
|
Usage quickstart:
|
||||||
|
|
||||||
|
from langgraph_qwen.factory import get_qwen_chat_model
|
||||||
|
from langgraph.prebuilt import create_react_agent
|
||||||
|
|
||||||
|
model = get_qwen_chat_model()
|
||||||
|
agent = create_react_agent(model, tools=[...])
|
||||||
|
|
||||||
|
Environment variables:
|
||||||
|
- QWEN_API_KEY: API key for OpenAI-compatible endpoint (or DashScope)
|
||||||
|
- QWEN_BASE_URL: Base URL (e.g. https://dashscope-intl.aliyuncs.com/compatible-mode/v1)
|
||||||
|
- QWEN_MODEL: Model name (e.g. qwen3-coder-flash)
|
||||||
|
"""
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_qwen_chat_model",
|
||||||
|
"bind_qwen_tools",
|
||||||
|
"ChatQwenOpenAICompat",
|
||||||
|
"create_qwen_react_agent",
|
||||||
|
]
|
||||||
|
|
||||||
|
from .factory import get_qwen_chat_model, bind_qwen_tools
|
||||||
|
from .chat_model import ChatQwenOpenAICompat
|
||||||
|
from .prebuilt import create_qwen_react_agent
|
||||||
414
langgraph_qwen/chat_model.py
Normal file
414
langgraph_qwen/chat_model.py
Normal file
@@ -0,0 +1,414 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple
|
||||||
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
|
from .utils import ensure_env_loaded
|
||||||
|
try:
|
||||||
|
# pydantic v2
|
||||||
|
from pydantic import PrivateAttr
|
||||||
|
except Exception: # pragma: no cover
|
||||||
|
# pydantic v1 fallback
|
||||||
|
from pydantic.fields import PrivateAttr # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def _env(name: str, default: Optional[str] = None) -> Optional[str]:
    """Return environment variable *name*, or *default* when unset or falsy (empty)."""
    value = os.getenv(name)
    if value:
        return value
    return default
|
||||||
|
|
||||||
|
|
||||||
|
def _coalesce(*vals):
    """Return the first argument that is not None, or None when all are None."""
    return next((candidate for candidate in vals if candidate is not None), None)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_openai_messages(messages: List["BaseMessage"]) -> List[Dict[str, Any]]:
    """Convert LangChain messages into OpenAI chat-completions message dicts.

    Handles System/Human/AI/Tool/Function messages; AI tool calls (either the
    structured ``tool_calls`` attribute or the legacy ``additional_kwargs``
    entry) are mapped to OpenAI's ``tool_calls`` wire shape. Unknown message
    types fall back to a best-effort role/content dict.
    """
    from langchain_core.messages import (
        AIMessage,
        BaseMessage,
        HumanMessage,
        SystemMessage,
        ToolMessage,
        FunctionMessage,
    )

    out: List[Dict[str, Any]] = []
    for m in messages:
        if isinstance(m, SystemMessage):
            out.append({"role": "system", "content": m.content})
        elif isinstance(m, HumanMessage):
            out.append({"role": "user", "content": m.content})
        elif isinstance(m, AIMessage):
            # If previous AI produced tool calls, pass them through
            entry: Dict[str, Any] = {"role": "assistant", "content": m.content}
            # BUG FIX: the original one-liner parsed as
            #   (getattr(m, "tool_calls", None) or m.additional_kwargs.get(...)) if hasattr(...) else None
            # because the conditional expression binds loosest, so a message
            # without `additional_kwargs` silently lost its `tool_calls` attr.
            tcs = getattr(m, "tool_calls", None)
            if not tcs and hasattr(m, "additional_kwargs"):
                tcs = m.additional_kwargs.get("tool_calls")
            if tcs:
                # Map to OpenAI's expected shape; entries may be dicts or objects.
                entry["tool_calls"] = [
                    {
                        "id": tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None),
                        "type": "function",
                        "function": {
                            "name": (tc.get("name") if isinstance(tc, dict) else getattr(tc, "name", "")) or "",
                            "arguments": json.dumps(tc.get("args") if isinstance(tc, dict) else getattr(tc, "args", {})),
                        },
                    }
                    for tc in tcs
                ]
            out.append(entry)
        elif isinstance(m, ToolMessage):
            # OpenAI-compatible "tool" role with the originating call id
            out.append(
                {
                    "role": "tool",
                    "content": m.content,
                    "tool_call_id": m.tool_call_id,
                }
            )
        elif isinstance(m, FunctionMessage):
            # Legacy function messages: the function name stands in for the call id
            out.append(
                {
                    "role": "tool",
                    "content": m.content,
                    "tool_call_id": m.name or "",
                }
            )
        else:
            # Fallback for unknown message types
            out.append({"role": getattr(m, "role", "user"), "content": getattr(m, "content", str(m))})
    return out
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_tool(tool: Any) -> Dict[str, Any]:
    """Best-effort conversion to OpenAI tool schema."""

    def _extract_schema(candidate: Any) -> Any:
        # Pydantic v1 first, then v2, then a plain `.args` attribute.
        if hasattr(candidate, "args_schema") and getattr(candidate, "args_schema") is not None:
            try:
                return candidate.args_schema.schema()  # pydantic v1
            except Exception:
                try:
                    return candidate.args_schema.model_json_schema()  # pydantic v2
                except Exception:
                    pass
        if hasattr(candidate, "args"):
            try:
                return candidate.args
            except Exception:
                pass
        return None

    # Prefer the tool's own OpenAI conversion when it offers one.
    if hasattr(tool, "to_openai_tool"):
        try:
            return tool.to_openai_tool()  # type: ignore[attr-defined]
        except Exception:
            pass

    tool_name = getattr(tool, "name", None) or getattr(tool, "__name__", "tool")
    tool_description = getattr(tool, "description", "")

    # Default to an empty object schema unless a real one can be extracted.
    parameters: Dict[str, Any] = {"type": "object", "properties": {}, "required": []}
    extracted = _extract_schema(tool)
    if isinstance(extracted, dict):
        parameters = extracted

    return {
        "type": "function",
        "function": {
            "name": tool_name,
            "description": tool_description,
            "parameters": parameters,
        },
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_tools(tools: Optional[List[Any]]) -> Optional[List[Dict[str, Any]]]:
    """Convert each tool to the OpenAI schema; empty/None input yields None."""
    if not tools:
        return None
    converted: List[Dict[str, Any]] = []
    for item in tools:
        converted.append(_convert_tool(item))
    return converted
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_tool_calls(msg: Dict[str, Any]) -> Tuple[str, List[Dict[str, Any]]]:
    """Extract ``(content, tool_calls)`` from an OpenAI-style assistant message.

    Each parsed tool call is ``{"id", "name", "args"}`` with ``args`` decoded
    from the JSON ``arguments`` string. Malformed argument JSON is preserved
    under ``{"$raw": <original string>}`` so downstream code may repair it.
    """
    content = msg.get("content")
    if content is None:
        content = ""
    parsed: List[Dict[str, Any]] = []
    for tc in msg.get("tool_calls") or []:
        # BUG FIX: a null entry in the provider payload previously crashed at
        # `tc.get("id")` even though the `function` lookup guarded against it.
        tc = tc or {}
        fn = tc.get("function") or {}
        name = fn.get("name") or ""
        args_str = fn.get("arguments") or "{}"
        args: Any = {}
        if isinstance(args_str, str):
            try:
                args = json.loads(args_str or "{}")
            except Exception:
                # keep raw if malformed; downstream may repair
                args = {"$raw": args_str}
        elif isinstance(args_str, dict):
            args = args_str
        parsed.append({
            "id": tc.get("id"),
            "name": name,
            "args": args,
        })
    return content, parsed
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class _ToolCallBuilder:
    """Accumulates streamed tool-call fragments: id, name, and argument chunks."""

    id: Optional[str] = None
    name: str = ""
    args_buf: List[str] = field(default_factory=list)

    def to_final(self) -> Dict[str, Any]:
        """Join the buffered argument fragments and decode them into a dict.

        Undecodable JSON is kept under ``{"$raw": <joined string>}``.
        """
        joined = "".join(self.args_buf) or "{}"
        try:
            decoded = json.loads(joined)
        except Exception:
            decoded = {"$raw": joined}
        return {"id": self.id, "name": self.name, "args": decoded}
|
||||||
|
|
||||||
|
|
||||||
|
class ChatQwenOpenAICompat(BaseChatModel):
    """
    A minimal LangChain BaseChatModel-compatible class that calls an
    OpenAI-compatible /chat/completions endpoint and assembles tool calls.

    Notes:
    - Implements _generate for non-stream and _stream for SSE streaming.
    - Stores bound tools and tool_choice to send with requests.
    - Avoids heavy dependencies; uses httpx at runtime.
    """

    # Private state to avoid pydantic field validation
    _model: str = PrivateAttr()
    _api_key: str = PrivateAttr()
    _base_url: str = PrivateAttr()
    _temperature: Optional[float] = PrivateAttr(default=None)
    _max_tokens: Optional[int] = PrivateAttr(default=None)
    _extra_body: Dict[str, Any] = PrivateAttr(default_factory=dict)
    _timeout: float = PrivateAttr(default=60.0)
    _tools: Optional[List[Dict[str, Any]]] = PrivateAttr(default=None)
    _tool_choice: Optional[str] = PrivateAttr(default=None)

    def __init__(
        self,
        *,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        extra_body: Optional[Dict[str, Any]] = None,
        timeout: float = 60.0,
    ) -> None:
        """Resolve configuration from arguments and environment variables.

        Raises:
            ValueError: if no API key is supplied and none is found in the
                QWEN_API_KEY/GPUSTACK_API_KEY/OPENAI_API_KEY/DASHSCOPE_API_KEY
                environment variables.
        """
        super().__init__()
        # Load .env variables once if available
        ensure_env_loaded()
        # Store config in private attrs
        self._model = model or _env("QWEN_MODEL", "qwen3-coder-flash")
        # API key resolution order: explicit arg, then the env vars below.
        self._api_key = (
            api_key
            or _env("QWEN_API_KEY")
            or _env("GPUSTACK_API_KEY")
            or _env("OPENAI_API_KEY")
            or _env("DASHSCOPE_API_KEY")
        )
        if not self._api_key:
            raise ValueError("Missing API key: set QWEN_API_KEY/OPENAI_API_KEY/DASHSCOPE_API_KEY")
        self._base_url = base_url or _env("QWEN_BASE_URL") or _env("OPENAI_BASE_URL")
        if not self._base_url:
            # Default to DashScope intl OpenAI-compatible
            self._base_url = "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
        # Allow env override for timeout
        timeout_env = _env("QWEN_TIMEOUT")
        if timeout_env:
            try:
                timeout = float(timeout_env)
            except Exception:
                # Non-numeric QWEN_TIMEOUT is silently ignored; the default stands.
                pass
        self._temperature = temperature
        self._max_tokens = max_tokens
        self._extra_body = extra_body or {}
        self._timeout = timeout
        # tools/tool_choice already defaulted via PrivateAttr declarations

    # Expose LangChain binding-style API
    def bind_tools(self, tools: List[Any]) -> "ChatQwenOpenAICompat":
        """Return a copy of this model with *tools* converted to OpenAI schema."""
        clone = self._copy()
        clone._tools = _convert_tools(tools)
        return clone

    def bind(self, **kwargs: Any) -> "ChatQwenOpenAICompat":
        """Return a copy carrying `tool_choice` and/or merged `extra_body` overrides.

        Other keyword arguments are silently ignored.
        """
        clone = self._copy()
        # Allow passing tool_choice or extra_body overrides
        if "tool_choice" in kwargs:
            clone._tool_choice = kwargs["tool_choice"]
        if "extra_body" in kwargs and isinstance(kwargs["extra_body"], dict):
            eb = dict(clone._extra_body)
            eb.update(kwargs["extra_body"])  # type: ignore
            clone._extra_body = eb
        return clone

    # LangChain BaseChatModel required interface
    @property
    def _llm_type(self) -> str:  # type: ignore[override]
        # Identifier LangChain uses for serialization/telemetry of this model type.
        return "qwen-openai-compat"

    def _copy(self) -> "ChatQwenOpenAICompat":
        """Shallow-clone configuration plus bound tools/tool_choice."""
        c = ChatQwenOpenAICompat(
            model=self._model,
            api_key=self._api_key,
            base_url=self._base_url,
            temperature=self._temperature,
            max_tokens=self._max_tokens,
            extra_body=dict(self._extra_body),
            timeout=self._timeout,
        )
        c._tools = self._tools[:] if self._tools else None
        c._tool_choice = self._tool_choice
        return c

    # -- Internal HTTP helpers --
    def _request(self, payload: Dict[str, Any], *, stream: bool = False) -> Any:
        """POST *payload* to the chat/completions endpoint.

        Returns the decoded JSON body when stream=False, or the (unentered)
        httpx streaming context manager when stream=True.

        Raises:
            RuntimeError: if httpx is not installed.

        NOTE(review): the httpx.Client created here is never closed; the
        connection is released only when the client is garbage-collected —
        consider a context manager or a shared client.
        """
        try:
            import httpx
        except Exception as e:
            raise RuntimeError("httpx is required for ChatQwenOpenAICompat. Install via `pip install httpx`. ") from e

        # Auth header customization (for non-standard gateways)
        auth_header = _env("QWEN_AUTH_HEADER", "Authorization")
        auth_scheme = _env("QWEN_AUTH_SCHEME", "Bearer")
        auth_value = f"{auth_scheme} {self._api_key}" if auth_scheme else self._api_key
        headers = {auth_header: auth_value, "Content-Type": "application/json"}
        base = self._base_url.rstrip("/")
        # Accept either a root base_url (ending with /v1) or a full endpoint that already includes /chat/completions
        if "chat/completions" in base:
            url = base
        else:
            url = base + "/chat/completions"
        # Proxy behavior: honor env by default; allow disabling via QWEN_HTTP_TRUST_ENV=0
        trust_env = _env("QWEN_HTTP_TRUST_ENV", "1") != "0"
        client = httpx.Client(timeout=self._timeout, trust_env=trust_env)
        if _env("QWEN_DEBUG") == "1":
            import sys as _sys
            print(f"[qwen] POST {url} timeout={self._timeout} trust_env={trust_env}", file=_sys.stderr)
        if stream:
            return client.stream("POST", url, headers=headers, json=payload)
        resp = client.post(url, headers=headers, json=payload)
        resp.raise_for_status()
        return resp.json()

    def _build_payload(self, messages: List["BaseMessage"], **kwargs: Any) -> Dict[str, Any]:
        """Assemble the request body: model, messages, generation params, tools.

        `extra_body` and then caller kwargs are merged last, so they can
        override any of the earlier keys.
        """
        payload: Dict[str, Any] = {
            "model": self._model,
            "messages": _to_openai_messages(messages),
        }
        if self._temperature is not None:
            payload["temperature"] = self._temperature
        if self._max_tokens is not None:
            payload["max_tokens"] = self._max_tokens
        if self._tools:
            payload["tools"] = self._tools
        if self._tool_choice is not None:
            payload["tool_choice"] = self._tool_choice
        if self._extra_body:
            payload.update(self._extra_body)
        if kwargs:
            payload.update(kwargs)
        return payload

    # -- Non-streaming generate --
    def _generate(
        self,
        messages: List["BaseMessage"],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> "ChatResult":  # type: ignore[name-defined]
        """Single non-streaming completion; parses tool calls into the AIMessage.

        NOTE(review): `stop` is accepted for interface compatibility but is
        not forwarded to the endpoint — confirm whether that is intended.
        """
        from langchain_core.outputs import ChatGeneration, ChatResult
        from langchain_core.messages import AIMessage

        payload = self._build_payload(messages, **kwargs)
        data = self._request(payload, stream=False)
        choice = (data.get("choices") or [{}])[0]
        msg = choice.get("message", {})
        content, tool_calls = _parse_tool_calls(msg)
        usage = data.get("usage", {})

        # Keep the raw response and any provider reasoning text for callers.
        extra_kwargs = {
            "raw": data,
            "reasoning_content": msg.get("reasoning_content"),
        }
        msg_kwargs: Dict[str, Any] = {"content": content or "", "additional_kwargs": extra_kwargs}
        if tool_calls:
            msg_kwargs["tool_calls"] = tool_calls
        ai = AIMessage(**msg_kwargs)
        gen = ChatGeneration(message=ai)
        return ChatResult(generations=[gen], llm_output={"usage": usage, "model": data.get("model")})

    # -- Streaming generate --
    def _stream(
        self,
        messages: List["BaseMessage"],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> Iterator["ChatGenerationChunk"]:  # type: ignore[name-defined]
        """Stream completion chunks from the SSE endpoint.

        Yields one chunk per text delta, then a final chunk on [DONE] carrying
        the assembled tool calls (if any) in additional_kwargs.

        NOTE(review): the final chunk's content repeats the full accumulated
        text that was already yielded delta-by-delta; consumers that
        concatenate chunk contents will see the text twice — confirm intent.
        """
        from langchain_core.outputs import ChatGenerationChunk
        from langchain_core.messages import AIMessageChunk

        payload = self._build_payload(messages, stream=True, **kwargs)
        with self._request(payload, stream=True) as r:  # type: ignore[attr-defined]
            # SSE stream; lines like: data: {...}\n\n; end with data: [DONE]
            text_buf: List[str] = []
            tc_builders: List[_ToolCallBuilder] = []
            for line in r.iter_lines():  # type: ignore[attr-defined]
                if not line:
                    continue
                if isinstance(line, (bytes, bytearray)):
                    try:
                        line = line.decode("utf-8", errors="ignore")
                    except Exception:
                        continue
                if not line.startswith("data:"):
                    continue
                data = line[len("data:"):].strip()
                if data == "[DONE]":
                    # Finalize: emit a chunk with final tool calls if any
                    if tc_builders or text_buf:
                        extra: Dict[str, Any] = {}
                        if tc_builders:
                            extra["tool_calls"] = [b.to_final() for b in tc_builders]
                        # NOTE(review): additional_kwargs=None (when extra is
                        # empty) may be rejected by AIMessageChunk, which
                        # expects a dict — verify against langchain_core.
                        final = AIMessageChunk(
                            content="".join(text_buf) if text_buf else "",
                            additional_kwargs=extra if extra else None,
                        )
                        yield ChatGenerationChunk(message=final)
                    break
                try:
                    obj = json.loads(data)
                except Exception:
                    # Skip malformed SSE payload lines rather than aborting the stream.
                    continue
                choices = obj.get("choices") or []
                if not choices:
                    continue
                delta = (choices[0].get("delta") or {})
                # Text deltas
                dtext = delta.get("content")
                if dtext:
                    text_buf.append(dtext)
                    yield ChatGenerationChunk(message=AIMessageChunk(content=dtext))
                # Tool call deltas: fragments are matched to builders by list index.
                dtools = delta.get("tool_calls") or []
                for i, part in enumerate(dtools):
                    # Ensure list length
                    while len(tc_builders) <= i:
                        tc_builders.append(_ToolCallBuilder())
                    b = tc_builders[i]
                    if "id" in part:
                        b.id = part.get("id") or b.id
                    fn = part.get("function") or {}
                    if fn.get("name"):
                        b.name = fn["name"]
                    if fn.get("arguments"):
                        b.args_buf.append(fn["arguments"])  # partial JSON string
|
||||||
135
langgraph_qwen/factory.py
Normal file
135
langgraph_qwen/factory.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from .utils import ensure_env_loaded
|
||||||
|
|
||||||
|
|
||||||
|
def _env(var: str, default: Optional[str] = None) -> Optional[str]:
    """Return env var *var*; unset and empty-string values both yield *default*."""
    value = os.getenv(var)
    if value is None or value == "":
        return default
    return value
|
||||||
|
|
||||||
|
|
||||||
|
def get_qwen_chat_model(
    *,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    extra: Optional[Dict[str, Any]] = None,
    prefer: str = "openai_compat",
):
    """
    Return a LangChain ChatModel configured for Qwen3-Coder(-Flash).

    Tries the OpenAI-compatible client (langchain_openai.ChatOpenAI) by default.
    If unavailable and prefer != 'openai_compat', will try DashScope
    (langchain_community.chat_models.tongyi.ChatTongyi).

    Args:
        model: Model name; falls back to QWEN_MODEL, then 'qwen3-coder-flash'.
        api_key: API key; falls back to QWEN_API_KEY/GPUSTACK_API_KEY/
            OPENAI_API_KEY/DASHSCOPE_API_KEY.
        base_url: Endpoint base URL; falls back to QWEN_BASE_URL/OPENAI_BASE_URL.
        temperature, max_tokens: Optional generation parameters.
        extra: Additional generation kwargs merged into the model constructor.
        prefer: 'openai_compat' (default), 'custom' for the bundled adapter,
            or anything else to go straight to the DashScope fallback.

    Raises:
        RuntimeError: when no provider path could be initialized.

    This function does not import heavy deps at module import time; imports are
    local to keep repo light.
    """
    # BUG FIX: load .env BEFORE reading env vars — previously QWEN_MODEL was
    # resolved first, so a value supplied only via a .env file was ignored.
    ensure_env_loaded()
    model = model or _env("QWEN_MODEL", "qwen3-coder-flash")
    api_key = (
        api_key
        or _env("QWEN_API_KEY")
        or _env("GPUSTACK_API_KEY")
        or _env("OPENAI_API_KEY")
        or _env("DASHSCOPE_API_KEY")
    )
    base_url = base_url or _env("QWEN_BASE_URL") or _env("OPENAI_BASE_URL")

    # Common generation params
    gen_kwargs: Dict[str, Any] = {}
    if temperature is not None:
        gen_kwargs["temperature"] = temperature
    if max_tokens is not None:
        gen_kwargs["max_tokens"] = max_tokens
    if extra:
        gen_kwargs.update(extra)

    err: Optional[Exception] = None

    if prefer == "custom":
        try:
            from .chat_model import ChatQwenOpenAICompat

            # api_key was already coalesced from the environment above;
            # ChatQwenOpenAICompat performs its own fallback and validation too.
            chat = ChatQwenOpenAICompat(
                model=model,
                api_key=api_key,
                base_url=base_url,
                **gen_kwargs,
            )
            return chat
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Custom adapter init failed: {e}") from e

    if prefer == "openai_compat":
        try:
            from langchain_openai import ChatOpenAI  # type: ignore

            if api_key is None:
                raise ValueError("QWEN_API_KEY/OPENAI_API_KEY/DASHSCOPE_API_KEY is required")

            # ChatOpenAI supports base_url override for OpenAI-compatible servers
            chat = ChatOpenAI(
                model=model,
                api_key=api_key,
                base_url=base_url,
                **gen_kwargs,
            )
            return chat
        except Exception as e:
            # Remember the failure and fall through to the DashScope path.
            err = e

    # Fallback: DashScope (Tongyi) path via community provider
    try:
        from langchain_community.chat_models.tongyi import ChatTongyi  # type: ignore

        # ChatTongyi uses DASHSCOPE_API_KEY from env by default; api_key arg optional
        chat = ChatTongyi(
            model=model,
            dashscope_api_key=api_key,
            **gen_kwargs,
        )
        return chat
    except Exception as e2:
        raise RuntimeError(
            "Failed to create Qwen ChatModel. Tried OpenAI-compatible and DashScope paths.\n"
            f"Prefer='{prefer}' error: {err}\n"
            f"DashScope path error: {e2}"
        ) from e2
|
||||||
|
|
||||||
|
|
||||||
|
def bind_qwen_tools(
    model, tools: List[Any], *, tool_choice: Optional[str] = None
):
    """
    Bind tools to a LangChain ChatModel with Qwen-friendly defaults.

    - tools: list of BaseTool or functions decorated with @tool
    - tool_choice: 'auto' | 'required' | tool-name | None
    Returns: a model instance with bound tools.
    """
    # LangChain chat models expose .bind_tools(...); rely on the framework's
    # native tool schema handling rather than converting here.
    bound = model.bind_tools(tools)
    if tool_choice is None:
        return bound
    try:
        # Most providers accept tool_choice directly via .bind.
        return bound.bind(tool_choice=tool_choice)
    except Exception:
        # Providers that reject bind(tool_choice=..) get it via the request body.
        return bound.bind(extra_body={"tool_choice": tool_choice})
|
||||||
36
langgraph_qwen/prebuilt.py
Normal file
36
langgraph_qwen/prebuilt.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
from typing import Any, Dict, Iterable, List, Optional
|
||||||
|
|
||||||
|
from .factory import get_qwen_chat_model, bind_qwen_tools
|
||||||
|
from .validators import validate_tool_names, ToolValidationError
|
||||||
|
|
||||||
|
|
||||||
|
def create_qwen_react_agent(
    tools: List[Any],
    *,
    model: Optional[Any] = None,
    prefer: str = "custom",
    tool_choice: Optional[str] = "auto",
    model_kwargs: Optional[Dict[str, Any]] = None,
    enforce_tool_name_check: bool = True,
    **agent_kwargs: Any,
):
    """
    Sugar around langgraph.prebuilt.create_react_agent configured for Qwen.

    - If `model` is None, constructs one via get_qwen_chat_model(prefer=..., **model_kwargs).
    - Binds tools and applies `tool_choice`.
    - Returns a LangGraph ReAct agent ready to invoke/stream.
    """
    from langgraph.prebuilt import create_react_agent as _create

    # Fail fast on tool names the OpenAI-compatible API would reject.
    if enforce_tool_name_check:
        validate_tool_names(tools, strict=True)

    chat_model = model
    if chat_model is None:
        chat_model = get_qwen_chat_model(prefer=prefer, **(model_kwargs or {}))

    bound = bind_qwen_tools(chat_model, tools, tool_choice=tool_choice)
    return _create(bound, tools, **agent_kwargs)
|
||||||
22
langgraph_qwen/utils.py
Normal file
22
langgraph_qwen/utils.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# Module-level guard so .env discovery runs at most once per process.
_DOTENV_LOADED = False


def ensure_env_loaded() -> None:
    """Load variables from a .env file, at most once per process.

    Uses python-dotenv when installed; existing process environment variables
    are never overridden, and any failure (including python-dotenv being
    absent) is silently ignored.
    """
    global _DOTENV_LOADED
    if _DOTENV_LOADED:
        return
    try:
        # Lazy import so a missing python-dotenv is not a hard failure.
        from dotenv import load_dotenv, find_dotenv  # type: ignore

        dotenv_path = find_dotenv(usecwd=True)
        if dotenv_path:
            load_dotenv(dotenv_path=dotenv_path, override=False)
        else:
            # Fall back to default discovery in the current working directory.
            load_dotenv(override=False)
    except Exception:
        # Silently ignore if python-dotenv is not available or errors out.
        pass
    _DOTENV_LOADED = True
|
||||||
|
|
||||||
67
langgraph_qwen/validators.py
Normal file
67
langgraph_qwen/validators.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import re
|
||||||
|
import hashlib
|
||||||
|
from typing import Iterable, List, Tuple, Any
|
||||||
|
|
||||||
|
|
||||||
|
# OpenAI-compatible tool/function name constraint: chars [a-zA-Z0-9_-], length 1-64.
TOOL_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
|
||||||
|
|
||||||
|
|
||||||
|
class ToolValidationError(Exception):
    """Raised when one or more tool names violate the naming constraints."""
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_tool_name(name: str, max_len: int = 64) -> str:
    """Map *name* to a legal tool name.

    Illegal characters become underscores; over-long names are truncated and
    suffixed with a stable 8-char sha1 digest to stay unique within *max_len*.
    """
    cleaned = re.sub(r"[^a-zA-Z0-9_-]", "_", name or "tool")
    if len(cleaned) > max_len:
        digest = hashlib.sha1(cleaned.encode()).hexdigest()[:8]
        cleaned = f"{cleaned[: max_len - 9]}_{digest}"
    return cleaned or "tool"
|
||||||
|
|
||||||
|
|
||||||
|
def _find_illegal_chars(name: str) -> List[str]:
    """Return the sorted, de-duplicated characters of *name* outside [a-zA-Z0-9_-]."""
    return sorted(set(re.findall(r"[^a-zA-Z0-9_-]", name)))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_name(obj: Any) -> str:
    """Best-effort tool name: .name, then .tool_name, then .__name__, else 'tool'."""
    name = getattr(obj, "name", None)
    if not name:
        name = getattr(obj, "tool_name", None)
    if not name:
        name = getattr(obj, "__name__", "tool")
    return name
|
||||||
|
|
||||||
|
|
||||||
|
def validate_tool_names(
    tools: Iterable[Any], *, strict: bool = True, max_len: int = 64
) -> Tuple[bool, List[str]]:
    """
    Validate that tool names conform to OpenAI-compatible constraints:
    - Allowed chars: a-z, A-Z, 0-9, underscore, hyphen
    - Max length: 64

    Returns (ok, messages). If strict=True and invalid tools exist,
    raises ToolValidationError with a detailed report.

    NOTE(review): TOOL_NAME_PATTERN hard-codes the 64-char cap, so a custom
    max_len only affects the message and suggestion, not which names pass —
    confirm whether max_len should be honored by the match as well.
    """
    messages: List[str] = []
    errors: List[str] = []
    for idx, t in enumerate(tools):
        name = _get_name(t)
        if TOOL_NAME_PATTERN.match(name or ""):
            continue
        parts: List[str] = [f"第{idx}个工具名称不合法: '{name}'"]
        if not name:
            parts.append("原因: 为空")
        else:
            illegal = _find_illegal_chars(name)
            if illegal:
                parts.append(f"包含非法字符: {''.join(illegal)} (仅允许[a-zA-Z0-9_-])")
            if len(name) > max_len:
                # BUG FIX: the limit in this message was a garbled constant
                # ("57,256"); report the actual configured maximum instead.
                parts.append(f"长度超限: {len(name)} > {max_len}")
        suggestion = sanitize_tool_name(name or "tool", max_len=max_len)
        parts.append(f"建议名称: '{suggestion}'")
        errors.append("; ".join(parts))

    if errors:
        report = "\n".join(errors)
        messages.append(report)
        if strict:
            raise ToolValidationError(report)
        return False, messages

    return True, messages
|
||||||
|
|
||||||
Reference in New Issue
Block a user