Initial commit

Co-authored-by: Zhuohan Li <zhuohan@openai.com>
Co-authored-by: Maratyszcza <marat@openai.com>
Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
This commit is contained in:
Dominik Kundel
2025-08-05 08:19:49 -07:00
commit 243a1b0276
124 changed files with 20405 additions and 0 deletions

View File

@@ -0,0 +1,78 @@
"""Metal backend for :mod:`gpt_oss.responses_api`."""
from typing import Callable
from gpt_oss.metal import Context, Model
def setup_model(checkpoint: str) -> Callable[[list[int], float], int]:
    """Load the Metal model and return an incremental inference function.

    The returned closure keeps the Metal KV cache in sync with the token
    sequence it has seen so far, only appending the suffix that changed.
    """
    model = Model(checkpoint)
    context = Context(model)

    tokens_so_far: list[int] = []  # tokens currently materialized in `context`

    def _lcp_len(cache: list[int], inp: list[int]) -> int:
        """Length of the longest common prefix of *cache* and *inp*."""
        i = 0
        max_len = min(len(cache), len(inp))
        while i < max_len and cache[i] == inp[i]:
            i += 1
        return i

    def _rebuild(tokens: list[int], temperature: float) -> int:
        """Reset the context, replay *tokens* from scratch, process, and sample.

        Extracted because the original repeated this exact sequence on the
        new-request, mid-stream-divergence, and truncation paths.
        """
        nonlocal tokens_so_far
        context.reset()
        for t in tokens:
            context.append(t)
        tokens_so_far = tokens.copy()
        context.process()
        return int(context.sample(temperature=temperature))

    def infer_next_token(
        tokens: list[int], temperature: float = 0.0, new_request: bool = False
    ) -> int:
        """Infer next token using incremental LCP caching when possible."""
        nonlocal tokens_so_far
        # Fast path: first call or explicitly new request.
        if new_request or not tokens_so_far:
            return _rebuild(tokens, temperature)
        ol = _lcp_len(tokens_so_far, tokens)
        prev_len, cur_len = len(tokens_so_far), len(tokens)
        if ol < prev_len and ol < cur_len:
            # Mismatch not at the end: diverged mid-stream; safest to rebuild.
            return _rebuild(tokens, temperature)
        if cur_len > prev_len:
            # Pure extension: append only the new suffix (good for KV reuse).
            for t in tokens[prev_len:]:
                context.append(t)
            tokens_so_far = tokens.copy()
            context.process()
            return int(context.sample(temperature=temperature))
        if cur_len < prev_len:
            # Truncation/backspace; easiest correct behavior is a rebuild.
            return _rebuild(tokens, temperature)
        # cur_len == prev_len and everything matches => no new tokens; just sample.
        return int(context.sample(temperature=temperature))

    return infer_next_token

View File

@@ -0,0 +1,192 @@
"""
NOTE: this is a stitched-together implementation that uses Ollama for inference. It's primarily used
for testing and development. It does not leverage any prompt caching or other optimizations and
can therefore be slow between turns.
"""
import json
import threading
import time
from typing import Callable, Optional
import requests
from openai_harmony import load_harmony_encoding, HarmonyEncodingName
EOS_TOKEN = 200002  # only used on hard timeout
# Tunables
POLL_INTERVAL_S = 0.01  # 10ms between buffer checks
CALL_MAX_WAIT_S = 0.250  # max time to block inside a single infer call
NO_TOKEN_TIMEOUT_S = 15.0  # overall inactivity timeout before emitting EOS
FIRST_BYTE_TIMEOUT_S = 30.0  # time to wait for first token before EOS
# Shared state: written by the background stream thread, read by infer_next_token.
_token_buffer: list[int] = []  # pending tokens; guarded by _buffer_lock
_buffer_lock = threading.Lock()  # protects _token_buffer
_stream_thread: Optional[threading.Thread] = None  # current streaming thread, if any
_stream_done = threading.Event()  # set when the stream finishes or errors out
_stream_error: Optional[Exception] = None  # surfaced to callers as RuntimeError
_last_progress_ts: float = 0.0  # updated whenever we enqueue or dequeue tokens
_previous_request_tokens: list[int] = []  # Ollama "context" from the last completed turn
def lcp(cache: list[int], inp: list[int]) -> list[int]:
    """Return the longest common prefix of *cache* and *inp*."""
    matched = 0
    for a, b in zip(cache, inp):
        if a != b:
            break
        matched += 1
    return cache[:matched]
def _now():
return time.monotonic()
def _touch_progress():
    """Stamp the shared progress clock; resets the inactivity-timeout window."""
    global _last_progress_ts
    _last_progress_ts = _now()
