#pragma once #include #include #include #include "internal/metal.h" struct gptoss_tokenizer { atomic_uint_least64_t ref_count; void* mapping_ptr; size_t mapping_size; const char* regex_ptr; const char* tokens_ptr; uint32_t num_text_tokens; uint32_t num_special_tokens; uint32_t special_token_id[gptoss_special_token_max - 1]; }; struct gptoss_model { atomic_uint_least64_t ref_count; struct gptoss_tokenizer* tokenizer; void* mapping_ptr; size_t mapping_size; uint32_t context_length; uint32_t num_blocks; uint32_t num_experts; uint32_t num_active_experts; uint32_t embedding_dim; uint32_t mlp_dim; float swiglu_limit; uint32_t head_dim; uint32_t num_heads; uint32_t num_kv_heads; uint32_t attention_window; float rope_theta; float interpolation_scale; float yarn_offset; float yarn_scale; float yarn_multiplier; float rmsnorm_epsilon; uint32_t vocabulary_size; // Maximum number of tokens that can be processed in a single batch. // Once the batch size is reached, we process it to fill the KV cache. size_t max_batch_tokens; size_t weights_size; size_t allocation_size; // Metal objects struct gptoss_metal_device device; size_t max_threadgroups; struct gptoss_metal_command_queue command_queue; struct gptoss_metal_library library; struct gptoss_metal_function bf16_f32_embeddings_fn; struct gptoss_metal_function f32_bf16w_rmsnorm_fn; struct gptoss_metal_function f32_bf16w_matmul_fn; struct gptoss_metal_function f32_bf16w_unembedding_fn; struct gptoss_metal_function f32_rope_fn; struct gptoss_metal_function f32_mf4w_moe_matmul_swiglu_fn; struct gptoss_metal_function f32_mf4w_moe_matmul_fn; struct gptoss_metal_function f32_accumulate_e4_fn; struct gptoss_metal_function f32_topk_softmax_e32_k4_fn; struct gptoss_metal_function f32_topk_softmax_e128_k4_fn; struct gptoss_metal_function f32_sdpa_q8_d64_fn; struct gptoss_metal_function f32_softmax_fn; // Activation buffers. // TODO: merge into a single buffer. struct gptoss_metal_buffer residual_activation_buffer; // Residual stream struct gptoss_metal_buffer rmsnorm_activation_buffer; // Both attention & MLP RMSNorm output struct gptoss_metal_buffer qkv_activation_buffer; // QKV projection output struct gptoss_metal_buffer sdpa_activation_buffer; // SDPA output struct gptoss_metal_buffer gate_activation_buffer; // MoE gating output struct gptoss_metal_buffer expert_activation_buffer; // MoE expert predictions struct gptoss_metal_buffer swiglu_activation_buffer; // MLP+SwiGLU output struct gptoss_metal_buffer moe_activation_buffer; // MoE MLP output (per-active expert) size_t per_block_shared_weights_size; size_t per_expert_block_weight_size; size_t attn_rmsnorm_gain_offset; size_t attn_qkv_weight_offset; size_t attn_qkv_bias_offset; size_t attn_sdpa_sink_offset; size_t attn_out_weight_offset; size_t attn_out_bias_offset; size_t mlp_rmsnorm_gain_offset; size_t mlp_gate_weight_offset; size_t mlp_gate_bias_offset; size_t mlp_swiglu_scale_offset; size_t mlp_swiglu_bias_offset; size_t mlp_out_block_offset; size_t mlp_out_scale_offset; size_t mlp_out_bias_offset; size_t rmsnorm_weight_offset; size_t unembedding_weight_offset; // Buffer with non-MoE weights. Includes MoE gates, embeddings/unembeddings. struct gptoss_metal_buffer shared_weight_buffer; // num_blocks per-block buffers with MoE weights to follow. struct gptoss_metal_buffer block_weight_buffers[]; }; #define GPTOSS_DEFAULT_BATCH_SIZE 128 struct gptoss_context { atomic_uint_least64_t ref_count; struct gptoss_model* model; // Number of tokens processed in the context. size_t num_tokens; // Number of tokens in the KV cache. size_t num_kv_tokens; // Length of the context. size_t max_tokens; // Current number of tokens in the batch. // Always in the [0, max_batch_tokens) range. size_t num_batch_tokens; // Number of tokens processed in the last batch. // Activations for [num_batch_tokens, num_processed_tokens) tokens can be accessed from internal structures. size_t num_processed_tokens; size_t kvcache_size; size_t allocation_size; struct gptoss_metal_buffer token_buffer; // uint32 token IDs struct gptoss_metal_buffer score_buffer; // unembedding outputs struct gptoss_metal_buffer prob_buffer; struct gptoss_metal_buffer sum_buffer; struct gptoss_metal_buffer argmax_buffer; struct gptoss_metal_buffer kvcache_buffer; }; struct gptoss_sampler { atomic_uint_least64_t ref_count; float temperature; float top_p; float presence_penalty; float frequency_penalty; };