openharmony-mlx/gpt_oss/responses_api/inference/triton.py
Dominik Kundel 243a1b0276 Initial commit
Co-authored-by: Zhuohan Li <zhuohan@openai.com>
Co-authored-by: Maratyszcza <marat@openai.com>
Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
2025-08-05 08:19:49 -07:00


import datetime
import os
from typing import Callable

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torch.distributed as dist

from gpt_oss.triton.model import Cache, ModelConfig, Transformer

DEFAULT_TEMPERATURE = 0.0
CONTEXT = 16_384
CONCURRENT_SESSIONS = 1

rank = int(
    os.environ.get("RANK", 0)
)  # set this env var to another value to run on other GPUs


def load_model(checkpoint: str):
    print(f"[{rank}] loading model...")
    torch.cuda.set_device(rank)
    torch.set_grad_enabled(False)
    device = torch.device(f"cuda:{rank}")

    # Load model
    model = Transformer.from_checkpoint(checkpoint, device=device)
    print(f"[{rank}] loaded")

    return model, device


def get_infer_next_token(model, device):
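    # One KV cache per transformer block, sized for CONCURRENT_SESSIONS sessions
    # of up to CONTEXT tokens each.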
    caches = [
        Cache(CONCURRENT_SESSIONS, CONTEXT, model.config.num_key_value_heads)
        for _ in range(len(model.block))
    ]
    # offsets = torch.zeros(CONCURRENT_SESSIONS, dtype=torch.int32, device=device) # TBD
    input_token = torch.zeros(
        1, dtype=torch.int32, device=device
    )  # add concurrent sessions support
    tokens_so_far = []

    model.prefill(torch.zeros(1, 4, dtype=torch.int32, device=device), caches)
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        logits = model(input_token[None, :], caches=caches)[0]
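    # The single-token forward pass above is captured as a CUDA graph; during
    # decoding, infer_next_token overwrites `input_token` in place and calls
    # graph.replay(), which re-runs the captured kernels and refreshes `logits`
    # without per-step kernel-launch overhead.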

    def lcp(cache: list[int], inp: list[int]) -> list[int]:
        # Longest common prefix of the previously processed tokens and the new request.
        i = 0
        max_len = min(len(cache), len(inp))
        while i < max_len and cache[i] == inp[i]:
            i += 1
        return cache[:i]

    def sample_next_token(
        logits: torch.Tensor, temperature: float = DEFAULT_TEMPERATURE
    ) -> int:
        """Executed only on rank 0."""
        if temperature == 0.0:
            return torch.argmax(logits[-1, :], dim=-1).item()
        probs = torch.softmax(logits * (1.0 / temperature), dim=-1)
        return torch.multinomial(probs[-1, :], num_samples=1).item()
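
    # infer_next_token reuses KV-cache state across calls: it keeps only the
    # longest common prefix between the previous and the current token list,
    # truncates the caches to that length, prefills any remaining tokens except
    # the last one, then replays the captured graph on the last token to get logits.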
    @torch.inference_mode()
    def infer_next_token(
        tokens: list[int],
        temperature: float = DEFAULT_TEMPERATURE,
        new_request: bool = False,
    ) -> int:
        nonlocal tokens_so_far
        tokens_so_far = lcp(tokens_so_far, tokens)
        for cache in caches:
            cache.truncate(len(tokens_so_far))

        all_tokens = tokens  # for pdb
        tokens = tokens[len(tokens_so_far) :]

        if len(tokens) > 1:
            model.prefill(
                torch.as_tensor(tokens[:-1], dtype=torch.int32, device=device)[None, :],
                caches,
            )
        if len(tokens) == 0:
            breakpoint()

        input_token[-1] = tokens[-1]
        graph.replay()

        # decide next token on rank0
        next_tok = sample_next_token(logits, temperature=temperature)
        return next_tok

    return infer_next_token


def setup_model(checkpoint: str) -> Callable[[list[int], float], int]:
    model, device = load_model(checkpoint)
    infer_next_token = get_infer_next_token(model, device)
    return infer_next_token
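
For orientation, below is a minimal sketch of how the callable returned by setup_model might be driven in a simple decode loop. It assumes the file is importable as gpt_oss.responses_api.inference.triton (per the path above); the checkpoint path, prompt token IDs, iteration cap, and stop-token ID are placeholders, and tokenization is assumed to happen elsewhere.

# Hypothetical driver loop (not part of this file); all concrete values below
# are placeholders.
from gpt_oss.responses_api.inference.triton import setup_model

infer_next_token = setup_model("/path/to/checkpoint")  # placeholder path

tokens = [1, 2, 3]  # placeholder prompt token IDs from your tokenizer
STOP_TOKEN = 0      # placeholder end-of-message token ID

for _ in range(128):  # placeholder cap on generated tokens
    next_tok = infer_next_token(tokens, temperature=0.0)  # 0.0 = greedy argmax
    tokens.append(next_tok)
    if next_tok == STOP_TOKEN:
        break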