# tests/peek_scales_v2.py
#
# Peek at the quantization-scale tensors of a safetensors checkpoint and
# report the global maximum, to decide whether a uint8 encoding of
# (value + 14) would overflow.
#
# Usage: python peek_scales_v2.py <checkpoint_dir>
import json
import os
import sys

from safetensors import safe_open


def open_any(srcdir):
    """Open a single- or multi-file safetensors checkpoint in *srcdir*.

    Returns a tuple ``(handles, get, keys)`` where
      handles: open safe_open handles (kept alive for the caller),
      get:     callable mapping a tensor name to its tensor; raises
               KeyError for an unknown name,
      keys:    list of all tensor names in the checkpoint.
    """
    single = os.path.join(srcdir, "model.safetensors")
    index = os.path.join(srcdir, "model.safetensors.index.json")

    if os.path.exists(single):
        f = safe_open(single, framework="pt", device="cpu")
        return [f], lambda k: f.get_tensor(k), list(f.keys())

    # Sharded checkpoint: the index maps each tensor name to its shard file.
    # (The original leaked the file handle via a bare open(); use `with`.)
    with open(index, encoding="utf-8") as fh:
        wm = json.load(fh)["weight_map"]

    files = sorted(set(wm.values()))
    handles = {
        fp: safe_open(os.path.join(srcdir, fp), framework="pt", device="cpu")
        for fp in files
    }

    def get(k):
        # O(1) lookup through the index, instead of scanning every shard's
        # full key list on each call as the original did.
        try:
            fp = wm[k]
        except KeyError:
            raise KeyError(k) from None
        return handles[fp].get_tensor(k)

    return list(handles.values()), get, list(wm.keys())


def main(srcdir):
    """Scan known scale-tensor names across layers and print the global max."""
    _handles, get, _keys = open_any(srcdir)

    # Candidate tensor names under the two checkpoint layouts we know of.
    patterns = [
        "block.{n}.mlp.mlp1_weight.scales",                  # openharmony-mlx native
        "block.{n}.mlp.mlp2_weight.scales",
        "model.layers.{n}.mlp.experts.gate_up_proj_scales",  # Jinx naming
        "model.layers.{n}.mlp.experts.down_proj_scales",
    ]
    # NOTE: dropped the original's unused `import torch` here —
    # safe_open(framework="pt") already requires torch to be importable.

    mx = []
    for n in range(64):  # upper bound on layer count; more than enough
        for pat in patterns:
            k = pat.format(n=n)
            try:
                t = get(k)
            except Exception:
                # Name doesn't exist under this layout — try the next pattern.
                continue
            v = float(t.max().item())
            mx.append(v)
            print(k, "max=", v)

    if mx:
        m = max(mx)
        print("\nGLOBAL MAX SCALE:", m, " (m + 14 =", m + 14, ")")
        if m + 14 >= 256:
            print("⚠️ 会发生 uint8 溢出,必须用 int16+clamp 的写法。")


if __name__ == "__main__":
    main(sys.argv[1])