Initial commit
Co-authored-by: Zhuohan Li <zhuohan@openai.com> Co-authored-by: Maratyszcza <marat@openai.com> Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
This commit is contained in:
47
gpt_oss/vllm/token_generator.py
Normal file
47
gpt_oss/vllm/token_generator.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from vllm import LLMEngine, EngineArgs, SamplingParams, TokensPrompt
|
||||
|
||||
|
||||
class TokenGenerator:
    """Streaming token generator backed by a vLLM ``LLMEngine``.

    Drives the engine step-by-step and yields newly generated tokens as
    they are produced, for one request at a time.
    """

    def __init__(self, model_path: str, tensor_parallel_size: int = 1):
        """Construct an engine for the model at *model_path*.

        Args:
            model_path: Model name or local path understood by vLLM.
            tensor_parallel_size: Number of GPUs to shard the model across.
        """
        args = EngineArgs(
            model=model_path,
            tensor_parallel_size=tensor_parallel_size,
        )
        self.engine = LLMEngine.from_engine_args(args)
        # Monotonically increasing counter so every generate() call gets a
        # unique engine request id.
        self.request_id = 0

    def generate(self,
                 prompt_tokens: list[int],
                 stop_tokens: list[int] | None = None,
                 temperature: float = 1.0,
                 max_tokens: int = 0,
                 return_logprobs: bool = False):
        """Yield generated tokens for *prompt_tokens* as they are decoded.

        Args:
            prompt_tokens: Prompt as a list of token ids.
            stop_tokens: Token ids that end generation (also passed to the
                engine as ``stop_token_ids``).
            temperature: Sampling temperature.
            max_tokens: Maximum new tokens; ``0`` means unlimited.
            return_logprobs: When true, yield ``(token_id, logprob)`` tuples
                instead of bare token ids.

        Yields:
            ``int`` token ids, or ``(int, float | None)`` tuples when
            *return_logprobs* is set.
        """
        if max_tokens == 0:
            # 0 is this API's "no limit" sentinel; vLLM expects None.
            max_tokens = None
        request_id = str(self.request_id)
        self.request_id += 1
        sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=max_tokens,
            stop_token_ids=stop_tokens,
            # logprobs=0 requests the logprob of the sampled token only.
            logprobs=0 if return_logprobs else None,
        )
        prompt = TokensPrompt(prompt_token_ids=prompt_tokens)
        self.engine.add_request(request_id, prompt, sampling_params)
        # All token ids yielded so far; engine.step() reports cumulative
        # output, so we slice off this prefix each step.
        seen_token_ids: list[int] = []
        try:
            while self.engine.has_unfinished_requests():
                step_outputs = self.engine.step()
                output = step_outputs[0].outputs[0]
                token_ids = output.token_ids
                logprobs_list = output.logprobs if hasattr(output, "logprobs") else None
                new_token_ids = token_ids[len(seen_token_ids):]
                if logprobs_list is not None:
                    new_logprobs = logprobs_list[len(seen_token_ids):]
                else:
                    new_logprobs = [None] * len(new_token_ids)
                for token_id, logprobs in zip(new_token_ids, new_logprobs):
                    seen_token_ids.append(token_id)
                    if return_logprobs:
                        logprob_val = None
                        if logprobs is not None and token_id in logprobs:
                            logprob_val = logprobs[token_id].logprob
                        yield (token_id, logprob_val)
                    else:
                        yield token_id
                    if stop_tokens is not None and token_id in stop_tokens:
                        # BUG FIX: the original `break` only exited the inner
                        # for-loop; the outer while-loop could re-enter and
                        # yield tokens past the stop token. Return to end the
                        # generator outright (cleanup happens in `finally`).
                        return
        finally:
            # BUG FIX: if the consumer abandons this generator early (or we
            # returned on a stop token above), tell the engine to drop the
            # request so it does not keep decoding for a dead consumer.
            # abort_request is best-effort and ignores already-finished ids.
            self.engine.abort_request(request_id)
|
||||
Reference in New Issue
Block a user