188 lines
5.4 KiB
Python
188 lines
5.4 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Prepare calibration_data_v5_rc_code.txt with exact composition:
|
|
- base: 1152 blocks from calibration_data_v5_rc.txt
|
|
- code: 2000 blocks from QuixiAI/Code-74k-ShareGPT-Vicuna
|
|
- pref: 1000 blocks from alvarobartt/openhermes-preferences-coding
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import random
|
|
import re
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
from datasets import load_dataset
|
|
|
|
# Upstream default calibration corpus shipped with llama.cpp; fetched once
# by ensure_base_file() when the local base file is missing.
BASE_URL = (
    "https://raw.githubusercontent.com/ggerganov/llama.cpp/master/"
    "examples/calibration/calibration_data.txt"
)
# One or more consecutive blank lines (whitespace-only lines included) —
# the block delimiter used for every calibration text file in this script.
BLOCK_SPLIT_RE = re.compile(r"\n\s*\n")
|
|
|
|
|
|
def split_blocks(text: str) -> list[str]:
    """Split *text* on blank lines and return the non-empty, stripped chunks."""
    return [chunk.strip() for chunk in BLOCK_SPLIT_RE.split(text) if chunk.strip()]
|
|
|
|
|
|
def read_blocks(path: Path) -> list[str]:
    """Load *path* (undecodable bytes ignored) and return its blank-line-delimited blocks."""
    text = path.read_text(encoding="utf-8", errors="ignore")
    return split_blocks(text)
|
|
|
|
|
|
def write_blocks(path: Path, blocks: list[str]) -> None:
    """Write *blocks* to *path* separated by blank lines, ending with one newline."""
    body = "\n\n".join(blocks).strip() + "\n"
    path.write_text(body, encoding="utf-8")
|
|
|
|
|
|
def ensure_base_file(path: Path) -> None:
    """Download the base calibration file to *path* if it is not already present.

    The download is written to a temporary ``.part`` file that is renamed
    into place only on success. The original code passed *path* straight to
    ``wget -O``, so a failed or interrupted download left a partial (often
    empty) file behind, and every later run saw ``path.exists()`` and
    silently skipped the download.

    Raises:
        subprocess.CalledProcessError: if wget exits with a non-zero status.
    """
    if path.exists():
        return
    partial = path.with_suffix(path.suffix + ".part")
    cmd = ["wget", BASE_URL, "-O", str(partial)]
    print("Downloading base calibration file:")
    print(" ", " ".join(cmd))
    try:
        subprocess.run(cmd, check=True)
    except BaseException:
        # Clean up whatever wget managed to write before re-raising
        # (BaseException so KeyboardInterrupt also triggers cleanup).
        partial.unlink(missing_ok=True)
        raise
    # Atomic on POSIX: the final path either has the full file or nothing.
    partial.replace(path)
|
|
|
|
|
|
def pick_blocks(blocks: list[str], target: int, seed: int) -> list[str]:
    """Return *target* blocks drawn without replacement, reproducibly by *seed*.

    Raises:
        ValueError: if *blocks* contains fewer than *target* entries.
    """
    if len(blocks) < target:
        raise ValueError(f"Need {target} blocks but only got {len(blocks)}.")
    # Shuffle index positions rather than the blocks themselves, then take
    # the first *target* positions — identical RNG sequence as before.
    order = list(range(len(blocks)))
    random.Random(seed).shuffle(order)
    return [blocks[pos] for pos in order[:target]]
|
|
|
|
|
|
def build_code74k_blocks(target: int, seed: int) -> list[str]:
    """Draw *target* text blocks from QuixiAI/Code-74k-ShareGPT-Vicuna.

    Rows are visited in a *seed*-reproducible random order; all non-empty
    message values of a conversation are concatenated into one block.

    Blank-line runs inside a block are collapsed to a single newline:
    the cache file is blank-line-delimited, so an embedded blank line
    (common in code samples) would otherwise split one block into several
    when re-read from the cache, silently changing the composition.

    Raises:
        RuntimeError: if the dataset yields fewer than *target* usable rows.
    """
    ds = load_dataset("QuixiAI/Code-74k-ShareGPT-Vicuna", split="train")
    order = list(range(len(ds)))
    random.Random(seed).shuffle(order)

    out: list[str] = []
    for row in order:
        conv = ds[row].get("conversations") or []
        parts = []
        for msg in conv:
            value = (msg.get("value") or "").strip()
            if value:
                parts.append(value)
        if parts:
            # Normalize so the block survives a write/read round trip.
            out.append(BLOCK_SPLIT_RE.sub("\n", "\n".join(parts)))
        if len(out) >= target:
            break

    if len(out) < target:
        raise RuntimeError(
            f"Code-74k yielded only {len(out)} valid blocks, target is {target}."
        )
    return out
