diff --git a/langgraph_qwen/__init__.py b/langgraph_qwen/__init__.py
new file mode 100644
index 0000000..9559f2b
--- /dev/null
+++ b/langgraph_qwen/__init__.py
@@ -0,0 +1,29 @@
+"""
+Lightweight adapter to use Qwen3-Coder(-Flash) models with LangGraph via
+LangChain ChatModels. Focuses on an OpenAI-compatible path and a clean
+factory API, plus helpers for tool binding and streaming examples.
+
+Usage quickstart:
+
+    from langgraph_qwen.factory import get_qwen_chat_model
+    from langgraph.prebuilt import create_react_agent
+
+    model = get_qwen_chat_model()
+    agent = create_react_agent(model, tools=[...])
+
+Environment variables:
+- QWEN_API_KEY: API key for an OpenAI-compatible endpoint (or DashScope)
+- QWEN_BASE_URL: Base URL (e.g. https://dashscope-intl.aliyuncs.com/compatible-mode/v1)
+- QWEN_MODEL: Model name (e.g. qwen3-coder-flash)
+"""
+
+__all__ = [
+    "get_qwen_chat_model",
+    "bind_qwen_tools",
+    "ChatQwenOpenAICompat",
+    "create_qwen_react_agent",
+]
+
+from .factory import get_qwen_chat_model, bind_qwen_tools
+from .chat_model import ChatQwenOpenAICompat
+from .prebuilt import create_qwen_react_agent
diff --git a/langgraph_qwen/chat_model.py b/langgraph_qwen/chat_model.py
new file mode 100644
index 0000000..b22f4b6
--- /dev/null
+++ b/langgraph_qwen/chat_model.py
@@ -0,0 +1,414 @@
+from __future__ import annotations
+
+import json
+import os
+from dataclasses import dataclass, field
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple
+
+from langchain_core.language_models.chat_models import BaseChatModel
+
+from .utils import ensure_env_loaded
+
+try:
+    # pydantic v2
+    from pydantic import PrivateAttr
+except Exception:  # pragma: no cover
+    # pydantic v1 fallback
+    from pydantic.fields import PrivateAttr  # type: ignore
+
+
+def _env(name: str, default: Optional[str] = None) -> Optional[str]:
+    v = os.getenv(name)
+    return v if v else default
+
+
+def _coalesce(*vals):
+    for v in vals:
+        if v is not None:
+            return v
+    return None
+
+
+def _to_openai_messages(messages: List["BaseMessage"]) -> List[Dict[str, Any]]:
+    from langchain_core.messages import (
+        AIMessage,
+        BaseMessage,
+        HumanMessage,
+        SystemMessage,
+        ToolMessage,
+        FunctionMessage,
+    )
+
+    out: List[Dict[str, Any]] = []
+    for m in messages:
+        if isinstance(m, SystemMessage):
+            out.append({"role": "system", "content": m.content})
+        elif isinstance(m, HumanMessage):
+            out.append({"role": "user", "content": m.content})
+        elif isinstance(m, AIMessage):
+            # If a previous AI turn produced tool calls, pass them through.
+            entry: Dict[str, Any] = {"role": "assistant", "content": m.content}
+            # Parenthesized so the ternary does not swallow the `or` branch
+            tcs = getattr(m, "tool_calls", None) or (
+                m.additional_kwargs.get("tool_calls")
+                if hasattr(m, "additional_kwargs")
+                else None
+            )
+            if tcs:
+                # Map to OpenAI's expected shape
+                entry["tool_calls"] = [
+                    {
+                        "id": tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None),
+                        "type": "function",
+                        "function": {
+                            "name": (tc.get("name") if isinstance(tc, dict) else getattr(tc, "name", "")) or "",
+                            "arguments": json.dumps(tc.get("args") if isinstance(tc, dict) else getattr(tc, "args", {})),
+                        },
+                    }
+                    for tc in (tcs or [])
+                ]
+            out.append(entry)
+        elif isinstance(m, ToolMessage):
+            # OpenAI-compatible "tool" role; the tool_call_id must be provided
+            out.append(
+                {
+                    "role": "tool",
+                    "content": m.content,
+                    "tool_call_id": m.tool_call_id,
+                }
+            )
+        elif isinstance(m, FunctionMessage):
+            # Legacy function messages
+            out.append(
+                {
+                    "role": "tool",
+                    "content": m.content,
+                    "tool_call_id": m.name or "",
+                }
+            )
+        else:
+            # Fallback
+            out.append({"role": getattr(m, "role", "user"), "content": getattr(m, "content", str(m))})
+    return out
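+
+
+# Example mapping (illustrative) for one tool-call round trip:
+#   HumanMessage("Weather in Paris?")
+#     -> {"role": "user", "content": "Weather in Paris?"}
+#   AIMessage(content="", tool_calls=[{"id": "call_1", "name": "get_weather",
+#                                      "args": {"city": "Paris"}}])
+#     -> {"role": "assistant", "content": "", "tool_calls": [{"id": "call_1",
+#         "type": "function", "function": {"name": "get_weather",
+#         "arguments": "{\"city\": \"Paris\"}"}}]}
+#   ToolMessage(content="Sunny, 21 C", tool_call_id="call_1")
+#     -> {"role": "tool", "content": "Sunny, 21 C", "tool_call_id": "call_1"}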
out.append({"role": getattr(m, "role", "user"), "content": getattr(m, "content", str(m))}) + return out + + +def _convert_tool(tool: Any) -> Dict[str, Any]: + """Best-effort conversion to OpenAI tool schema.""" + # Prefer built-in conversion if available + if hasattr(tool, "to_openai_tool"): + try: + return tool.to_openai_tool() # type: ignore[attr-defined] + except Exception: + pass + name = getattr(tool, "name", None) or getattr(tool, "__name__", "tool") + description = getattr(tool, "description", "") + # Attempt to get a JSON schema + params: Dict[str, Any] = {"type": "object", "properties": {}, "required": []} + schema = None + if hasattr(tool, "args_schema") and getattr(tool, "args_schema") is not None: + try: + schema = tool.args_schema.schema() # pydantic v1 + except Exception: + try: + schema = tool.args_schema.model_json_schema() # pydantic v2 + except Exception: + schema = None + if schema is None and hasattr(tool, "args"): + try: + schema = tool.args + except Exception: + schema = None + if isinstance(schema, dict): + params = schema + return { + "type": "function", + "function": { + "name": name, + "description": description, + "parameters": params, + }, + } + + +def _convert_tools(tools: Optional[List[Any]]) -> Optional[List[Dict[str, Any]]]: + if not tools: + return None + return [_convert_tool(t) for t in tools] + + +def _parse_tool_calls(msg: Dict[str, Any]) -> Tuple[str, List[Dict[str, Any]]]: + content = _coalesce(msg.get("content"), "") + tool_calls_raw = msg.get("tool_calls") or [] + parsed: List[Dict[str, Any]] = [] + for tc in tool_calls_raw: + fn = (tc or {}).get("function", {}) + name = fn.get("name") or "" + args_str = fn.get("arguments") or "{}" + args = {} + if isinstance(args_str, str): + try: + args = json.loads(args_str or "{}") + except Exception: + # keep raw if malformed; downstream may repair + args = {"$raw": args_str} + elif isinstance(args_str, dict): + args = args_str + parsed.append({ + "id": tc.get("id"), + "name": name, + "args": args, + }) + return content, parsed + + +@dataclass +class _ToolCallBuilder: + id: Optional[str] = None + name: str = "" + args_buf: List[str] = field(default_factory=list) + + def to_final(self) -> Dict[str, Any]: + args_str = "".join(self.args_buf) if self.args_buf else "{}" + try: + args = json.loads(args_str or "{}") + except Exception: + args = {"$raw": args_str} + return {"id": self.id, "name": self.name, "args": args} + + +class ChatQwenOpenAICompat(BaseChatModel): + """ + A minimal LangChain BaseChatModel-compatible class that calls an + OpenAI-compatible /chat/completions endpoint and assembles tool calls. + + Notes: + - Implements _generate for non-stream and _stream for SSE streaming. + - Stores bound tools and tool_choice to send with requests. + - Avoids heavy dependencies; uses httpx at runtime. 
+ """ + + # Private state to avoid pydantic field validation + _model: str = PrivateAttr() + _api_key: str = PrivateAttr() + _base_url: str = PrivateAttr() + _temperature: Optional[float] = PrivateAttr(default=None) + _max_tokens: Optional[int] = PrivateAttr(default=None) + _extra_body: Dict[str, Any] = PrivateAttr(default_factory=dict) + _timeout: float = PrivateAttr(default=60.0) + _tools: Optional[List[Dict[str, Any]]] = PrivateAttr(default=None) + _tool_choice: Optional[str] = PrivateAttr(default=None) + + def __init__( + self, + *, + model: Optional[str] = None, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + extra_body: Optional[Dict[str, Any]] = None, + timeout: float = 60.0, + ) -> None: + super().__init__() + # Load .env variables once if available + ensure_env_loaded() + # Store config in private attrs + self._model = model or _env("QWEN_MODEL", "qwen3-coder-flash") + self._api_key = ( + api_key + or _env("QWEN_API_KEY") + or _env("GPUSTACK_API_KEY") + or _env("OPENAI_API_KEY") + or _env("DASHSCOPE_API_KEY") + ) + if not self._api_key: + raise ValueError("Missing API key: set QWEN_API_KEY/OPENAI_API_KEY/DASHSCOPE_API_KEY") + self._base_url = base_url or _env("QWEN_BASE_URL") or _env("OPENAI_BASE_URL") + if not self._base_url: + # Default to DashScope intl OpenAI-compatible + self._base_url = "https://dashscope-intl.aliyuncs.com/compatible-mode/v1" + # Allow env override for timeout + timeout_env = _env("QWEN_TIMEOUT") + if timeout_env: + try: + timeout = float(timeout_env) + except Exception: + pass + self._temperature = temperature + self._max_tokens = max_tokens + self._extra_body = extra_body or {} + self._timeout = timeout + # tools/tool_choice already defaulted via PrivateAttr declarations + + # Expose LangChain binding-style API + def bind_tools(self, tools: List[Any]) -> "ChatQwenOpenAICompat": + clone = self._copy() + clone._tools = _convert_tools(tools) + return clone + + def bind(self, **kwargs: Any) -> "ChatQwenOpenAICompat": + clone = self._copy() + # Allow passing tool_choice or extra_body overrides + if "tool_choice" in kwargs: + clone._tool_choice = kwargs["tool_choice"] + if "extra_body" in kwargs and isinstance(kwargs["extra_body"], dict): + eb = dict(clone._extra_body) + eb.update(kwargs["extra_body"]) # type: ignore + clone._extra_body = eb + return clone + + # LangChain BaseChatModel required interface + @property + def _llm_type(self) -> str: # type: ignore[override] + return "qwen-openai-compat" + + def _copy(self) -> "ChatQwenOpenAICompat": + c = ChatQwenOpenAICompat( + model=self._model, + api_key=self._api_key, + base_url=self._base_url, + temperature=self._temperature, + max_tokens=self._max_tokens, + extra_body=dict(self._extra_body), + timeout=self._timeout, + ) + c._tools = self._tools[:] if self._tools else None + c._tool_choice = self._tool_choice + return c + + # -- Internal HTTP helpers -- + def _request(self, payload: Dict[str, Any], *, stream: bool = False) -> Any: + try: + import httpx + except Exception as e: + raise RuntimeError("httpx is required for ChatQwenOpenAICompat. Install via `pip install httpx`. 
") from e + + # Auth header customization (for non-standard gateways) + auth_header = _env("QWEN_AUTH_HEADER", "Authorization") + auth_scheme = _env("QWEN_AUTH_SCHEME", "Bearer") + auth_value = f"{auth_scheme} {self._api_key}" if auth_scheme else self._api_key + headers = {auth_header: auth_value, "Content-Type": "application/json"} + base = self._base_url.rstrip("/") + # Accept either a root base_url (ending with /v1) or a full endpoint that already includes /chat/completions + if "chat/completions" in base: + url = base + else: + url = base + "/chat/completions" + # Proxy behavior: honor env by default; allow disabling via QWEN_HTTP_TRUST_ENV=0 + trust_env = _env("QWEN_HTTP_TRUST_ENV", "1") != "0" + client = httpx.Client(timeout=self._timeout, trust_env=trust_env) + if _env("QWEN_DEBUG") == "1": + import sys as _sys + print(f"[qwen] POST {url} timeout={self._timeout} trust_env={trust_env}", file=_sys.stderr) + if stream: + return client.stream("POST", url, headers=headers, json=payload) + resp = client.post(url, headers=headers, json=payload) + resp.raise_for_status() + return resp.json() + + def _build_payload(self, messages: List["BaseMessage"], **kwargs: Any) -> Dict[str, Any]: + payload: Dict[str, Any] = { + "model": self._model, + "messages": _to_openai_messages(messages), + } + if self._temperature is not None: + payload["temperature"] = self._temperature + if self._max_tokens is not None: + payload["max_tokens"] = self._max_tokens + if self._tools: + payload["tools"] = self._tools + if self._tool_choice is not None: + payload["tool_choice"] = self._tool_choice + if self._extra_body: + payload.update(self._extra_body) + if kwargs: + payload.update(kwargs) + return payload + + # -- Non-streaming generate -- + def _generate( + self, + messages: List["BaseMessage"], + stop: Optional[List[str]] = None, + run_manager: Optional[Any] = None, + **kwargs: Any, + ) -> "ChatResult": # type: ignore[name-defined] + from langchain_core.outputs import ChatGeneration, ChatResult + from langchain_core.messages import AIMessage + + payload = self._build_payload(messages, **kwargs) + data = self._request(payload, stream=False) + choice = (data.get("choices") or [{}])[0] + msg = choice.get("message", {}) + content, tool_calls = _parse_tool_calls(msg) + usage = data.get("usage", {}) + + extra_kwargs = { + "raw": data, + "reasoning_content": msg.get("reasoning_content"), + } + msg_kwargs: Dict[str, Any] = {"content": content or "", "additional_kwargs": extra_kwargs} + if tool_calls: + msg_kwargs["tool_calls"] = tool_calls + ai = AIMessage(**msg_kwargs) + gen = ChatGeneration(message=ai) + return ChatResult(generations=[gen], llm_output={"usage": usage, "model": data.get("model")}) + + # -- Streaming generate -- + def _stream( + self, + messages: List["BaseMessage"], + stop: Optional[List[str]] = None, + run_manager: Optional[Any] = None, + **kwargs: Any, + ) -> Iterator["ChatGenerationChunk"]: # type: ignore[name-defined] + from langchain_core.outputs import ChatGenerationChunk + from langchain_core.messages import AIMessageChunk + + payload = self._build_payload(messages, stream=True, **kwargs) + with self._request(payload, stream=True) as r: # type: ignore[attr-defined] + # SSE stream; lines like: data: {...}\n\n; end with data: [DONE] + text_buf: List[str] = [] + tc_builders: List[_ToolCallBuilder] = [] + for line in r.iter_lines(): # type: ignore[attr-defined] + if not line: + continue + if isinstance(line, (bytes, bytearray)): + try: + line = line.decode("utf-8", errors="ignore") + except 
+
+    # -- Non-streaming generate --
+    def _generate(
+        self,
+        messages: List["BaseMessage"],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> "ChatResult":  # type: ignore[name-defined]
+        from langchain_core.outputs import ChatGeneration, ChatResult
+        from langchain_core.messages import AIMessage
+
+        payload = self._build_payload(messages, **kwargs)
+        data = self._request(payload, stream=False)
+        choice = (data.get("choices") or [{}])[0]
+        msg = choice.get("message", {})
+        content, tool_calls = _parse_tool_calls(msg)
+        usage = data.get("usage", {})
+
+        extra_kwargs = {
+            "raw": data,
+            "reasoning_content": msg.get("reasoning_content"),
+        }
+        msg_kwargs: Dict[str, Any] = {"content": content or "", "additional_kwargs": extra_kwargs}
+        if tool_calls:
+            msg_kwargs["tool_calls"] = tool_calls
+        ai = AIMessage(**msg_kwargs)
+        gen = ChatGeneration(message=ai)
+        return ChatResult(generations=[gen], llm_output={"usage": usage, "model": data.get("model")})
+
+    # -- Streaming generate --
+    def _stream(
+        self,
+        messages: List["BaseMessage"],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> Iterator["ChatGenerationChunk"]:  # type: ignore[name-defined]
+        from langchain_core.outputs import ChatGenerationChunk
+        from langchain_core.messages import AIMessageChunk
+
+        payload = self._build_payload(messages, stream=True, **kwargs)
+        with self._request(payload, stream=True) as r:  # type: ignore[attr-defined]
+            # SSE stream: lines look like `data: {...}` and end with `data: [DONE]`
+            text_buf: List[str] = []
+            tc_builders: List[_ToolCallBuilder] = []
+            for line in r.iter_lines():  # type: ignore[attr-defined]
+                if not line:
+                    continue
+                if isinstance(line, (bytes, bytearray)):
+                    try:
+                        line = line.decode("utf-8", errors="ignore")
+                    except Exception:
+                        continue
+                if not line.startswith("data:"):
+                    continue
+                data = line[len("data:"):].strip()
+                if data == "[DONE]":
+                    # Finalize: emit a last chunk carrying the accumulated text
+                    # and the assembled tool calls, if any
+                    if tc_builders or text_buf:
+                        extra: Dict[str, Any] = {}
+                        if tc_builders:
+                            extra["tool_calls"] = [b.to_final() for b in tc_builders]
+                        final = AIMessageChunk(
+                            content="".join(text_buf) if text_buf else "",
+                            additional_kwargs=extra,
+                        )
+                        yield ChatGenerationChunk(message=final)
+                    break
+                try:
+                    obj = json.loads(data)
+                except Exception:
+                    continue
+                choices = obj.get("choices") or []
+                if not choices:
+                    continue
+                delta = choices[0].get("delta") or {}
+                # Text deltas
+                dtext = delta.get("content")
+                if dtext:
+                    text_buf.append(dtext)
+                    yield ChatGenerationChunk(message=AIMessageChunk(content=dtext))
+                # Tool call deltas
+                dtools = delta.get("tool_calls") or []
+                for i, part in enumerate(dtools):
+                    # Prefer the server-supplied index so interleaved tool calls
+                    # land in the right builder; fall back to position
+                    idx = part.get("index", i)
+                    while len(tc_builders) <= idx:
+                        tc_builders.append(_ToolCallBuilder())
+                    b = tc_builders[idx]
+                    if "id" in part:
+                        b.id = part.get("id") or b.id
+                    fn = part.get("function") or {}
+                    if fn.get("name"):
+                        b.name = fn["name"]
+                    if fn.get("arguments"):
+                        b.args_buf.append(fn["arguments"])  # partial JSON string
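+
+
+# Example (illustrative): an SSE stream such as
+#   data: {"choices":[{"delta":{"content":"Hel"}}]}
+#   data: {"choices":[{"delta":{"content":"lo"}}]}
+#   data: [DONE]
+# yields chunks "Hel" and "lo", then a final chunk whose content is "Hello"
+# and whose additional_kwargs carry any tool calls assembled along the way.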
+ """ + + model = model or _env("QWEN_MODEL", "qwen3-coder-flash") + # Load .env variables once if available + ensure_env_loaded() + api_key = ( + api_key + or _env("QWEN_API_KEY") + or _env("GPUSTACK_API_KEY") + or _env("OPENAI_API_KEY") + or _env("DASHSCOPE_API_KEY") + ) + base_url = base_url or _env("QWEN_BASE_URL") or _env("OPENAI_BASE_URL") + + # Common generation params + gen_kwargs: Dict[str, Any] = {} + if temperature is not None: + gen_kwargs["temperature"] = temperature + if max_tokens is not None: + gen_kwargs["max_tokens"] = max_tokens + if extra: + gen_kwargs.update(extra) + + err: Optional[Exception] = None + + if prefer == "custom": + try: + from .chat_model import ChatQwenOpenAICompat + + if api_key is None: + api_key = ( + _env("QWEN_API_KEY") + or _env("GPUSTACK_API_KEY") + or _env("OPENAI_API_KEY") + or _env("DASHSCOPE_API_KEY") + ) + chat = ChatQwenOpenAICompat( + model=model, + api_key=api_key, + base_url=base_url, + **gen_kwargs, + ) + return chat + except Exception as e: + raise RuntimeError(f"Custom adapter init failed: {e}") + + if prefer == "openai_compat": + try: + from langchain_openai import ChatOpenAI # type: ignore + + if api_key is None: + raise ValueError("QWEN_API_KEY/OPENAI_API_KEY/DASHSCOPE_API_KEY is required") + + # ChatOpenAI supports base_url override for OpenAI-compatible servers + chat = ChatOpenAI( + model=model, + api_key=api_key, + base_url=base_url, + **gen_kwargs, + ) + return chat + except Exception as e: + err = e + + # Fallback: DashScope (Tongyi) path via community provider + try: + from langchain_community.chat_models.tongyi import ChatTongyi # type: ignore + + # ChatTongyi uses DASHSCOPE_API_KEY from env by default; api_key arg optional + chat = ChatTongyi( + model=model, + dashscope_api_key=api_key, + **gen_kwargs, + ) + return chat + except Exception as e2: + raise RuntimeError( + "Failed to create Qwen ChatModel. Tried OpenAI-compatible and DashScope paths.\n" + f"Prefer='" + prefer + "' error: " + (str(err) if err else "None") + "\n" + + "DashScope path error: " + str(e2) + ) + + +def bind_qwen_tools( + model, tools: List[Any], *, tool_choice: Optional[str] = None +): + """ + Bind tools to a LangChain ChatModel with Qwen-friendly defaults. + + - tools: list of BaseTool or functions decorated with @tool + - tool_choice: 'auto' | 'required' | tool-name | None + Returns: a model instance with bound tools. + """ + + # Many LangChain chat models implement .bind_tools(...) + # We keep this thin and rely on LangChain's native tool schema handling. + bound = model.bind_tools(tools) + if tool_choice is not None: + # Some providers accept tool_choice via .bind; others via gen kwargs + # We try to set it via .bind when available. 
+
+
+def bind_qwen_tools(
+    model, tools: List[Any], *, tool_choice: Optional[str] = None
+):
+    """
+    Bind tools to a LangChain ChatModel with Qwen-friendly defaults.
+
+    - tools: list of BaseTool or functions decorated with @tool
+    - tool_choice: 'auto' | 'required' | tool-name | None
+
+    Returns: a model instance with bound tools.
+    """
+
+    # Many LangChain chat models implement .bind_tools(...); we keep this thin
+    # and rely on LangChain's native tool schema handling.
+    bound = model.bind_tools(tools)
+    if tool_choice is not None:
+        # Some providers accept tool_choice via .bind; others via gen kwargs.
+        # We try to set it via .bind when available.
+        try:
+            bound = bound.bind(tool_choice=tool_choice)
+        except Exception:
+            # If the provider doesn't accept bind(tool_choice=...), attach to kwargs
+            bound = bound.bind(extra_body={"tool_choice": tool_choice})
+    return bound
diff --git a/langgraph_qwen/prebuilt.py b/langgraph_qwen/prebuilt.py
new file mode 100644
index 0000000..6b21914
--- /dev/null
+++ b/langgraph_qwen/prebuilt.py
@@ -0,0 +1,36 @@
+from typing import Any, Dict, Iterable, List, Optional
+
+from .factory import get_qwen_chat_model, bind_qwen_tools
+from .validators import validate_tool_names, ToolValidationError
+
+
+def create_qwen_react_agent(
+    tools: List[Any],
+    *,
+    model: Optional[Any] = None,
+    prefer: str = "custom",
+    tool_choice: Optional[str] = "auto",
+    model_kwargs: Optional[Dict[str, Any]] = None,
+    enforce_tool_name_check: bool = True,
+    **agent_kwargs: Any,
+):
+    """
+    Sugar around langgraph.prebuilt.create_react_agent, configured for Qwen.
+
+    - If `model` is None, constructs one via get_qwen_chat_model(prefer=..., **model_kwargs).
+    - Binds tools and applies `tool_choice`.
+    - Returns a LangGraph ReAct agent ready to invoke/stream.
+    """
+
+    from langgraph.prebuilt import create_react_agent as _create
+
+    model_kwargs = model_kwargs or {}
+
+    # Strictly validate tool names before binding, if enabled
+    if enforce_tool_name_check:
+        validate_tool_names(tools, strict=True)
+    if model is None:
+        model = get_qwen_chat_model(prefer=prefer, **model_kwargs)
+
+    model = bind_qwen_tools(model, tools, tool_choice=tool_choice)
+    return _create(model, tools, **agent_kwargs)
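+
+
+# Example usage (illustrative; `search_docs` stands in for a @tool-decorated function):
+#
+#   agent = create_qwen_react_agent([search_docs], prefer="custom")
+#   result = agent.invoke({"messages": [("user", "Find the streaming docs")]})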
diff --git a/langgraph_qwen/utils.py b/langgraph_qwen/utils.py
new file mode 100644
index 0000000..4bafe21
--- /dev/null
+++ b/langgraph_qwen/utils.py
@@ -0,0 +1,22 @@
+_DOTENV_LOADED = False
+
+
+def ensure_env_loaded() -> None:
+    global _DOTENV_LOADED
+    if _DOTENV_LOADED:
+        return
+    try:
+        # Lazy import to avoid a hard failure if python-dotenv is not installed
+        from dotenv import load_dotenv, find_dotenv  # type: ignore
+
+        path = find_dotenv(usecwd=True)
+        if path:
+            load_dotenv(dotenv_path=path, override=False)
+        else:
+            # Fall back to default discovery in CWD
+            load_dotenv(override=False)
+    except Exception:
+        # Silently ignore if python-dotenv is unavailable or errors out
+        pass
+    _DOTENV_LOADED = True
diff --git a/langgraph_qwen/validators.py b/langgraph_qwen/validators.py
new file mode 100644
index 0000000..f050853
--- /dev/null
+++ b/langgraph_qwen/validators.py
@@ -0,0 +1,67 @@
+import re
+import hashlib
+from typing import Iterable, List, Tuple, Any
+
+
+# NOTE: the pattern hardcodes the 64-char limit; max_len below mirrors it
+TOOL_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
+
+
+class ToolValidationError(Exception):
+    pass
+
+
+def sanitize_tool_name(name: str, max_len: int = 64) -> str:
+    safe = re.sub(r"[^a-zA-Z0-9_-]", "_", name or "tool")
+    if len(safe) > max_len:
+        suffix = hashlib.sha1(safe.encode()).hexdigest()[:8]
+        safe = safe[: max_len - 9] + "_" + suffix
+    return safe or "tool"
+
+
+def _find_illegal_chars(name: str) -> List[str]:
+    return sorted(list(set(re.findall(r"[^a-zA-Z0-9_-]", name))))
+
+
+def _get_name(obj: Any) -> str:
+    return getattr(obj, "name", None) or getattr(obj, "tool_name", None) or getattr(obj, "__name__", "tool")
+
+
+def validate_tool_names(
+    tools: Iterable[Any], *, strict: bool = True, max_len: int = 64
+) -> Tuple[bool, List[str]]:
+    """
+    Validate that tool names conform to OpenAI-compatible constraints:
+    - Allowed chars: a-z, A-Z, 0-9, underscore, hyphen
+    - Max length: 64
+
+    Returns (ok, messages). If strict=True and invalid tools exist,
+    raises ToolValidationError with a detailed report.
+    """
+    messages: List[str] = []
+    errors: List[str] = []
+    for idx, t in enumerate(tools):
+        name = _get_name(t)
+        if TOOL_NAME_PATTERN.match(name or ""):
+            continue
+        parts: List[str] = [f"Tool #{idx} has an invalid name: '{name}'"]
+        if not name:
+            parts.append("reason: empty name")
+        else:
+            illegal = _find_illegal_chars(name)
+            if illegal:
+                parts.append(f"illegal characters: {''.join(illegal)} (only [a-zA-Z0-9_-] is allowed)")
+            if len(name) > max_len:
+                parts.append(f"length {len(name)} exceeds the {max_len}-char limit")
+        suggestion = sanitize_tool_name(name or "tool", max_len=max_len)
+        parts.append(f"suggested name: '{suggestion}'")
+        errors.append("; ".join(parts))
+
+    if errors:
+        report = "\n".join(errors)
+        messages.append(report)
+        if strict:
+            raise ToolValidationError(report)
+        return False, messages
+
+    return True, messages
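+
+
+# Examples (illustrative):
+#   sanitize_tool_name("get weather!")  -> "get_weather_"
+#   sanitize_tool_name("x" * 70)        -> first 55 chars + "_" + 8-char sha1 suffix
+#   validate_tool_names([bad_tool], strict=False) -> (False, [report]) instead of raising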