langgraph adapter

This commit is contained in:
lingyuzeng
2025-08-30 23:29:26 +08:00
parent 273e6088fc
commit e1a36bc1e6
6 changed files with 703 additions and 0 deletions

View File

@@ -0,0 +1,414 @@
from __future__ import annotations
import json
import os
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple
from langchain_core.language_models.chat_models import BaseChatModel
from .utils import ensure_env_loaded
try:
# pydantic v2
from pydantic import PrivateAttr
except Exception: # pragma: no cover
# pydantic v1 fallback
from pydantic.fields import PrivateAttr # type: ignore
def _env(name: str, default: Optional[str] = None) -> Optional[str]:
v = os.getenv(name)
return v if v else default
def _coalesce(*vals):
for v in vals:
if v is not None:
return v
return None
def _to_openai_messages(messages: List["BaseMessage"]) -> List[Dict[str, Any]]:
    """Convert LangChain message objects into OpenAI chat-completions dicts.

    Maps system/human/AI/tool/function messages onto the corresponding
    OpenAI roles; assistant tool calls are re-serialized into the
    {"type": "function", "function": {...}} shape with JSON-string arguments.
    Unknown message types fall back to their `role`/`content` attributes.
    """
    from langchain_core.messages import (
        AIMessage,
        BaseMessage,
        HumanMessage,
        SystemMessage,
        ToolMessage,
        FunctionMessage,
    )

    out: List[Dict[str, Any]] = []
    for m in messages:
        if isinstance(m, SystemMessage):
            out.append({"role": "system", "content": m.content})
        elif isinstance(m, HumanMessage):
            out.append({"role": "user", "content": m.content})
        elif isinstance(m, AIMessage):
            # If previous AI produced tool calls, pass them through
            entry: Dict[str, Any] = {"role": "assistant", "content": m.content}
            # Bug fix: the original mixed `or` with a conditional expression,
            # which parsed as `(native or extra) if hasattr(...) else None` and
            # dropped native tool_calls whenever additional_kwargs was absent.
            tcs = getattr(m, "tool_calls", None)
            if not tcs and hasattr(m, "additional_kwargs"):
                tcs = m.additional_kwargs.get("tool_calls")
            if tcs:
                # Map to OpenAI's expected shape
                entry["tool_calls"] = [
                    {
                        "id": tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None),
                        "type": "function",
                        "function": {
                            "name": (tc.get("name") if isinstance(tc, dict) else getattr(tc, "name", "")) or "",
                            "arguments": json.dumps(tc.get("args") if isinstance(tc, dict) else getattr(tc, "args", {})),
                        },
                    }
                    for tc in tcs
                ]
            out.append(entry)
        elif isinstance(m, ToolMessage):
            # OpenAI-compatible "tool" role with name + id must be provided
            out.append(
                {
                    "role": "tool",
                    "content": m.content,
                    "tool_call_id": m.tool_call_id,
                }
            )
        elif isinstance(m, FunctionMessage):
            # Legacy function messages; the function name stands in for the id.
            out.append(
                {
                    "role": "tool",
                    "content": m.content,
                    "tool_call_id": m.name or "",
                }
            )
        else:
            # Fallback for unrecognized message classes.
            out.append({"role": getattr(m, "role", "user"), "content": getattr(m, "content", str(m))})
    return out
