Files
openharmony-mlx/gpt_oss/responses_api/inference/transformers.py
2025-08-05 10:02:16 -07:00

57 lines
1.7 KiB
Python

"""
NOTE: this is not the most efficient way to use transformers. It's a simple implementation that infers
one token at a time to mimic the behavior of the Triton implementation.
"""
import os
from typing import Callable, List
# Transformers imports
from transformers import AutoModelForCausalLM, PreTrainedModel
import torch
# Greedy decoding by default (temperature 0 => do_sample=False in generate).
DEFAULT_TEMPERATURE = 0.0
# Tensor-parallel degree. os.environ values are always strings, so coerce to
# int explicitly; previously TP was a str when the env var was set but an int
# when it fell back to the default.
TP = int(os.environ.get("TP", "2"))
def load_model(checkpoint: str):
    """Load *checkpoint* with the Transformers Auto API and return the model.

    Weights are loaded in bfloat16 and placed automatically across the
    available devices (``device_map="auto"``).
    """
    return AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
def get_infer_next_token(model: PreTrainedModel):
    """
    Return a callable with the same shape as the original triton implementation:

        infer_next_token(tokens: List[int], temperature: float, new_request: bool) -> int

    Implementation detail:
    - We issue a single-token generation using model.generate.
    - generate handles sampling (temperature=0 => greedy, otherwise sampling).
    """
    def infer_next_token(
        tokens: List[int],
        temperature: float = DEFAULT_TEMPERATURE,
        new_request: bool = False,  # kept for interface compatibility; unused here
    ) -> int:
        # Separate name: don't shadow the `tokens` parameter with its tensor form.
        input_ids = torch.tensor([tokens], dtype=torch.int64, device=model.device)
        if temperature == 0.0:
            # Greedy decode. Do NOT forward temperature here: transformers
            # warns/errors on temperature <= 0 when do_sample is False, and
            # greedy decoding ignores it anyway.
            output = model.generate(input_ids, max_new_tokens=1, do_sample=False)
        else:
            output = model.generate(
                input_ids,
                max_new_tokens=1,
                do_sample=True,
                temperature=temperature,
            )
        # output has shape (1, len(tokens) + 1); the last position is the
        # newly generated token id.
        return int(output[0, -1].item())

    return infer_next_token
def setup_model(checkpoint: str) -> Callable[[List[int], float, bool], int]:
    """Load the model at *checkpoint* and return its infer_next_token callable."""
    return get_infer_next_token(load_model(checkpoint))