def _reset_stream_state():
    """Clear all shared streaming state ahead of a new request.

    Called on ``new_request`` before spawning a fresh Ollama stream thread.
    """
    global _token_buffer, _stream_thread, _stream_error
    with _buffer_lock:
        # Drop any tokens left over from the previous request.
        _token_buffer = []
    _stream_done.clear()
    _stream_thread = None
    _stream_error = None
    # Restart the inactivity clock so the new request gets a full timeout window.
    _touch_progress()
def setup_model(checkpoint: str) -> Callable[[list[int], float, bool], int]:
    """Return an ``infer_next_token(tokens, temperature, new_request)`` callable
    backed by a streaming Ollama ``/api/generate`` request.

    A background thread pushes decoded tokens into the module-level
    ``_token_buffer``; the returned closure pops them off one at a time.
    """
    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
    model_name = checkpoint

    def _start_stream(token_ids: list[int], temperature: float):
        """Spawn a daemon thread streaming Ollama output into _token_buffer."""
        prompt_text = encoding.decode(token_ids)

        def run():
            global _stream_error
            global _previous_request_tokens
            accum_text = ""
            last_len = 0  # number of tokens already emitted
            try:
                url = "http://localhost:11434/api/generate"
                context = None
                if len(_previous_request_tokens) > 0:
                    context = _previous_request_tokens
                payload = {
                    "model": model_name,
                    "prompt": prompt_text,
                    "stream": True,
                    "context": context,
                    "options": {"temperature": temperature},
                }
                with requests.post(url, json=payload, stream=True, timeout=60) as resp:
                    resp.raise_for_status()
                    for line in resp.iter_lines(decode_unicode=True):
                        if not line:
                            continue
                        obj = json.loads(line)
                        if isinstance(obj.get("response"), str):
                            accum_text += obj["response"]
                            # Re-encode the accumulated text and forward only the
                            # not-yet-emitted suffix of tokens.
                            toks = encoding.encode(accum_text, allowed_special="all")
                            if len(toks) > last_len:
                                new_toks = toks[last_len:]
                                with _buffer_lock:
                                    _token_buffer.extend(new_toks)
                                last_len = len(toks)
                                _touch_progress()
                        if obj.get("done", False):
                            # FIX: the EOS append must hold the lock like every
                            # other buffer mutation; it previously raced the
                            # consumer popping from _token_buffer.
                            with _buffer_lock:
                                _token_buffer.append(EOS_TOKEN)
                            _touch_progress()
                            # FIX: do not read `toks` here — it is unbound when
                            # "done" arrives before any "response" chunk.
                            context = obj.get("context")
                            if context and len(context) > 0:
                                _previous_request_tokens = context
                            break
                _stream_done.set()
            except Exception as e:
                _stream_error = e
                _stream_done.set()

        t = threading.Thread(target=run, name="ollama-stream", daemon=True)
        t.start()
        return t

    def infer_next_token(
        tokens: list[int], temperature: float = 0.0, new_request: bool = False
    ) -> int:
        """
        - Starts a new Ollama stream on new_request.
        - Forwards tokens as they arrive.
        - Only emits EOS_TOKEN if we exceed an inactivity timeout.
        """
        global _stream_thread
        if new_request:
            _reset_stream_state()
            _stream_thread = _start_stream(token_ids=tokens, temperature=temperature)
            # Wait for first byte within FIRST_BYTE_TIMEOUT_S (without emitting EOS early).
            start = _now()
            while _now() - start < FIRST_BYTE_TIMEOUT_S:
                with _buffer_lock:
                    if _token_buffer:
                        tok = _token_buffer.pop(0)
                        _touch_progress()
                        return tok
                if _stream_error is not None:
                    raise RuntimeError(f"Ollama stream error: {_stream_error!r}")
                # If Ollama finished instantly with no output, continue loop until timeout.
                time.sleep(POLL_INTERVAL_S)
            # Hard first-byte timeout -> emit EOS so the server can stop this request.
            return EOS_TOKEN
        if _stream_error is not None:
            raise RuntimeError(f"Ollama stream error: {_stream_error!r}")
        # Normal path: wait up to CALL_MAX_WAIT_S for a token to arrive.
        wait_start = _now()
        while _now() - wait_start < CALL_MAX_WAIT_S:
            with _buffer_lock:
                if _token_buffer:
                    tok = _token_buffer.pop(0)
                    _touch_progress()
                    return tok
            # No token yet; if we've been idle too long overall, end with EOS.
            if _now() - _last_progress_ts > NO_TOKEN_TIMEOUT_S:
                return EOS_TOKEN
            time.sleep(POLL_INTERVAL_S)
        # Call slice expired without a token. One more short wait to reduce
        # hot-looping, then a final check before yielding back to the caller.
        time.sleep(POLL_INTERVAL_S)
        with _buffer_lock:
            if _token_buffer:
                tok = _token_buffer.pop(0)
                _touch_progress()
                return tok
        # Only emit EOS on a true inactivity timeout.
        if _now() - _last_progress_ts > NO_TOKEN_TIMEOUT_S:
            return EOS_TOKEN
        # FIX: replaces the dead `EOS_TOKEN if False else 0` expression; behavior
        # is unchanged (returns 0 as a placeholder the server should ignore).
        return 0  # PAD/NOOP token the server is expected to ignore

    return infer_next_token

