Initial commit
Co-authored-by: Zhuohan Li <zhuohan@openai.com> Co-authored-by: Maratyszcza <marat@openai.com> Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
This commit is contained in:
0
gpt_oss/responses_api/inference/__init__.py
Normal file
0
gpt_oss/responses_api/inference/__init__.py
Normal file
78
gpt_oss/responses_api/inference/metal.py
Normal file
78
gpt_oss/responses_api/inference/metal.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Metal backend for :mod:`gpt_oss.responses_api`."""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from gpt_oss.metal import Context, Model
|
||||
|
||||
|
||||
def setup_model(checkpoint: str) -> Callable[[list[int], float], int]:
    """Load the Metal model and return an inference function.

    The returned callable has the signature
    ``infer_next_token(tokens, temperature=0.0, new_request=False) -> int``.
    It reuses the Metal context's KV state across calls whenever the new
    token list is a pure extension of the previous one; any other change
    triggers a full context rebuild.
    """
    model = Model(checkpoint)
    context = Context(model)

    def _lcp_len(cache: list[int], inp: list[int]) -> int:
        """Length of the longest common prefix of *cache* and *inp*."""
        i = 0
        max_len = min(len(cache), len(inp))
        while i < max_len and cache[i] == inp[i]:
            i += 1
        return i

    tokens_so_far: list[int] = []

    def _rebuild_and_sample(tokens: list[int], temperature: float) -> int:
        """Reset the context, replay *tokens* from scratch, and sample.

        Extracted because the original body repeated this sequence three
        times (new request, mid-stream divergence, truncation).
        """
        nonlocal tokens_so_far
        context.reset()
        for t in tokens:
            context.append(t)
        tokens_so_far = tokens.copy()
        context.process()
        return int(context.sample(temperature=temperature))

    def infer_next_token(
        tokens: list[int], temperature: float = 0.0, new_request: bool = False
    ) -> int:
        """Infer next token using incremental LCP caching when possible."""
        nonlocal tokens_so_far

        # Fast path: first call or explicitly new request.
        if new_request or not tokens_so_far:
            return _rebuild_and_sample(tokens, temperature)

        ol = _lcp_len(tokens_so_far, tokens)
        prev_len = len(tokens_so_far)
        cur_len = len(tokens)

        # Mismatch that is not at the end of either list: safest is rebuild.
        if ol < prev_len and ol < cur_len:
            return _rebuild_and_sample(tokens, temperature)

        if cur_len > prev_len:
            # Pure extension: append only the new suffix (good for KV reuse).
            for t in tokens[prev_len:]:
                context.append(t)
            tokens_so_far = tokens.copy()
            context.process()
            return int(context.sample(temperature=temperature))

        if cur_len < prev_len:
            # Truncation/backspace; easiest correct behavior is rebuild.
            return _rebuild_and_sample(tokens, temperature)

        # cur_len == prev_len and everything matches => no new tokens; just sample.
        return int(context.sample(temperature=temperature))

    return infer_next_token
|
||||
192
gpt_oss/responses_api/inference/ollama.py
Normal file
192
gpt_oss/responses_api/inference/ollama.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
NOTE: this is a stiched together implementation that uses Ollama for inference. It's primarily used
|
||||
for testing and development. It does not leverage any prompt caching or other optimizations and
|
||||
can therefore be slow between turns.
|
||||
"""
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
import requests
|
||||
|
||||
from openai_harmony import load_harmony_encoding, HarmonyEncodingName
|
||||
|
||||
EOS_TOKEN = 200002 # only used on hard timeout
|
||||
|
||||
# Tunables
|
||||
POLL_INTERVAL_S = 0.01 # 10ms between buffer checks
|
||||
CALL_MAX_WAIT_S = 0.250 # max time to block inside a single infer call
|
||||
NO_TOKEN_TIMEOUT_S = 15.0 # overall inactivity timeout before emitting EOS
|
||||
FIRST_BYTE_TIMEOUT_S = 30.0 # time to wait for first token before EOS
|
||||
|
||||
# Shared state
|
||||
_token_buffer: list[int] = []
|
||||
_buffer_lock = threading.Lock()
|
||||
_stream_thread: Optional[threading.Thread] = None
|
||||
_stream_done = threading.Event()
|
||||
_stream_error: Optional[Exception] = None
|
||||
_last_progress_ts: float = 0.0 # updated whenever we enqueue or dequeue tokens
|
||||
_previous_request_tokens: list[int] = []
|
||||
|
||||
def lcp(cache: list[int], inp: list[int]) -> list[int]:
    """Return the longest common prefix shared by *cache* and *inp*."""
    matched = 0
    for a, b in zip(cache, inp):
        if a != b:
            break
        matched += 1
    return cache[:matched]
