Files
openharmony-mlx/gpt_oss/metal/examples/generate.py
Dominik Kundel 243a1b0276 Initial commit
Co-authored-by: Zhuohan Li <zhuohan@openai.com>
Co-authored-by: Maratyszcza <marat@openai.com>
Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
2025-08-05 08:19:49 -07:00

36 lines
1.1 KiB
Python

#!/usr/bin/env python
import argparse
import sys
from datetime import date
from gpt_oss import Context, Model
# CLI definition: one required positional (checkpoint path), a required
# prompt, and two optional generation knobs with visible defaults.
parser = argparse.ArgumentParser(
    description='Chat with gpt-oss',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    'model', metavar='PATH', type=str,
    help='Path to gpt-oss checkpoint',
)
parser.add_argument(
    '-p', '--prompt', required=True, type=str,
    help='Prompt',
)
parser.add_argument(
    '-l', '--limit', default=100, type=int,
    help='Number of tokens to generate',
)
parser.add_argument(
    '--context-length', default=0, type=int,
    help='The maximum context length',
)
def main(args):
    """Parse CLI args, load the checkpoint, and stream generated text.

    Tokenizes the prompt, prints the resulting token ids, then samples
    up to ``--limit`` tokens, decoding and printing each as it arrives.
    """
    opts = parser.parse_args(args)

    # Load the model and create an inference context bound to it.
    # NOTE(review): a --context-length of 0 is forwarded as-is; presumably
    # the library treats 0 as "use the model default" — confirm upstream.
    model = Model(opts.model)
    context = Context(model, context_length=opts.context_length)

    # Feed the prompt and echo its token ids for inspection.
    context.append(opts.prompt)
    print(context.tokens)

    num_prompt_tokens = context.num_tokens
    tok = model.tokenizer

    # Sample one token at a time until --limit new tokens have been added,
    # streaming the decoded text to stdout without buffering.
    while context.num_tokens - num_prompt_tokens < opts.limit:
        sampled = context.sample()
        context.append(sampled)
        # decode() appears to return bytes for a single token; convert to
        # text before printing.
        print(str(tok.decode(sampled), encoding="utf-8"), end='', flush=True)
# Script entry point: drop the program name and hand the rest to main().
if __name__ == '__main__':
    main(sys.argv[1:])