增加权重转化时候的检查脚本

This commit is contained in:
2025-10-08 19:59:25 +08:00
parent 9974fc7a00
commit 5c5838605c
5 changed files with 233 additions and 0 deletions

52
tests/peek_scales.py Normal file
View File

@@ -0,0 +1,52 @@
# peek_scales.py —— 兼容分片 safetensors
import json, os, sys
from safetensors import safe_open
def iter_keys(srcdir):
    """Yield a ``(handle, keys)`` pair for the checkpoint under *srcdir*.

    For a single-file checkpoint (``model.safetensors``) yields the open
    ``safe_open`` handle and its key list.  For a sharded checkpoint yields
    a dict mapping shard filename -> open handle, plus the key list taken
    from ``weight_map`` in ``model.safetensors.index.json``.

    NOTE(review): the original yielded from *inside* context managers, so
    when a caller did ``next(iter_keys(d))`` and dropped the generator,
    CPython's refcounting GC finalized the generator at once and closed
    every handle before a single tensor could be read.  Handles are now
    opened without a context manager and live for the rest of the process,
    which is acceptable for a one-shot diagnostic script.
    """
    single = os.path.join(srcdir, "model.safetensors")
    index = os.path.join(srcdir, "model.safetensors.index.json")
    if os.path.exists(single):
        # Un-sharded checkpoint: one file holds every tensor.
        f = safe_open(single, framework="pt", device="cpu")
        yield f, list(f.keys())
    else:
        # Sharded checkpoint: the index maps tensor name -> shard file.
        # Close the index file promptly (the original leaked the handle).
        with open(index) as fh:
            wm = json.load(fh)["weight_map"]
        opened = {
            rel: safe_open(os.path.join(srcdir, rel), framework="pt", device="cpu")
            for rel in set(wm.values())
        }
        yield opened, list(wm.keys())
def peek(srcdir):
    """Print the max of every MLP weight-scale tensor and the global max.

    Scans keys named ``block.{n}.mlp.mlp{1,2}_weight.scales`` for n in
    [0, 1000) and warns when GLOBAL_MAX + 14 could overflow uint8.
    """
    # Keep a reference to the generator for the whole function: the handles
    # yielded by iter_keys are closed when the generator object is
    # finalized, so a bare `next(iter_keys(srcdir))` would let CPython's GC
    # close every file before we read a single tensor.
    gen = iter_keys(srcdir)
    opened, keys = next(gen)
    keyset = set(keys)  # O(1) membership tests below

    def get(k):
        # Fetch tensor *k* from either the single handle or the shard map.
        if isinstance(opened, dict):
            # Sharded mode: probe each shard for the key (simple linear scan).
            for f in opened.values():
                if k in f.keys():
                    return f.get_tensor(k)
            raise KeyError(k)
        return opened.get_tensor(k)

    # (the original's `import torch` was unused -- .max()/.item() are tensor
    # methods and never reference the torch name)
    mx = []
    for n in range(1000):  # coarse scan of block indices below 1000
        for which in ("mlp1_weight.scales", "mlp2_weight.scales"):
            k = f"block.{n}.mlp.{which}"
            # Explicit membership test instead of exception-driven probing:
            # the original's blanket `except Exception: pass` also swallowed
            # real I/O errors, making failures look like "no keys found".
            if k not in keyset:
                continue
            t = get(k)
            mx.append(t.max().item())
            print(k, "max=", float(mx[-1]))
    if mx:
        m = max(mx)
        print("\nGLOBAL MAX SCALE:", m, " (m + 14 =", m+14, ")")
        if m >= 241:
            print("⚠️ 警告m+14 可能溢出 uint8请用 int16+clamp 写回。")
    # Files stayed open until here because `gen` was kept alive; now close.
    gen.close()
if __name__ == "__main__":
    # Fail with a readable usage line instead of an IndexError traceback
    # when the checkpoint directory argument is missing.
    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} <checkpoint_dir>")
    peek(sys.argv[1])