"""Metal backend for :mod:`gpt_oss.responses_api`."""
from typing import Callable
from gpt_oss.metal import Context, Model


def setup_model(checkpoint: str) -> Callable[[list[int], float], int]:
    """Load the Metal model and return an inference function."""
    model = Model(checkpoint)
    context = Context(model)

    def lcp(cache: list[int], inp: list[int]) -> list[int]:
        """Return the longest common prefix of ``cache`` and ``inp``."""
        i = 0
        max_len = min(len(cache), len(inp))
        while i < max_len and cache[i] == inp[i]:
            i += 1
        return cache[:i]
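
    # Illustration (hypothetical values): with a cached sequence [1, 2, 3, 4]
    # and a new prompt [1, 2, 5], lcp returns [1, 2]; what follows the common
    # prefix determines below whether the KV cache can be reused.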

    # Token sequence currently materialized in the Metal context (KV cache).
    tokens_so_far = []

    def infer_next_token(
        tokens: list[int], temperature: float = 0.0, new_request: bool = False
    ) -> int:
        """Infer next token using incremental LCP caching when possible."""
        nonlocal tokens_so_far

        # Fast path: first call or explicitly new request.
        if new_request or not tokens_so_far:
            context.reset()
            for t in tokens:
                context.append(t)
            tokens_so_far = tokens.copy()
            context.process()
            return int(context.sample(temperature=temperature))

        # Length of the longest common prefix with the cached sequence.
        overlap = lcp(tokens_so_far, tokens)
        ol = len(overlap)
        prev_len = len(tokens_so_far)
        cur_len = len(tokens)

        # A mismatch that is not at the end of either sequence means the
        # streams diverged mid-stream.
        diverged_midstream = (ol < prev_len) and (ol < cur_len)
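
        # Illustration (hypothetical values): cached [7, 8, 9] vs. new
        # [7, 1, 9] gives ol == 1, which is less than both lengths, so the
        # sequences diverged mid-stream and the cache cannot be reused.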

        if diverged_midstream:
            # Safest option: rebuild the context from scratch.
            context.reset()
            for t in tokens:
                context.append(t)
            tokens_so_far = tokens.copy()
            context.process()
            return int(context.sample(temperature=temperature))

        if cur_len > prev_len:
            # Pure extension of the cached sequence (good for KV reuse):
            # append only the new suffix.
            extension = tokens[prev_len:]
            for t in extension:
                context.append(t)
            tokens_so_far = tokens.copy()
            context.process()
            return int(context.sample(temperature=temperature))

        if cur_len < prev_len:
            # Truncation/backspace; the easiest correct behavior is a rebuild.
            context.reset()
            for t in tokens:
                context.append(t)
            tokens_so_far = tokens.copy()
            context.process()
            return int(context.sample(temperature=temperature))

        # cur_len == prev_len and everything matches => no new tokens; just
        # sample.
        return int(context.sample(temperature=temperature))

    return infer_next_token
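

# Usage sketch: a minimal greedy decode loop on top of setup_model. The
# checkpoint path and prompt token IDs are placeholders, not values taken
# from this module.
if __name__ == "__main__":
    infer = setup_model("/path/to/checkpoint")  # placeholder path
    tokens = [1, 2, 3]  # placeholder prompt token IDs

    # Each iteration passes the full sequence back in; since it grows by one
    # token at a time, the pure-extension path above reuses the KV cache.
    for _ in range(16):
        tokens.append(infer(tokens, temperature=0.0))
    print(tokens)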