Files
llm-gguf-quant-template/scripts/prepare_calib_data.py
2026-03-02 23:22:33 +08:00

196 lines
5.7 KiB
Python

#!/usr/bin/env python3
"""
Prepare calibration_data_v5_rc_code.txt with exact composition:
- base: 1152 blocks from calibration_data_v5_rc.txt
- code: 2000 blocks from QuixiAI/Code-74k-ShareGPT-Vicuna
- pref: 1000 blocks from alvarobartt/openhermes-preferences-coding
"""
from __future__ import annotations
import argparse
import random
import re
import subprocess
import sys
from pathlib import Path
from datasets import load_dataset
# Upstream base calibration dataset shipped with llama.cpp; fetched once
# by ensure_base_file() when the local copy is missing.
BASE_URL = (
    "https://raw.githubusercontent.com/ggerganov/llama.cpp/master/"
    "examples/calibration/calibration_data.txt"
)
# A "block" is a run of text delimited by a blank (or whitespace-only) line.
BLOCK_SPLIT_RE = re.compile(r"\n\s*\n")
# Relative CLI paths are resolved against the repo root, i.e. the parent of
# this scripts/ directory.
SCRIPT_DIR = Path(__file__).resolve().parent
ROOT_DIR = SCRIPT_DIR.parent
def split_blocks(text: str) -> list[str]:
    """Split *text* on blank lines and return the non-empty, stripped blocks."""
    # Same pattern as the module-level BLOCK_SPLIT_RE: one newline, optional
    # whitespace (including more newlines), then another newline.
    separator = re.compile(r"\n\s*\n")
    return [part.strip() for part in separator.split(text) if part.strip()]
def read_blocks(path: Path) -> list[str]:
    """Read *path* as UTF-8 (undecodable bytes ignored) and split into blocks."""
    raw = path.read_text(encoding="utf-8", errors="ignore")
    return split_blocks(raw)
def write_blocks(path: Path, blocks: list[str]) -> None:
    """Write *blocks* to *path*, blank-line separated, with a trailing newline."""
    payload = "\n\n".join(blocks).strip() + "\n"
    path.write_text(payload, encoding="utf-8")
def ensure_base_file(path: Path) -> None:
    """Download the base calibration file to *path* if it does not exist.

    Downloads to a temporary ``.part`` sibling and renames it into place only
    on success. The original wrote straight to *path* with ``wget -O``, so a
    failed or interrupted download left a zero-byte/partial file that the
    ``path.exists()`` guard then mistook for a complete one on every later run.

    Raises:
        subprocess.CalledProcessError: if wget exits non-zero.
    """
    if path.exists():
        return
    # wget does not create missing directories for -O targets.
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_suffix(path.suffix + ".part")
    cmd = ["wget", BASE_URL, "-O", str(tmp_path)]
    print("Downloading base calibration file:")
    print(" ", " ".join(cmd))
    try:
        subprocess.run(cmd, check=True)
    except BaseException:
        # Remove the partial download (including on KeyboardInterrupt) so the
        # next run retries instead of trusting a truncated file.
        tmp_path.unlink(missing_ok=True)
        raise
    # Atomic on POSIX: path either doesn't exist or is the full download.
    tmp_path.replace(path)
def pick_blocks(blocks: list[str], target: int, seed: int) -> list[str]:
    """Return *target* blocks chosen by a seeded Fisher-Yates shuffle.

    Deterministic for a given (blocks, target, seed) triple.

    Raises:
        ValueError: if there are fewer than *target* blocks to pick from.
    """
    if len(blocks) < target:
        raise ValueError(f"Need {target} blocks but only got {len(blocks)}.")
    # Shuffle index positions (not the blocks themselves) and keep a prefix;
    # this reproduces the exact selection order of random.Random(seed).shuffle.
    order = list(range(len(blocks)))
    random.Random(seed).shuffle(order)
    del order[target:]
    return [blocks[i] for i in order]
def build_code74k_blocks(target: int, seed: int) -> list[str]:
    """Draw *target* text blocks from the Code-74k ShareGPT dataset.

    Rows are visited in a seeded random order; each row's conversation turns
    are joined into one newline-separated block, skipping empty messages and
    rows that yield no text at all.

    Raises:
        RuntimeError: if the dataset cannot supply *target* non-empty blocks.
    """
    ds = load_dataset("QuixiAI/Code-74k-ShareGPT-Vicuna", split="train")
    order = list(range(len(ds)))
    random.Random(seed).shuffle(order)
    blocks: list[str] = []
    for row_idx in order:
        conversation = ds[row_idx].get("conversations") or []
        texts = [
            stripped
            for stripped in ((msg.get("value") or "").strip() for msg in conversation)
            if stripped
        ]
        if not texts:
            continue
        blocks.append("\n".join(texts))
        if len(blocks) >= target:
            break
    if len(blocks) < target:
        raise RuntimeError(
            f"Code-74k yielded only {len(blocks)} valid blocks, target is {target}."
        )
    return blocks
