From a601a63cdc122f9a79b74b7c66859260b910c983 Mon Sep 17 00:00:00 2001
From: Lysandre Debut
Date: Tue, 5 Aug 2025 19:02:16 +0200
Subject: [PATCH] Transformers responses API (#1)

---
 README.md                                    |  1 +
 .../responses_api/inference/transformers.py  | 56 +++++++++++++++++++
 gpt_oss/responses_api/serve.py               |  2 +
 3 files changed, 59 insertions(+)
 create mode 100644 gpt_oss/responses_api/inference/transformers.py

diff --git a/README.md b/README.md
index f7f515c..38a7780 100644
--- a/README.md
+++ b/README.md
@@ -273,6 +273,7 @@ You can start this server with the following inference backends:
 - `metal` — uses the metal implementation on Apple Silicon only
 - `ollama` — uses the Ollama /api/generate API as a inference solution
 - `vllm` — uses your installed vllm version to perform inference
+- `transformers` — uses your installed transformers version to perform local inference
 
 ```bash
 usage: python -m gpt_oss.responses_api.serve [-h] [--checkpoint FILE] [--port PORT] [--inference-backend BACKEND]
diff --git a/gpt_oss/responses_api/inference/transformers.py b/gpt_oss/responses_api/inference/transformers.py
new file mode 100644
index 0000000..c743707
--- /dev/null
+++ b/gpt_oss/responses_api/inference/transformers.py
@@ -0,0 +1,56 @@
+"""
+NOTE: this is not the most efficient way to use transformers. It's a simple implementation that infers
+one token at a time to mimic the behavior of the Triton implementation.
+"""
+
+import os
+from typing import Callable, List
+
+# Transformers imports
+from transformers import AutoModelForCausalLM, PreTrainedModel
+import torch
+
+
+DEFAULT_TEMPERATURE = 0.0
+TP = os.environ.get("TP", 2)
+
+def load_model(checkpoint: str):
+    """
+    Serve the model directly with the Auto API.
+    """
+
+    model = AutoModelForCausalLM.from_pretrained(
+        checkpoint,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+    )
+
+    return model
+
+
+def get_infer_next_token(model: PreTrainedModel):
+    """
+    Return a callable with the same shape as the original triton implementation:
+    infer_next_token(tokens: List[int], temperature: float, new_request: bool) -> int
+
+    Implementation detail:
+    - We issue a single-token generation with using model.generate
+    - generate handles sampling (temperature=0 => greedy, otherwise, sampling).
+    """
+
+    def infer_next_token(
+        tokens: List[int],
+        temperature: float = DEFAULT_TEMPERATURE,
+        new_request: bool = False,  # kept for interface compatibility; unused here
+    ) -> int:
+        tokens = torch.tensor([tokens], dtype=torch.int64, device=model.device)
+        output = model.generate(tokens, max_new_tokens=1, do_sample=temperature != 0, temperature=temperature)
+        return output[0, -1].tolist()
+
+    return infer_next_token
+
+
+def setup_model(checkpoint: str) -> Callable[[List[int], float, bool], int]:
+    model = load_model(checkpoint)
+    infer_next_token = get_infer_next_token(model)
+    return infer_next_token
diff --git a/gpt_oss/responses_api/serve.py b/gpt_oss/responses_api/serve.py
index 0d6974e..35fc3f4 100644
--- a/gpt_oss/responses_api/serve.py
+++ b/gpt_oss/responses_api/serve.py
@@ -47,6 +47,8 @@ if __name__ == "__main__":
         from .inference.ollama import setup_model
     elif args.inference_backend == "vllm":
         from .inference.vllm import setup_model
+    elif args.inference_backend == "transformers":
+        from .inference.transformers import setup_model
     else:
         raise ValueError(f"Invalid inference backend: {args.inference_backend}")
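
For reference, a minimal sketch of how the new backend's `infer_next_token` interface can be exercised directly (not part of the patch; the checkpoint path and prompt token IDs below are placeholders):

```python
# Minimal usage sketch (not part of the patch). The checkpoint path and the
# prompt token IDs are placeholders; tokenizer handling is left out.
from gpt_oss.responses_api.inference.transformers import setup_model

# setup_model loads the checkpoint with AutoModelForCausalLM and returns the
# infer_next_token callable defined in the new module.
infer_next_token = setup_model("/path/to/checkpoint")

tokens = [1, 2, 3]  # placeholder prompt token IDs
for _ in range(16):
    # temperature=0.0 -> greedy decoding; new_request is accepted but unused
    next_token = infer_next_token(tokens, temperature=0.0)
    tokens.append(next_token)
print(tokens)
```

Equivalently, the server can be started with this backend via `python -m gpt_oss.responses_api.serve --checkpoint FILE --inference-backend transformers`, as documented in the README hunk above.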