176 lines
5.4 KiB
Python
176 lines
5.4 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
import re
|
|
import shlex
|
|
import subprocess
|
|
from typing import Any
|
|
|
|
from .models import QueryType
|
|
|
|
|
|
class QMDCommandError(RuntimeError):
    """Raised when a ``qmd`` invocation fails.

    The exception message embeds the shell-quoted command line, the return
    code, and the captured stdout/stderr. The same pieces are also exposed as
    attributes for programmatic inspection.

    Attributes:
        args_list: The argument vector that was executed. Stored under this
            name because ``args`` is already used by ``BaseException``.
        returncode: The return code reported for the command.
        stdout: Captured standard output.
        stderr: Captured standard error.
    """

    def __init__(self, args: list[str], returncode: int, stdout: str, stderr: str) -> None:
        """Build the error message and remember the raw command details.

        Args:
            args: Argument vector of the failed command.
            returncode: Return code to report.
            stdout: Captured standard output.
            stderr: Captured standard error.
        """
        # Quote each argument so the rendered command can be pasted into a shell.
        cmd = " ".join(shlex.quote(x) for x in args)
        message = f"qmd command failed ({returncode}): {cmd}\nstdout:\n{stdout}\nstderr:\n{stderr}"
        super().__init__(message)
        self.args_list = args
        self.returncode = returncode
        self.stdout = stdout
        self.stderr = stderr

    def __reduce__(self):
        """Support ``pickle`` and ``copy`` by replaying the constructor arguments.

        The default exception reduction rebuilds the instance as
        ``cls(*self.args)``, i.e. with the formatted message only, which does
        not match this four-argument constructor and raises ``TypeError``.
        """
        return (
            type(self),
            (self.args_list, self.returncode, self.stdout, self.stderr),
        )
