from __future__ import annotations import json import os from pathlib import Path import re import shlex import subprocess from typing import Any from .models import QueryType class QMDCommandError(RuntimeError): def __init__(self, args: list[str], returncode: int, stdout: str, stderr: str) -> None: cmd = " ".join(shlex.quote(x) for x in args) message = f"qmd command failed ({returncode}): {cmd}\nstdout:\n{stdout}\nstderr:\n{stderr}" super().__init__(message) self.args_list = args self.returncode = returncode self.stdout = stdout self.stderr = stderr class QMDClient: def __init__( self, qmd_binary: str, timeout_seconds: int, xdg_cache_home: Path, xdg_config_home: Path, ) -> None: self.qmd_binary = qmd_binary self.timeout_seconds = timeout_seconds self.xdg_cache_home = xdg_cache_home self.xdg_config_home = xdg_config_home def ensure_collection(self, index_name: str, collection_name: str, workspace_path: Path) -> bool: existing = self.list_collections(index_name=index_name) if collection_name in existing: return False self._run( [ self.qmd_binary, "--index", index_name, "collection", "add", str(workspace_path), "--name", collection_name, ] ) return True def list_collections(self, index_name: str) -> set[str]: output = self._run([self.qmd_binary, "--index", index_name, "collection", "list"]) names: set[str] = set() for line in output.splitlines(): match = re.match(r"^([A-Za-z0-9_.-]+) \(qmd://", line.strip()) if match: names.add(match.group(1)) return names def update_workspace(self, index_name: str) -> None: self._run([self.qmd_binary, "--index", index_name, "update"]) def embed_workspace_if_needed(self, index_name: str, should_embed: bool) -> bool: if not should_embed: return False self._run([self.qmd_binary, "--index", index_name, "embed"]) return True def run_query( self, *, index_name: str, collection_name: str, query_type: QueryType, query: str, n: int, ) -> tuple[Any, str]: command_name = self._query_command_name(query_type) top_k = max(n, 10) if query_type == QueryType.deep_search else n args = self._query_args( index_name=index_name, command_name=command_name, query=query, top_k=top_k, collection_name=collection_name, ) try: output = self._run(args) return self._parse_output(output), command_name except QMDCommandError: # Query/deep_search may fail when LLM stack is not ready; keep API available. if query_type not in {QueryType.query, QueryType.deep_search}: raise fallback_command = "search" fallback_args = self._query_args( index_name=index_name, command_name=fallback_command, query=query, top_k=top_k, collection_name=collection_name, ) fallback_output = self._run(fallback_args) return self._parse_output(fallback_output), fallback_command def _query_command_name(self, query_type: QueryType) -> str: if query_type in {QueryType.query, QueryType.deep_search}: # Keep request latency bounded when LLM stack is not prewarmed. return "search" return query_type.value def _parse_output(self, output: str) -> Any: stripped = output.strip() if not stripped: return [] try: return json.loads(stripped) except json.JSONDecodeError: return {"raw": stripped} def _query_args( self, *, index_name: str, command_name: str, query: str, top_k: int, collection_name: str, ) -> list[str]: return [ self.qmd_binary, "--index", index_name, command_name, query, "--json", "-n", str(top_k), "-c", collection_name, ] def _run(self, args: list[str]) -> str: env = { **os.environ, "XDG_CACHE_HOME": str(self.xdg_cache_home), "XDG_CONFIG_HOME": str(self.xdg_config_home), } try: proc = subprocess.run( args, check=False, capture_output=True, text=True, env=env, timeout=self.timeout_seconds, ) except subprocess.TimeoutExpired as exc: stdout = exc.stdout if isinstance(exc.stdout, str) else "" stderr = exc.stderr if isinstance(exc.stderr, str) else "" raise QMDCommandError( args=args, returncode=124, stdout=stdout, stderr=f"{stderr}\nTimeout after {self.timeout_seconds}s", ) from exc if proc.returncode != 0: raise QMDCommandError(args=args, returncode=proc.returncode, stdout=proc.stdout, stderr=proc.stderr) return proc.stdout