def _convert_tool(tool: Any) -> Dict[str, Any]:
"""Best-effort conversion to OpenAI tool schema."""
# Prefer built-in conversion if available
if hasattr(tool, "to_openai_tool"):
try:
return tool.to_openai_tool() # type: ignore[attr-defined]
except Exception:
pass
name = getattr(tool, "name", None) or getattr(tool, "__name__", "tool")
description = getattr(tool, "description", "")
# Attempt to get a JSON schema
params: Dict[str, Any] = {"type": "object", "properties": {}, "required": []}
schema = None
if hasattr(tool, "args_schema") and getattr(tool, "args_schema") is not None:
try:
schema = tool.args_schema.schema() # pydantic v1
except Exception:
try:
schema = tool.args_schema.model_json_schema() # pydantic v2
except Exception:
schema = None
if schema is None and hasattr(tool, "args"):
try:
schema = tool.args
except Exception:
schema = None
if isinstance(schema, dict):
params = schema
return {
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": params,
},
}
def _convert_tools(tools: Optional[List[Any]]) -> Optional[List[Dict[str, Any]]]:
if not tools:
return None
return [_convert_tool(t) for t in tools]
def _parse_tool_calls(msg: Dict[str, Any]) -> Tuple[str, List[Dict[str, Any]]]:
content = _coalesce(msg.get("content"), "")
tool_calls_raw = msg.get("tool_calls") or []
parsed: List[Dict[str, Any]] = []
for tc in tool_calls_raw:
fn = (tc or {}).get("function", {})
name = fn.get("name") or ""
args_str = fn.get("arguments") or "{}"
args = {}
if isinstance(args_str, str):
try:
args = json.loads(args_str or "{}")
except Exception:
# keep raw if malformed; downstream may repair
args = {"$raw": args_str}
elif isinstance(args_str, dict):
args = args_str
parsed.append({
"id": tc.get("id"),
"name": name,
"args": args,
})
return content, parsed
@dataclass
class _ToolCallBuilder:
id: Optional[str] = None
name: str = ""
args_buf: List[str] = field(default_factory=list)
def to_final(self) -> Dict[str, Any]:
args_str = "".join(self.args_buf) if self.args_buf else "{}"
try:
args = json.loads(args_str or "{}")
except Exception:
args = {"$raw": args_str}
return {"id": self.id, "name": self.name, "args": args}
class ChatQwenOpenAICompat(BaseChatModel):
    """
    A minimal LangChain BaseChatModel-compatible class that calls an
    OpenAI-compatible /chat/completions endpoint and assembles tool calls.
    Notes:
    - Implements _generate for non-stream and _stream for SSE streaming.
    - Stores bound tools and tool_choice to send with requests.
    - Avoids heavy dependencies; uses httpx at runtime.
    """
    # Private state to avoid pydantic field validation
    _model: str = PrivateAttr()  # model name sent as "model" in the payload
    _api_key: str = PrivateAttr()  # credential placed in the auth header
    _base_url: str = PrivateAttr()  # root URL (.../v1) or full /chat/completions endpoint
    _temperature: Optional[float] = PrivateAttr(default=None)  # omitted from payload when None
    _max_tokens: Optional[int] = PrivateAttr(default=None)  # omitted from payload when None
    _extra_body: Dict[str, Any] = PrivateAttr(default_factory=dict)  # merged into every payload
    _timeout: float = PrivateAttr(default=60.0)  # HTTP timeout in seconds
    _tools: Optional[List[Dict[str, Any]]] = PrivateAttr(default=None)  # OpenAI tool schemas set by bind_tools()
    _tool_choice: Optional[str] = PrivateAttr(default=None)  # set via bind(tool_choice=...)
    def __init__(
        self,
        *,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        extra_body: Optional[Dict[str, Any]] = None,
        timeout: float = 60.0,
    ) -> None:
        """Configure the model from explicit args, falling back to the environment.

        Args:
            model: Model name; defaults to $QWEN_MODEL, else "qwen3-coder-flash".
            api_key: API key; falls back to QWEN_API_KEY, GPUSTACK_API_KEY,
                OPENAI_API_KEY, then DASHSCOPE_API_KEY (first non-empty wins).
            base_url: OpenAI-compatible base URL; falls back to QWEN_BASE_URL,
                then OPENAI_BASE_URL, else the DashScope intl endpoint.
            temperature: Optional sampling temperature, sent only when set.
            max_tokens: Optional completion-length cap, sent only when set.
            extra_body: Extra JSON merged into every request payload.
            timeout: HTTP timeout in seconds; $QWEN_TIMEOUT overrides it when
                it parses as a float.

        Raises:
            ValueError: if no API key can be resolved from args or environment.
        """
        super().__init__()
        # Load .env variables once if available
        ensure_env_loaded()
        # Store config in private attrs (bypasses pydantic field validation)
        self._model = model or _env("QWEN_MODEL", "qwen3-coder-flash")
        self._api_key = (
            api_key
            or _env("QWEN_API_KEY")
            or _env("GPUSTACK_API_KEY")
            or _env("OPENAI_API_KEY")
            or _env("DASHSCOPE_API_KEY")
        )
        if not self._api_key:
            raise ValueError("Missing API key: set QWEN_API_KEY/OPENAI_API_KEY/DASHSCOPE_API_KEY")
        self._base_url = base_url or _env("QWEN_BASE_URL") or _env("OPENAI_BASE_URL")
        if not self._base_url:
            # Default to DashScope intl OpenAI-compatible
            self._base_url = "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
        # Allow env override for timeout
        timeout_env = _env("QWEN_TIMEOUT")
        if timeout_env:
            try:
                timeout = float(timeout_env)
            except Exception:
                # Malformed env value: keep the caller-provided timeout.
                pass
        self._temperature = temperature
        self._max_tokens = max_tokens
        self._extra_body = extra_body or {}
        self._timeout = timeout
        # tools/tool_choice already defaulted via PrivateAttr declarations