|
||||
|
||||
def _now():
|
||||
return time.monotonic()
|
||||
|
||||
def _touch_progress():
    """Stamp the shared progress clock; called on every enqueue/dequeue."""
    global _last_progress_ts
    _last_progress_ts = time.monotonic()
|
||||
|
||||
def _reset_stream_state():
    """Drop buffered tokens and reset stream bookkeeping for a fresh request."""
    global _token_buffer, _stream_thread, _stream_error
    with _buffer_lock:
        _token_buffer = []
    _stream_error = None
    _stream_thread = None
    _stream_done.clear()
    _touch_progress()
|
||||
|
||||
def setup_model(checkpoint: str) -> Callable[[list[int], float, bool], int]:
    """Create an Ollama-backed ``infer_next_token`` callable.

    *checkpoint* is interpreted as the Ollama model name.  Prompt tokens are
    decoded to text with the Harmony encoding, streamed through Ollama's
    ``/api/generate`` endpoint on a daemon thread, re-encoded to token ids,
    and handed out one id per call.
    """
    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
    model_name = checkpoint

    def _try_pop() -> Optional[int]:
        """Pop one buffered token under the lock, or return None if empty."""
        with _buffer_lock:
            if _token_buffer:
                tok = _token_buffer.pop(0)
                _touch_progress()
                return tok
        return None

    def _start_stream(token_ids: list[int], temperature: float):
        prompt_text = encoding.decode(token_ids)

        def run():
            nonlocal prompt_text, temperature
            global _stream_error
            global _previous_request_tokens

            accum_text = ""
            last_len = 0  # number of tokens already emitted

            try:
                url = "http://localhost:11434/api/generate"
                # Reuse the previous response's opaque "context" blob, if any,
                # so Ollama can resume its KV state between turns.
                context = _previous_request_tokens or None

                payload = {
                    "model": model_name,
                    "prompt": prompt_text,
                    "stream": True,
                    "context": context,
                    "options": {"temperature": temperature},
                }

                with requests.post(url, json=payload, stream=True, timeout=60) as resp:
                    resp.raise_for_status()
                    for line in resp.iter_lines(decode_unicode=True):
                        if not line:
                            continue
                        obj = json.loads(line)

                        if isinstance(obj.get("response"), str):
                            accum_text += obj["response"]
                            # Re-encode the whole accumulated text and emit
                            # only the token suffix we haven't emitted yet.
                            toks = encoding.encode(accum_text, allowed_special="all")
                            if len(toks) > last_len:
                                new_toks = toks[last_len:]
                                with _buffer_lock:
                                    _token_buffer.extend(new_toks)
                                last_len = len(toks)
                                _touch_progress()

                        if obj.get("done", False):
                            # BUG FIX: append under the lock (the consumer
                            # pops concurrently), and do not reference `toks`
                            # here — it is unbound when the stream finishes
                            # without ever sending a "response" chunk.
                            with _buffer_lock:
                                _token_buffer.append(EOS_TOKEN)
                            _touch_progress()
                            context = obj.get("context")
                            if context:
                                _previous_request_tokens = context
                            break

                _stream_done.set()

            except Exception as e:
                _stream_error = e
                _stream_done.set()

        t = threading.Thread(target=run, name="ollama-stream", daemon=True)
        t.start()
        return t

    def infer_next_token(
        tokens: list[int], temperature: float = 0.0, new_request: bool = False
    ) -> int:
        """
        - Starts a new Ollama stream on new_request.
        - Forwards tokens as they arrive.
        - Only emits EOS_TOKEN if we exceed an inactivity timeout.
        """
        global _stream_thread

        if new_request:
            _reset_stream_state()
            _stream_thread = _start_stream(token_ids=tokens, temperature=temperature)
            # Wait for first byte within FIRST_BYTE_TIMEOUT_S (without emitting EOS early)
            start = _now()
            while _now() - start < FIRST_BYTE_TIMEOUT_S:
                tok = _try_pop()
                if tok is not None:
                    return tok
                if _stream_error is not None:
                    raise RuntimeError(f"Ollama stream error: {_stream_error!r}")
                time.sleep(POLL_INTERVAL_S)
            # Hard first-byte timeout -> emit EOS so the server can stop this request
            return EOS_TOKEN

        if _stream_error is not None:
            raise RuntimeError(f"Ollama stream error: {_stream_error!r}")

        # Normal path: wait up to CALL_MAX_WAIT_S for a token to arrive.
        wait_start = _now()
        while _now() - wait_start < CALL_MAX_WAIT_S:
            tok = _try_pop()
            if tok is not None:
                return tok
            # No token yet; if we've been idle too long overall, end with EOS.
            if _now() - _last_progress_ts > NO_TOKEN_TIMEOUT_S:
                return EOS_TOKEN
            time.sleep(POLL_INTERVAL_S)

        # One more short wait to reduce hot-looping before giving up on this
        # call slice.
        time.sleep(POLL_INTERVAL_S)
        tok = _try_pop()
        if tok is not None:
            return tok

        # Emit EOS only on a true inactivity timeout.
        if _now() - _last_progress_ts > NO_TOKEN_TIMEOUT_S:
            return EOS_TOKEN

        # No token yet and no timeout: return 0 as a PAD/NOOP sentinel the
        # server is expected to ignore; the caller should invoke us again
        # shortly.  (Was the dead expression `EOS_TOKEN if False else 0`.)
        return 0

    return infer_next_token
