from __future__ import annotations

import json
import os
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Tuple

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.utils.function_calling import convert_to_openai_tool

from .utils import ensure_env_loaded

try:  # pydantic v2
    from pydantic import PrivateAttr
except Exception:  # pragma: no cover
    from pydantic.fields import PrivateAttr  # type: ignore  # pydantic v1 fallback


def _env(name: str, default: Optional[str] = None) -> Optional[str]:
    v = os.getenv(name)
    return v if v else default


def _coalesce(*vals):
    for v in vals:
        if v is not None:
            return v
    return None


def _to_openai_messages(messages: List["BaseMessage"]) -> List[Dict[str, Any]]:
    from langchain_core.messages import (
        AIMessage,
        BaseMessage,
        FunctionMessage,
        HumanMessage,
        SystemMessage,
        ToolMessage,
    )

    out: List[Dict[str, Any]] = []
    for m in messages:
        if isinstance(m, SystemMessage):
            out.append({"role": "system", "content": m.content})
        elif isinstance(m, HumanMessage):
            out.append({"role": "user", "content": m.content})
        elif isinstance(m, AIMessage):
            # If a previous AI turn produced tool calls, pass them through.
            entry: Dict[str, Any] = {"role": "assistant", "content": m.content}
            tcs = getattr(m, "tool_calls", None) or (
                m.additional_kwargs.get("tool_calls") if hasattr(m, "additional_kwargs") else None
            )
            if tcs:
                # Map to OpenAI's expected shape.
                entry["tool_calls"] = [
                    {
                        "id": tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None),
                        "type": "function",
                        "function": {
                            "name": (tc.get("name") if isinstance(tc, dict) else getattr(tc, "name", "")) or "",
                            "arguments": json.dumps(
                                tc.get("args") if isinstance(tc, dict) else getattr(tc, "args", {})
                            ),
                        },
                    }
                    for tc in tcs
                ]
            out.append(entry)
        elif isinstance(m, ToolMessage):
            # OpenAI-compatible "tool" role; tool_call_id must be provided.
            out.append(
                {
                    "role": "tool",
                    "content": m.content,
                    "tool_call_id": m.tool_call_id,
                }
            )
        elif isinstance(m, FunctionMessage):
            # Legacy function messages: reuse the function name as the call id.
            out.append(
                {
                    "role": "tool",
                    "content": m.content,
                    "tool_call_id": m.name or "",
                }
            )
        else:
            # Fallback for unknown message types.
            out.append({"role": getattr(m, "role", "user"), "content": getattr(m, "content", str(m))})
    return out


def _normalize_params(schema: Any) -> Dict[str, Any]:
    """Normalize various schema-like inputs into a valid OpenAI JSON Schema object."""
    if not isinstance(schema, dict):
        return {"type": "object", "properties": {}, "required": []}
    if schema.get("type") == "object" and isinstance(schema.get("properties"), dict):
        out = dict(schema)
    elif "properties" in schema and isinstance(schema["properties"], dict):
        out = {"type": "object", **schema}
    elif all(isinstance(v, dict) for v in schema.values()):
        # Looks like a bare property map, e.g. {"query": {"type": "string"}}.
        out = {"type": "object", "properties": schema}
    else:
        out = {"type": "object", "properties": {}}
    # Strip metadata keys that some providers reject.
    for k in ("$schema", "$defs", "definitions", "title"):
        out.pop(k, None)
    req = out.get("required", [])
    if not isinstance(req, list):
        req = list(req) if req else []
    out["required"] = req
    return out


def _convert_tool(tool: Any) -> Dict[str, Any]:
    """Convert a LangChain/MCP tool to OpenAI JSON via the official utility, with fallback normalization."""
    try:
        t = convert_to_openai_tool(tool)
        fn = t.get("function", {})
        fn["description"] = (fn.get("description") or "")[:1000]
        fn["parameters"] = _normalize_params(fn.get("parameters"))
        return {"type": "function", "function": fn}
    except Exception:
        pass
    # Fallback: introspect common tool attributes directly.
    name = getattr(tool, "name", None) or getattr(tool, "__name__", "tool")
    desc = getattr(tool, "description", "") or ""
    schema = None
    if getattr(tool, "args_schema", None) is not None:
        try:
            schema = tool.args_schema.model_json_schema()  # pydantic v2
        except Exception:
            try:
                schema = tool.args_schema.schema()  # pydantic v1
            except Exception:
                schema = None
    if schema is None and hasattr(tool, "args"):
        try:
            schema = tool.args
        except Exception:
            schema = None
    params = _normalize_params(schema)
    return {"type": "function", "function": {"name": name, "description": desc[:1000], "parameters": params}}
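
# Illustrative only (the tool name and schema below are hypothetical): for a
# LangChain tool "web_search" that takes one string argument, _convert_tool
# would produce roughly:
#   {
#       "type": "function",
#       "function": {
#           "name": "web_search",
#           "description": "Search the web.",
#           "parameters": {
#               "type": "object",
#               "properties": {"query": {"type": "string"}},
#               "required": ["query"],
#           },
#       },
#   }
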
getattr(tool, "description", "") or "" schema = None if getattr(tool, "args_schema", None) is not None: try: schema = tool.args_schema.model_json_schema() except Exception: try: schema = tool.args_schema.schema() except Exception: schema = None if schema is None and hasattr(tool, "args"): try: schema = tool.args except Exception: schema = None params = _normalize_params(schema) return {"type": "function", "function": {"name": name, "description": desc[:1000], "parameters": params}} def _convert_tools(tools: Optional[List[Any]]) -> Optional[List[Dict[str, Any]]]: if not tools: return None return [_convert_tool(t) for t in tools] def _parse_tool_calls(msg: Dict[str, Any]) -> Tuple[str, List[Dict[str, Any]]]: content = _coalesce(msg.get("content"), "") tool_calls_raw = msg.get("tool_calls") or [] parsed: List[Dict[str, Any]] = [] for tc in tool_calls_raw: fn = (tc or {}).get("function", {}) name = fn.get("name") or "" args_str = fn.get("arguments") or "{}" args = {} if isinstance(args_str, str): try: args = json.loads(args_str or "{}") except Exception: # keep raw if malformed; downstream may repair args = {"$raw": args_str} elif isinstance(args_str, dict): args = args_str parsed.append({ "id": tc.get("id"), "name": name, "args": args, }) return content, parsed @dataclass class _ToolCallBuilder: id: Optional[str] = None name: str = "" args_buf: List[str] = field(default_factory=list) def to_final(self) -> Dict[str, Any]: args_str = "".join(self.args_buf) if self.args_buf else "{}" try: args = json.loads(args_str or "{}") except Exception: args = {"$raw": args_str} return {"id": self.id, "name": self.name, "args": args} class ChatQwenOpenAICompat(BaseChatModel): """ A minimal LangChain BaseChatModel-compatible class that calls an OpenAI-compatible /chat/completions endpoint and assembles tool calls. Notes: - Implements _generate for non-stream and _stream for SSE streaming. - Stores bound tools and tool_choice to send with requests. - Avoids heavy dependencies; uses httpx at runtime. 
""" # Private state to avoid pydantic field validation _model: str = PrivateAttr() _api_key: str = PrivateAttr() _base_url: str = PrivateAttr() _temperature: Optional[float] = PrivateAttr(default=None) _max_tokens: Optional[int] = PrivateAttr(default=None) _extra_body: Dict[str, Any] = PrivateAttr(default_factory=dict) _timeout: float = PrivateAttr(default=60.0) _tools: Optional[List[Dict[str, Any]]] = PrivateAttr(default=None) _tool_choice: Optional[str] = PrivateAttr(default=None) def __init__( self, *, model: Optional[str] = None, api_key: Optional[str] = None, base_url: Optional[str] = None, temperature: Optional[float] = None, max_tokens: Optional[int] = None, extra_body: Optional[Dict[str, Any]] = None, timeout: float = 60.0, ) -> None: super().__init__() # Load .env variables once if available ensure_env_loaded() # Store config in private attrs self._model = model or _env("QWEN_MODEL", "qwen3-coder-flash") self._api_key = ( api_key or _env("QWEN_API_KEY") or _env("GPUSTACK_API_KEY") or _env("OPENAI_API_KEY") or _env("DASHSCOPE_API_KEY") ) if not self._api_key: raise ValueError("Missing API key: set QWEN_API_KEY/OPENAI_API_KEY/DASHSCOPE_API_KEY") self._base_url = base_url or _env("QWEN_BASE_URL") or _env("OPENAI_BASE_URL") if not self._base_url: # Default to DashScope intl OpenAI-compatible self._base_url = "https://dashscope-intl.aliyuncs.com/compatible-mode/v1" # Allow env override for timeout timeout_env = _env("QWEN_TIMEOUT") if timeout_env: try: timeout = float(timeout_env) except Exception: pass self._temperature = temperature self._max_tokens = max_tokens self._extra_body = extra_body or {} self._timeout = timeout # tools/tool_choice already defaulted via PrivateAttr declarations # Expose LangChain binding-style API def bind_tools(self, tools: List[Any]) -> "ChatQwenOpenAICompat": clone = self._copy() clone._tools = _convert_tools(tools) return clone def bind(self, **kwargs: Any) -> "ChatQwenOpenAICompat": clone = self._copy() # Allow passing tool_choice or extra_body overrides if "tool_choice" in kwargs: clone._tool_choice = kwargs["tool_choice"] if "extra_body" in kwargs and isinstance(kwargs["extra_body"], dict): eb = dict(clone._extra_body) eb.update(kwargs["extra_body"]) # type: ignore clone._extra_body = eb return clone # LangChain BaseChatModel required interface @property def _llm_type(self) -> str: # type: ignore[override] return "qwen-openai-compat" def _copy(self) -> "ChatQwenOpenAICompat": c = ChatQwenOpenAICompat( model=self._model, api_key=self._api_key, base_url=self._base_url, temperature=self._temperature, max_tokens=self._max_tokens, extra_body=dict(self._extra_body), timeout=self._timeout, ) c._tools = self._tools[:] if self._tools else None c._tool_choice = self._tool_choice return c # -- Internal HTTP helpers -- def _request(self, payload: Dict[str, Any], *, stream: bool = False) -> Any: try: import httpx import sys as _sys except Exception as e: raise RuntimeError("httpx is required for ChatQwenOpenAICompat. Install via `pip install httpx`. 
") from e # Auth header customization (for non-standard gateways) auth_header = _env("QWEN_AUTH_HEADER", "Authorization") auth_scheme = _env("QWEN_AUTH_SCHEME", "Bearer") auth_value = f"{auth_scheme} {self._api_key}" if auth_scheme else self._api_key headers = {auth_header: auth_value, "Content-Type": "application/json"} base = self._base_url.rstrip("/") url = base if "chat/completions" in base else base + "/chat/completions" # honor system proxy by default; disable via QWEN_HTTP_TRUST_ENV=0 trust_env = _env("QWEN_HTTP_TRUST_ENV", "1") != "0" if _env("QWEN_DEBUG") == "1": print(f"[qwen] POST {url} timeout={self._timeout} trust_env={trust_env}", file=_sys.stderr) # ===== DEBUG: dump payload & headers (API key redacted) ===== if _env("QWEN_DEBUG_BODY") == "1": try: # 打印请求体 print("[qwen] payload =\n" + json.dumps(payload, ensure_ascii=False, indent=2), file=_sys.stderr) except Exception: print("[qwen] payload = ", file=_sys.stderr) try: # 打印打码后的 headers redacted = dict(headers) if auth_header in redacted and isinstance(redacted[auth_header], str): # 只保留认证方案,密钥打码 val = redacted[auth_header] if " " in val: scheme, _token = val.split(" ", 1) redacted[auth_header] = f"{scheme} ***REDACTED***" else: redacted[auth_header] = "***REDACTED***" print("[qwen] headers = " + json.dumps(redacted, ensure_ascii=False), file=_sys.stderr) except Exception: pass # =========================================================== client = httpx.Client(timeout=self._timeout, trust_env=trust_env) if stream: # 流式时也会先把 payload 打出来(上面已做),这里直接返回流 return client.stream("POST", url, headers=headers, json=payload) resp = client.post(url, headers=headers, json=payload) # ===== DEBUG: dump response ===== if _env("QWEN_DEBUG_RESP") == "1": print(f"[qwen] status={resp.status_code}", file=_sys.stderr) try: print("[qwen] resp.text =\n" + resp.text, file=_sys.stderr) except Exception: print("[qwen] resp.text = ", file=_sys.stderr) # ================================= try: resp.raise_for_status() except httpx.HTTPStatusError as e: # 打印错误体,帮助定位 4xx/5xx 的具体原因 if _env("QWEN_DEBUG_RESP", "1") == "1": try: print(f"[qwen] ERROR status={e.response.status_code}\n{e.response.text}", file=_sys.stderr) except Exception: pass raise return resp.json() def _build_payload(self, messages: List["BaseMessage"], **kwargs: Any) -> Dict[str, Any]: payload: Dict[str, Any] = { "model": self._model, "messages": _to_openai_messages(messages), } if self._temperature is not None: payload["temperature"] = self._temperature if self._max_tokens is not None: payload["max_tokens"] = self._max_tokens if self._tools: safe_tools: List[Dict[str, Any]] = [] for t in self._tools: if isinstance(t, dict) and t.get("type") == "function" and isinstance(t.get("function"), dict): fn = deepcopy(t["function"]) # type: ignore fn["description"] = (fn.get("description") or "")[:1000] fn["parameters"] = _normalize_params(fn.get("parameters")) safe_tools.append({"type": "function", "function": fn}) else: safe_tools.append(_convert_tool(t)) payload["tools"] = safe_tools if self._tool_choice is not None: payload["tool_choice"] = self._tool_choice if self._extra_body: payload.update(self._extra_body) if kwargs: allowed = { "tool_choice","temperature","top_p","max_tokens","stop","stream","n","logprobs","logit_bias","user","presence_penalty","frequency_penalty","seed","parallel_tool_calls","response_format" } for k, v in kwargs.items(): if k in allowed and v is not None: payload[k] = v return payload # -- Non-streaming generate -- def _generate( self, messages: List["BaseMessage"], stop: 
    # -- Non-streaming generate --
    def _generate(
        self,
        messages: List["BaseMessage"],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> "ChatResult":  # type: ignore[name-defined]
        from langchain_core.messages import AIMessage
        from langchain_core.outputs import ChatGeneration, ChatResult

        # Forward LangChain's stop sequences unless the caller already set them.
        if stop and "stop" not in kwargs:
            kwargs["stop"] = stop
        payload = self._build_payload(messages, **kwargs)
        data = self._request(payload, stream=False)
        choice = (data.get("choices") or [{}])[0]
        msg = choice.get("message", {})
        content, tool_calls = _parse_tool_calls(msg)
        usage = data.get("usage", {})
        extra_kwargs = {
            "raw": data,
            "reasoning_content": msg.get("reasoning_content"),
        }
        msg_kwargs: Dict[str, Any] = {"content": content or "", "additional_kwargs": extra_kwargs}
        if tool_calls:
            msg_kwargs["tool_calls"] = tool_calls
        ai = AIMessage(**msg_kwargs)
        gen = ChatGeneration(message=ai)
        return ChatResult(generations=[gen], llm_output={"usage": usage, "model": data.get("model")})

    # -- Streaming generate --
    def _stream(
        self,
        messages: List["BaseMessage"],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> Iterator["ChatGenerationChunk"]:  # type: ignore[name-defined]
        from langchain_core.messages import AIMessageChunk
        from langchain_core.outputs import ChatGenerationChunk

        if stop and "stop" not in kwargs:
            kwargs["stop"] = stop
        payload = self._build_payload(messages, stream=True, **kwargs)
        with self._request(payload, stream=True) as r:  # type: ignore[attr-defined]
            # SSE stream: lines look like `data: {...}`; the stream ends with `data: [DONE]`.
            text_buf: List[str] = []
            tc_builders: List[_ToolCallBuilder] = []
            for line in r.iter_lines():  # type: ignore[attr-defined]
                if not line:
                    continue
                if isinstance(line, (bytes, bytearray)):
                    try:
                        line = line.decode("utf-8", errors="ignore")
                    except Exception:
                        continue
                if not line.startswith("data:"):
                    continue
                data = line[len("data:"):].strip()
                if data == "[DONE]":
                    # Finalize: emit a chunk carrying the assembled tool calls, if any.
                    if tc_builders or text_buf:
                        extra: Dict[str, Any] = {}
                        if tc_builders:
                            extra["tool_calls"] = [b.to_final() for b in tc_builders]
                        final = AIMessageChunk(
                            content="".join(text_buf) if text_buf else "",
                            additional_kwargs=extra,  # must be a dict, never None
                        )
                        yield ChatGenerationChunk(message=final)
                    break
                try:
                    obj = json.loads(data)
                except Exception:
                    continue
                choices = obj.get("choices") or []
                if not choices:
                    continue
                delta = choices[0].get("delta") or {}
                # Text deltas
                dtext = delta.get("content")
                if dtext:
                    text_buf.append(dtext)
                    yield ChatGenerationChunk(message=AIMessageChunk(content=dtext))
                # Tool-call deltas
                dtools = delta.get("tool_calls") or []
                for i, part in enumerate(dtools):
                    # Prefer the delta's declared "index" (OpenAI streams parallel
                    # tool calls with an explicit index); fall back to the
                    # position within this event.
                    idx = part.get("index", i) if isinstance(part, dict) else i
                    # Grow the builder list to cover this index.
                    while len(tc_builders) <= idx:
                        tc_builders.append(_ToolCallBuilder())
                    b = tc_builders[idx]
                    if "id" in part:
                        b.id = part.get("id") or b.id
                    fn = part.get("function") or {}
                    if fn.get("name"):
                        b.name = fn["name"]
                    if fn.get("arguments"):
                        b.args_buf.append(fn["arguments"])  # partial JSON string
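

if __name__ == "__main__":
    # Minimal smoke test, assuming the required environment variables
    # (e.g. QWEN_API_KEY) are set. Because of the relative `.utils` import,
    # run it as a module, e.g. `python -m <package>.<this_module>`.
    # The prompt below is illustrative only.
    from langchain_core.messages import HumanMessage

    llm = ChatQwenOpenAICompat(temperature=0.2)
    reply = llm.invoke([HumanMessage(content="Reply with a single word: ping")])
    print(reply.content)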