Co-authored-by: Zhuohan Li <zhuohan@openai.com> Co-authored-by: Maratyszcza <marat@openai.com> Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
368 lines
14 KiB
Python
368 lines
14 KiB
Python
"""
Harmony chat with tools
"""
|
|
|
|
import atexit
import argparse
import asyncio
import datetime
import json
import os
from pathlib import Path

try:
    import gnureadline as readline
except ImportError:
    import readline

import torch
import termcolor

from gpt_oss.tools import apply_patch
from gpt_oss.tools.simple_browser import SimpleBrowserTool
from gpt_oss.tools.simple_browser.backend import ExaBackend
from gpt_oss.tools.python_docker.docker_tool import PythonTool

from openai_harmony import (
    Author,
    Conversation,
    DeveloperContent,
    HarmonyEncodingName,
    Message,
    ReasoningEffort,
    Role,
    StreamableParser,
    StreamState,
    SystemContent,
    TextContent,
    ToolDescription,
    load_harmony_encoding,
)
|
|
|
|
|
|
# Maps the --reasoning-effort command-line choice to the Harmony enum value.
REASONING_EFFORT = dict(
    high=ReasoningEffort.HIGH,
    medium=ReasoningEffort.MEDIUM,
    low=ReasoningEffort.LOW,
)
|
|
|
|
|
|
def get_user_input():
    """Read one line of user input and return it on every rank.

    When ``torch.distributed`` is not initialized, this simply returns
    ``input()``. When it is, only rank 0 reads from stdin; the other ranks
    supply an empty placeholder that is replaced by rank 0's value through
    ``broadcast_object_list``.

    Returns:
        The text entered on rank 0.
    """
    distributed = torch.distributed.is_initialized()
    rank = torch.distributed.get_rank() if distributed else 0

    # Single-element list because broadcast_object_list fills its argument
    # in place.
    payload = [input() if rank == 0 else ""]
    if distributed:
        torch.distributed.broadcast_object_list(payload, 0)
    return payload[0]