|
|
|
|
|
|
def build_openhermes_blocks(target: int, seed: int) -> list[str]:
    """Draw *target* text blocks from alvarobartt/openhermes-preferences-coding.

    Rows are visited in a random order derived from ``seed + 1`` (offset so
    the shuffle differs from the Code-74k pass at the same seed); all
    non-empty ``chosen`` message contents of a row are concatenated into
    one block.

    Blank-line runs inside a block are collapsed to a single newline:
    the cache file is blank-line-delimited, so an embedded blank line
    (common in code samples) would otherwise split one block into several
    when re-read from the cache, silently changing the composition.

    Raises:
        RuntimeError: if the dataset yields fewer than *target* usable rows.
    """
    ds = load_dataset("alvarobartt/openhermes-preferences-coding", split="train")
    order = list(range(len(ds)))
    random.Random(seed + 1).shuffle(order)

    out: list[str] = []
    for row in order:
        chosen = ds[row].get("chosen") or []
        parts = []
        for msg in chosen:
            value = (msg.get("content") or "").strip()
            if value:
                parts.append(value)
        if parts:
            # Normalize so the block survives a write/read round trip.
            out.append(BLOCK_SPLIT_RE.sub("\n", "\n".join(parts)))
        if len(out) >= target:
            break

    if len(out) < target:
        raise RuntimeError(
            f"OpenHermes yielded only {len(out)} valid blocks, target is {target}."
        )
    return out
|
|
|
|
|
|
def ensure_cached_blocks(
    cache_path: Path,
    target: int,
    build_fn,
    seed: int,
) -> list[str]:
    """Return *target* blocks, served from *cache_path* when it suffices.

    A cache holding at least *target* blocks is truncated and returned
    as-is. Otherwise ``build_fn(target, seed)`` regenerates the data from
    source and the result is written back to the cache before returning.
    """
    if cache_path.exists():
        cached = read_blocks(cache_path)
        if len(cached) >= target:
            return cached[:target]
        print(
            f"{cache_path} has {len(cached)} blocks (< {target}), rebuilding from source."
        )

    fresh = build_fn(target, seed)
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    write_blocks(cache_path, fresh)
    return fresh
|
|
|
|
|
|
def main() -> int:
    """Assemble the merged calibration file and report its composition.

    Composition: 1152 base blocks + 2000 Code-74k blocks + 1000 OpenHermes
    blocks, each selection reproducible from ``--seed``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--base-file", default="calibration_data_v5_rc.txt")
    parser.add_argument("--output", default="calibration_data_v5_rc_code.txt")
    parser.add_argument("--data-dir", default="data")
    parser.add_argument("--force-refresh", action="store_true")
    args = parser.parse_args()

    base_file = Path(args.base_file)
    output_file = Path(args.output)
    cache_dir = Path(args.data_dir)
    code_cache = cache_dir / "code74k_2000.txt"
    openhermes_cache = cache_dir / "openhermes_coding_chosen_1000.txt"

    if args.force_refresh:
        # Drop stale caches so both datasets are rebuilt from source.
        for cache in (code_cache, openhermes_cache):
            if cache.exists():
                cache.unlink()

    ensure_base_file(base_file)
    base_blocks = pick_blocks(read_blocks(base_file), target=1152, seed=args.seed)

    code_blocks = ensure_cached_blocks(
        cache_path=code_cache,
        target=2000,
        build_fn=build_code74k_blocks,
        seed=args.seed,
    )
    openhermes_blocks = ensure_cached_blocks(
        cache_path=openhermes_cache,
        target=1000,
        build_fn=build_openhermes_blocks,
        seed=args.seed,
    )

    merged = base_blocks + code_blocks + openhermes_blocks
    write_blocks(output_file, merged)

    print("Done.")
    print(f"base blocks: {len(base_blocks)} ({base_file})")
    print(f"code blocks: {len(code_blocks)} (QuixiAI/Code-74k-ShareGPT-Vicuna)")
    print(
        "openhermes blocks: "
        f"{len(openhermes_blocks)} (alvarobartt/openhermes-preferences-coding)"
    )
    print(f"total blocks: {len(merged)}")
    print(f"output: {output_file}")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Report wget failures tersely on stderr, then re-raise for the
    # traceback / non-zero exit. SystemExit is raised outside the try,
    # which is equivalent: it is not a CalledProcessError either way.
    try:
        exit_code = main()
    except subprocess.CalledProcessError as exc:
        print(f"Command failed with exit code {exc.returncode}", file=sys.stderr)
        raise
    raise SystemExit(exit_code)
|