# NOTE(review): the lines below are copy/paste chrome from a file-browser UI,
# not Python source; commented out so the module parses.
# Files
# openharmony-mlx/tests/peek_scales_v2.py
# 49 lines
# 1.6 KiB
# Python

# tests/peek_scales_v2.py
import os, json, sys
from safetensors import safe_open
def open_any(srcdir):
    """Open a checkpoint directory holding either a single ``model.safetensors``
    or a sharded checkpoint described by ``model.safetensors.index.json``.

    Returns ``(handles, get, keys)``:
      handles -- list of open safe_open objects (kept so tensors stay loadable),
      get(k)  -- loads tensor ``k`` by name, raising KeyError if unknown
                 (sharded case; the single-file case defers to safetensors),
      keys    -- every tensor name available in the checkpoint.
    """
    single = os.path.join(srcdir, "model.safetensors")
    index = os.path.join(srcdir, "model.safetensors.index.json")
    if os.path.exists(single):
        f = safe_open(single, framework="pt", device="cpu")
        return [f], lambda k: f.get_tensor(k), list(f.keys())
    # Sharded checkpoint: the index's weight_map maps tensor name -> shard file.
    # Use a context manager so the index file handle is closed promptly.
    with open(index, encoding="utf-8") as fh:
        wm = json.load(fh)["weight_map"]
    files = sorted(set(wm.values()))
    # One handle per shard, keyed by shard filename for O(1) tensor dispatch
    # (the original scanned every shard's full key list on each lookup).
    handle_by_file = {
        fp: safe_open(os.path.join(srcdir, fp), framework="pt", device="cpu")
        for fp in files
    }
    opened = list(handle_by_file.values())

    def get(k):
        try:
            return handle_by_file[wm[k]].get_tensor(k)
        except KeyError:
            # Re-raise with the tensor name only, hiding the internal lookup chain.
            raise KeyError(k) from None

    return opened, get, list(wm.keys())
def main(srcdir, max_layers=64):
    """Scan per-layer quantization-scale tensors in a checkpoint directory.

    Probes both known naming schemes for every layer index, prints each
    tensor's max scale, then the global max, and warns when ``m + 14``
    would overflow uint8 (>= 256).

    Parameters:
      srcdir     -- checkpoint directory (single or sharded safetensors).
      max_layers -- upper bound on layer indices to probe (default 64,
                    matching the original hard-coded range).
    """
    opened, get, keys = open_any(srcdir)
    patterns = [
        "block.{n}.mlp.mlp1_weight.scales",  # openharmony-mlx native naming
        "block.{n}.mlp.mlp2_weight.scales",
        "model.layers.{n}.mlp.experts.gate_up_proj_scales",  # Jinx naming
        "model.layers.{n}.mlp.experts.down_proj_scales",
    ]
    # Membership test against the known key set instead of the original broad
    # `except Exception: continue`, which silently swallowed real I/O errors.
    # (Narrowing to KeyError alone would break the single-file path, whose
    # getter raises a safetensors-specific error for missing names.)
    available = set(keys)
    mx = []
    for n in range(max_layers):
        for pat in patterns:
            k = pat.format(n=n)
            if k not in available:  # tensor absent under this naming scheme
                continue
            t = get(k)
            v = float(t.max().item())
            mx.append(v)
            print(k, "max=", v)
    if mx:
        m = max(mx)
        print("\nGLOBAL MAX SCALE:", m, " (m + 14 =", m + 14, ")")
        if m + 14 >= 256:
            print("⚠️ 会发生 uint8 溢出,必须用 int16+clamp 的写法。")
if __name__ == "__main__":
    # CLI entry point: the single positional argument is the checkpoint
    # directory to scan.
    main(sys.argv[1])