# Expose LangChain binding-style API
def bind_tools(self, tools: List[Any]) -> "ChatQwenOpenAICompat":
clone = self._copy()
clone._tools = _convert_tools(tools)
return clone
def bind(self, **kwargs: Any) -> "ChatQwenOpenAICompat":
clone = self._copy()
# Allow passing tool_choice or extra_body overrides
if "tool_choice" in kwargs:
clone._tool_choice = kwargs["tool_choice"]
if "extra_body" in kwargs and isinstance(kwargs["extra_body"], dict):
eb = dict(clone._extra_body)
eb.update(kwargs["extra_body"]) # type: ignore
clone._extra_body = eb
return clone
# LangChain BaseChatModel required interface
    @property
    def _llm_type(self) -> str:  # type: ignore[override]
        """Identifier LangChain uses to tag this model in callbacks/serialization."""
        return "qwen-openai-compat"
def _copy(self) -> "ChatQwenOpenAICompat":
c = ChatQwenOpenAICompat(
model=self._model,
api_key=self._api_key,
base_url=self._base_url,
temperature=self._temperature,
max_tokens=self._max_tokens,
extra_body=dict(self._extra_body),
timeout=self._timeout,
)
c._tools = self._tools[:] if self._tools else None
c._tool_choice = self._tool_choice
return c
# -- Internal HTTP helpers --
def _request(self, payload: Dict[str, Any], *, stream: bool = False) -> Any:
try:
import httpx
except Exception as e:
raise RuntimeError("httpx is required for ChatQwenOpenAICompat. Install via `pip install httpx`. ") from e
# Auth header customization (for non-standard gateways)
auth_header = _env("QWEN_AUTH_HEADER", "Authorization")
auth_scheme = _env("QWEN_AUTH_SCHEME", "Bearer")
auth_value = f"{auth_scheme} {self._api_key}" if auth_scheme else self._api_key
headers = {auth_header: auth_value, "Content-Type": "application/json"}
base = self._base_url.rstrip("/")
# Accept either a root base_url (ending with /v1) or a full endpoint that already includes /chat/completions
if "chat/completions" in base:
url = base
else:
url = base + "/chat/completions"
# Proxy behavior: honor env by default; allow disabling via QWEN_HTTP_TRUST_ENV=0
trust_env = _env("QWEN_HTTP_TRUST_ENV", "1") != "0"
client = httpx.Client(timeout=self._timeout, trust_env=trust_env)
if _env("QWEN_DEBUG") == "1":
import sys as _sys
print(f"[qwen] POST {url} timeout={self._timeout} trust_env={trust_env}", file=_sys.stderr)
if stream:
return client.stream("POST", url, headers=headers, json=payload)
resp = client.post(url, headers=headers, json=payload)
resp.raise_for_status()
return resp.json()
def _build_payload(self, messages: List["BaseMessage"], **kwargs: Any) -> Dict[str, Any]:
payload: Dict[str, Any] = {
"model": self._model,
"messages": _to_openai_messages(messages),
}
if self._temperature is not None:
payload["temperature"] = self._temperature
if self._max_tokens is not None:
payload["max_tokens"] = self._max_tokens
if self._tools:
payload["tools"] = self._tools
if self._tool_choice is not None:
payload["tool_choice"] = self._tool_choice
if self._extra_body:
payload.update(self._extra_body)
if kwargs:
payload.update(kwargs)
return payload
# -- Non-streaming generate --
def _generate(
self,
messages: List["BaseMessage"],
stop: Optional[List[str]] = None,
run_manager: Optional[Any] = None,
**kwargs: Any,
) -> "ChatResult": # type: ignore[name-defined]
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.messages import AIMessage
payload = self._build_payload(messages, **kwargs)
data = self._request(payload, stream=False)
choice = (data.get("choices") or [{}])[0]
msg = choice.get("message", {})
content, tool_calls = _parse_tool_calls(msg)
usage = data.get("usage", {})
extra_kwargs = {
"raw": data,
"reasoning_content": msg.get("reasoning_content"),
}
msg_kwargs: Dict[str, Any] = {"content": content or "", "additional_kwargs": extra_kwargs}
if tool_calls:
msg_kwargs["tool_calls"] = tool_calls
ai = AIMessage(**msg_kwargs)
gen = ChatGeneration(message=ai)
return ChatResult(generations=[gen], llm_output={"usage": usage, "model": data.get("model")})
# -- Streaming generate --
def _stream(
self,
messages: List["BaseMessage"],
stop: Optional[List[str]] = None,
run_manager: Optional[Any] = None,
**kwargs: Any,
) -> Iterator["ChatGenerationChunk"]: # type: ignore[name-defined]
from langchain_core.outputs import ChatGenerationChunk
from langchain_core.messages import AIMessageChunk
payload = self._build_payload(messages, stream=True, **kwargs)
with self._request(payload, stream=True) as r: # type: ignore[attr-defined]
# SSE stream; lines like: data: {...}\n\n; end with data: [DONE]
text_buf: List[str] = []
tc_builders: List[_ToolCallBuilder] = []
for line in r.iter_lines(): # type: ignore[attr-defined]
if not line:
continue
if isinstance(line, (bytes, bytearray)):
try:
line = line.decode("utf-8", errors="ignore")
except Exception:
continue
if not line.startswith("data:"):
continue
data = line[len("data:"):].strip()
if data == "[DONE]":
# Finalize: emit a chunk with final tool calls if any
if tc_builders or text_buf:
extra: Dict[str, Any] = {}
if tc_builders:
extra["tool_calls"] = [b.to_final() for b in tc_builders]
final = AIMessageChunk(
content="".join(text_buf) if text_buf else "",
additional_kwargs=extra if extra else None,
)
yield ChatGenerationChunk(message=final)
break
try:
obj = json.loads(data)
except Exception:
continue
choices = obj.get("choices") or []
if not choices:
continue
delta = (choices[0].get("delta") or {})
# Text deltas
dtext = delta.get("content")
if dtext:
text_buf.append(dtext)
yield ChatGenerationChunk(message=AIMessageChunk(content=dtext))
# Tool call deltas
dtools = delta.get("tool_calls") or []
for i, part in enumerate(dtools):
# Ensure list length
while len(tc_builders) <= i:
tc_builders.append(_ToolCallBuilder())
b = tc_builders[i]
if "id" in part:
b.id = part.get("id") or b.id
fn = part.get("function") or {}
if fn.get("name"):
b.name = fn["name"]
if fn.get("arguments"):
b.args_buf.append(fn["arguments"]) # partial JSON string