Initial commit
Co-authored-by: Zhuohan Li <zhuohan@openai.com> Co-authored-by: Maratyszcza <marat@openai.com> Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
This commit is contained in:
516
gpt_oss/triton/model.py
Normal file
516
gpt_oss/triton/model.py
Normal file
@@ -0,0 +1,516 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.profiler import record_function
|
||||
|
||||
from gpt_oss.torch.model import ModelConfig, RMSNorm
|
||||
from gpt_oss.torch.weights import Checkpoint
|
||||
from gpt_oss.triton.attention import attention, attention_ref
|
||||
from gpt_oss.triton.moe import quantize_mx4, moe
|
||||
|
||||
|
||||
class RotaryEmbedding(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
head_dim: int,
|
||||
base: int,
|
||||
dtype: torch.dtype,
|
||||
initial_context_length: int = 4096,
|
||||
max_context_length: int = 131072,
|
||||
scaling_factor: float = 1.0,
|
||||
ntk_alpha: float = 1.0,
|
||||
ntk_beta: float = 32.0,
|
||||
device: torch.device | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.base = base
|
||||
self.dtype = dtype
|
||||
self.initial_context_length = initial_context_length
|
||||
self.max_context_length = max_context_length
|
||||
self.scaling_factor = scaling_factor
|
||||
self.ntk_alpha = ntk_alpha
|
||||
self.ntk_beta = ntk_beta
|
||||
self.device = device
|
||||
self.cos, self.sin = self._compute_cos_sin(0, self.max_context_length)
|
||||
|
||||
def _compute_concentration_and_inv_freq(self) -> torch.Tensor:
|
||||
"""See YaRN paper: https://arxiv.org/abs/2309.00071"""
|
||||
freq = self.base ** (
|
||||
torch.arange(0, self.head_dim, 2, dtype=torch.float, device=self.device)
|
||||
/ self.head_dim
|
||||
)
|
||||
if self.scaling_factor > 1.0:
|
||||
concentration = (
|
||||
0.1 * math.log(self.scaling_factor) + 1.0
|
||||
) # YaRN concentration
|
||||
|
||||
d_half = self.head_dim / 2
|
||||
# NTK by parts
|
||||
low = (
|
||||
d_half
|
||||
* math.log(self.initial_context_length / (self.ntk_beta * 2 * math.pi))
|
||||
/ math.log(self.base)
|
||||
)
|
||||
high = (
|
||||
d_half
|
||||
* math.log(self.initial_context_length / (self.ntk_alpha * 2 * math.pi))
|
||||
/ math.log(self.base)
|
||||
)
|
||||
assert 0 < low < high < d_half - 1
|
||||
|
||||
interpolation = 1.0 / (self.scaling_factor * freq)
|
||||
extrapolation = 1.0 / freq
|
||||
|
||||
ramp = (
|
||||
torch.arange(d_half, dtype=torch.float32, device=freq.device) - low
|
||||
) / (high - low)
|
||||
mask = 1 - ramp.clamp(0, 1)
|
||||
|
||||
inv_freq = interpolation * (1 - mask) + extrapolation * mask
|
||||
else:
|
||||
concentration = 1.0
|
||||
inv_freq = 1.0 / freq
|
||||
|
||||
return concentration, inv_freq
|
||||
|
||||
def _compute_cos_sin(self, start: int, num_tokens: int):
|
||||
concentration, inv_freq = self._compute_concentration_and_inv_freq()
|
||||
t = torch.arange(start, start + num_tokens, dtype=torch.float32, device=self.device)
|
||||
freqs = torch.einsum("i,j->ij", t, inv_freq)
|
||||
cos = freqs.cos() * concentration
|
||||
sin = freqs.sin() * concentration
|
||||
return cos, sin
|
||||
|
||||
@record_function("rotate")
|
||||
def _rotate(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
cos = cos[None, :, None, :].to(x.dtype)
|
||||
sin = sin[None, :, None, :].to(x.dtype)
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
return torch.cat((o1, o2), dim=-1)
|
||||
|
||||
@record_function("rope")
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offset: torch.LongTensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size, num_tokens, num_heads, head_dim = query.shape
|
||||
batch_size, num_tokens, num_key_value_heads, head_dim = key.shape
|
||||
|
||||
idx = torch.arange(num_tokens, device=query.device, dtype=torch.long) + offset
|
||||
idx = idx % self.max_context_length
|
||||
cos = self.cos.index_select(0, idx)
|
||||
sin = self.sin.index_select(0, idx)
|
||||
|
||||
query = self._rotate(query, cos, sin)
|
||||
key = self._rotate(key, cos, sin)
|
||||
return query, key
|
||||
|
||||
|
||||
class Cache:
|
||||
def __init__(self, batch_size, n_ctx, n_kv_heads, d_head=64, device: torch.device | None = None):
|
||||
self.k = torch.zeros((batch_size, n_ctx, n_kv_heads, d_head), dtype=torch.bfloat16, device=device)
|
||||
self.v = torch.zeros((batch_size, n_ctx, n_kv_heads, d_head), dtype=torch.bfloat16, device=device)
|
||||
self.offset = torch.zeros((1,), dtype=torch.long, device=device)
|
||||
|
||||
def reset(self):
|
||||
self.k.zero_()
|
||||
self.v.zero_()
|
||||
self.offset.zero_()
|
||||
|
||||
def repeat_interleave(self, n):
|
||||
"""Repeat each cache entry n times along the batch dimension."""
|
||||
self.k = self.k.repeat_interleave(n, dim=0)
|
||||
self.v = self.v.repeat_interleave(n, dim=0)
|
||||
|
||||
def truncate(self, n_ctx):
|
||||
"""Truncate the cache to the first n_ctx tokens."""
|
||||
batch_size, _, n_kv_heads, d_head = self.k.shape
|
||||
assert batch_size == self.v.shape[0]
|
||||
assert n_ctx <= self.k.shape[1]
|
||||
self.k[:, n_ctx:, :, :].zero_()
|
||||
self.v[:, n_ctx:, :, :].zero_()
|
||||
self.offset.fill_(n_ctx)
|
||||
return self.k, self.v
|
||||
|
||||
def extend(self, k, v):
|
||||
batch_size, n_ctx, *_rest = k.shape
|
||||
assert batch_size == self.k.shape[0]
|
||||
indices = torch.arange(0, n_ctx, device=k.device, dtype=torch.long) + self.offset
|
||||
self.k.index_copy_(1, indices, k)
|
||||
self.v.index_copy_(1, indices, v)
|
||||
self.offset.add_(n_ctx)
|
||||
return self.k, self.v
|
||||
|
||||
|
||||
class AttentionBlock(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelConfig,
|
||||
layer_idx: int = 0,
|
||||
device: torch.device | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.head_dim = config.head_dim
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
# Only apply sliding window to every other layer
|
||||
self.sliding_window = config.sliding_window if layer_idx % 2 == 0 else 0
|
||||
self.layer_idx = layer_idx
|
||||
self.sinks = torch.nn.Parameter(
|
||||
torch.empty(config.num_attention_heads, device=device, dtype=torch.bfloat16)
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, device=device)
|
||||
qkv_dim = config.head_dim * (
|
||||
config.num_attention_heads + 2 * config.num_key_value_heads
|
||||
)
|
||||
self.qkv = torch.nn.Linear(
|
||||
config.hidden_size, qkv_dim, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
self.out = torch.nn.Linear(
|
||||
config.head_dim * config.num_attention_heads,
|
||||
config.hidden_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
self.sm_scale = 1 / math.sqrt(config.head_dim)
|
||||
self.rope = RotaryEmbedding(
|
||||
config.head_dim,
|
||||
config.rope_theta,
|
||||
torch.float32,
|
||||
initial_context_length=config.initial_context_length,
|
||||
scaling_factor=config.rope_scaling_factor,
|
||||
ntk_alpha=config.rope_ntk_alpha,
|
||||
ntk_beta=config.rope_ntk_beta,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@record_function("attn")
|
||||
def forward(self, x: torch.Tensor, cache: Cache | None = None) -> torch.Tensor:
|
||||
batch_size, n_ctx, dim = x.shape
|
||||
|
||||
t = self.norm(x)
|
||||
with record_function("qkv"):
|
||||
qkv = self.qkv(t)
|
||||
qkv_parts = (
|
||||
self.num_attention_heads * self.head_dim,
|
||||
self.num_key_value_heads * self.head_dim,
|
||||
self.num_key_value_heads * self.head_dim
|
||||
)
|
||||
q, k, v = torch.split(qkv, qkv_parts, dim=-1)
|
||||
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
|
||||
|
||||
q = q.view(batch_size, n_ctx, self.num_attention_heads, self.head_dim)
|
||||
k = k.view(batch_size, n_ctx, self.num_key_value_heads, self.head_dim)
|
||||
v = v.view(batch_size, n_ctx, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
if cache is not None:
|
||||
offset = cache.offset.clone()
|
||||
q, k = self.rope(q, k, offset=offset)
|
||||
k, v = cache.extend(k, v)
|
||||
else:
|
||||
offset = torch.zeros((1,), dtype=torch.long, device=x.device)
|
||||
q, k = self.rope(q, k, offset=offset)
|
||||
|
||||
q = q.view(
|
||||
batch_size,
|
||||
n_ctx,
|
||||
self.num_attention_heads // self.num_key_value_heads,
|
||||
self.num_key_value_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
with record_function("attn_kernel"):
|
||||
if n_ctx == 1:
|
||||
t = attention_ref(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
self.sinks,
|
||||
self.sm_scale,
|
||||
self.sliding_window,
|
||||
offset,
|
||||
)
|
||||
else:
|
||||
t = attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
self.sinks,
|
||||
self.sm_scale,
|
||||
self.sliding_window,
|
||||
offset,
|
||||
)
|
||||
if n_ctx < 64:
|
||||
t1 = attention_ref(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
self.sinks,
|
||||
self.sm_scale,
|
||||
self.sliding_window,
|
||||
offset,
|
||||
)
|
||||
torch.testing.assert_close(t, t1)
|
||||
t = t1
|
||||
|
||||
with record_function("c_proj"):
|
||||
t = self.out(t)
|
||||
t = x + t
|
||||
return t
|
||||
|
||||
|
||||
class MLPBlock(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelConfig,
|
||||
layer_idx: int = 0,
|
||||
device: torch.device | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.num_experts = config.num_experts
|
||||
self.experts_per_token = config.experts_per_token
|
||||
self.swiglu_limit = config.swiglu_limit
|
||||
self.norm = RMSNorm(config.hidden_size, device=device)
|
||||
self.gate = torch.nn.ParameterDict({
|
||||
"weight": torch.nn.Parameter(
|
||||
torch.empty(
|
||||
(config.hidden_size, config.num_experts),
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
),
|
||||
"bias": torch.nn.Parameter(
|
||||
torch.empty(
|
||||
(config.num_experts,),
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
),
|
||||
})
|
||||
self.mlp1_weight_tensor, self.mlp1_weight_mx = quantize_mx4(
|
||||
torch.empty(
|
||||
(
|
||||
config.num_experts,
|
||||
config.hidden_size,
|
||||
config.intermediate_size * 2,
|
||||
),
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
)
|
||||
self.mlp1_weight = torch.nn.Parameter(self.mlp1_weight_tensor.storage.data, requires_grad=False)
|
||||
self.mlp1_bias = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
(config.num_experts, config.intermediate_size * 2),
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
)
|
||||
self.mlp2_weight_tensor, self.mlp2_weight_mx = quantize_mx4(
|
||||
torch.empty(
|
||||
(
|
||||
config.num_experts,
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
),
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
)
|
||||
self.mlp2_weight = torch.nn.Parameter(self.mlp2_weight_tensor.storage.data, requires_grad=False)
|
||||
self.mlp2_bias = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
(config.num_experts, config.hidden_size),
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
)
|
||||
|
||||
@record_function("mlp")
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, n_ctx, dim = x.shape
|
||||
t = self.norm(x)
|
||||
|
||||
t = t.view(batch_size * n_ctx, dim)
|
||||
t = moe(
|
||||
t,
|
||||
self.gate["weight"],
|
||||
self.mlp1_weight_tensor, self.mlp1_weight_mx,
|
||||
self.mlp2_weight_tensor, self.mlp2_weight_mx,
|
||||
self.gate["bias"].float(),
|
||||
self.mlp1_bias.float(),
|
||||
self.mlp2_bias.float(),
|
||||
experts_per_token=self.experts_per_token,
|
||||
num_experts=self.num_experts,
|
||||
swiglu_limit=self.swiglu_limit,
|
||||
)
|
||||
t = t.view(batch_size, n_ctx, dim)
|
||||
|
||||
return x + t
|
||||
|
||||
|
||||
class TransformerBlock(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelConfig,
|
||||
layer_idx: int,
|
||||
device: torch.device | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.attn = AttentionBlock(config, layer_idx, device)
|
||||
self.mlp = MLPBlock(config, layer_idx, device)
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: Cache | None = None) -> torch.Tensor:
|
||||
x = self.attn(x, cache=cache)
|
||||
x = self.mlp(x)
|
||||
return x
|
||||
|
||||
|
||||
class Transformer(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelConfig,
|
||||
device: torch.device | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embedding = torch.nn.Embedding(
|
||||
config.vocab_size, config.hidden_size, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
self.block = torch.nn.ModuleList(
|
||||
[
|
||||
TransformerBlock(config, layer_idx, device)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, device=device)
|
||||
self.unembedding = torch.nn.Linear(
|
||||
config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, caches: list[Cache] | None = None) -> torch.Tensor:
|
||||
caches=caches or [None] * len(self.block)
|
||||
with record_function("embedding"):
|
||||
x = self.embedding(x)
|
||||
for block, cache in zip(self.block, caches):
|
||||
with record_function("block"):
|
||||
x = block(x, cache=cache)
|
||||
with record_function("norm_f"):
|
||||
x = self.norm(x)
|
||||
with record_function("unembedding"):
|
||||
x = self.unembedding(x)
|
||||
return x.float()
|
||||
|
||||
@staticmethod
|
||||
def from_checkpoint(
|
||||
path: str, config: ModelConfig | None = None, device: str | torch.device = "cuda",
|
||||
) -> "Transformer":
|
||||
if not isinstance(device, torch.device):
|
||||
device = torch.device(device)
|
||||
|
||||
if config is None:
|
||||
config_path = os.path.join(path, "config.json")
|
||||
with open(config_path, "r") as f:
|
||||
json_config = json.load(f)
|
||||
config = ModelConfig(**json_config)
|
||||
|
||||
model = Transformer(config=config, device=device)
|
||||
model.eval()
|
||||
|
||||
checkpoint = Checkpoint(path, device)
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
torch.cuda.empty_cache()
|
||||
loaded_tensor = checkpoint.get(name)
|
||||
|
||||
if "mlp1" in name:
|
||||
if "weight" in name:
|
||||
loaded_tensor, scales = quantize_mx4(loaded_tensor.mT.contiguous())
|
||||
_, block_index, _, _ = name.split(".")
|
||||
model.block[int(block_index)].mlp.mlp1_weight_mx = scales
|
||||
param.data.copy_(loaded_tensor.storage.data)
|
||||
else:
|
||||
param.data.copy_(loaded_tensor)
|
||||
|
||||
elif "mlp2_weight" in name:
|
||||
loaded_tensor, scales = quantize_mx4(loaded_tensor.mT.contiguous())
|
||||
_, block_index, _, _ = name.split(".")
|
||||
model.block[int(block_index)].mlp.mlp2_weight_mx = scales
|
||||
param.data.copy_(loaded_tensor.storage.data)
|
||||
|
||||
elif "gate" in name and loaded_tensor.ndim == 2:
|
||||
loaded_tensor = loaded_tensor.mT.contiguous()
|
||||
param.data.copy_(loaded_tensor)
|
||||
|
||||
else:
|
||||
param.data.copy_(loaded_tensor)
|
||||
|
||||
# NOTE: Required to avoid OOM errors
|
||||
torch.cuda.empty_cache()
|
||||
return model
|
||||
|
||||
|
||||
class TokenGenerator:
|
||||
@torch.inference_mode()
|
||||
def __init__(self, checkpoint: str, context: int, device: torch.device):
|
||||
self.device = device
|
||||
self.model = Transformer.from_checkpoint(checkpoint, device=self.device)
|
||||
self.caches = [Cache(1, context, self.model.config.num_key_value_heads, device=self.device) for _ in range(len(self.model.block))]
|
||||
self.input_token = torch.zeros(1, dtype=torch.int32, device=self.device)
|
||||
# warmup
|
||||
self.model(self.input_token[None, :], caches=self.caches)
|
||||
# capture for sampling
|
||||
self.graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.graph):
|
||||
self.logits = self.model(self.input_token[None, :], caches=self.caches)[0]
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(self,
|
||||
prompt_tokens: list[int],
|
||||
stop_tokens: list[int] | None = None,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: int = 0,
|
||||
return_logprobs: bool = False):
|
||||
stop_tokens = stop_tokens or []
|
||||
for cache in self.caches:
|
||||
cache.reset()
|
||||
prompt_tokens = torch.as_tensor(prompt_tokens, dtype=torch.int32, device=self.device)
|
||||
self.model(prompt_tokens[None, :-1], self.caches)
|
||||
predicted_token = prompt_tokens[-1]
|
||||
num_generated_tokens = 0
|
||||
while max_tokens == 0 or num_generated_tokens < max_tokens:
|
||||
self.input_token[0] = predicted_token
|
||||
self.graph.replay()
|
||||
if temperature == 0.0:
|
||||
predicted_token = torch.argmax(self.logits[-1, :], dim=-1).item()
|
||||
else:
|
||||
probs = torch.softmax(self.logits * (1.0 / temperature), dim=-1)
|
||||
predicted_token = torch.multinomial(probs[-1, :], num_samples=1).item()
|
||||
num_generated_tokens += 1
|
||||
|
||||
if return_logprobs:
|
||||
logprobs = torch.log_softmax(self.logits[-1, :], dim=-1)
|
||||
selected_logprobs = logprobs[predicted_token].item()
|
||||
yield predicted_token, selected_logprobs
|
||||
else:
|
||||
yield predicted_token
|
||||
|
||||
if predicted_token in stop_tokens:
|
||||
break
|
||||
Reference in New Issue
Block a user