57 lines
1.7 KiB
Python
57 lines
1.7 KiB
Python
"""
|
|
NOTE: this is not the most efficient way to use transformers. It's a simple implementation that infers
|
|
one token at a time to mimic the behavior of the Triton implementation.
|
|
"""
|
|
|
|
import os
|
|
from typing import Callable, List
|
|
|
|
# Transformers imports
|
|
from transformers import AutoModelForCausalLM, PreTrainedModel
|
|
import torch
|
|
|
|
|
|
# Sampling temperature used when the caller does not supply one;
# 0.0 selects greedy decoding (sampling disabled).
DEFAULT_TEMPERATURE = 0.0

# Tensor-parallel degree. os.environ.get returns a *string* when the
# variable is set, so normalize to int to match the integer default
# and keep the type consistent for downstream arithmetic/comparisons.
TP = int(os.environ.get("TP", 2))
|
|
|
|
def load_model(checkpoint: str):
    """
    Load a causal-LM checkpoint through the Transformers Auto API.

    Weights are loaded in bfloat16, and ``device_map="auto"`` lets
    accelerate place the model across the available devices.
    """
    return AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
|
|
|
|
|
|
def get_infer_next_token(model: PreTrainedModel):
    """
    Return a callable with the same shape as the original triton implementation:

        infer_next_token(tokens: List[int], temperature: float, new_request: bool) -> int

    Implementation detail:
    - We issue a single-token generation using model.generate.
    - generate handles sampling (temperature=0 => greedy, otherwise, sampling).
    """

    def infer_next_token(
        tokens: List[int],
        temperature: float = DEFAULT_TEMPERATURE,
        new_request: bool = False,  # kept for interface compatibility; unused here
    ) -> int:
        # Batch of one; use a distinct name instead of shadowing the parameter.
        input_ids = torch.tensor([tokens], dtype=torch.int64, device=model.device)

        do_sample = temperature != 0
        # Only forward `temperature` when sampling is on: recent transformers
        # releases warn about (and may reject) temperature=0.0 combined with
        # do_sample=False, since the value would be silently ignored.
        sampling_kwargs = {"temperature": temperature} if do_sample else {}

        output = model.generate(
            input_ids,
            max_new_tokens=1,
            do_sample=do_sample,
            **sampling_kwargs,
        )
        # output is (1, len(tokens) + 1); the last column is the new token.
        # .item() extracts it as a plain Python int.
        return output[0, -1].item()

    return infer_next_token
|
|
|
|
|
|
def setup_model(checkpoint: str) -> Callable[[List[int], float, bool], int]:
    """Load *checkpoint* and return its single-token inference callable."""
    return get_infer_next_token(load_model(checkpoint))
|