Initial commit
Co-authored-by: Zhuohan Li <zhuohan@openai.com> Co-authored-by: Maratyszcza <marat@openai.com> Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
This commit is contained in:
102
gpt_oss/responses_api/inference/triton.py
Normal file
102
gpt_oss/responses_api/inference/triton.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import datetime
|
||||
import os
|
||||
from typing import Callable
|
||||
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from gpt_oss.triton.model import Cache, ModelConfig, Transformer
|
||||
|
||||
DEFAULT_TEMPERATURE = 0.0
|
||||
CONTEXT = 16_384
|
||||
CONCURRENT_SESSIONS = 1
|
||||
|
||||
rank = int(
|
||||
os.environ.get("RANK", 0)
|
||||
) # set this env var to another value to run on other GPUs
|
||||
|
||||
|
||||
def load_model(checkpoint: str):
|
||||
print(f"[{rank}] loading model...")
|
||||
|
||||
torch.cuda.set_device(rank)
|
||||
torch.set_grad_enabled(False)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
|
||||
# Load model
|
||||
model = Transformer.from_checkpoint(checkpoint, device=device)
|
||||
|
||||
print(f"[{rank}] loaded")
|
||||
return model, device
|
||||
|
||||
|
||||
def get_infer_next_token(model, device):
|
||||
caches = [
|
||||
Cache(CONCURRENT_SESSIONS, CONTEXT, model.config.num_key_value_heads)
|
||||
for _ in range(len(model.block))
|
||||
]
|
||||
# offsets = torch.zeros(CONCURRENT_SESSIONS, dtype=torch.int32, device=device) # TBD
|
||||
input_token = torch.zeros(
|
||||
1, dtype=torch.int32, device=device
|
||||
) # add concurrent sessions support
|
||||
tokens_so_far = []
|
||||
|
||||
model.prefill(torch.zeros(1, 4, dtype=torch.int32, device=device), caches)
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
logits = model(input_token[None, :], caches=caches)[0]
|
||||
|
||||
def lcp(cache: list[int], inp: list[int]) -> list[int]:
|
||||
i = 0
|
||||
max_len = min(len(cache), len(inp))
|
||||
while i < max_len and cache[i] == inp[i]:
|
||||
i += 1
|
||||
return cache[:i]
|
||||
|
||||
def sample_next_token(
|
||||
logits: torch.Tensor, temperature: float = DEFAULT_TEMPERATURE
|
||||
) -> int:
|
||||
"""Executed only on rank 0."""
|
||||
if temperature == 0.0:
|
||||
return torch.argmax(logits[-1, :], dim=-1).item()
|
||||
probs = torch.softmax(logits * (1.0 / temperature), dim=-1)
|
||||
return torch.multinomial(probs[-1, :], num_samples=1).item()
|
||||
|
||||
@torch.inference_mode()
|
||||
def infer_next_token(
|
||||
tokens: list[int],
|
||||
temperature: float = DEFAULT_TEMPERATURE,
|
||||
new_request: bool = False,
|
||||
) -> int:
|
||||
nonlocal tokens_so_far
|
||||
tokens_so_far = lcp(tokens_so_far, tokens)
|
||||
for cache in caches:
|
||||
cache.truncate(len(tokens_so_far))
|
||||
all_tokens = tokens # for pdb
|
||||
tokens = tokens[len(tokens_so_far) :]
|
||||
|
||||
if len(tokens) > 1:
|
||||
model.prefill(
|
||||
torch.as_tensor(tokens[:-1], dtype=torch.int32, device=device)[None, :],
|
||||
caches,
|
||||
)
|
||||
|
||||
if len(tokens) == 0:
|
||||
breakpoint()
|
||||
|
||||
input_token[-1] = tokens[-1]
|
||||
graph.replay()
|
||||
|
||||
# decide next token on rank‑0
|
||||
next_tok = sample_next_token(logits, temperature=temperature)
|
||||
|
||||
return next_tok
|
||||
|
||||
return infer_next_token
|
||||
|
||||
|
||||
def setup_model(checkpoint: str) -> Callable[[list[int], float], int]:
|
||||
model, device = load_model(checkpoint)
|
||||
infer_next_token = get_infer_next_token(model, device)
|
||||
return infer_next_token
|
||||
Reference in New Issue
Block a user