49 lines
1.6 KiB
Python
49 lines
1.6 KiB
Python
# tests/peek_scales_v2.py
|
|
import os, json, sys
|
|
from safetensors import safe_open
|
|
|
|
def open_any(srcdir):
    """Open the safetensors checkpoint in *srcdir*, single-file or sharded.

    A single ``model.safetensors`` is used when it exists; otherwise the
    shards named in ``model.safetensors.index.json`` are opened.

    Args:
        srcdir: Directory containing the checkpoint files.

    Returns:
        A tuple ``(opened, get, keys)`` where ``opened`` is the list of open
        ``safe_open`` handles, ``get(k)`` returns the tensor named ``k``, and
        ``keys`` is the list of tensor names (read from the file itself for a
        single file, from the index's ``weight_map`` for shards).

    Raises:
        FileNotFoundError: If neither the single file nor the index exists.
        KeyError: From ``get`` in the sharded case when no shard holds ``k``.
    """
    single = os.path.join(srcdir, "model.safetensors")
    index = os.path.join(srcdir, "model.safetensors.index.json")

    if os.path.exists(single):
        f = safe_open(single, framework="pt", device="cpu")
        return [f], f.get_tensor, list(f.keys())

    # Close the index file deterministically and decode it as UTF-8
    # regardless of the platform's locale encoding.
    with open(index, encoding="utf-8") as fh:
        wm = json.load(fh)["weight_map"]

    files = sorted(set(wm.values()))
    opened = [
        safe_open(os.path.join(srcdir, fp), framework="pt", device="cpu")
        for fp in files
    ]
    keys = list(wm)

    # Map each tensor name to the shard that holds it, once, so lookups do not
    # re-list every shard's keys. setdefault keeps the first shard (in sorted
    # file order) that contains a name, matching a linear scan over `opened`.
    owner = {}
    for handle in opened:
        for name in handle.keys():
            owner.setdefault(name, handle)

    def get(k):
        try:
            handle = owner[k]
        except KeyError:
            raise KeyError(k) from None
        return handle.get_tensor(k)

    return opened, get, keys
|
|
|
|
def main(srcdir, max_layers=64):
    """Print the maximum of every expert-scale tensor found in a checkpoint.

    For each layer index below *max_layers* and each known naming pattern,
    the tensor's maximum value is printed when the tensor exists. Afterwards
    the global maximum is printed, together with a warning when
    ``max + 14`` reaches 256.

    Args:
        srcdir: Directory containing the safetensors checkpoint.
        max_layers: Number of layer indices (starting at 0) to probe.
    """
    # Keep the handles referenced for as long as tensors are being read.
    _handles, get, keys = open_any(srcdir)
    available = set(keys)

    patterns = (
        "block.{n}.mlp.mlp1_weight.scales",  # native openharmony-mlx naming
        "block.{n}.mlp.mlp2_weight.scales",
        "model.layers.{n}.mlp.experts.gate_up_proj_scales",  # Jinx naming
        "model.layers.{n}.mlp.experts.down_proj_scales",
    )

    maxima = []
    for n in range(max_layers):
        for pat in patterns:
            k = pat.format(n=n)
            # Skip names the checkpoint does not list, instead of swallowing
            # every exception: a real load failure should surface.
            if k not in available:
                continue
            v = float(get(k).max().item())
            maxima.append(v)
            print(k, "max=", v)

    if not maxima:
        return

    m = max(maxima)
    print("\nGLOBAL MAX SCALE:", m, " (m + 14 =", m+14, ")")
    # NOTE(review): the +14 offset presumably mirrors an adjustment applied to
    # the scales by the consumer of this checkpoint - confirm against it.
    if m + 14 >= 256:
        # Message: "uint8 overflow will occur; the int16+clamp approach must
        # be used."
        print("⚠️ 会发生 uint8 溢出,必须用 int16+clamp 的写法。")
|
|
|
|
if __name__ == "__main__":
    # Fail with a usage line rather than an IndexError traceback when the
    # checkpoint directory argument is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <checkpoint_dir>")
    main(sys.argv[1])
|