def build_openhermes_blocks(target: int, seed: int) -> list[str]:
    """Draw *target* text blocks from the OpenHermes preferences-coding dataset.

    Only the "chosen" side of each preference pair is used. Rows are visited
    in a seeded random order; note the `seed + 1` offset keeps this shuffle
    decorrelated from the Code-74k shuffle that uses the raw seed.

    Raises:
        RuntimeError: if the dataset cannot supply *target* non-empty blocks.
    """
    ds = load_dataset("alvarobartt/openhermes-preferences-coding", split="train")
    order = list(range(len(ds)))
    random.Random(seed + 1).shuffle(order)
    blocks: list[str] = []
    for row_idx in order:
        chosen_msgs = ds[row_idx].get("chosen") or []
        texts = [
            stripped
            for stripped in ((msg.get("content") or "").strip() for msg in chosen_msgs)
            if stripped
        ]
        if not texts:
            continue
        blocks.append("\n".join(texts))
        if len(blocks) >= target:
            break
    if len(blocks) < target:
        raise RuntimeError(
            f"OpenHermes yielded only {len(blocks)} valid blocks, target is {target}."
        )
    return blocks
def ensure_cached_blocks(
    cache_path: Path,
    target: int,
    build_fn,
    seed: int,
) -> list[str]:
    """Return *target* blocks, serving from *cache_path* when possible.

    A cache file with at least *target* blocks is truncated to the target and
    returned as-is; a short or missing cache is treated as stale, rebuilt via
    ``build_fn(target, seed)``, and written back for the next run.
    """
    if cache_path.exists():
        have = read_blocks(cache_path)
        if len(have) >= target:
            return have[:target]
        print(
            f"{cache_path} has {len(have)} blocks (< {target}), rebuilding from source."
        )
    rebuilt = build_fn(target, seed)
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    write_blocks(cache_path, rebuilt)
    return rebuilt
def main() -> int:
    """Assemble the merged calibration file and report its composition.

    Composition: 1152 base blocks + 2000 Code-74k blocks + 1000 OpenHermes
    blocks, written as one blank-line-separated text file. Returns 0.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--base-file", default="calibration/calibration_data_v5_rc.txt")
    parser.add_argument("--output", default="calibration/calibration_data_v5_rc_code.txt")
    parser.add_argument("--data-dir", default="calibration/sources")
    parser.add_argument("--force-refresh", action="store_true")
    args = parser.parse_args()

    def resolve_path(path_text: str) -> Path:
        # Absolute paths pass through; relative ones anchor at the repo root.
        candidate = Path(path_text)
        return candidate if candidate.is_absolute() else ROOT_DIR / candidate

    base_file = resolve_path(args.base_file)
    output_file = resolve_path(args.output)
    data_dir = resolve_path(args.data_dir)
    code_cache = data_dir / "code74k_2000.txt"
    openhermes_cache = data_dir / "openhermes_coding_chosen_1000.txt"

    # --force-refresh drops the dataset caches so they rebuild from source.
    if args.force_refresh:
        for cache in (code_cache, openhermes_cache):
            if cache.exists():
                cache.unlink()

    ensure_base_file(base_file)
    base_blocks = pick_blocks(read_blocks(base_file), target=1152, seed=args.seed)
    code_blocks = ensure_cached_blocks(code_cache, 2000, build_code74k_blocks, args.seed)
    openhermes_blocks = ensure_cached_blocks(
        openhermes_cache, 1000, build_openhermes_blocks, args.seed
    )

    merged = base_blocks + code_blocks + openhermes_blocks
    write_blocks(output_file, merged)

    print("Done.")
    print(f"base blocks: {len(base_blocks)} ({base_file})")
    print(f"code blocks: {len(code_blocks)} (QuixiAI/Code-74k-ShareGPT-Vicuna)")
    print(
        "openhermes blocks: "
        f"{len(openhermes_blocks)} (alvarobartt/openhermes-preferences-coding)"
    )
    print(f"total blocks: {len(merged)}")
    print(f"output: {output_file}")
    return 0
if __name__ == "__main__":
    try:
        # sys.exit(main()) is equivalent to raising SystemExit(main()).
        sys.exit(main())
    except subprocess.CalledProcessError as exc:
        # Surface the failing command's exit code, then re-raise so the
        # traceback (and a non-zero exit status) still reach the caller.
        print(f"Command failed with exit code {exc.returncode}", file=sys.stderr)
        raise