Co-authored-by: Zhuohan Li <zhuohan@openai.com> Co-authored-by: Maratyszcza <marat@openai.com> Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
101 lines
3.6 KiB
Python
101 lines
3.6 KiB
Python
from abc import ABC, abstractmethod
|
|
from uuid import UUID, uuid4
|
|
from typing import AsyncIterator
|
|
|
|
from openai_harmony import (
|
|
Author,
|
|
Role,
|
|
Message,
|
|
TextContent,
|
|
)
|
|
|
|
|
|
def _maybe_update_inplace_and_validate_channel(
    *, input_message: Message, tool_message: Message
) -> None:
    """
    Reconcile the channel of a tool-produced message with its triggering message.

    - Channels already equal: nothing happens.
    - ``tool_message.channel`` is None: it inherits ``input_message.channel``
      (``tool_message`` is mutated in place).
    - Otherwise the mismatch is rejected.

    Raises:
        ValueError: if ``tool_message.channel`` is set and differs from
            ``input_message.channel``.
    """
    channels_differ = tool_message.channel != input_message.channel
    if not channels_differ:
        return

    if tool_message.channel is not None:
        # The tool explicitly chose a different channel; refuse rather than overwrite it.
        raise ValueError(
            f"Messages from tool should have the same channel ({tool_message.channel=}) as "
            f"the triggering message ({input_message.channel=})."
        )

    # Unset channel: adopt the one from the message that triggered the tool.
    tool_message.channel = input_message.channel
|
|
|
|
|
|
class Tool(ABC):
    """
    Base class for a capability the model can invoke.

    A tool exposes an API to the model in a syntax the model already understands
    and knows how to call (from training data), letting the model do things like
    run code or browse the web. Concrete tools implement ``name``, ``_process``
    and ``instruction``.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """
        Identifier of this tool. By convention, a message is routed to the tool
        whose name equals the message's recipient field.
        """

    @property
    def output_channel_should_match_input_channel(self) -> bool:
        """
        Whether messages produced by this tool must share the channel of the
        message that triggered them. Defaults to True. When True, ``process``
        fills in an unset channel on each produced message, or raises
        ValueError if the produced message names a different channel.
        """
        return True

    async def process(self, message: Message) -> AsyncIterator[Message]:
        """
        Run the tool on ``message`` and asynchronously yield the new messages it produces.

        The input message should already be applicable to this tool. Only newly
        produced messages are yielded; the input message itself is not echoed back.

        If implementing a tool that has to block while calling a function use
        `call_on_background_thread` to get a coroutine.

        If you just want to test this use `evaluate_generator` to get the results.

        Do not override this method; override `_process` below (to avoid
        interfering with tracing).
        """
        replies = self._process(message)
        async for reply in replies:
            # The flag is consulted for every reply, so a dynamic property is honoured.
            if self.output_channel_should_match_input_channel:
                # May set reply.channel in place, or raise ValueError on a mismatch.
                _maybe_update_inplace_and_validate_channel(
                    input_message=message, tool_message=reply
                )
            yield reply

    @abstractmethod
    async def _process(self, message: Message) -> AsyncIterator[Message]:
        """Implementation hook: subclasses override this to yield the tool's output messages."""
        _ = message  # Silence "unused argument" warnings.
        if False:  # Never executes; the bare yield makes this an async generator function.
            yield  # type: ignore[unreachable]
        raise NotImplementedError

    @abstractmethod
    def instruction(self) -> str:
        """
        Describe the tool's functionality to the model. For example, a tool that
        accepts python-formatted code should document the functions available.
        """
        raise NotImplementedError

    def instruction_dict(self) -> dict[str, str]:
        """Return a one-entry mapping from this tool's ``name`` to its ``instruction()``."""
        key = self.name
        return {key: self.instruction()}

    def error_message(
        self, error_message: str, id: UUID | None = None, channel: str | None = None
    ) -> Message:
        """
        Build an error message authored by this tool and addressed to "assistant".

        Args:
            error_message: Text placed in the message body.
            id: Passed as the ``id`` keyword to ``Message``; a fresh ``uuid4()``
                is generated when it is falsy (e.g. None).
            channel: Passed through unchanged as the message's ``channel``.

        NOTE(review): confirm that ``Message`` keeps an ``id`` field and accepts a
        single ``TextContent`` (rather than a list) as ``content``.
        """
        message_id = uuid4() if not id else id
        author = Author(role=Role.TOOL, name=self.name)
        body = TextContent(text=error_message)  # TODO: Use SystemError instead
        return Message(
            id=message_id,
            author=author,
            content=body,
            channel=channel,
        ).with_recipient("assistant")
|
|
|