add convert_to_openai_tool
This commit is contained in:
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple
|
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple
|
||||||
from langchain_core.language_models.chat_models import BaseChatModel
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
@@ -13,6 +14,8 @@ except Exception: # pragma: no cover
|
|||||||
# pydantic v1 fallback
|
# pydantic v1 fallback
|
||||||
from pydantic.fields import PrivateAttr # type: ignore
|
from pydantic.fields import PrivateAttr # type: ignore
|
||||||
|
|
||||||
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||||
|
|
||||||
|
|
||||||
def _env(name: str, default: Optional[str] = None) -> Optional[str]:
|
def _env(name: str, default: Optional[str] = None) -> Optional[str]:
|
||||||
v = os.getenv(name)
|
v = os.getenv(name)
|
||||||
@@ -84,25 +87,46 @@ def _to_openai_messages(messages: List["BaseMessage"]) -> List[Dict[str, Any]]:
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_params(schema: Any) -> Dict[str, Any]:
|
||||||
|
"""Normalize various schema-like inputs into a valid OpenAI JSON Schema."""
|
||||||
|
if not isinstance(schema, dict):
|
||||||
|
return {"type": "object", "properties": {}, "required": []}
|
||||||
|
if schema.get("type") == "object" and isinstance(schema.get("properties"), dict):
|
||||||
|
out = dict(schema)
|
||||||
|
elif "properties" in schema and isinstance(schema["properties"], dict):
|
||||||
|
out = {"type": "object", **schema}
|
||||||
|
elif all(isinstance(v, dict) for v in schema.values()):
|
||||||
|
out = {"type": "object", "properties": schema}
|
||||||
|
else:
|
||||||
|
out = {"type": "object", "properties": {}}
|
||||||
|
for k in ("$schema", "$defs", "definitions", "title"):
|
||||||
|
out.pop(k, None)
|
||||||
|
req = out.get("required", [])
|
||||||
|
if not isinstance(req, list):
|
||||||
|
req = list(req) if req else []
|
||||||
|
out["required"] = req
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _convert_tool(tool: Any) -> Dict[str, Any]:
|
def _convert_tool(tool: Any) -> Dict[str, Any]:
|
||||||
"""Best-effort conversion to OpenAI tool schema."""
|
"""Convert LangChain/MCP tool to OpenAI JSON via official utility, with fallback normalization."""
|
||||||
# Prefer built-in conversion if available
|
try:
|
||||||
if hasattr(tool, "to_openai_tool"):
|
t = convert_to_openai_tool(tool)
|
||||||
try:
|
fn = t.get("function", {})
|
||||||
return tool.to_openai_tool() # type: ignore[attr-defined]
|
fn["description"] = (fn.get("description") or "")[:1000]
|
||||||
except Exception:
|
fn["parameters"] = _normalize_params(fn.get("parameters"))
|
||||||
pass
|
return {"type": "function", "function": fn}
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
name = getattr(tool, "name", None) or getattr(tool, "__name__", "tool")
|
name = getattr(tool, "name", None) or getattr(tool, "__name__", "tool")
|
||||||
description = getattr(tool, "description", "")
|
desc = getattr(tool, "description", "") or ""
|
||||||
# Attempt to get a JSON schema
|
|
||||||
params: Dict[str, Any] = {"type": "object", "properties": {}, "required": []}
|
|
||||||
schema = None
|
schema = None
|
||||||
if hasattr(tool, "args_schema") and getattr(tool, "args_schema") is not None:
|
if getattr(tool, "args_schema", None) is not None:
|
||||||
try:
|
try:
|
||||||
schema = tool.args_schema.schema() # pydantic v1
|
schema = tool.args_schema.model_json_schema()
|
||||||
except Exception:
|
except Exception:
|
||||||
try:
|
try:
|
||||||
schema = tool.args_schema.model_json_schema() # pydantic v2
|
schema = tool.args_schema.schema()
|
||||||
except Exception:
|
except Exception:
|
||||||
schema = None
|
schema = None
|
||||||
if schema is None and hasattr(tool, "args"):
|
if schema is None and hasattr(tool, "args"):
|
||||||
@@ -110,16 +134,8 @@ def _convert_tool(tool: Any) -> Dict[str, Any]:
|
|||||||
schema = tool.args
|
schema = tool.args
|
||||||
except Exception:
|
except Exception:
|
||||||
schema = None
|
schema = None
|
||||||
if isinstance(schema, dict):
|
params = _normalize_params(schema)
|
||||||
params = schema
|
return {"type": "function", "function": {"name": name, "description": desc[:1000], "parameters": params}}
|
||||||
return {
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": name,
|
|
||||||
"description": description,
|
|
||||||
"parameters": params,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_tools(tools: Optional[List[Any]]) -> Optional[List[Dict[str, Any]]]:
|
def _convert_tools(tools: Optional[List[Any]]) -> Optional[List[Dict[str, Any]]]:
|
||||||
@@ -272,6 +288,7 @@ class ChatQwenOpenAICompat(BaseChatModel):
|
|||||||
def _request(self, payload: Dict[str, Any], *, stream: bool = False) -> Any:
|
def _request(self, payload: Dict[str, Any], *, stream: bool = False) -> Any:
|
||||||
try:
|
try:
|
||||||
import httpx
|
import httpx
|
||||||
|
import sys as _sys
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError("httpx is required for ChatQwenOpenAICompat. Install via `pip install httpx`. ") from e
|
raise RuntimeError("httpx is required for ChatQwenOpenAICompat. Install via `pip install httpx`. ") from e
|
||||||
|
|
||||||
@@ -280,24 +297,69 @@ class ChatQwenOpenAICompat(BaseChatModel):
|
|||||||
auth_scheme = _env("QWEN_AUTH_SCHEME", "Bearer")
|
auth_scheme = _env("QWEN_AUTH_SCHEME", "Bearer")
|
||||||
auth_value = f"{auth_scheme} {self._api_key}" if auth_scheme else self._api_key
|
auth_value = f"{auth_scheme} {self._api_key}" if auth_scheme else self._api_key
|
||||||
headers = {auth_header: auth_value, "Content-Type": "application/json"}
|
headers = {auth_header: auth_value, "Content-Type": "application/json"}
|
||||||
|
|
||||||
base = self._base_url.rstrip("/")
|
base = self._base_url.rstrip("/")
|
||||||
# Accept either a root base_url (ending with /v1) or a full endpoint that already includes /chat/completions
|
url = base if "chat/completions" in base else base + "/chat/completions"
|
||||||
if "chat/completions" in base:
|
|
||||||
url = base
|
# honor system proxy by default; disable via QWEN_HTTP_TRUST_ENV=0
|
||||||
else:
|
|
||||||
url = base + "/chat/completions"
|
|
||||||
# Proxy behavior: honor env by default; allow disabling via QWEN_HTTP_TRUST_ENV=0
|
|
||||||
trust_env = _env("QWEN_HTTP_TRUST_ENV", "1") != "0"
|
trust_env = _env("QWEN_HTTP_TRUST_ENV", "1") != "0"
|
||||||
client = httpx.Client(timeout=self._timeout, trust_env=trust_env)
|
|
||||||
if _env("QWEN_DEBUG") == "1":
|
if _env("QWEN_DEBUG") == "1":
|
||||||
import sys as _sys
|
|
||||||
print(f"[qwen] POST {url} timeout={self._timeout} trust_env={trust_env}", file=_sys.stderr)
|
print(f"[qwen] POST {url} timeout={self._timeout} trust_env={trust_env}", file=_sys.stderr)
|
||||||
|
|
||||||
|
# ===== DEBUG: dump payload & headers (API key redacted) =====
|
||||||
|
if _env("QWEN_DEBUG_BODY") == "1":
|
||||||
|
try:
|
||||||
|
# Print the request body
|
||||||
|
print("[qwen] payload =\n" + json.dumps(payload, ensure_ascii=False, indent=2), file=_sys.stderr)
|
||||||
|
except Exception:
|
||||||
|
print("[qwen] payload = <non-serializable>", file=_sys.stderr)
|
||||||
|
try:
|
||||||
|
# Print the headers with the credential redacted
|
||||||
|
redacted = dict(headers)
|
||||||
|
if auth_header in redacted and isinstance(redacted[auth_header], str):
|
||||||
|
# Keep only the auth scheme; mask the key
|
||||||
|
val = redacted[auth_header]
|
||||||
|
if " " in val:
|
||||||
|
scheme, _token = val.split(" ", 1)
|
||||||
|
redacted[auth_header] = f"{scheme} ***REDACTED***"
|
||||||
|
else:
|
||||||
|
redacted[auth_header] = "***REDACTED***"
|
||||||
|
print("[qwen] headers = " + json.dumps(redacted, ensure_ascii=False), file=_sys.stderr)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# ===========================================================
|
||||||
|
|
||||||
|
client = httpx.Client(timeout=self._timeout, trust_env=trust_env)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
|
# In streaming mode the payload has already been dumped above; just return the stream here
|
||||||
return client.stream("POST", url, headers=headers, json=payload)
|
return client.stream("POST", url, headers=headers, json=payload)
|
||||||
|
|
||||||
resp = client.post(url, headers=headers, json=payload)
|
resp = client.post(url, headers=headers, json=payload)
|
||||||
resp.raise_for_status()
|
|
||||||
|
# ===== DEBUG: dump response =====
|
||||||
|
if _env("QWEN_DEBUG_RESP") == "1":
|
||||||
|
print(f"[qwen] status={resp.status_code}", file=_sys.stderr)
|
||||||
|
try:
|
||||||
|
print("[qwen] resp.text =\n" + resp.text, file=_sys.stderr)
|
||||||
|
except Exception:
|
||||||
|
print("[qwen] resp.text = <binary or non-text>", file=_sys.stderr)
|
||||||
|
# =================================
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
# Print the error body to help pinpoint the cause of a 4xx/5xx response
|
||||||
|
if _env("QWEN_DEBUG_RESP", "1") == "1":
|
||||||
|
try:
|
||||||
|
print(f"[qwen] ERROR status={e.response.status_code}\n{e.response.text}", file=_sys.stderr)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
return resp.json()
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
def _build_payload(self, messages: List["BaseMessage"], **kwargs: Any) -> Dict[str, Any]:
|
def _build_payload(self, messages: List["BaseMessage"], **kwargs: Any) -> Dict[str, Any]:
|
||||||
payload: Dict[str, Any] = {
|
payload: Dict[str, Any] = {
|
||||||
"model": self._model,
|
"model": self._model,
|
||||||
@@ -308,13 +370,27 @@ class ChatQwenOpenAICompat(BaseChatModel):
|
|||||||
if self._max_tokens is not None:
|
if self._max_tokens is not None:
|
||||||
payload["max_tokens"] = self._max_tokens
|
payload["max_tokens"] = self._max_tokens
|
||||||
if self._tools:
|
if self._tools:
|
||||||
payload["tools"] = self._tools
|
safe_tools: List[Dict[str, Any]] = []
|
||||||
|
for t in self._tools:
|
||||||
|
if isinstance(t, dict) and t.get("type") == "function" and isinstance(t.get("function"), dict):
|
||||||
|
fn = deepcopy(t["function"]) # type: ignore
|
||||||
|
fn["description"] = (fn.get("description") or "")[:1000]
|
||||||
|
fn["parameters"] = _normalize_params(fn.get("parameters"))
|
||||||
|
safe_tools.append({"type": "function", "function": fn})
|
||||||
|
else:
|
||||||
|
safe_tools.append(_convert_tool(t))
|
||||||
|
payload["tools"] = safe_tools
|
||||||
if self._tool_choice is not None:
|
if self._tool_choice is not None:
|
||||||
payload["tool_choice"] = self._tool_choice
|
payload["tool_choice"] = self._tool_choice
|
||||||
if self._extra_body:
|
if self._extra_body:
|
||||||
payload.update(self._extra_body)
|
payload.update(self._extra_body)
|
||||||
if kwargs:
|
if kwargs:
|
||||||
payload.update(kwargs)
|
allowed = {
|
||||||
|
"tool_choice","temperature","top_p","max_tokens","stop","stream","n","logprobs","logit_bias","user","presence_penalty","frequency_penalty","seed","parallel_tool_calls","response_format"
|
||||||
|
}
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if k in allowed and v is not None:
|
||||||
|
payload[k] = v
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
# -- Non-streaming generate --
|
# -- Non-streaming generate --
|
||||||
|
|||||||
Reference in New Issue
Block a user