Initial commit

Co-authored-by: Zhuohan Li <zhuohan@openai.com>
Co-authored-by: Maratyszcza <marat@openai.com>
Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
commit 243a1b0276
Author: Dominik Kundel
Date:   2025-08-05 08:19:49 -07:00
124 changed files with 20405 additions and 0 deletions

@@ -0,0 +1,102 @@
import datetime
import os
from typing import Callable

# Must be set before torch initializes the CUDA caching allocator so the
# expandable-segments configuration takes effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torch.distributed as dist

from gpt_oss.triton.model import Cache, ModelConfig, Transformer

DEFAULT_TEMPERATURE = 0.0
CONTEXT = 16_384
CONCURRENT_SESSIONS = 1

rank = int(
    os.environ.get("RANK", 0)
)  # set this env var to another value to run on other GPUs


def load_model(checkpoint: str):
    print(f"[{rank}] loading model...")
    torch.cuda.set_device(rank)
    torch.set_grad_enabled(False)
    device = torch.device(f"cuda:{rank}")

    # Load model
    model = Transformer.from_checkpoint(checkpoint, device=device)
    print(f"[{rank}] loaded")
    return model, device


def get_infer_next_token(model, device):
    caches = [
        Cache(CONCURRENT_SESSIONS, CONTEXT, model.config.num_key_value_heads)
        for _ in range(len(model.block))
    ]
    # offsets = torch.zeros(CONCURRENT_SESSIONS, dtype=torch.int32, device=device) # TBD
    input_token = torch.zeros(
        1, dtype=torch.int32, device=device
    )  # add concurrent sessions support
    tokens_so_far = []

    # Warm-up prefill so kernels and cache buffers are initialized before the
    # decode step is captured as a CUDA graph.
    model.prefill(torch.zeros(1, 4, dtype=torch.int32, device=device), caches)

    # Capture the single-token decode step once; generation later just updates
    # input_token and replays the graph, which writes the result into `logits`.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        logits = model(input_token[None, :], caches=caches)[0]

    def lcp(cache: list[int], inp: list[int]) -> list[int]:
        """Longest common prefix of the cached and the incoming token lists."""
        i = 0
        max_len = min(len(cache), len(inp))
        while i < max_len and cache[i] == inp[i]:
            i += 1
        return cache[:i]

    def sample_next_token(
        logits: torch.Tensor, temperature: float = DEFAULT_TEMPERATURE
    ) -> int:
        """Executed only on rank 0."""
        if temperature == 0.0:
            return torch.argmax(logits[-1, :], dim=-1).item()
        probs = torch.softmax(logits * (1.0 / temperature), dim=-1)
        return torch.multinomial(probs[-1, :], num_samples=1).item()

    @torch.inference_mode()
    def infer_next_token(
        tokens: list[int],
        temperature: float = DEFAULT_TEMPERATURE,
        new_request: bool = False,
    ) -> int:
        nonlocal tokens_so_far

        # Reuse whatever prefix of the request is already in the KV caches and
        # drop the stale suffix.
        tokens_so_far = lcp(tokens_so_far, tokens)
        for cache in caches:
            cache.truncate(len(tokens_so_far))

        all_tokens = tokens  # for pdb
        tokens = tokens[len(tokens_so_far) :]

        # Prefill all but the last new token; the last one goes through the
        # captured decode graph below.
        if len(tokens) > 1:
            model.prefill(
                torch.as_tensor(tokens[:-1], dtype=torch.int32, device=device)[None, :],
                caches,
            )
        if len(tokens) == 0:
            # No new tokens after prefix reuse -- unexpected, drop into the debugger.
            breakpoint()

        input_token[-1] = tokens[-1]
        graph.replay()

        # The KV caches now hold exactly the request tokens; remember them so the
        # next call can reuse the prefix instead of re-prefilling from scratch.
        tokens_so_far = all_tokens

        # decide next token on rank0
        next_tok = sample_next_token(logits, temperature=temperature)
        return next_tok

    return infer_next_token


def setup_model(checkpoint: str) -> Callable[[list[int], float], int]:
    model, device = load_model(checkpoint)
    infer_next_token = get_infer_next_token(model, device)
    return infer_next_token
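
For reference, a minimal sketch of how this backend is typically driven once loaded: setup_model returns the infer_next_token callable, which a caller invokes in a loop with the growing token list, relying on the KV-prefix reuse above. The checkpoint path, the placeholder prompt token ids, the stop-token id, and the step budget below are illustrative assumptions, not part of this commit.

# Illustrative driver only; ids and paths are placeholders, not values from this file.
infer_next_token = setup_model("/path/to/checkpoint")

tokens = [200006, 1428, 285, 290]  # placeholder prompt token ids from a tokenizer
stop_token = 200002                # placeholder end-of-turn token id

for _ in range(128):
    next_tok = infer_next_token(tokens, temperature=DEFAULT_TEMPERATURE)
    if next_tok == stop_token:
        break
    # Pass the full history back in; infer_next_token reuses the cached prefix
    # and only processes the newly appended suffix.
    tokens.append(next_tok)

print(tokens)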