add convert_to_openai_tool

This commit is contained in:
2025-09-01 11:59:58 +08:00
parent 87a7f3f9df
commit 5055fddeb1

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import json import json
import os import os
from copy import deepcopy
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple
from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel
@@ -13,6 +14,8 @@ except Exception: # pragma: no cover
# pydantic v1 fallback # pydantic v1 fallback
from pydantic.fields import PrivateAttr # type: ignore from pydantic.fields import PrivateAttr # type: ignore
from langchain_core.utils.function_calling import convert_to_openai_tool
def _env(name: str, default: Optional[str] = None) -> Optional[str]: def _env(name: str, default: Optional[str] = None) -> Optional[str]:
v = os.getenv(name) v = os.getenv(name)
@@ -84,25 +87,46 @@ def _to_openai_messages(messages: List["BaseMessage"]) -> List[Dict[str, Any]]:
return out return out
def _normalize_params(schema: Any) -> Dict[str, Any]:
"""Normalize various schema-like inputs into a valid OpenAI JSON Schema."""
if not isinstance(schema, dict):
return {"type": "object", "properties": {}, "required": []}
if schema.get("type") == "object" and isinstance(schema.get("properties"), dict):
out = dict(schema)
elif "properties" in schema and isinstance(schema["properties"], dict):
out = {"type": "object", **schema}
elif all(isinstance(v, dict) for v in schema.values()):
out = {"type": "object", "properties": schema}
else:
out = {"type": "object", "properties": {}}
for k in ("$schema", "$defs", "definitions", "title"):
out.pop(k, None)
req = out.get("required", [])
if not isinstance(req, list):
req = list(req) if req else []
out["required"] = req
return out
def _convert_tool(tool: Any) -> Dict[str, Any]: def _convert_tool(tool: Any) -> Dict[str, Any]:
"""Best-effort conversion to OpenAI tool schema.""" """Convert LangChain/MCP tool to OpenAI JSON via official utility, with fallback normalization."""
# Prefer built-in conversion if available
if hasattr(tool, "to_openai_tool"):
try: try:
return tool.to_openai_tool() # type: ignore[attr-defined] t = convert_to_openai_tool(tool)
fn = t.get("function", {})
fn["description"] = (fn.get("description") or "")[:1000]
fn["parameters"] = _normalize_params(fn.get("parameters"))
return {"type": "function", "function": fn}
except Exception: except Exception:
pass pass
name = getattr(tool, "name", None) or getattr(tool, "__name__", "tool") name = getattr(tool, "name", None) or getattr(tool, "__name__", "tool")
description = getattr(tool, "description", "") desc = getattr(tool, "description", "") or ""
# Attempt to get a JSON schema
params: Dict[str, Any] = {"type": "object", "properties": {}, "required": []}
schema = None schema = None
if hasattr(tool, "args_schema") and getattr(tool, "args_schema") is not None: if getattr(tool, "args_schema", None) is not None:
try: try:
schema = tool.args_schema.schema() # pydantic v1 schema = tool.args_schema.model_json_schema()
except Exception: except Exception:
try: try:
schema = tool.args_schema.model_json_schema() # pydantic v2 schema = tool.args_schema.schema()
except Exception: except Exception:
schema = None schema = None
if schema is None and hasattr(tool, "args"): if schema is None and hasattr(tool, "args"):
@@ -110,16 +134,8 @@ def _convert_tool(tool: Any) -> Dict[str, Any]:
schema = tool.args schema = tool.args
except Exception: except Exception:
schema = None schema = None
if isinstance(schema, dict): params = _normalize_params(schema)
params = schema return {"type": "function", "function": {"name": name, "description": desc[:1000], "parameters": params}}
return {
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": params,
},
}
def _convert_tools(tools: Optional[List[Any]]) -> Optional[List[Dict[str, Any]]]: def _convert_tools(tools: Optional[List[Any]]) -> Optional[List[Dict[str, Any]]]:
@@ -272,6 +288,7 @@ class ChatQwenOpenAICompat(BaseChatModel):
def _request(self, payload: Dict[str, Any], *, stream: bool = False) -> Any: def _request(self, payload: Dict[str, Any], *, stream: bool = False) -> Any:
try: try:
import httpx import httpx
import sys as _sys
except Exception as e: except Exception as e:
raise RuntimeError("httpx is required for ChatQwenOpenAICompat. Install via `pip install httpx`. ") from e raise RuntimeError("httpx is required for ChatQwenOpenAICompat. Install via `pip install httpx`. ") from e
@@ -280,24 +297,69 @@ class ChatQwenOpenAICompat(BaseChatModel):
auth_scheme = _env("QWEN_AUTH_SCHEME", "Bearer") auth_scheme = _env("QWEN_AUTH_SCHEME", "Bearer")
auth_value = f"{auth_scheme} {self._api_key}" if auth_scheme else self._api_key auth_value = f"{auth_scheme} {self._api_key}" if auth_scheme else self._api_key
headers = {auth_header: auth_value, "Content-Type": "application/json"} headers = {auth_header: auth_value, "Content-Type": "application/json"}
base = self._base_url.rstrip("/") base = self._base_url.rstrip("/")
# Accept either a root base_url (ending with /v1) or a full endpoint that already includes /chat/completions url = base if "chat/completions" in base else base + "/chat/completions"
if "chat/completions" in base:
url = base # honor system proxy by default; disable via QWEN_HTTP_TRUST_ENV=0
else:
url = base + "/chat/completions"
# Proxy behavior: honor env by default; allow disabling via QWEN_HTTP_TRUST_ENV=0
trust_env = _env("QWEN_HTTP_TRUST_ENV", "1") != "0" trust_env = _env("QWEN_HTTP_TRUST_ENV", "1") != "0"
client = httpx.Client(timeout=self._timeout, trust_env=trust_env)
if _env("QWEN_DEBUG") == "1": if _env("QWEN_DEBUG") == "1":
import sys as _sys
print(f"[qwen] POST {url} timeout={self._timeout} trust_env={trust_env}", file=_sys.stderr) print(f"[qwen] POST {url} timeout={self._timeout} trust_env={trust_env}", file=_sys.stderr)
# ===== DEBUG: dump payload & headers (API key redacted) =====
if _env("QWEN_DEBUG_BODY") == "1":
try:
# 打印请求体
print("[qwen] payload =\n" + json.dumps(payload, ensure_ascii=False, indent=2), file=_sys.stderr)
except Exception:
print("[qwen] payload = <non-serializable>", file=_sys.stderr)
try:
# 打印打码后的 headers
redacted = dict(headers)
if auth_header in redacted and isinstance(redacted[auth_header], str):
# 只保留认证方案,密钥打码
val = redacted[auth_header]
if " " in val:
scheme, _token = val.split(" ", 1)
redacted[auth_header] = f"{scheme} ***REDACTED***"
else:
redacted[auth_header] = "***REDACTED***"
print("[qwen] headers = " + json.dumps(redacted, ensure_ascii=False), file=_sys.stderr)
except Exception:
pass
# ===========================================================
client = httpx.Client(timeout=self._timeout, trust_env=trust_env)
if stream: if stream:
# 流式时也会先把 payload 打出来(上面已做),这里直接返回流
return client.stream("POST", url, headers=headers, json=payload) return client.stream("POST", url, headers=headers, json=payload)
resp = client.post(url, headers=headers, json=payload) resp = client.post(url, headers=headers, json=payload)
# ===== DEBUG: dump response =====
if _env("QWEN_DEBUG_RESP") == "1":
print(f"[qwen] status={resp.status_code}", file=_sys.stderr)
try:
print("[qwen] resp.text =\n" + resp.text, file=_sys.stderr)
except Exception:
print("[qwen] resp.text = <binary or non-text>", file=_sys.stderr)
# =================================
try:
resp.raise_for_status() resp.raise_for_status()
except httpx.HTTPStatusError as e:
# 打印错误体,帮助定位 4xx/5xx 的具体原因
if _env("QWEN_DEBUG_RESP", "1") == "1":
try:
print(f"[qwen] ERROR status={e.response.status_code}\n{e.response.text}", file=_sys.stderr)
except Exception:
pass
raise
return resp.json() return resp.json()
def _build_payload(self, messages: List["BaseMessage"], **kwargs: Any) -> Dict[str, Any]: def _build_payload(self, messages: List["BaseMessage"], **kwargs: Any) -> Dict[str, Any]:
payload: Dict[str, Any] = { payload: Dict[str, Any] = {
"model": self._model, "model": self._model,
@@ -308,13 +370,27 @@ class ChatQwenOpenAICompat(BaseChatModel):
if self._max_tokens is not None: if self._max_tokens is not None:
payload["max_tokens"] = self._max_tokens payload["max_tokens"] = self._max_tokens
if self._tools: if self._tools:
payload["tools"] = self._tools safe_tools: List[Dict[str, Any]] = []
for t in self._tools:
if isinstance(t, dict) and t.get("type") == "function" and isinstance(t.get("function"), dict):
fn = deepcopy(t["function"]) # type: ignore
fn["description"] = (fn.get("description") or "")[:1000]
fn["parameters"] = _normalize_params(fn.get("parameters"))
safe_tools.append({"type": "function", "function": fn})
else:
safe_tools.append(_convert_tool(t))
payload["tools"] = safe_tools
if self._tool_choice is not None: if self._tool_choice is not None:
payload["tool_choice"] = self._tool_choice payload["tool_choice"] = self._tool_choice
if self._extra_body: if self._extra_body:
payload.update(self._extra_body) payload.update(self._extra_body)
if kwargs: if kwargs:
payload.update(kwargs) allowed = {
"tool_choice","temperature","top_p","max_tokens","stop","stream","n","logprobs","logit_bias","user","presence_penalty","frequency_penalty","seed","parallel_tool_calls","response_format"
}
for k, v in kwargs.items():
if k in allowed and v is not None:
payload[k] = v
return payload return payload
# -- Non-streaming generate -- # -- Non-streaming generate --