Files

136 lines
5.4 KiB
Python

from __future__ import annotations

from dataclasses import dataclass
import os
from pathlib import Path
import shlex
import subprocess
from typing import Any

import pytest
from fastapi.testclient import TestClient
# Point the app's storage settings at a scratch tree *before* `app.config` and
# `app.main` are imported below, so importing them during local test runs does
# not try to write under /data. `setdefault` leaves any value that is already
# present in the environment untouched.
for _env_name, _env_default in (
    ("WORKSPACES_ROOT", "/tmp/qmd-gateway-tests/workspaces"),
    ("WORKSPACE_STATE_DIR", "/tmp/qmd-gateway-tests/state"),
    ("GIT_MIRROR_PATH", "/tmp/qmd-gateway-tests/git-mirror/repo.git"),
    ("GIT_REMOTE_URL", "/tmp/qmd-gateway-tests/remote.git"),
):
    os.environ.setdefault(_env_name, _env_default)
# Don't leave the loop variables behind in the module namespace.
del _env_name, _env_default
from app.config import Settings
from app.main import create_app
from app.models import QueryType
class FakeQMDClient:
    """In-memory stand-in for the QMD client that tests inject into ``create_app``.

    It records the calls the gateway makes instead of indexing anything. It
    answers queries with a naive case-insensitive substring scan over the
    ``*.md`` files of the registered workspace directory.
    """

    def __init__(self) -> None:
        # index name -> {collection name -> workspace directory}
        self.collections: dict[str, dict[str, Path]] = {}
        # Index names passed to update_workspace, in call order.
        self.update_calls: list[str] = []
        # Index names for which embed_workspace_if_needed actually "embedded".
        self.embed_calls: list[str] = []

    def ensure_collection(self, index_name: str, collection_name: str, workspace_path: Path) -> bool:
        """Register ``workspace_path`` as ``collection_name`` under ``index_name``.

        Returns:
            True if the collection was newly registered. False if it already
            existed, in which case the previously stored path is kept.
        """
        registered = self.collections.setdefault(index_name, {})
        if collection_name not in registered:
            registered[collection_name] = workspace_path
            return True
        return False

    def list_collections(self, index_name: str) -> set[str]:
        """Return the collection names registered for ``index_name``, or an empty set."""
        return set(self.collections.get(index_name, {}))

    def update_workspace(self, index_name: str) -> None:
        """Record an index refresh request. Nothing is actually re-indexed."""
        self.update_calls.append(index_name)

    def embed_workspace_if_needed(self, index_name: str, should_embed: bool) -> bool:
        """Record an embedding run when ``should_embed`` is truthy.

        Returns:
            Whether an embedding run was recorded.
        """
        if not should_embed:
            return False
        self.embed_calls.append(index_name)
        return True

    def run_query(self, *, index_name: str, collection_name: str, query_type: QueryType, query: str, n: int):
        """Return up to ``n`` markdown hits for ``query`` plus ``query_type.value``.

        A file matches when the lowercased query occurs in its lowercased text
        or in its lowercased file name. Files are visited in sorted path order.
        Each hit is ``{"file": str(path), "snippet": <first 120 characters>}``.
        The result is sliced with ``[:n]``, so a negative ``n`` drops hits
        from the end rather than returning none.

        Raises:
            KeyError: if the index or collection was never registered.
        """
        root = self.collections[index_name][collection_name]
        lowered_query = query.lower()
        hits: list[dict[str, Any]] = []
        for md_file in sorted(root.rglob("*.md")):
            body = md_file.read_text(encoding="utf-8")
            if lowered_query not in body.lower() and lowered_query not in md_file.name.lower():
                continue
            hits.append({"file": str(md_file), "snippet": body[:120]})
        return hits[:n], query_type.value
@dataclass
class TestRepo:
    """Handle to the throwaway git repositories built by the ``repo`` fixture.

    ``remote_path`` is the bare repository passed to the app as its git remote.
    ``seed_path`` is a non-bare clone of it, used to author commits and push
    them to ``remote_path``.
    """

    # The class name starts with "Test", so pytest tries to collect it from
    # any test module that imports it. Because dataclasses define __init__,
    # that triggers a PytestCollectionWarning. Opt out explicitly. An
    # unannotated class attribute is not turned into a dataclass field.
    __test__ = False

    remote_path: Path  # bare repository acting as the git remote
    seed_path: Path  # working clone whose pushes update remote_path

    def commit_on_branch(self, branch: str, rel_path: str, content: str, message: str) -> str:
        """Write ``content`` to ``rel_path`` on ``branch``, commit it, and push it.

        The seed clone is left checked out on ``branch`` afterwards.

        Args:
            branch: Branch to check out in the seed clone and push to ``origin``.
            rel_path: File path relative to the seed clone's root. Missing
                parent directories are created.
            content: Text written to the file (UTF-8).
            message: Commit message.

        Returns:
            The new commit's SHA as printed by ``git rev-parse HEAD``, with the
            surrounding whitespace stripped.

        Raises:
            RuntimeError: if any git command exits non-zero. Presumably this
                includes committing identical content, since there is then
                nothing to commit.
        """
        seed = str(self.seed_path)
        _run(["git", "-C", seed, "checkout", branch])
        target = self.seed_path / rel_path
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(content, encoding="utf-8")
        _run(["git", "-C", seed, "add", rel_path])
        _run(["git", "-C", seed, "commit", "-m", message])
        _run(["git", "-C", seed, "push", "origin", branch])
        return _run(["git", "-C", seed, "rev-parse", "HEAD"]).strip()