|
||||
142
gpt_oss/responses_api/inference/stub.py
Normal file
142
gpt_oss/responses_api/inference/stub.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
# Canned token-id sequence replayed in a loop by the stub backend.
# NOTE(review): these look like Harmony-encoded ids ending in a stop token
# (200002), matching the encoding used by the other backends — confirm
# against openai_harmony before relying on their meaning.
# BUG FIX: the original defined fake_tokens twice; the first list was a dead
# store immediately overwritten, so only the effective list is kept.
fake_tokens = [
    200005, 35644, 200008, 1844, 31064, 25, 392, 4827, 382, 220, 17, 659,
    220, 17, 16842, 12295, 81645, 13, 51441, 6052, 13, 200007, 200006,
    173781, 200005, 17196, 200008, 17, 659, 220, 17, 314, 220, 19, 13,
    9552, 238, 242, 200002,
]

# Mutable queue of tokens still to emit; refilled from fake_tokens when drained.
token_queue = fake_tokens.copy()
|
||||
|
||||
|
||||
def stub_infer_next_token(
    tokens: list[int], temperature: float = 0.0, new_request: bool = False
) -> int:
    """Return the next canned token; refill and pause briefly on wrap-around.

    *tokens*, *temperature* and *new_request* are accepted only for interface
    compatibility with the real backends and are ignored.
    """
    global token_queue
    head = token_queue.pop(0)
    if not token_queue:
        # Queue exhausted: start the canned sequence over, with a short
        # sleep to mimic generation latency.
        token_queue = fake_tokens.copy()
        time.sleep(0.1)
    return head
