Transformers responses API (#1)
@@ -273,6 +273,7 @@ You can start this server with the following inference backends:
- `metal` — uses the metal implementation on Apple Silicon only
- `ollama` — uses the Ollama /api/generate API as an inference solution
- `vllm` — uses your installed vllm version to perform inference
- `transformers` — uses your installed transformers version to perform local inference

```bash
usage: python -m gpt_oss.responses_api.serve [-h] [--checkpoint FILE] [--port PORT] [--inference-backend BACKEND]
```
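
For example, a checkpoint could be served with the new backend along the lines of `python -m gpt_oss.responses_api.serve --checkpoint <path-to-checkpoint> --port 8000 --inference-backend transformers`; the checkpoint path and port shown here are placeholders, not values taken from this change.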

gpt_oss/responses_api/inference/transformers.py (new file, 56 lines)
@@ -0,0 +1,56 @@
"""
NOTE: this is not the most efficient way to use transformers. It's a simple implementation that infers
one token at a time to mimic the behavior of the Triton implementation.
"""

import os
from typing import Callable, List

# Transformers imports
from transformers import AutoModelForCausalLM, PreTrainedModel
import torch


DEFAULT_TEMPERATURE = 0.0
TP = os.environ.get("TP", 2)  # read from the environment but not used in this file


def load_model(checkpoint: str):
    """
    Serve the model directly with the Auto API.
    """

    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    return model


def get_infer_next_token(model: PreTrainedModel):
    """
    Return a callable with the same shape as the original Triton implementation:

        infer_next_token(tokens: List[int], temperature: float, new_request: bool) -> int

    Implementation detail:
    - We issue a single-token generation using model.generate.
    - generate handles sampling (temperature=0 => greedy decoding, otherwise sampling).
    """

    def infer_next_token(
        tokens: List[int],
        temperature: float = DEFAULT_TEMPERATURE,
        new_request: bool = False,  # kept for interface compatibility; unused here
    ) -> int:
        tokens = torch.tensor([tokens], dtype=torch.int64, device=model.device)
        output = model.generate(tokens, max_new_tokens=1, do_sample=temperature != 0, temperature=temperature)
        return output[0, -1].tolist()

    return infer_next_token


def setup_model(checkpoint: str) -> Callable[[List[int], float, bool], int]:
    model = load_model(checkpoint)
    infer_next_token = get_infer_next_token(model)
    return infer_next_token
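
For context, a minimal sketch of how the returned callable can be driven one token at a time. The checkpoint id, prompt token ids, stop-token id, and generation budget below are assumptions for illustration, not values from this change:

```python
# Sketch (not part of the diff): drive infer_next_token one token per call.
from gpt_oss.responses_api.inference.transformers import setup_model

# Assumption: any causal-LM checkpoint loadable by AutoModelForCausalLM works here.
infer_next_token = setup_model("openai/gpt-oss-20b")

tokens = [1, 2, 3]   # placeholder prompt token ids; real ids come from the tokenizer/encoding
STOP_TOKEN = 0       # placeholder; the server's actual stop-token ids are defined elsewhere

for _ in range(64):  # generate at most 64 new tokens, one per call
    next_token = infer_next_token(tokens, temperature=0.0, new_request=False)
    tokens.append(next_token)
    if next_token == STOP_TOKEN:
        break
```
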
@@ -47,6 +47,8 @@ if __name__ == "__main__":
        from .inference.ollama import setup_model
    elif args.inference_backend == "vllm":
        from .inference.vllm import setup_model
    elif args.inference_backend == "transformers":
        from .inference.transformers import setup_model
    else:
        raise ValueError(f"Invalid inference backend: {args.inference_backend}")
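
Each branch above imports a module-level `setup_model` with the same shape, so the rest of serve.py stays backend-agnostic. A sketch of that implied contract, inferred from the transformers backend added here (the other backend modules are assumed to follow the same pattern; this is not copied from them):

```python
# Inferred backend contract: setup_model(checkpoint) returns a next-token callable.
from typing import Callable, List

def setup_model(checkpoint: str) -> Callable[[List[int], float, bool], int]:
    def infer_next_token(tokens: List[int], temperature: float = 0.0, new_request: bool = False) -> int:
        raise NotImplementedError  # each backend supplies its own single-token inference here
    return infer_next_token
```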