import json
import math
import os

import torch
from torch.profiler import record_function

from gpt_oss.torch.model import ModelConfig, RMSNorm
from gpt_oss.torch.weights import Checkpoint
from gpt_oss.triton.attention import attention, attention_ref
from gpt_oss.triton.moe import quantize_mx4, moe


class RotaryEmbedding(torch.nn.Module):
    def __init__(
        self,
        head_dim: int,
        base: int,
        dtype: torch.dtype,
        initial_context_length: int = 4096,
        max_context_length: int = 131072,
        scaling_factor: float = 1.0,
        ntk_alpha: float = 1.0,
        ntk_beta: float = 32.0,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.base = base
        self.dtype = dtype
        self.initial_context_length = initial_context_length
        self.max_context_length = max_context_length
        self.scaling_factor = scaling_factor
        self.ntk_alpha = ntk_alpha
        self.ntk_beta = ntk_beta
        self.device = device
        # Precompute cos/sin tables for the full context length once.
        self.cos, self.sin = self._compute_cos_sin(0, self.max_context_length)

    def _compute_concentration_and_inv_freq(self) -> tuple[float, torch.Tensor]:
        """See YaRN paper: https://arxiv.org/abs/2309.00071"""
        freq = self.base ** (
            torch.arange(0, self.head_dim, 2, dtype=torch.float, device=self.device)
            / self.head_dim
        )
        if self.scaling_factor > 1.0:
            concentration = (
                0.1 * math.log(self.scaling_factor) + 1.0
            )  # YaRN concentration

            d_half = self.head_dim / 2
            # NTK by parts
            low = (
                d_half
                * math.log(self.initial_context_length / (self.ntk_beta * 2 * math.pi))
                / math.log(self.base)
            )
            high = (
                d_half
                * math.log(self.initial_context_length / (self.ntk_alpha * 2 * math.pi))
                / math.log(self.base)
            )
            assert 0 < low < high < d_half - 1

            interpolation = 1.0 / (self.scaling_factor * freq)
            extrapolation = 1.0 / freq

            ramp = (
                torch.arange(d_half, dtype=torch.float32, device=freq.device) - low
            ) / (high - low)
            mask = 1 - ramp.clamp(0, 1)

            inv_freq = interpolation * (1 - mask) + extrapolation * mask
        else:
            concentration = 1.0
            inv_freq = 1.0 / freq

        return concentration, inv_freq

    def _compute_cos_sin(self, start: int, num_tokens: int):
        concentration, inv_freq = self._compute_concentration_and_inv_freq()
        t = torch.arange(start, start + num_tokens, dtype=torch.float32, device=self.device)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        cos = freqs.cos() * concentration
        sin = freqs.sin() * concentration
        return cos, sin

    @record_function("rotate")
    def _rotate(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        cos = cos[None, :, None, :].to(x.dtype)
        sin = sin[None, :, None, :].to(x.dtype)
        x1, x2 = torch.chunk(x, 2, dim=-1)
        o1 = x1 * cos - x2 * sin
        o2 = x2 * cos + x1 * sin
        return torch.cat((o1, o2), dim=-1)

    @record_function("rope")
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        offset: torch.LongTensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, num_tokens, num_heads, head_dim = query.shape
        batch_size, num_tokens, num_key_value_heads, head_dim = key.shape

        # Index the precomputed tables at the absolute positions of this chunk.
        idx = torch.arange(num_tokens, device=query.device, dtype=torch.long) + offset
        idx = idx % self.max_context_length
        cos = self.cos.index_select(0, idx)
        sin = self.sin.index_select(0, idx)

        query = self._rotate(query, cos, sin)
        key = self._rotate(key, cos, sin)
        return query, key


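# Illustrative usage sketch (not exercised by this module; the shapes and base
# value below are hypothetical, not taken from a real ModelConfig):
#
#   rope = RotaryEmbedding(head_dim=64, base=150000, dtype=torch.float32)
#   q = torch.randn(1, 8, 4, 64)  # (batch, n_tokens, n_heads, head_dim)
#   k = torch.randn(1, 8, 2, 64)  # (batch, n_tokens, n_kv_heads, head_dim)
#   offset = torch.zeros((1,), dtype=torch.long)
#   q_rot, k_rot = rope(q, k, offset=offset)

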
class Cache:
    def __init__(self, batch_size, n_ctx, n_kv_heads, d_head=64, device: torch.device | None = None):
        self.k = torch.zeros((batch_size, n_ctx, n_kv_heads, d_head), dtype=torch.bfloat16, device=device)
        self.v = torch.zeros((batch_size, n_ctx, n_kv_heads, d_head), dtype=torch.bfloat16, device=device)
        self.offset = torch.zeros((1,), dtype=torch.long, device=device)

    def reset(self):
        self.k.zero_()
        self.v.zero_()
        self.offset.zero_()

    def repeat_interleave(self, n):
        """Repeat each cache entry n times along the batch dimension."""
        self.k = self.k.repeat_interleave(n, dim=0)
        self.v = self.v.repeat_interleave(n, dim=0)

    def truncate(self, n_ctx):
        """Truncate the cache to the first n_ctx tokens."""
        batch_size, _, n_kv_heads, d_head = self.k.shape
        assert batch_size == self.v.shape[0]
        assert n_ctx <= self.k.shape[1]
        self.k[:, n_ctx:, :, :].zero_()
        self.v[:, n_ctx:, :, :].zero_()
        self.offset.fill_(n_ctx)
        return self.k, self.v

    def extend(self, k, v):
        """Write k/v for the next n_ctx tokens at the current offset and advance it."""
        batch_size, n_ctx, *_rest = k.shape
        assert batch_size == self.k.shape[0]
        indices = torch.arange(0, n_ctx, device=k.device, dtype=torch.long) + self.offset
        self.k.index_copy_(1, indices, k)
        self.v.index_copy_(1, indices, v)
        self.offset.add_(n_ctx)
        return self.k, self.v


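# Illustrative usage sketch (hypothetical shapes): extend() appends along the
# context dimension, so successive calls keep writing past the prior offset:
#
#   cache = Cache(batch_size=1, n_ctx=4096, n_kv_heads=2, d_head=64)
#   k_new = torch.randn(1, 8, 2, 64, dtype=torch.bfloat16)
#   v_new = torch.randn(1, 8, 2, 64, dtype=torch.bfloat16)
#   k_all, v_all = cache.extend(k_new, v_new)  # cache.offset is now 8

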
class AttentionBlock(torch.nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        layer_idx: int = 0,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Only apply sliding window to every other layer
        self.sliding_window = config.sliding_window if layer_idx % 2 == 0 else 0
        self.layer_idx = layer_idx
        # Learned attention-sink logits, one per attention head.
        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads, device=device, dtype=torch.bfloat16)
        )
        self.norm = RMSNorm(config.hidden_size, device=device)
        qkv_dim = config.head_dim * (
            config.num_attention_heads + 2 * config.num_key_value_heads
        )
        self.qkv = torch.nn.Linear(
            config.hidden_size, qkv_dim, device=device, dtype=torch.bfloat16
        )
        self.out = torch.nn.Linear(
            config.head_dim * config.num_attention_heads,
            config.hidden_size,
            device=device,
            dtype=torch.bfloat16,
        )
        self.sm_scale = 1 / math.sqrt(config.head_dim)
        self.rope = RotaryEmbedding(
            config.head_dim,
            config.rope_theta,
            torch.float32,
            initial_context_length=config.initial_context_length,
            scaling_factor=config.rope_scaling_factor,
            ntk_alpha=config.rope_ntk_alpha,
            ntk_beta=config.rope_ntk_beta,
            device=device,
        )

    @record_function("attn")
    def forward(self, x: torch.Tensor, cache: Cache | None = None) -> torch.Tensor:
        batch_size, n_ctx, dim = x.shape

        t = self.norm(x)
        with record_function("qkv"):
            qkv = self.qkv(t)
            qkv_parts = (
                self.num_attention_heads * self.head_dim,
                self.num_key_value_heads * self.head_dim,
                self.num_key_value_heads * self.head_dim,
            )
            q, k, v = torch.split(qkv, qkv_parts, dim=-1)
            q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

            q = q.view(batch_size, n_ctx, self.num_attention_heads, self.head_dim)
            k = k.view(batch_size, n_ctx, self.num_key_value_heads, self.head_dim)
            v = v.view(batch_size, n_ctx, self.num_key_value_heads, self.head_dim)

        if cache is not None:
            offset = cache.offset.clone()
            q, k = self.rope(q, k, offset=offset)
            k, v = cache.extend(k, v)
        else:
            offset = torch.zeros((1,), dtype=torch.long, device=x.device)
            q, k = self.rope(q, k, offset=offset)

        # Group query heads by their shared KV head (grouped-query attention).
        q = q.view(
            batch_size,
            n_ctx,
            self.num_attention_heads // self.num_key_value_heads,
            self.num_key_value_heads,
            self.head_dim,
        )
        with record_function("attn_kernel"):
            # Single-token decode goes through the reference kernel; longer
            # chunks use the Triton kernel.
            if n_ctx == 1:
                t = attention_ref(
                    q, k, v, self.sinks, self.sm_scale, self.sliding_window, offset,
                )
            else:
                t = attention(
                    q, k, v, self.sinks, self.sm_scale, self.sliding_window, offset,
                )
                if n_ctx < 64:
                    # Cross-check the Triton kernel against the reference
                    # implementation on short sequences.
                    t1 = attention_ref(
                        q, k, v, self.sinks, self.sm_scale, self.sliding_window, offset,
                    )
                    torch.testing.assert_close(t, t1)
                    t = t1

        with record_function("c_proj"):
            t = self.out(t)
        t = x + t
        return t


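# Illustrative usage sketch (assumes a ModelConfig instance `config` and a
# hidden-state tensor `x`; both are placeholders):
#
#   attn = AttentionBlock(config, layer_idx=0)
#   cache = Cache(1, 4096, config.num_key_value_heads, config.head_dim)
#   y = attn(x, cache=cache)  # x, y: (batch, n_ctx, hidden_size)

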
class MLPBlock(torch.nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        layer_idx: int = 0,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.num_experts = config.num_experts
        self.experts_per_token = config.experts_per_token
        self.swiglu_limit = config.swiglu_limit
        self.norm = RMSNorm(config.hidden_size, device=device)
        self.gate = torch.nn.ParameterDict({
            "weight": torch.nn.Parameter(
                torch.empty(
                    (config.hidden_size, config.num_experts),
                    device=device,
                    dtype=torch.bfloat16,
                )
            ),
            "bias": torch.nn.Parameter(
                torch.empty(
                    (config.num_experts,),
                    device=device,
                    dtype=torch.bfloat16,
                )
            ),
        })
        # Expert weights are held MX4-quantized; their raw storage is registered
        # as non-trainable parameters so they appear in named_parameters() and
        # can be filled in by checkpoint loading.
        self.mlp1_weight_tensor, self.mlp1_weight_mx = quantize_mx4(
            torch.empty(
                (
                    config.num_experts,
                    config.hidden_size,
                    config.intermediate_size * 2,
                ),
                device=device,
                dtype=torch.bfloat16,
            ),
        )
        self.mlp1_weight = torch.nn.Parameter(self.mlp1_weight_tensor.storage.data, requires_grad=False)
        self.mlp1_bias = torch.nn.Parameter(
            torch.empty(
                (config.num_experts, config.intermediate_size * 2),
                device=device,
                dtype=torch.bfloat16,
            )
        )
        self.mlp2_weight_tensor, self.mlp2_weight_mx = quantize_mx4(
            torch.empty(
                (
                    config.num_experts,
                    config.intermediate_size,
                    config.hidden_size,
                ),
                device=device,
                dtype=torch.bfloat16,
            ),
        )
        self.mlp2_weight = torch.nn.Parameter(self.mlp2_weight_tensor.storage.data, requires_grad=False)
        self.mlp2_bias = torch.nn.Parameter(
            torch.empty(
                (config.num_experts, config.hidden_size),
                device=device,
                dtype=torch.bfloat16,
            )
        )

    @record_function("mlp")
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, n_ctx, dim = x.shape
        t = self.norm(x)

        t = t.view(batch_size * n_ctx, dim)
        t = moe(
            t,
            self.gate["weight"],
            self.mlp1_weight_tensor, self.mlp1_weight_mx,
            self.mlp2_weight_tensor, self.mlp2_weight_mx,
            self.gate["bias"].float(),
            self.mlp1_bias.float(),
            self.mlp2_bias.float(),
            experts_per_token=self.experts_per_token,
            num_experts=self.num_experts,
            swiglu_limit=self.swiglu_limit,
        )
        t = t.view(batch_size, n_ctx, dim)

        return x + t


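# Illustrative usage sketch (assumes a ModelConfig instance `config` and a
# hidden-state tensor `x`): the MoE MLP routes each token to
# config.experts_per_token of the config.num_experts experts and preserves
# the input shape:
#
#   mlp = MLPBlock(config, layer_idx=0)
#   y = mlp(x)  # x, y: (batch, n_ctx, hidden_size)

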
class TransformerBlock(torch.nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        layer_idx: int,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.attn = AttentionBlock(config, layer_idx, device)
        self.mlp = MLPBlock(config, layer_idx, device)

    def forward(self, x: torch.Tensor, cache: Cache | None = None) -> torch.Tensor:
        x = self.attn(x, cache=cache)
        x = self.mlp(x)
        return x


class Transformer(torch.nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.config = config
        self.embedding = torch.nn.Embedding(
            config.vocab_size, config.hidden_size, device=device, dtype=torch.bfloat16
        )
        self.block = torch.nn.ModuleList(
            [
                TransformerBlock(config, layer_idx, device)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, device=device)
        self.unembedding = torch.nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            device=device,
            dtype=torch.bfloat16,
        )

    def forward(self, x: torch.Tensor, caches: list[Cache] | None = None) -> torch.Tensor:
        caches = caches or [None] * len(self.block)
        with record_function("embedding"):
            x = self.embedding(x)
        for block, cache in zip(self.block, caches):
            with record_function("block"):
                x = block(x, cache=cache)
        with record_function("norm_f"):
            x = self.norm(x)
        with record_function("unembedding"):
            x = self.unembedding(x)
        return x.float()

    @staticmethod
    def from_checkpoint(
        path: str, config: ModelConfig | None = None, device: str | torch.device = "cuda",
    ) -> "Transformer":
        if not isinstance(device, torch.device):
            device = torch.device(device)

        if config is None:
            config_path = os.path.join(path, "config.json")
            with open(config_path, "r") as f:
                json_config = json.load(f)
                config = ModelConfig(**json_config)

        model = Transformer(config=config, device=device)
        model.eval()

        checkpoint = Checkpoint(path, device)

        for name, param in model.named_parameters():
            torch.cuda.empty_cache()
            loaded_tensor = checkpoint.get(name)

            if "mlp1" in name:
                if "weight" in name:
                    # Quantize the checkpoint weights to MX4 and stash the
                    # per-block scales on the owning MLPBlock.
                    loaded_tensor, scales = quantize_mx4(loaded_tensor.mT.contiguous())
                    _, block_index, _, _ = name.split(".")
                    model.block[int(block_index)].mlp.mlp1_weight_mx = scales
                    param.data.copy_(loaded_tensor.storage.data)
                else:
                    param.data.copy_(loaded_tensor)

            elif "mlp2_weight" in name:
                loaded_tensor, scales = quantize_mx4(loaded_tensor.mT.contiguous())
                _, block_index, _, _ = name.split(".")
                model.block[int(block_index)].mlp.mlp2_weight_mx = scales
                param.data.copy_(loaded_tensor.storage.data)

            elif "gate" in name and loaded_tensor.ndim == 2:
                loaded_tensor = loaded_tensor.mT.contiguous()
                param.data.copy_(loaded_tensor)

            else:
                param.data.copy_(loaded_tensor)

        # NOTE: Required to avoid OOM errors
        torch.cuda.empty_cache()
        return model


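# Illustrative usage sketch (the checkpoint path is a placeholder):
#
#   model = Transformer.from_checkpoint("/path/to/checkpoint", device="cuda")
#   tokens = torch.tensor([[1, 2, 3]], dtype=torch.int32, device="cuda")
#   logits = model(tokens)  # (batch, n_ctx, vocab_size), float32

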
class TokenGenerator:
    @torch.inference_mode()
    def __init__(self, checkpoint: str, context: int, device: torch.device):
        self.device = device
        self.model = Transformer.from_checkpoint(checkpoint, device=self.device)
        self.caches = [
            Cache(1, context, self.model.config.num_key_value_heads, device=self.device)
            for _ in range(len(self.model.block))
        ]
        self.input_token = torch.zeros(1, dtype=torch.int32, device=self.device)
        # warmup
        self.model(self.input_token[None, :], caches=self.caches)
        # capture for sampling
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph):
            self.logits = self.model(self.input_token[None, :], caches=self.caches)[0]

    @torch.inference_mode()
    def generate(self,
                 prompt_tokens: list[int],
                 stop_tokens: list[int] | None = None,
                 temperature: float = 1.0,
                 max_tokens: int = 0,
                 return_logprobs: bool = False):
        stop_tokens = stop_tokens or []
        for cache in self.caches:
            cache.reset()
        prompt_tokens = torch.as_tensor(prompt_tokens, dtype=torch.int32, device=self.device)
        # Prefill all but the last prompt token; the last one is fed through
        # the captured single-token decode graph below.
        self.model(prompt_tokens[None, :-1], self.caches)
        predicted_token = prompt_tokens[-1]
        num_generated_tokens = 0
        while max_tokens == 0 or num_generated_tokens < max_tokens:
            self.input_token[0] = predicted_token
            self.graph.replay()
            if temperature == 0.0:
                predicted_token = torch.argmax(self.logits[-1, :], dim=-1).item()
            else:
                probs = torch.softmax(self.logits * (1.0 / temperature), dim=-1)
                predicted_token = torch.multinomial(probs[-1, :], num_samples=1).item()
            num_generated_tokens += 1

            if return_logprobs:
                logprobs = torch.log_softmax(self.logits[-1, :], dim=-1)
                selected_logprobs = logprobs[predicted_token].item()
                yield predicted_token, selected_logprobs
            else:
                yield predicted_token

            if predicted_token in stop_tokens:
                break


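# Illustrative usage sketch (checkpoint path and token ids are placeholders):
#
#   generator = TokenGenerator("/path/to/checkpoint", context=4096,
#                              device=torch.device("cuda"))
#   for token in generator.generate(prompt_tokens, stop_tokens=[eos_token_id],
#                                   temperature=0.7, max_tokens=64):
#       print(token)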