53 lines
1.8 KiB
Python
53 lines
1.8 KiB
Python
# peek_scales.py — inspect quantization scale tensors; works with both single-file and sharded safetensors checkpoints
|
||
import json, os, sys
|
||
from safetensors import safe_open
|
||
|
||
def iter_keys(srcdir):
    """Yield one (handles, keys) pair for the safetensors checkpoint in *srcdir*.

    Two on-disk layouts are supported:
      * single file (model.safetensors)       -> yields (safe_open handle, list of keys)
      * sharded (model.safetensors.index.json) -> yields (dict {shard filename: handle},
                                                  list of keys from the index weight_map)

    The yield happens while the underlying files are still open, so the caller
    must finish reading tensors before closing/discarding the generator —
    closing it releases every file handle.
    """
    single = os.path.join(srcdir, "model.safetensors")
    index = os.path.join(srcdir, "model.safetensors.index.json")
    if os.path.exists(single):
        with safe_open(single, framework="pt", device="cpu") as f:
            yield f, list(f.keys())
    else:
        # Sharded checkpoint: the index maps each tensor key to its shard file.
        # Fix: close the index file deterministically (was `json.load(open(index))`,
        # which leaked the file handle).
        with open(index) as fp:
            idx = json.load(fp)
        wm = idx["weight_map"]
        from contextlib import ExitStack
        opened = {}
        with ExitStack() as stack:
            # Open each distinct shard once; ExitStack closes them all on exit.
            for rel in set(wm.values()):
                opened[rel] = stack.enter_context(
                    safe_open(os.path.join(srcdir, rel), framework="pt", device="cpu")
                )
            yield opened, list(wm.keys())
|
||
|
||
def peek(srcdir, max_blocks=1000):
    """Print the max of every MLP weight-scale tensor found under *srcdir*.

    Probes keys of the form ``block.{n}.mlp.{mlp1,mlp2}_weight.scales`` for
    ``n`` in ``[0, max_blocks)``, prints each tensor's max, then the global
    max.  ``max_blocks`` (default 1000, matching the previous hard-coded
    bound) caps how many block indices are scanned.
    """
    opened, keys = next(iter_keys(srcdir))
    # Use the known key list for O(1) membership tests instead of probing
    # every candidate key and swallowing all exceptions (the old broad
    # `except Exception: pass` also hid genuine read errors).
    keyset = set(keys)

    def get(k):
        if isinstance(opened, dict):
            # Sharded mode: `opened` maps shard filename -> handle; scan shards.
            for f in opened.values():
                if k in f.keys():
                    return f.get_tensor(k)
            raise KeyError(k)
        # Single-file mode: `opened` is one safe_open handle.
        return opened.get_tensor(k)

    mx = []
    for n in range(max_blocks):
        for which in ("mlp1_weight.scales", "mlp2_weight.scales"):
            k = f"block.{n}.mlp.{which}"
            if k not in keyset:
                continue
            t = get(k)
            mx.append(t.max().item())
            print(k, "max=", float(mx[-1]))
    if mx:
        m = max(mx)
        print("\nGLOBAL MAX SCALE:", m, " (m + 14 =", m+14, ")")
        # NOTE(review): threshold assumes downstream stores (scale + 14) in
        # uint8, so values >= 241 would overflow — see warning text below.
        if m >= 241:
            print("⚠️ 警告:m+14 可能溢出 uint8!请用 int16+clamp 写回。")
|
||
|
||
if __name__ == "__main__":
    # CLI usage: python peek_scales.py <checkpoint_dir>
    # (sys.argv[1] must be a directory containing model.safetensors or
    #  model.safetensors.index.json — no argument validation is performed)
    peek(sys.argv[1])
|