Files
openharmony-mlx/gpt_oss/tools/tool.py
Dominik Kundel 243a1b0276 Initial commit
Co-authored-by: Zhuohan Li <zhuohan@openai.com>
Co-authored-by: Maratyszcza <marat@openai.com>
Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
2025-08-05 08:19:49 -07:00

101 lines
3.6 KiB
Python

from abc import ABC, abstractmethod
from uuid import UUID, uuid4
from typing import AsyncIterator
from openai_harmony import (
Author,
Role,
Message,
TextContent,
)
def _maybe_update_inplace_and_validate_channel(
    *, input_message: Message, tool_message: Message
) -> None:
    """
    Reconcile the channel of a tool-produced message with its triggering message.

    A tool message whose channel does not differ from the input's is left
    untouched. One whose channel is unset (``None``) is updated in place to
    carry the input message's channel.

    Raises:
        ValueError: If the tool message already has a channel and it differs
            from the channel of the triggering message.
    """
    channels_differ = tool_message.channel != input_message.channel
    if not channels_differ:
        # Already consistent; nothing to adjust or validate.
        return

    if tool_message.channel is not None:
        # The tool explicitly picked a different channel: refuse to guess.
        raise ValueError(
            f"Messages from tool should have the same channel ({tool_message.channel=}) as "
            f"the triggering message ({input_message.channel=})."
        )

    # Unset channel: inherit it from the message that triggered the tool.
    tool_message.channel = input_message.channel
class Tool(ABC):
    """
    Something the model can call.

    Tools expose APIs that are shown to the model in a syntax that the model
    understands and knows how to call (from training data). Tools allow the
    model to do things like run code, browse the web, etc.

    Subclasses must implement `name`, `_process` and `instruction`.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """
        An identifier for the tool. The convention is that a message will be routed to the tool
        whose name matches its recipient field.
        """

    @property
    def output_channel_should_match_input_channel(self) -> bool:
        """
        A flag which indicates whether the output channel of the tool should match the input channel.

        Defaults to True. While it is True, `process` runs the channel
        check/auto-fill on every message produced by `_process`.
        """
        return True

    async def process(self, message: Message) -> AsyncIterator[Message]:
        """
        Process the message and yield the new messages to add to the conversation.

        This is an async generator: consume it with `async for`.
        The input message should already be applicable to this tool.
        Don't yield the input message, just the new messages.

        If implementing a tool that has to block while calling a function use `call_on_background_thread` to get a coroutine.

        If you just want to test this use `evaluate_generator` to get the results.

        Do not override this method; override `_process` below (to avoid interfering with tracing).

        Raises:
            ValueError: Propagated from the channel check when a produced
                message carries a channel different from the input message's
                and `output_channel_should_match_input_channel` is True.
        """
        async for m in self._process(message):
            if self.output_channel_should_match_input_channel:
                # Fills in a missing channel in place, or raises on a mismatch.
                _maybe_update_inplace_and_validate_channel(input_message=message, tool_message=m)
            yield m

    @abstractmethod
    async def _process(self, message: Message) -> AsyncIterator[Message]:
        """Override this method to provide the implementation of the tool."""
        if False:  # This is to convince the type checker that this is an async generator.
            yield  # type: ignore[unreachable]
        _ = message  # Stifle "unused argument" warning.
        raise NotImplementedError

    @abstractmethod
    def instruction(self) -> str:
        """
        Describe the tool's functionality. For example, if it accepts python-formatted code,
        provide documentation on the functions available.
        """
        raise NotImplementedError

    def instruction_dict(self) -> dict[str, str]:
        """Return a one-entry mapping from this tool's name to its instruction text."""
        return {self.name: self.instruction()}

    def error_message(
        self, error_message: str, id: UUID | None = None, channel: str | None = None
    ) -> Message:
        """
        Return an error message that's from this tool.

        Args:
            error_message: Text of the error to report.
            id: Identifier for the message; a fresh `uuid4()` is generated when it is None.
            channel: Channel to set on the message; None is passed through unchanged.

        Returns:
            A message authored by this tool (`Role.TOOL` with this tool's name)
            whose recipient is "assistant".
        """
        # Explicit None test: only an omitted id triggers generation of a new one.
        message_id = id if id is not None else uuid4()
        # NOTE(review): assumes `Message` accepts `id=` and a single `TextContent`
        # for `content=` -- confirm against the openai_harmony API.
        return Message(
            id=message_id,
            author=Author(role=Role.TOOL, name=self.name),
            content=TextContent(text=error_message),  # TODO: Use SystemError instead
            channel=channel,
        ).with_recipient("assistant")