136 lines
5.4 KiB
Python
136 lines
5.4 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
import os
|
|
from pathlib import Path
|
|
import subprocess
|
|
from typing import Any
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
# The app package is imported at module level further down; seed throwaway
# /tmp locations first so that import does not try to write to /data when the
# tests run locally. setdefault() leaves any value already exported untouched.
for _env_name, _env_default in (
    ("WORKSPACES_ROOT", "/tmp/qmd-gateway-tests/workspaces"),
    ("WORKSPACE_STATE_DIR", "/tmp/qmd-gateway-tests/state"),
    ("GIT_MIRROR_PATH", "/tmp/qmd-gateway-tests/git-mirror/repo.git"),
    ("GIT_REMOTE_URL", "/tmp/qmd-gateway-tests/remote.git"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default
|
|
|
|
from app.config import Settings
|
|
from app.main import create_app
|
|
from app.models import QueryType
|
|
|
|
|
|
class FakeQMDClient:
    """In-memory stand-in for the QMD client.

    Collections are tracked in a nested dict, update/embed requests are
    recorded in lists so tests can inspect them, and queries are answered by
    a case-insensitive substring scan over the ``*.md`` files of the
    registered workspace directory.
    """

    def __init__(self) -> None:
        # index name -> {collection name -> workspace directory}
        self.collections: dict[str, dict[str, Path]] = {}
        # Index names passed to update_workspace(), in call order.
        self.update_calls: list[str] = []
        # Index names for which an embed was actually requested, in call order.
        self.embed_calls: list[str] = []

    def ensure_collection(self, index_name: str, collection_name: str, workspace_path: Path) -> bool:
        """Register a collection once; return True only when it was newly added."""
        registered = self.collections.setdefault(index_name, {})
        newly_added = collection_name not in registered
        if newly_added:
            registered[collection_name] = workspace_path
        return newly_added

    def list_collections(self, index_name: str) -> set[str]:
        """Return the collection names known for ``index_name`` (empty set if none)."""
        # .get() with a throwaway dict: an unknown index must not be created here.
        return set(self.collections.get(index_name, {}))

    def update_workspace(self, index_name: str) -> None:
        """Record that an update was requested for ``index_name``."""
        self.update_calls.append(index_name)

    def embed_workspace_if_needed(self, index_name: str, should_embed: bool) -> bool:
        """Record an embed request when ``should_embed`` is truthy; return whether it was recorded."""
        if not should_embed:
            return False
        self.embed_calls.append(index_name)
        return True

    def run_query(self, *, index_name: str, collection_name: str, query_type: QueryType, query: str, n: int):
        """Search the registered workspace for ``query``.

        A markdown file matches when the lower-cased query occurs in its
        lower-cased text or file name. Files are visited in sorted path order.

        Returns:
            A ``(hits, query_type.value)`` tuple, where ``hits`` holds at most
            ``n`` dicts with the keys ``"file"`` (path as a string) and
            ``"snippet"`` (first 120 characters of the file).

        Raises:
            KeyError: if the index or collection has not been registered.
        """
        root = self.collections[index_name][collection_name]
        wanted = query.lower()
        hits: list[dict[str, Any]] = []
        for md_file in sorted(root.rglob("*.md")):
            body = md_file.read_text(encoding="utf-8")
            if wanted not in body.lower() and wanted not in md_file.name.lower():
                continue
            hits.append({"file": str(md_file), "snippet": body[:120]})
        return hits[:n], query_type.value
|
|
|
|
|
|
@dataclass
class TestRepo:
    """Handle on the git repository pair the tests push commits through."""

    # The name starts with "Test", which matches pytest's default class
    # collection pattern. This is a helper, not a test case, so opt out of
    # collection explicitly; otherwise importing it into a test module makes
    # pytest warn that it cannot collect a class with an __init__ constructor.
    # No annotation on purpose: this stays a class attribute, not a dataclass field.
    __test__ = False

    # Location of the remote repository (not used by the methods of this class).
    remote_path: Path
    # Working tree that every git command below is run in via ``git -C``.
    seed_path: Path

    def commit_on_branch(self, branch: str, rel_path: str, content: str, message: str) -> str:
        """Commit one file on ``branch``, push it to ``origin`` and return the new HEAD hash.

        Args:
            branch: Branch to check out in the seed working tree; it must
                already exist, since a plain ``git checkout`` is used.
            rel_path: File path relative to the seed working tree; missing
                parent directories are created.
            content: Text written to the file as UTF-8, replacing any
                previous content.
            message: Commit message.

        Returns:
            The output of ``git rev-parse HEAD`` with surrounding whitespace
            stripped.

        Any error raised by ``_run`` for a failing git command propagates.
        """
        git_in_seed = ["git", "-C", str(self.seed_path)]
        _run([*git_in_seed, "checkout", branch])

        target = self.seed_path / rel_path
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(content, encoding="utf-8")

        _run([*git_in_seed, "add", rel_path])
        _run([*git_in_seed, "commit", "-m", message])
        _run([*git_in_seed, "push", "origin", branch])
        return _run([*git_in_seed, "rev-parse", "HEAD"]).strip()
|
|
|
|
|
|
@pytest.fixture()
def repo(tmp_path: Path) -> TestRepo:
    """Build a bare remote plus a seeded clone under ``tmp_path``.

    After setup the remote holds two branches:

    * ``main``: ``README.md``, ``docs/main-only.md`` and
      ``docs/main-exclusive.md``.
    * ``memory/2026-03``: branched from the first ``main`` commit and adds
      ``docs/monthly-only.md``.

    The seed working tree is left checked out on ``main``.
    """
    remote = tmp_path / "remote.git"
    seed = tmp_path / "seed"
    docs_dir = seed / "docs"

    def git(*argv: str) -> None:
        # Everything after the clone runs inside the seed working tree.
        _run(["git", "-C", str(seed), *argv])

    _run(["git", "init", "--bare", str(remote)])
    _run(["git", "clone", str(remote), str(seed)])
    git("config", "user.name", "Test User")
    git("config", "user.email", "test@example.com")

    # Initial commit on main.
    (seed / "README.md").write_text("main branch memory root\n", encoding="utf-8")
    docs_dir.mkdir(parents=True, exist_ok=True)
    (docs_dir / "main-only.md").write_text("alpha-main-signal\n", encoding="utf-8")
    git("add", "README.md", "docs/main-only.md")
    git("commit", "-m", "init main")
    git("branch", "-M", "main")
    git("push", "-u", "origin", "main")

    # Monthly branch with a file that never lands on main.
    git("checkout", "-b", "memory/2026-03")
    (docs_dir / "monthly-only.md").write_text(
        "beta-monthly-signal\nmonthly-exclusive-signal\n",
        encoding="utf-8",
    )
    git("add", "docs/monthly-only.md")
    git("commit", "-m", "add monthly")
    git("push", "-u", "origin", "memory/2026-03")

    # Back on main: add a file the monthly branch does not have.
    git("checkout", "main")
    (docs_dir / "main-exclusive.md").write_text("main-exclusive-signal\n", encoding="utf-8")
    git("add", "docs/main-exclusive.md")
    git("commit", "-m", "add main exclusive")
    git("push", "origin", "main")

    return TestRepo(remote_path=remote, seed_path=seed)
|
|
|
|
|
|
@pytest.fixture()
def fake_qmd() -> FakeQMDClient:
    """Provide a newly constructed ``FakeQMDClient`` with no recorded state."""
    qmd_double = FakeQMDClient()
    return qmd_double
|
|
|
|
|
|
@pytest.fixture()
def client(tmp_path: Path, repo: TestRepo, fake_qmd: FakeQMDClient) -> TestClient:
    """Return a ``TestClient`` for an app wired to the temp repo and the fake QMD client.

    Every path handed to ``Settings`` lives under ``tmp_path``, and the git
    remote URL is the filesystem path of the ``repo`` fixture's remote.
    """
    isolated_settings = Settings(
        git_remote_url=str(repo.remote_path),
        git_mirror_path=tmp_path / "git-mirror" / "repo.git",
        workspaces_root=tmp_path / "workspaces",
        workspace_state_dir=tmp_path / "state",
        xdg_cache_home=tmp_path / "cache",
        xdg_config_home=tmp_path / "config",
    )
    return TestClient(create_app(settings=isolated_settings, qmd_client=fake_qmd))
|
|
|
|
|
|
def _run(args: list[str]) -> str:
    """Run ``args`` as a subprocess and return its captured stdout as text.

    Raises:
        RuntimeError: if the process exits with a non-zero status; the message
            carries the command line plus the captured stdout and stderr.
    """
    completed = subprocess.run(args, check=False, capture_output=True, text=True)
    if completed.returncode == 0:
        return completed.stdout

    command_line = " ".join(args)
    raise RuntimeError(
        f"command failed: {command_line}\nstdout={completed.stdout}\nstderr={completed.stderr}"
    )
|