openharmony-mlx/gpt_oss/generate.py
Arthur Colle 92f5b57da3 Initial release: OpenHarmony-MLX - High-Performance Apple Silicon GPT-OSS Implementation
This is a complete rebranding and optimization of the original GPT-OSS codebase for Apple Silicon:

🚀 Features:
- Native MLX acceleration for M1/M2/M3/M4 chips
- Complete MLX implementation with Mixture of Experts (MoE)
- Memory-efficient quantization (4-bit MXFP4; see the quantization sketch after this list)
- Drop-in replacement APIs for existing backends
- Full tool integration (browser, python, apply_patch)
- Comprehensive build system with Metal kernels
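
To make the 4-bit claim concrete, here is a minimal quantization sketch in MLX. It uses MLX's built-in affine quantization (mx.quantize) rather than the MXFP4 format named above, so it illustrates the memory-saving idea only, not this repo's exact scheme:

    import mlx.core as mx

    # 4 bits per weight plus per-group scales/biases: roughly 8x smaller than fp32.
    w = mx.random.normal((4096, 4096))
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
    w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)
    print(w_q.dtype, w_q.shape)  # packed uint32 storage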

📦 What's Included:
- gpt_oss/mlx_gpt_oss/ - Complete MLX implementation
- All original inference backends (torch, triton, metal, vllm)
- Command-line interfaces and Python APIs
- Developer tools and evaluation suite
- Updated branding and documentation

🍎 Apple Silicon Optimized:
- Up to 40 tokens/sec on Apple Silicon
- Runs GPT-OSS-120b in 30 GB of memory with quantization
- Native Metal kernel acceleration
- Memory-mapped weight loading (see the sketch below)
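
Memory-mapped loading is easiest to picture with the safetensors library, which backs a checkpoint file with an mmap and materializes tensors only on demand. A hedged sketch of the general technique, not this repo's actual loader ("model.safetensors" is a placeholder path):

    from safetensors import safe_open

    # The file is memory-mapped; tensor data is copied out only when requested.
    with safe_open("model.safetensors", framework="numpy") as f:
        for name in f.keys():
            tensor = f.get_tensor(name)
            print(name, tensor.shape)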

🔧 Ready to Deploy:
- Updated package name to openharmony-mlx
- Comprehensive .gitignore for clean releases
- Updated README with Apple Silicon focus
- All build artifacts cleaned up

🧠 Generated with Claude Code

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-06 19:28:25 -04:00


# Model parallel inference
# Note: This script is for demonstration purposes only. It is not designed for production use.
# See gpt_oss.chat for a more complete example with the Harmony parser.
# torchrun --nproc-per-node=4 -m gpt_oss.generate -p "why did the chicken cross the road?" model/

import argparse

from gpt_oss.tokenizer import get_tokenizer


def main(args):
    # Construct the token generator for the requested backend. Backend imports
    # are deferred so only the chosen backend's dependencies must be installed.
    match args.backend:
        case "torch":
            from gpt_oss.torch.utils import init_distributed
            from gpt_oss.torch.model import TokenGenerator as TorchGenerator
            device = init_distributed()
            generator = TorchGenerator(args.checkpoint, device=device)
        case "triton":
            from gpt_oss.torch.utils import init_distributed
            from gpt_oss.triton.model import TokenGenerator as TritonGenerator
            device = init_distributed()
            generator = TritonGenerator(args.checkpoint, context=4096, device=device)
        case "vllm":
            from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator
            generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=2)
        case "mlx":
            from gpt_oss.mlx_gpt_oss.generate import TokenGenerator as MLXGenerator
            generator = MLXGenerator(args.checkpoint)
        case _:
            raise ValueError(f"Invalid backend: {args.backend}")

    tokenizer = get_tokenizer()
    tokens = tokenizer.encode(args.prompt)

    # Stream tokens one at a time, stopping at end-of-turn or the token limit.
    for token, logprob in generator.generate(
        tokens,
        stop_tokens=[tokenizer.eot_token],
        temperature=args.temperature,
        max_tokens=args.limit,
        return_logprobs=True,
    ):
        tokens.append(token)
        decoded_token = tokenizer.decode([token])
        print(f"Generated token: {decoded_token!r}, logprob: {logprob}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Text generation example")
    parser.add_argument(
        "checkpoint",
        metavar="FILE",
        type=str,
        help="Path to the SafeTensors checkpoint",
    )
    parser.add_argument(
        "-p",
        "--prompt",
        metavar="PROMPT",
        type=str,
        default="How are you?",
        help="LLM prompt",
    )
    parser.add_argument(
        "-t",
        "--temperature",
        metavar="TEMP",
        type=float,
        default=0.0,
        help="Sampling temperature",
    )
    parser.add_argument(
        "-l",
        "--limit",
        metavar="LIMIT",
        type=int,
        default=0,
        help="Limit on the number of generated tokens (0 to disable)",
    )
    parser.add_argument(
        "-b",
        "--backend",
        metavar="BACKEND",
        type=str,
        default="torch",
        choices=["triton", "torch", "vllm", "mlx"],
        help="Inference backend",
    )
    args = parser.parse_args()
    main(args)
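
For reference, example invocations with model/ as a placeholder checkpoint directory (the torchrun form is taken from the comment at the top of the script; the flag spellings come from the argparse definitions above):

    # Single-process generation with the MLX backend on Apple Silicon:
    python -m gpt_oss.generate -b mlx -p "How are you?" model/

    # Model-parallel generation with the torch backend across four ranks:
    torchrun --nproc-per-node=4 -m gpt_oss.generate -b torch -p "How are you?" model/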