|
|
|
|
|
|
class QMDClient:
    """Wrapper that drives the ``qmd`` command-line tool through ``subprocess``.

    Every command is executed with ``XDG_CACHE_HOME`` and ``XDG_CONFIG_HOME``
    overridden to the directories given at construction time, and is bounded
    by ``timeout_seconds``. A non-zero exit status or a timeout is reported as
    ``QMDCommandError``.
    """

    # One entry of ``collection list`` output: "<name> (qmd://...".
    _COLLECTION_LINE = re.compile(r"^([A-Za-z0-9_.-]+) \(qmd://")

    def __init__(
        self,
        qmd_binary: str,
        timeout_seconds: int,
        xdg_cache_home: Path,
        xdg_config_home: Path,
    ) -> None:
        """Store the settings used for every ``qmd`` invocation.

        Args:
            qmd_binary: Executable name or path used as ``argv[0]``.
            timeout_seconds: Per-command timeout passed to ``subprocess.run``.
            xdg_cache_home: Value exported as ``XDG_CACHE_HOME``.
            xdg_config_home: Value exported as ``XDG_CONFIG_HOME``.
        """
        self.qmd_binary = qmd_binary
        self.timeout_seconds = timeout_seconds
        self.xdg_cache_home = xdg_cache_home
        self.xdg_config_home = xdg_config_home

    def ensure_collection(self, index_name: str, collection_name: str, workspace_path: Path) -> bool:
        """Add ``collection_name`` to the index unless it is already listed.

        Returns:
            ``True`` when ``collection add`` was executed, ``False`` when the
            collection was already present in ``collection list``.

        Raises:
            QMDCommandError: If listing or adding the collection fails.
        """
        if collection_name in self.list_collections(index_name=index_name):
            return False

        self._run(
            [
                self.qmd_binary,
                "--index",
                index_name,
                "collection",
                "add",
                str(workspace_path),
                "--name",
                collection_name,
            ]
        )
        return True

    def list_collections(self, index_name: str) -> set[str]:
        """Return the collection names parsed from ``collection list`` output.

        Only lines of the form ``<name> (qmd://...`` contribute a name; every
        other line is ignored.

        Raises:
            QMDCommandError: If the command fails.
        """
        output = self._run([self.qmd_binary, "--index", index_name, "collection", "list"])
        names: set[str] = set()
        for line in output.splitlines():
            match = self._COLLECTION_LINE.match(line.strip())
            if match:
                names.add(match.group(1))
        return names

    def update_workspace(self, index_name: str) -> None:
        """Run ``qmd --index <index_name> update``.

        Raises:
            QMDCommandError: If the command fails.
        """
        self._run([self.qmd_binary, "--index", index_name, "update"])

    def embed_workspace_if_needed(self, index_name: str, should_embed: bool) -> bool:
        """Run ``qmd --index <index_name> embed`` when ``should_embed`` is true.

        Returns:
            ``True`` when the embed command was executed, ``False`` otherwise.

        Raises:
            QMDCommandError: If the command fails.
        """
        if not should_embed:
            return False
        self._run([self.qmd_binary, "--index", index_name, "embed"])
        return True

    def run_query(
        self,
        *,
        index_name: str,
        collection_name: str,
        query_type: QueryType,
        query: str,
        n: int,
    ) -> tuple[Any, str]:
        """Execute a query and return the parsed output with the command used.

        For ``QueryType.deep_search`` at least 10 results are requested; other
        query types request exactly ``n``.

        Returns:
            A ``(parsed_output, command_name)`` tuple, where ``command_name``
            is the ``qmd`` sub-command that produced the output.

        Raises:
            QMDCommandError: If the command fails and no distinct fallback
                applies, or if the fallback fails as well.
        """
        command_name = self._query_command_name(query_type)
        top_k = max(n, 10) if query_type == QueryType.deep_search else n
        args = self._query_args(
            index_name=index_name,
            command_name=command_name,
            query=query,
            top_k=top_k,
            collection_name=collection_name,
        )
        try:
            output = self._run(args)
        except QMDCommandError:
            # Query/deep_search may fail when LLM stack is not ready; keep API available.
            if query_type not in {QueryType.query, QueryType.deep_search}:
                raise
            fallback_command = "search"
            if command_name == fallback_command:
                # The fallback would be the exact command that just failed
                # (``_query_command_name`` currently maps these query types to
                # "search" already). Re-running it would only double the wait,
                # up to two full timeouts, for the same outcome.
                raise
            fallback_args = self._query_args(
                index_name=index_name,
                command_name=fallback_command,
                query=query,
                top_k=top_k,
                collection_name=collection_name,
            )
            output = self._run(fallback_args)
            command_name = fallback_command
        return self._parse_output(output), command_name

    def _query_command_name(self, query_type: QueryType) -> str:
        """Map a query type to the ``qmd`` sub-command to execute."""
        if query_type in {QueryType.query, QueryType.deep_search}:
            # Keep request latency bounded when LLM stack is not prewarmed.
            return "search"
        return query_type.value

    def _parse_output(self, output: str) -> Any:
        """Decode command output.

        Returns:
            ``[]`` for blank output, the decoded JSON value when the output is
            valid JSON, otherwise ``{"raw": <stripped output>}``.
        """
        stripped = output.strip()
        if not stripped:
            return []
        try:
            return json.loads(stripped)
        except json.JSONDecodeError:
            return {"raw": stripped}

    def _query_args(
        self,
        *,
        index_name: str,
        command_name: str,
        query: str,
        top_k: int,
        collection_name: str,
    ) -> list[str]:
        """Build the argument vector for a JSON query against one collection."""
        return [
            self.qmd_binary,
            "--index",
            index_name,
            command_name,
            query,
            "--json",
            "-n",
            str(top_k),
            "-c",
            collection_name,
        ]

    @staticmethod
    def _decode_captured(value: Any) -> str:
        """Return captured process output as text.

        ``subprocess.TimeoutExpired`` carries partial output as bytes (or
        ``None``) even when the process was started with ``text=True``, so
        bytes are decoded here instead of being dropped. Undecodable bytes are
        replaced rather than raising, since this runs on an error path.
        """
        if isinstance(value, str):
            return value
        if isinstance(value, (bytes, bytearray)):
            return bytes(value).decode("utf-8", errors="replace")
        return ""

    def _run(self, args: list[str]) -> str:
        """Run ``args`` and return its standard output.

        The child inherits the current environment with ``XDG_CACHE_HOME`` and
        ``XDG_CONFIG_HOME`` replaced by the configured directories.

        Raises:
            QMDCommandError: On a non-zero exit status, or on timeout (reported
                with return code 124 and whatever output was captured so far).
        """
        env = {
            **os.environ,
            "XDG_CACHE_HOME": str(self.xdg_cache_home),
            "XDG_CONFIG_HOME": str(self.xdg_config_home),
        }
        try:
            proc = subprocess.run(
                args,
                check=False,
                capture_output=True,
                text=True,
                env=env,
                timeout=self.timeout_seconds,
            )
        except subprocess.TimeoutExpired as exc:
            # Keep the partial output: it is usually the only hint about where
            # the command got stuck.
            stdout = self._decode_captured(exc.stdout)
            stderr = self._decode_captured(exc.stderr)
            raise QMDCommandError(
                args=args,
                returncode=124,
                stdout=stdout,
                stderr=f"{stderr}\nTimeout after {self.timeout_seconds}s",
            ) from exc

        if proc.returncode != 0:
            raise QMDCommandError(args=args, returncode=proc.returncode, stdout=proc.stdout, stderr=proc.stderr)
        return proc.stdout
|