|
|
|
|
|
|
def main(args):
|
|
match args.backend:
|
|
case "triton":
|
|
from gpt_oss.triton.model import TokenGenerator as TritonGenerator
|
|
from gpt_oss.torch.utils import init_distributed
|
|
device = init_distributed()
|
|
generator = TritonGenerator(args.checkpoint, args.context, device)
|
|
case "torch":
|
|
from gpt_oss.torch.model import TokenGenerator as TorchGenerator
|
|
from gpt_oss.torch.utils import init_distributed
|
|
device = init_distributed()
|
|
generator = TorchGenerator(args.checkpoint, device)
|
|
case "vllm":
|
|
from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator
|
|
generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=2)
|
|
case _:
|
|
raise ValueError(f"Invalid backend: {args.backend}")
|
|
|
|
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
|
|
|
system_message_content = (
|
|
SystemContent.new()
|
|
.with_reasoning_effort(REASONING_EFFORT[args.reasoning_effort])
|
|
.with_conversation_start_date(datetime.datetime.now().strftime("%Y-%m-%d"))
|
|
)
|
|
|
|
if args.browser:
|
|
backend = ExaBackend(
|
|
source="web",
|
|
)
|
|
browser_tool = SimpleBrowserTool(backend=backend)
|
|
system_message_content = system_message_content.with_tools(browser_tool.tool_config)
|
|
|
|
if args.python:
|
|
python_tool = PythonTool()
|
|
system_message_content = system_message_content.with_tools(python_tool.tool_config)
|
|
|
|
system_message = Message.from_role_and_content(Role.SYSTEM, system_message_content)
|
|
messages = [system_message]
|
|
|
|
if args.apply_patch:
|
|
apply_patch_instructions = Path(apply_patch.__file__).parent / "apply_patch.md"
|
|
developer_message = ""
|
|
if args.developer_message:
|
|
developer_message = args.developer_message + "\n"
|
|
developer_message += apply_patch_instructions.read_text()
|
|
developer_message_content = (
|
|
DeveloperContent.new()
|
|
.with_instructions(developer_message)
|
|
.with_function_tools([
|
|
ToolDescription.new(
|
|
"apply_patch",
|
|
"Patch a file",
|
|
parameters={
|
|
"type": "string",
|
|
"description": "Formatted patch code",
|
|
"default": "*** Begin Patch\n*** End Patch\n",
|
|
}
|
|
),
|
|
])
|
|
)
|
|
messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_message_content))
|
|
elif args.developer_message:
|
|
developer_message_content = DeveloperContent.new().with_instructions(args.developer_message)
|
|
messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_message_content))
|
|
|
|
if args.raw:
|
|
conversation = Conversation.from_messages(messages)
|
|
tokens = encoding.render_conversation(conversation)
|
|
system_message = encoding.decode(tokens)
|
|
print(system_message, flush=True, end="")
|
|
empty_user_message_tokens = encoding.render(Message.from_role_and_content(Role.USER, ""))
|
|
user_message_start = encoding.decode(empty_user_message_tokens[:-1])
|
|
user_message_end = encoding.decode(empty_user_message_tokens[-1:])
|
|
else:
|
|
# System message
|
|
print(termcolor.colored("System Message:", "cyan"), flush=True)
|
|
print(termcolor.colored("Model Identity:", "cyan"), system_message_content.model_identity, flush=True)
|
|
print(termcolor.colored("Reasoning Effort:", "cyan"), system_message_content.reasoning_effort, flush=True)
|
|
print(termcolor.colored("Conversation Start Date:", "cyan"), system_message_content.conversation_start_date, flush=True)
|
|
print(termcolor.colored("Knowledge Cutoff:", "cyan"), system_message_content.knowledge_cutoff, flush=True)
|
|
print(termcolor.colored("Browser Tool:", "cyan"), "Enabled" if args.browser else "Disabled", flush=True)
|
|
print(termcolor.colored("Python Tool:", "cyan"), "Enabled" if args.python else "Disabled", flush=True)
|
|
print(termcolor.colored("Apply Patch Function:", "cyan"), "Enabled" if args.apply_patch else "Disabled", flush=True)
|
|
# Developer message
|
|
print(termcolor.colored("Developer Message:", "yellow"), flush=True)
|
|
print(developer_message_content.instructions, flush=True)
|
|
|
|
# Print the system message and the user message start
|
|
MESSAGE_PADDING = 12
|
|
while True:
|
|
last_message = messages[-1]
|
|
if last_message.recipient is None:
|
|
if args.raw:
|
|
print(user_message_start, end="", flush=True)
|
|
user_message = get_user_input()
|
|
print(user_message_end, flush=True, end="")
|
|
else:
|
|
print(termcolor.colored("User:".ljust(MESSAGE_PADDING), "red"), flush=True)
|
|
user_message = get_user_input()
|
|
user_message = Message.from_role_and_content(Role.USER, user_message)
|
|
messages.append(user_message)
|
|
else:
|
|
# Tool or function call
|
|
if last_message.recipient.startswith("browser."):
|
|
assert args.browser, "Browser tool is not enabled"
|
|
tool_name = "Search"
|
|
async def run_tool():
|
|
results = []
|
|
async for msg in browser_tool.process(last_message):
|
|
results.append(msg)
|
|
return results
|
|
|
|
result = asyncio.run(run_tool())
|
|
messages += result
|
|
elif last_message.recipient.startswith("python"):
|
|
assert args.python, "Python tool is not enabled"
|
|
tool_name = "Python"
|
|
async def run_tool():
|
|
results = []
|
|
async for msg in python_tool.process(last_message):
|
|
results.append(msg)
|
|
return results
|
|
|
|
result = asyncio.run(run_tool())
|
|
messages += result
|
|
elif last_message.recipient == "functions.apply_patch":
|
|
assert args.apply_patch, "Apply patch tool is not enabled"
|
|
tool_name = "Apply Patch"
|
|
text = last_message.content[0].text
|
|
tool_output = None
|
|
|
|
if text.startswith("{"):
|
|
# this is json, try to extract the patch from it
|
|
import json
|
|
try:
|
|
some_dict = json.loads(text)
|
|
_, text = some_dict.popitem()
|
|
except Exception as e:
|
|
tool_output = f"Error parsing JSON: {e}"
|
|
|
|
if tool_output is None:
|
|
try:
|
|
tool_output = apply_patch.apply_patch(text)
|
|
except Exception as e:
|
|
tool_output = f"Error applying patch: {e}"
|
|
|
|
message = (
|
|
Message(
|
|
author=Author.new(Role.TOOL, last_message.recipient),
|
|
content=[TextContent(text=tool_output)]
|
|
)
|
|
.with_recipient("assistant")
|
|
)
|
|
if last_message.channel:
|
|
message = message.with_channel(last_message.channel)
|
|
|
|
result = [message]
|
|
messages += result
|
|
else:
|
|
raise ValueError(f"Unknown tool or function call: {last_message.recipient}")
|
|
# Print the tool or function call result
|
|
if args.raw:
|
|
rendered_result = encoding.render_conversation(Conversation.from_messages(result))
|
|
print(encoding.decode(rendered_result), flush=True, end="")
|
|
else:
|
|
print(termcolor.colored(f"{tool_name} output:".ljust(MESSAGE_PADDING), "magenta"), flush=True)
|
|
if tool_name == "Search" and not args.show_browser_results:
|
|
print("[Search results fed to the model]")
|
|
else:
|
|
print(result[0].content[0].text)
|
|
|
|
conversation = Conversation.from_messages(messages)
|
|
tokens = encoding.render_conversation_for_completion(
|
|
conversation, Role.ASSISTANT
|
|
)
|
|
|
|
if args.raw:
|
|
# Print the last two tokens, which are the start of the assistant message
|
|
print(encoding.decode(tokens[-2:]), flush=True, end="")
|
|
|
|
parser = StreamableParser(encoding, role=Role.ASSISTANT)
|
|
field_created = False
|
|
current_output_text = ""
|
|
output_text_delta_buffer = ""
|
|
for predicted_token in generator.generate(tokens, encoding.stop_tokens_for_assistant_actions()):
|
|
parser.process(predicted_token)
|
|
if args.raw:
|
|
print(encoding.decode([predicted_token]), end="", flush=True)
|
|
continue
|
|
|
|
if parser.state == StreamState.EXPECT_START:
|
|
print("") # new line
|
|
field_created = False
|
|
|
|
if not parser.last_content_delta:
|
|
continue
|
|
|
|
if not field_created:
|
|
field_created = True
|
|
if parser.current_channel == "final":
|
|
print(termcolor.colored("Assistant:", "green"), flush=True)
|
|
elif parser.current_recipient is not None:
|
|
print(termcolor.colored(f"Tool call to {parser.current_recipient}:", "cyan"), flush=True)
|
|
else:
|
|
print(termcolor.colored("CoT:", "yellow"), flush=True)
|
|
|
|
should_send_output_text_delta = True
|
|
output_text_delta_buffer += parser.last_content_delta
|
|
if args.browser:
|
|
updated_output_text, _annotations, has_partial_citations = browser_tool.normalize_citations(current_output_text + output_text_delta_buffer)
|
|
output_text_delta_buffer = updated_output_text[len(current_output_text):]
|
|
if has_partial_citations:
|
|
should_send_output_text_delta = False
|
|
if should_send_output_text_delta:
|
|
print(output_text_delta_buffer, end="", flush=True)
|
|
current_output_text += output_text_delta_buffer
|
|
output_text_delta_buffer = ""
|
|
|
|
messages += parser.messages
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line options as (flags, keyword arguments) pairs. They are
    # registered in this order, which is also the order shown by --help.
    option_specs = [
        (("checkpoint",), dict(
            metavar="FILE",
            type=str,
            help="Path to the SafeTensors checkpoint",
        )),
        (("-r", "--reasoning-effort"), dict(
            metavar="REASONING_EFFORT",
            type=str,
            default="low",
            choices=["high", "medium", "low"],
            help="Reasoning effort",
        )),
        (("-a", "--apply-patch"), dict(
            action="store_true",
            help="Make apply_patch function available to the model",
        )),
        (("-b", "--browser"), dict(
            default=False,
            action="store_true",
            help="Use browser tool",
        )),
        (("--show-browser-results",), dict(
            default=False,
            action="store_true",
            help="Show browser results",
        )),
        (("-p", "--python"), dict(
            default=False,
            action="store_true",
            help="Use python tool",
        )),
        (("--developer-message",), dict(
            default="",
            help="Developer message",
        )),
        (("-c", "--context"), dict(
            metavar="CONTEXT",
            type=int,
            default=8192,
            help="Max context length",
        )),
        (("--raw",), dict(
            default=False,
            action="store_true",
            help="Raw mode (does not render Harmony encoding)",
        )),
        (("--backend",), dict(
            type=str,
            default="triton",
            choices=["triton", "torch", "vllm"],
            help="Inference backend",
        )),
    ]

    parser = argparse.ArgumentParser(
        description="Chat example",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    args = parser.parse_args()

    # Input history is loaded and saved only when WORLD_SIZE is 1 (or unset).
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if world_size == 1:
        histfile = os.path.join(os.path.expanduser("~"), ".chat")
        try:
            readline.read_history_file(histfile)
        except FileNotFoundError:
            # No history file yet; it is created by the atexit hook below.
            pass
        else:
            readline.set_history_length(10000)

        atexit.register(readline.write_history_file, histfile)

    main(args)