|
||||
|
||||
|
||||
def setup_model(_checkpoint: str) -> Callable[[list[int], float], int]:
    """Return the canned-token stub; *_checkpoint* is accepted but ignored."""
    return stub_infer_next_token
|
||||
102
gpt_oss/responses_api/inference/triton.py
Normal file
102
gpt_oss/responses_api/inference/triton.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import datetime
|
||||
import os
|
||||
from typing import Callable
|
||||
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from gpt_oss.triton.model import Cache, ModelConfig, Transformer
|
||||
|
||||
DEFAULT_TEMPERATURE = 0.0
|
||||
CONTEXT = 16_384
|
||||
CONCURRENT_SESSIONS = 1
|
||||
|
||||
rank = int(
|
||||
os.environ.get("RANK", 0)
|
||||
) # set this env var to another value to run on other GPUs
|
||||
|
||||
|
||||
def load_model(checkpoint: str):
    """Load the Triton Transformer onto this rank's GPU.

    Returns a ``(model, device)`` pair with gradients globally disabled.
    """
    print(f"[{rank}] loading model...")

    torch.set_grad_enabled(False)
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")

    model = Transformer.from_checkpoint(checkpoint, device=device)

    print(f"[{rank}] loaded")
    return model, device
|
||||
|
||||
|
||||
def get_infer_next_token(model, device):
    """Build an ``infer_next_token`` closure around *model*.

    The single-token decode step is captured as a CUDA graph and replayed for
    every generated token; longer prompt growth goes through
    ``model.prefill``.  KV caches are reused across calls by truncating them
    to the longest common prefix of the previous and current token lists.
    """
    caches = [
        Cache(CONCURRENT_SESSIONS, CONTEXT, model.config.num_key_value_heads)
        for _ in range(len(model.block))
    ]
    # offsets = torch.zeros(CONCURRENT_SESSIONS, dtype=torch.int32, device=device)  # TBD
    input_token = torch.zeros(
        1, dtype=torch.int32, device=device
    )  # add concurrent sessions support
    tokens_so_far: list[int] = []

    # Warm the caches once, then capture the one-token forward pass as a CUDA
    # graph so steady-state decoding replays it with minimal Python overhead.
    model.prefill(torch.zeros(1, 4, dtype=torch.int32, device=device), caches)
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        logits = model(input_token[None, :], caches=caches)[0]

    def lcp(cache: list[int], inp: list[int]) -> list[int]:
        """Longest common prefix of *cache* and *inp*."""
        i = 0
        max_len = min(len(cache), len(inp))
        while i < max_len and cache[i] == inp[i]:
            i += 1
        return cache[:i]

    def sample_next_token(
        logits: torch.Tensor, temperature: float = DEFAULT_TEMPERATURE
    ) -> int:
        """Executed only on rank 0: greedy at temperature 0, else multinomial."""
        if temperature == 0.0:
            return torch.argmax(logits[-1, :], dim=-1).item()
        probs = torch.softmax(logits * (1.0 / temperature), dim=-1)
        return torch.multinomial(probs[-1, :], num_samples=1).item()

    @torch.inference_mode()
    def infer_next_token(
        tokens: list[int],
        temperature: float = DEFAULT_TEMPERATURE,
        new_request: bool = False,
    ) -> int:
        nonlocal tokens_so_far
        all_tokens = tokens
        # Keep only the prefix shared with the previous call and roll the KV
        # caches back to that point; everything after it must be reprocessed.
        tokens_so_far = lcp(tokens_so_far, tokens)
        for cache in caches:
            cache.truncate(len(tokens_so_far))
        tokens = tokens[len(tokens_so_far) :]

        if len(tokens) == 0:
            # BUG FIX: this was a bare `breakpoint()`, which would hang a
            # server in the debugger (and then IndexError on tokens[-1]).
            # An empty suffix means the caller re-sent exactly the previous
            # token list; fail loudly instead.
            raise ValueError("infer_next_token called with no new tokens")

        if len(tokens) > 1:
            # Batch-prefill everything except the last new token.
            model.prefill(
                torch.as_tensor(tokens[:-1], dtype=torch.int32, device=device)[None, :],
                caches,
            )

        # Decode the final token via the captured CUDA graph.
        input_token[-1] = tokens[-1]
        graph.replay()

        # decide next token on rank-0
        next_tok = sample_next_token(logits, temperature=temperature)

        # BUG FIX: record the full processed token list so the next call's
        # LCP actually reuses the cache (mirrors the metal backend).  The
        # original never extended tokens_so_far, forcing a re-prefill of the
        # whole prompt on every call.
        tokens_so_far = all_tokens.copy()

        return next_tok

    return infer_next_token