View File

@@ -0,0 +1,142 @@
import time
from typing import Callable
# NOTE(review): this fixture is immediately shadowed by the `fake_tokens`
# reassignment below, so it is dead data — kept only for easy swapping
# between canned conversations during development.
fake_tokens = [
    200005,
    35644,
    200008,
    23483,
    316,
    1199,
    1114,
    717,
    170154,
    13,
    200007,
    200006,
    173781,
    200005,
    35644,
    316,
    28,
    44580,
    775,
    170154,
    464,
    91,
    542,
    141043,
    91,
    29,
    4108,
    200008,
    10848,
    7693,
    7534,
    28499,
    18826,
    18583,
    200012,
]
# Canned Harmony token stream replayed by the stub backend.
fake_tokens = [
    200005, 35644, 200008, 1844, 31064, 25, 392, 4827, 382, 220,
    17, 659, 220, 17, 16842, 12295, 81645, 13, 51441, 6052,
    13, 200007, 200006, 173781, 200005, 17196, 200008, 17, 659, 220,
    17, 314, 220, 19, 13, 9552, 238, 242, 200002,
]

# Working copy consumed one token per call; restocked once it runs dry.
token_queue = fake_tokens.copy()


def stub_infer_next_token(
    tokens: list[int], temperature: float = 0.0, new_request: bool = False
) -> int:
    """Pop and return the next canned token, refilling the queue when empty."""
    global token_queue
    tok = token_queue.pop(0)
    if not token_queue:
        token_queue = fake_tokens.copy()
    # Brief pause per token to mimic real generation latency.
    time.sleep(0.1)
    return tok
def setup_model(_checkpoint: str) -> Callable[[list[int], float], int]:
    """Ignore the checkpoint path and hand back the canned stub sampler."""
    return stub_infer_next_token

View File

@@ -0,0 +1,102 @@
import datetime
import os
from typing import Callable
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torch.distributed as dist
from gpt_oss.triton.model import Cache, ModelConfig, Transformer
DEFAULT_TEMPERATURE = 0.0  # greedy sampling unless a caller overrides it
CONTEXT = 16_384  # KV-cache capacity in tokens
CONCURRENT_SESSIONS = 1  # cache is laid out for a single session for now
rank = int(
    os.environ.get("RANK", 0)
)  # set this env var to another value to run on other GPUs
def load_model(checkpoint: str):
    """Pin this process to GPU `rank`, disable autograd, and load the checkpoint."""
    print(f"[{rank}] loading model...")
    torch.cuda.set_device(rank)
    torch.set_grad_enabled(False)
    target = torch.device(f"cuda:{rank}")
    transformer = Transformer.from_checkpoint(checkpoint, device=target)
    print(f"[{rank}] loaded")
    return transformer, target