@pytest.fixture()
def repo(tmp_path: Path) -> TestRepo:
    """Create a bare "remote" repository plus a seeded working clone of it.

    Resulting branches on the remote:

    * ``main``: ``README.md`` and ``docs/main-only.md`` in the first commit,
      then ``docs/main-exclusive.md`` in a second commit.
    * ``memory/2026-03``: branched from the first ``main`` commit. It adds
      ``docs/monthly-only.md`` and does *not* contain
      ``docs/main-exclusive.md``.

    The seed clone is left checked out on ``main``.
    """
    remote = tmp_path / "remote.git"
    seed = tmp_path / "seed"

    def seed_git(*git_args: str) -> str:
        # Run one git subcommand inside the seed working clone.
        return _run(["git", "-C", str(seed), *git_args])

    _run(["git", "init", "--bare", str(remote)])
    _run(["git", "clone", str(remote), str(seed)])
    seed_git("config", "user.name", "Test User")
    seed_git("config", "user.email", "test@example.com")
    # Isolate the seed clone from the developer's global git config. A global
    # `commit.gpgsign = true` can make every commit made here fail or prompt
    # for a passphrase.
    seed_git("config", "commit.gpgsign", "false")

    # main: the initial commit, shared with the monthly branch.
    (seed / "README.md").write_text("main branch memory root\n", encoding="utf-8")
    (seed / "docs").mkdir(parents=True, exist_ok=True)
    (seed / "docs" / "main-only.md").write_text("alpha-main-signal\n", encoding="utf-8")
    seed_git("add", "README.md", "docs/main-only.md")
    seed_git("commit", "-m", "init main")
    # The initial branch name depends on git's init.defaultBranch; normalize it.
    seed_git("branch", "-M", "main")
    seed_git("push", "-u", "origin", "main")
    # `git init --bare` also points the remote's HEAD at init.defaultBranch,
    # which may be a branch that is never pushed (e.g. "master"). Pin it to
    # main so the remote's default branch does not depend on local git config.
    _run(["git", "-C", str(remote), "symbolic-ref", "HEAD", "refs/heads/main"])

    # memory/2026-03: branched from the initial main commit, adds a monthly-only file.
    seed_git("checkout", "-b", "memory/2026-03")
    (seed / "docs" / "monthly-only.md").write_text(
        "beta-monthly-signal\nmonthly-exclusive-signal\n",
        encoding="utf-8",
    )
    seed_git("add", "docs/monthly-only.md")
    seed_git("commit", "-m", "add monthly")
    seed_git("push", "-u", "origin", "memory/2026-03")

    # Back on main: a commit that the monthly branch does not contain.
    seed_git("checkout", "main")
    (seed / "docs" / "main-exclusive.md").write_text("main-exclusive-signal\n", encoding="utf-8")
    seed_git("add", "docs/main-exclusive.md")
    seed_git("commit", "-m", "add main exclusive")
    seed_git("push", "origin", "main")

    return TestRepo(remote_path=remote, seed_path=seed)
@pytest.fixture()
def fake_qmd() -> FakeQMDClient:
    """Provide a brand-new FakeQMDClient for each test.

    The recorded calls and registered collections therefore never carry over
    between tests.
    """
    qmd = FakeQMDClient()
    return qmd
@pytest.fixture()
def client(tmp_path: Path, repo: TestRepo, fake_qmd: FakeQMDClient) -> TestClient:
    """Build the gateway app against the ``repo`` fixture and wrap it in a TestClient.

    The git remote is the fixture's bare repository. Every writable location
    (mirror, workspaces, state, XDG cache/config) lives under this test's
    ``tmp_path``. ``fake_qmd`` replaces the real QMD client.
    """
    gateway_settings = Settings(
        git_remote_url=str(repo.remote_path),
        git_mirror_path=tmp_path / "git-mirror" / "repo.git",
        workspaces_root=tmp_path / "workspaces",
        workspace_state_dir=tmp_path / "state",
        xdg_cache_home=tmp_path / "cache",
        xdg_config_home=tmp_path / "config",
    )
    gateway_app = create_app(settings=gateway_settings, qmd_client=fake_qmd)
    # NOTE(review): the TestClient is not entered as a context manager, so
    # Starlette lifespan/startup handlers do not run. Confirm this is intended.
    return TestClient(gateway_app)
def _run(args: list[str]) -> str:
proc = subprocess.run(args, check=False, capture_output=True, text=True)
if proc.returncode != 0:
raise RuntimeError(f"command failed: {' '.join(args)}\nstdout={proc.stdout}\nstderr={proc.stderr}")
return proc.stdout