import argparse
import os
import math
import sys
import json
import itertools
import struct
from uuid import UUID

import tiktoken
import torch
from safetensors import safe_open
from tqdm import tqdm

from openai_harmony import load_harmony_encoding, HarmonyEncodingName

parser = argparse.ArgumentParser(prog='check-mxfp4-weights.py', description='Validate MXFP4 weights')
parser.add_argument('-s', '--src', metavar='DIR', type=str, required=True, help='Path to the input checkpoint directory')
parser.add_argument('-d', '--dst', metavar='FILE', type=str, required=True, help='Path to the output model file')

o200k_base = tiktoken.get_encoding("o200k_base")
harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

o200k_gptoss = tiktoken.Encoding(
    name="o200k_gptoss",
    pat_str=o200k_base._pat_str,
    mergeable_ranks=o200k_base._mergeable_ranks,
    special_tokens={
        "<|reserved199998|>": 199998,  # unused
        "<|endoftext|>": 199999,
        "<|untrusted|>": 200000,
        "<|endofuntrusted|>": 200001,
        "<|return|>": 200002,
        "<|constrain|>": 200003,
        "<|reserved200004|>": 200004,  # unused
        "<|channel|>": 200005,
        "<|start|>": 200006,
        "<|end|>": 200007,
        "<|message|>": 200008,
        "<|reserved200009|>": 200009,  # unused
        "<|reserved200010|>": 200010,  # unused
        "<|reserved200011|>": 200011,  # unused
        "<|call|>": 200012,
        "<|refusal|>": 200013,
    }
)

FILE_MAGIC = struct.pack('ccccccccccccI', b'G', b'P', b'T', b'-', b'O', b'S', b'S', b' ', b'v', b'1', b'.', b'0', 0)

SPECIAL_TOKEN_UUID = {
    '<|start|>': UUID('55a77c2f-8a01-4c54-8ac2-313bfc7e208d').bytes,
    '<|message|>': UUID('16e40431-f47f-4b22-b59b-8b278fc30a54').bytes,
    '<|end|>': UUID('fcac2f6d-4705-4f6b-b228-642accac7238').bytes,
    '<|return|>': UUID('f799ff69-1992-43c4-a3d8-d831f475dc75').bytes,
    '<|refusal|>': UUID('e15ba702-28c4-4292-ab8f-ffa434709128').bytes,
    '<|constrain|>': UUID('c0bb14c7-6022-49da-ad08-792d67e8b470').bytes,
    '<|channel|>': UUID('fd3dda11-c8ab-4033-876e-d93deb172c93').bytes,
    '<|call|>': UUID('1220f796-e388-4de5-b487-fe2eb5fe03c0').bytes,
    '<|untrusted|>': UUID('07d7da55-b346-4cff-8b37-7cefacf8a3e8').bytes,
    '<|endofuntrusted|>': UUID('f265bd9c-c717-469e-a447-920687d65d90').bytes,
}

INCLUDE_SPECIAL_TOKENS = [
    "<|start|>",
    "<|message|>",
    "<|end|>",
    "<|return|>",
    "<|refusal|>",
    "<|constrain|>",
    "<|channel|>",
    "<|call|>",
    "<|untrusted|>",
    "<|endofuntrusted|>",
]

GPTOSS_MODEL_UUID = UUID('df52dc86-1789-4ed0-a295-66f10508145b').bytes
APPLE_GPU_LAYOUT_UUID = UUID('229177a8-5775-4268-bfd8-d588b351c56d').bytes
TIKTOKEN_TOKENIZER_UUID = UUID('7401aded-2a95-40cb-b782-9ccebaafe72b').bytes

UE8_OFFSET = 14  # bias added to MXFP4 block scales


def write_file_header(f):
    f.write(FILE_MAGIC)


def write_tokenizer_header(f,
                           num_special_tokens: int,
                           num_text_tokens: int,
                           regex_size: int,
                           tokens_size: int):
    f.write(TIKTOKEN_TOKENIZER_UUID)
    f.write(struct.pack('<I', num_special_tokens))
    f.write(struct.pack('<I', num_text_tokens))
    f.write(struct.pack('<I', regex_size))
    f.write(struct.pack('<I', tokens_size))


# Reconstructed from the call site below: fields are written in keyword-argument
# order, uint32 for integer fields and float32 for floating-point fields (all
# little-endian), framed by the model UUID and the weight-layout UUID.
def write_model_header(f,
                       context_length: int,
                       num_blocks: int,
                       num_experts: int,
                       num_active_experts: int,
                       embedding_dim: int,
                       mlp_dim: int,
                       swiglu_limit: float,
                       head_dim: int,
                       num_heads: int,
                       num_kv_heads: int,
                       attention_window: int,
                       rope_theta: float,
                       interpolation_scale: float,
                       yarn_offset: float,
                       yarn_scale: float,
                       yarn_multiplier: float,
                       rmsnorm_epsilon: float):
    f.write(GPTOSS_MODEL_UUID)
    f.write(struct.pack('<I', context_length))
    f.write(struct.pack('<I', num_blocks))
    f.write(struct.pack('<I', num_experts))
    f.write(struct.pack('<I', num_active_experts))
    f.write(struct.pack('<I', embedding_dim))
    f.write(struct.pack('<I', mlp_dim))
    f.write(struct.pack('<f', swiglu_limit))
    f.write(struct.pack('<I', head_dim))
    f.write(struct.pack('<I', num_heads))
    f.write(struct.pack('<I', num_kv_heads))
    f.write(struct.pack('<I', attention_window))
    f.write(struct.pack('<f', rope_theta))
    f.write(struct.pack('<f', interpolation_scale))
    f.write(struct.pack('<f', yarn_offset))
    f.write(struct.pack('<f', yarn_scale))
    f.write(struct.pack('<f', yarn_multiplier))
    f.write(struct.pack('<f', rmsnorm_epsilon))
    f.write(APPLE_GPU_LAYOUT_UUID)


options = parser.parse_args()

# Model hyperparameters come from the checkpoint's config.json; the key names
# below are assumed to follow the gpt-oss original-checkpoint layout.
with open(os.path.join(options.src, "config.json"), "r") as config_file:
    config = json.load(config_file)

num_blocks = config["num_hidden_layers"]
num_experts = config["num_experts"]
num_active_experts = config["experts_per_token"]
embedding_dim = config["hidden_size"]
mlp_dim = config["intermediate_size"]
swiglu_limit = config["swiglu_limit"]
head_dim = config["head_dim"]
num_q_heads = config["num_attention_heads"]
num_kv_heads = config["num_key_value_heads"]
attention_window = config["sliding_window"]
initial_context_length = config["initial_context_length"]
rope_theta = config["rope_theta"]
rope_scaling_factor = config["rope_scaling_factor"]
rope_ntk_alpha = config["rope_ntk_alpha"]
rope_ntk_beta = config["rope_ntk_beta"]

# First count all text tokens (ranks below the first special token)
num_text_tokens = 0
tokens_size = 0
for t in range(199998):
    token_bytes = o200k_gptoss.decode_single_token_bytes(t)
    assert len(token_bytes) > 0
    tokens_size += len(token_bytes) + 2  # uint16_t string length + string data
    num_text_tokens += 1
# Then add all special tokens
num_included_tokens = 200013 + 1
print(f"Tokenizer: {num_included_tokens} tokens")

tensors = {}
with open(options.dst, "wb") as dst:
    with safe_open(os.path.join(options.src, "model.safetensors"), framework="pt", device="cpu") as src:
        write_file_header(dst)

        yarn_low = (
            head_dim / 2
            * math.log(initial_context_length / (rope_ntk_beta * 2 * math.pi))
            / math.log(rope_theta)
        )
        yarn_high = (
            head_dim / 2
            * math.log(initial_context_length / (rope_ntk_alpha * 2 * math.pi))
            / math.log(rope_theta)
        )
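        # yarn_low/yarn_high implement YaRN's NTK-by-parts rule: RoPE dimension
        # pair d rotates with wavelength 2*pi*theta^(2d/head_dim), so the pair
        # whose wavelength equals initial_context_length / ntk_beta (resp.
        # / ntk_alpha) sits at fractional index
        # (head_dim / 2) * log(L / (2*pi*beta)) / log(theta).
        # yarn_offset/yarn_scale below renormalize [yarn_low, yarn_high] onto
        # [0, 1] so a kernel can evaluate the interpolation ramp with a single
        # fused multiply-add; yarn_multiplier is YaRN's attention temperature,
        # 0.1 * ln(rope_scaling_factor) + 1.0.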
        write_model_header(dst,
                           context_length=int(initial_context_length * rope_scaling_factor),
                           num_blocks=num_blocks,
                           num_experts=num_experts,
                           num_active_experts=num_active_experts,
                           embedding_dim=embedding_dim,
                           mlp_dim=mlp_dim,
                           swiglu_limit=swiglu_limit,
                           head_dim=head_dim,
                           num_heads=num_q_heads,
                           num_kv_heads=num_kv_heads,
                           attention_window=attention_window,
                           rope_theta=rope_theta,
                           interpolation_scale=1.0 / rope_scaling_factor,
                           yarn_offset=-yarn_low / (yarn_high - yarn_low),
                           yarn_scale=1.0 / (yarn_high - yarn_low),
                           yarn_multiplier=0.1 * math.log(rope_scaling_factor) + 1.0,
                           rmsnorm_epsilon=1.0e-5)

        write_tokenizer_header(dst,
                               num_special_tokens=num_included_tokens - num_text_tokens,
                               num_text_tokens=num_text_tokens,
                               regex_size=len(o200k_gptoss._pat_str.encode("ascii")) + 1,
                               tokens_size=tokens_size)

        ### Tokenizer

        # Special tokens: a 16-byte UUID per included token, zeroes for the rest
        for token_idx in range(num_text_tokens, num_included_tokens):
            token = o200k_gptoss.decode_single_token_bytes(token_idx).decode('ascii')
            if token in INCLUDE_SPECIAL_TOKENS:
                dst.write(SPECIAL_TOKEN_UUID[token])
            else:
                dst.write(bytes(16))

        # Regex, NUL-terminated (regex_size above accounts for the terminator)
        dst.write(o200k_gptoss._pat_str.encode("ascii"))
        dst.write(struct.pack('B', 0))

        # Text tokens
        tokenizer_bytes_written = 0
        for t in range(num_text_tokens):
            token_bytes = o200k_gptoss.decode_single_token_bytes(t)
            assert len(token_bytes) > 0
            dst.write(struct.pack('<H', len(token_bytes)))  # uint16_t string length
            dst.write(token_bytes)                          # string data
            tokenizer_bytes_written += len(token_bytes) + 2
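# Example invocation (paths are illustrative), reading the original MXFP4
# checkpoint directory and writing the single-file model:
#
#   python check-mxfp4-weights.py --src gpt-oss-20b/original/ --dst gpt-oss-20b.bin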