"""
|
|
NOTE: this is not the most efficient way to use vLLM. It's a simple implementation that infers
|
|
one token at a time to mimic the behavior of the Triton implementation.
|
|
"""
|
|
|
|
import os
from typing import Callable, List

# vLLM imports
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

DEFAULT_TEMPERATURE = 0.0
# Tensor-parallel degree; override via the TP environment variable (defaults to 2).
TP = int(os.environ.get("TP", "2"))


def load_model(checkpoint: str):
    """
    Create the vLLM engine. We enable prefix caching so repeated prefixes
    across calls can reuse KV cache for faster prefill.
    """

    llm = LLM(
        model=checkpoint,
        tensor_parallel_size=TP,      # set >1 for tensor parallelism across GPUs
        enable_prefix_caching=True,   # reuse KV cache for shared prefixes
        disable_log_stats=True,       # quiet the periodic engine stats logs
    )

    return llm


def get_infer_next_token(llm: LLM):
    """
    Return a callable with the same shape as the original:

        infer_next_token(tokens: List[int], temperature: float, new_request: bool) -> int

    Implementation details:
      - We issue a single-token generation with TokensPrompt(prompt_token_ids=tokens).
      - vLLM handles sampling (temperature=0 => greedy).
      - With enable_prefix_caching=True, the shared prefix prefill can be reused
        across calls that share the same prefix.
    """

    # Maintain compatibility with the previous closure signature.
    def infer_next_token(
        tokens: List[int],
        temperature: float = DEFAULT_TEMPERATURE,
        new_request: bool = False,  # kept for interface compatibility; unused here
    ) -> int:
        if not tokens:
            raise ValueError("tokens must contain at least one input token id")

        sampling = SamplingParams(
            temperature=float(temperature),
            max_tokens=1,  # we only want the next token
            n=1,           # single continuation
            # More sampling controls (top_p, top_k, etc.) can be exposed here.
        )

        # Provide token IDs directly (no re-tokenization).
        outputs = llm.generate(
            TokensPrompt(prompt_token_ids=tokens),
            sampling_params=sampling,
        )

        if not outputs or not outputs[0].outputs:
            raise RuntimeError("vLLM returned empty outputs")

        gen = outputs[0].outputs[0]
        if not gen.token_ids:
            # If the model immediately finished (e.g., EOS), decide how you'd like
            # to signal that. Here we raise; you could also return an EOS id.
            raise RuntimeError("No next token was generated (possibly EOS).")

        next_tok = int(gen.token_ids[0])
        return next_tok

    return infer_next_token


def setup_model(checkpoint: str) -> Callable[[List[int], float, bool], int]:
    llm = load_model(checkpoint)
    infer_next_token = get_infer_next_token(llm)
    return infer_next_token
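

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original interface): greedy-decode a handful
# of tokens by calling infer_next_token in a loop, one token at a time, as
# described in the docstrings above. The checkpoint path and the prompt token
# ids below are placeholders; substitute a real checkpoint and a tokenized
# prompt for your setup. Repeated calls share a growing prefix, which is what
# enable_prefix_caching speeds up.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    checkpoint = os.environ.get("CHECKPOINT", "/path/to/checkpoint")  # assumed location
    infer_next_token = setup_model(checkpoint)

    tokens: List[int] = [1, 2, 3]  # placeholder prompt token ids
    for _ in range(16):
        # Note: infer_next_token raises if the model finishes immediately (see above).
        next_tok = infer_next_token(tokens, temperature=DEFAULT_TEMPERATURE, new_request=False)
        tokens.append(next_tok)

    print("generated token ids:", tokens)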