|
||||
|
||||
|
||||
def setup_model(checkpoint: str) -> Callable[[list[int], float], int]:
    """Load *checkpoint* on this rank's GPU and return the inference callable."""
    model, device = load_model(checkpoint)
    infer_next_token = get_infer_next_token(model, device)
    return infer_next_token
|
||||
84
gpt_oss/responses_api/inference/vllm.py
Normal file
84
gpt_oss/responses_api/inference/vllm.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
NOTE: this is not the most efficient way to use vLLM. It's a simple implementation that infers
|
||||
one token at a time to mimic the behavior of the Triton implementation.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
# vLLM imports
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.inputs import TokensPrompt
|
||||
|
||||
DEFAULT_TEMPERATURE = 0.0  # greedy sampling unless the caller overrides it

# Tensor-parallel degree.  BUG FIX: os.environ values are strings, so the
# original passed e.g. "4" (str) as vLLM's tensor_parallel_size whenever the
# TP env var was set; coerce to int.
TP = int(os.environ.get("TP", 2))
|
||||
|
||||
def load_model(checkpoint: str):
    """
    Create the vLLM engine. We enable prefix caching so repeated prefixes
    across calls can reuse KV cache for faster prefill.
    """
    return LLM(
        model=checkpoint,
        tensor_parallel_size=TP,  # set >1 if you want TP across GPUs
        enable_prefix_caching=True,  # reuse KV for shared prefixes
        disable_log_stats=True,  # quiet per-step engine stats logging
    )
|
||||
|
||||
|
||||
def get_infer_next_token(llm: LLM):
    """
    Wrap *llm* in a closure matching the backend interface:
        infer_next_token(tokens: List[int], temperature: float, new_request: bool) -> int

    Each call issues a single-token generation directly from the given token
    ids (no re-tokenization); vLLM handles sampling (temperature=0 => greedy)
    and, with enable_prefix_caching=True, reuses the shared-prefix prefill
    across calls.
    """

    def infer_next_token(
        tokens: List[int],
        temperature: float = DEFAULT_TEMPERATURE,
        new_request: bool = False,  # kept for interface compatibility; unused here
    ) -> int:
        if not tokens:
            raise ValueError("tokens must contain at least one input token id")

        # Exactly one continuation of exactly one token.
        params = SamplingParams(
            temperature=float(temperature),
            max_tokens=1,
            n=1,
            # You can expose/enable more controls here (top_p, top_k, etc.)
        )

        results = llm.generate(
            TokensPrompt(prompt_token_ids=tokens),
            sampling_params=params,
        )
        if not results or not results[0].outputs:
            raise RuntimeError("vLLM returned empty outputs")

        completion = results[0].outputs[0]
        if not completion.token_ids:
            # The engine finished immediately (e.g. EOS) without producing a
            # new token; surface that to the caller.
            raise RuntimeError("No next token was generated (possibly EOS).")

        return int(completion.token_ids[0])

    return infer_next_token
|
||||
|
||||
|
||||
def setup_model(checkpoint: str) -> Callable[[List[int], float, bool], int]:
    """Build the vLLM engine for *checkpoint* and return its inference callable."""
    llm = load_model(checkpoint)
    infer_next_token = get_infer_next_token(llm)
    return infer_next_token
|
||||
Reference in New Issue
Block a user