def get_infer_next_token(model, device):
    """Build an incremental next-token closure around a CUDA-graph decode step.

    Maintains per-layer KV caches and a record of the tokens they contain so
    repeated calls only prefill the suffix that changed.
    """
    caches = [
        Cache(CONCURRENT_SESSIONS, CONTEXT, model.config.num_key_value_heads)
        for _ in range(len(model.block))
    ]
    input_token = torch.zeros(
        1, dtype=torch.int32, device=device
    )  # add concurrent sessions support
    tokens_so_far: list[int] = []  # tokens currently materialized in the caches
    # Warm up, then capture the single-token decode step in a CUDA graph.
    model.prefill(torch.zeros(1, 4, dtype=torch.int32, device=device), caches)
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        logits = model(input_token[None, :], caches=caches)[0]

    def lcp(cache: list[int], inp: list[int]) -> list[int]:
        """Longest common prefix of two token lists."""
        i = 0
        max_len = min(len(cache), len(inp))
        while i < max_len and cache[i] == inp[i]:
            i += 1
        return cache[:i]

    def sample_next_token(
        logits: torch.Tensor, temperature: float = DEFAULT_TEMPERATURE
    ) -> int:
        """Executed only on rank 0."""
        if temperature == 0.0:
            return torch.argmax(logits[-1, :], dim=-1).item()
        probs = torch.softmax(logits * (1.0 / temperature), dim=-1)
        return torch.multinomial(probs[-1, :], num_samples=1).item()

    @torch.inference_mode()
    def infer_next_token(
        tokens: list[int],
        temperature: float = DEFAULT_TEMPERATURE,
        new_request: bool = False,
    ) -> int:
        nonlocal tokens_so_far
        if not tokens:
            raise ValueError("tokens must contain at least one token id")
        all_tokens = tokens
        # Keep only the cached prefix that matches the incoming request.
        tokens_so_far = lcp(tokens_so_far, tokens)
        if len(tokens_so_far) == len(tokens):
            # FIX: identical request resubmitted. The original left a debugging
            # breakpoint() here and then crashed on tokens[-1]; instead roll the
            # cache back one position and replay the final token.
            tokens_so_far = tokens_so_far[:-1]
        for cache in caches:
            cache.truncate(len(tokens_so_far))
        tokens = tokens[len(tokens_so_far) :]
        if len(tokens) > 1:
            # Batch-prefill everything except the last new token.
            model.prefill(
                torch.as_tensor(tokens[:-1], dtype=torch.int32, device=device)[None, :],
                caches,
            )
        input_token[-1] = tokens[-1]
        graph.replay()
        # decide next token on rank0
        next_tok = sample_next_token(logits, temperature=temperature)
        # FIX: record what now sits in the KV caches. The original never updated
        # tokens_so_far, so lcp() always returned [] and every call re-prefilled
        # the entire sequence, defeating the cache-reuse machinery.
        tokens_so_far = list(all_tokens)
        return next_tok

    return infer_next_token
def setup_model(checkpoint: str) -> Callable[[list[int], float], int]:
    """Load the Triton-backed model and return its next-token closure."""
    return get_infer_next_token(*load_model(checkpoint))

View File

@@ -0,0 +1,84 @@
"""
NOTE: this is not the most efficient way to use vLLM. It's a simple implementation that infers
one token at a time to mimic the behavior of the Triton implementation.
"""
import os
from typing import Callable, List, Optional
# vLLM imports
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt
DEFAULT_TEMPERATURE = 0.0  # greedy decoding unless a caller overrides it
# FIX: os.environ values are strings, so a set TP env var would previously be
# passed to vLLM as e.g. "2" instead of 2. Coerce so both paths yield an int.
TP = int(os.environ.get("TP", 2))  # tensor-parallel degree
def load_model(checkpoint: str):
    """
    Create the vLLM engine. We enable prefix caching so repeated prefixes
    across calls can reuse KV cache for faster prefill.
    """
    llm = LLM(
        model=checkpoint,
        tensor_parallel_size=TP,  # set >1 if you want TP across GPUs
        enable_prefix_caching=True,  # reuse KV for shared prefixes
        disable_log_stats=True,  # keep engine stats out of the logs
    )
    return llm
def get_infer_next_token(llm: LLM):
    """
    Wrap *llm* in a closure matching the backend interface:

        infer_next_token(tokens: List[int], temperature: float, new_request: bool) -> int

    Each call asks vLLM for exactly one continuation token from the given
    prompt token IDs; with prefix caching enabled, calls sharing a prefix
    reuse its KV state.
    """

    def infer_next_token(
        tokens: List[int],
        temperature: float = DEFAULT_TEMPERATURE,
        new_request: bool = False,  # kept for interface compatibility; unused here
    ) -> int:
        if not tokens:
            raise ValueError("tokens must contain at least one input token id")
        params = SamplingParams(
            temperature=float(temperature),
            max_tokens=1,  # we only want the next token
            n=1,  # single continuation
        )
        # Provide token IDs directly (no re-tokenization).
        results = llm.generate(
            TokensPrompt(prompt_token_ids=tokens),
            sampling_params=params,
        )
        if not results or not results[0].outputs:
            raise RuntimeError("vLLM returned empty outputs")
        completion = results[0].outputs[0]
        if not completion.token_ids:
            # The model may have finished immediately (e.g. EOS); surface that
            # as an error rather than returning a bogus token id.
            raise RuntimeError("No next token was generated (possibly EOS).")
        return int(completion.token_ids[0])

    return infer_next_token
def setup_model(checkpoint: str) -> Callable[[List[int], float, bool], int]:
    """Create the vLLM engine for *checkpoint* and return its next-token closure."""
    return get_infer_next_token(load_model(checkpoint))