Files
Dominik Kundel 243a1b0276 Initial commit
Co-authored-by: Zhuohan Li <zhuohan@openai.com>
Co-authored-by: Maratyszcza <marat@openai.com>
Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
2025-08-05 08:19:49 -07:00

916 lines
39 KiB
Python

import asyncio
import datetime
import uuid
from typing import Callable, Literal, Optional
import json
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from openai_harmony import (
Author,
Conversation,
DeveloperContent,
HarmonyEncoding,
Message,
ReasoningEffort,
Role,
StreamableParser,
StreamState,
SystemContent,
ToolDescription,
)
from gpt_oss.tools.simple_browser import SimpleBrowserTool
from gpt_oss.tools.simple_browser.backend import ExaBackend
from .events import (
ResponseCompletedEvent,
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseEvent,
ResponseOutputItemAdded,
ResponseOutputItemDone,
ResponseContentPartAdded,
ResponseContentPartDone,
ResponseOutputTextDone,
ResponseOutputTextDelta,
ResponseReasoningTextDone,
ResponseReasoningTextDelta,
ResponseWebSearchCallInProgress,
ResponseWebSearchCallSearching,
ResponseWebSearchCallCompleted,
ResponseOutputTextAnnotationAdded
)
from .types import (
UrlCitation,
Error,
FunctionCallItem,
Item,
ReasoningItem,
ReasoningTextContentItem,
ResponseObject,
ResponsesRequest,
TextContentItem,
Usage,
WebSearchCallItem,
WebSearchActionSearch,
WebSearchActionOpenPage,
WebSearchActionFind,
)
DEFAULT_TEMPERATURE = 0.0
def get_reasoning_effort(effort: Literal["low", "medium", "high"]) -> ReasoningEffort:
    """Map a request-level effort string onto the harmony ``ReasoningEffort`` enum.

    Args:
        effort: one of ``"low"``, ``"medium"`` or ``"high"``.

    Returns:
        The matching ``ReasoningEffort`` member.

    Raises:
        ValueError: if ``effort`` is not a recognized level (the original
            implementation silently returned ``None`` here).
    """
    mapping = {
        "low": ReasoningEffort.LOW,
        "medium": ReasoningEffort.MEDIUM,
        "high": ReasoningEffort.HIGH,
    }
    try:
        return mapping[effort]
    except KeyError:
        raise ValueError(f"Invalid reasoning effort: {effort}") from None
def is_not_builtin_tool(recipient: str) -> bool:
    """Return True when *recipient* names a user-defined function tool.

    Built-in destinations — any ``browser.*`` recipient, the ``python`` tool,
    and the ``assistant`` itself — yield False.
    """
    is_builtin = recipient.startswith("browser.") or recipient in ("python", "assistant")
    return not is_builtin
def create_api_server(
    infer_next_token: Callable[[list[int], float], int], encoding: HarmonyEncoding
) -> FastAPI:
    """Build a FastAPI app exposing a subset of the Responses API.

    Args:
        infer_next_token: callable producing the next token id given the token
            prefix; also invoked with ``temperature`` and ``new_request``
            keyword arguments by the streaming loop.
        encoding: harmony encoding used to render conversations into tokens
            and to parse completion tokens back into messages.

    Returns:
        The configured ``FastAPI`` application with a ``POST /v1/responses``
        endpoint supporting both streaming (SSE) and non-streaming modes.
    """
    app = FastAPI()

    # In-memory store backing previous_response_id chaining; keyed by
    # response id, holding the originating request and the produced response.
    responses_store: dict[str, tuple[ResponsesRequest, ResponseObject]] = {}

    def generate_response(
        input_tokens: list[int],
        output_tokens: list[int],
        request_body: ResponsesRequest,
        debug_mode: bool = False,
        function_call_ids: Optional[list[tuple[str, str]]] = None,
        response_id: Optional[str] = None,
        previous_response_id: Optional[str] = None,
        browser_tool: Optional[SimpleBrowserTool] = None,
        browser_call_ids: Optional[list[str]] = None,
    ) -> ResponseObject:
        """Parse completion tokens into a final ``ResponseObject``.

        Converts parsed assistant messages into output items (function calls,
        web-search calls, final messages with citations, reasoning) and
        attaches usage plus optional debug metadata.
        """
        output = []
        error = None
        if len(output_tokens) > 0:
            if debug_mode:
                # In debug mode a parse failure is reported inside the
                # response (error field) instead of raising, so the raw token
                # debug metadata still ships to the client.
                try:
                    entries = encoding.parse_messages_from_completion_tokens(
                        output_tokens, Role.ASSISTANT
                    )
                except Exception as e:
                    print(f"Error parsing tokens: {e}")
                    error = Error(
                        code="invalid_function_call",
                        message=f"{e}",
                    )
                    entries = []
            else:
                entries = encoding.parse_messages_from_completion_tokens(
                    output_tokens, Role.ASSISTANT
                )

            fc_index = 0
            browser_tool_index = 0
            for entry in entries:
                entry_dict = entry.to_dict()
                if len(entry_dict.get("recipient", "")) > 0 and is_not_builtin_tool(entry_dict["recipient"]):
                    call = entry_dict["content"][0]
                    arguments = call["text"]
                    name = entry_dict["recipient"]
                    if name.startswith("functions."):
                        name = name[len("functions.") :]
                    # Reuse the ids minted while streaming so the stored
                    # response matches the streamed events; mint new ones
                    # otherwise.
                    if function_call_ids and fc_index < len(function_call_ids):
                        fc_id, call_id = function_call_ids[fc_index]
                    else:
                        fc_id, call_id = (
                            f"fc_{uuid.uuid4().hex}",
                            f"call_{uuid.uuid4().hex}",
                        )
                    fc_index += 1
                    output.append(
                        FunctionCallItem(
                            type="function_call",
                            name=name,
                            arguments=arguments,
                            id=fc_id,
                            call_id=call_id,
                        )
                    )
                elif len(entry_dict.get("recipient", "")) > 0 and entry_dict["recipient"].startswith("browser.") and browser_tool is not None:
                    # Mirror event-based creation of WebSearchCallItems when the browser tool is invoked
                    name = entry_dict["recipient"]
                    call = entry_dict["content"][0]
                    arguments = call["text"]
                    function_name = name[len("browser."):]
                    # Reconstruct a Message for argument parsing
                    tool_msg = (
                        Message.from_role_and_content(Role.ASSISTANT, arguments)
                        .with_recipient(name)
                        .with_channel("analysis")
                    )
                    action = None
                    try:
                        parsed_args = browser_tool.process_arguments(tool_msg)
                        if function_name == "search":
                            action = WebSearchActionSearch(
                                type="search",
                                query=parsed_args["query"],
                            )
                        elif function_name == "open":
                            action = WebSearchActionOpenPage(
                                type="open_page",
                                url=parsed_args["url"],
                            )
                        elif function_name == "find":
                            action = WebSearchActionFind(
                                type="find",
                                pattern=parsed_args["pattern"],
                                url=parsed_args["url"],
                            )
                    except Exception as e:
                        print(f"Error processing browser tool arguments: {e}")
                        action = None
                    if action is not None:
                        if browser_call_ids and browser_tool_index < len(browser_call_ids):
                            web_search_call_id = browser_call_ids[browser_tool_index]
                        else:
                            web_search_call_id = f"ws_{uuid.uuid4().hex}"
                        browser_tool_index += 1
                        output.append(
                            WebSearchCallItem(
                                type="web_search_call",
                                id=web_search_call_id,
                                action=action,
                            )
                        )
                elif entry_dict["channel"] == "final":
                    content = []
                    for content_entry in entry_dict["content"]:
                        if browser_tool:
                            # Rewrite inline citation markers into text plus
                            # structured URL annotations.
                            text_content, annotation_entries, _has_partial_citations = browser_tool.normalize_citations(content_entry["text"])
                            annotations = [UrlCitation(**a) for a in annotation_entries]
                        else:
                            text_content = content_entry["text"]
                            annotations = []
                        content.append(
                            TextContentItem(
                                type="output_text",
                                text=text_content,
                                annotations=annotations,
                            )
                        )
                    output.append(
                        Item(
                            type="message",
                            role="assistant",
                            content=content,
                            status="completed",
                        )
                    )
                elif entry_dict["channel"] == "analysis":
                    summary = []
                    content = [
                        ReasoningTextContentItem(
                            type="reasoning_text",
                            text=content_entry["text"],
                        )
                        for content_entry in entry_dict["content"]
                    ]
                    output.append(
                        ReasoningItem(
                            type="reasoning",
                            summary=summary,
                            content=content,
                        )
                    )
        else:
            output = []

        usage = (
            Usage(
                input_tokens=len(input_tokens),
                output_tokens=len(output_tokens),
                total_tokens=len(input_tokens) + len(output_tokens),
            )
            if len(output_tokens) > 0
            else None
        )

        # Decoded token text is attached as metadata in debug mode; fall back
        # to the raw token list when the sequence does not decode cleanly.
        try:
            debug_str = encoding.decode_utf8(input_tokens + output_tokens)
        except Exception:
            debug_str = input_tokens + output_tokens
        try:
            debug_input_str = encoding.decode_utf8(input_tokens)
        except Exception:
            debug_input_str = input_tokens
        try:
            debug_output_str = encoding.decode_utf8(output_tokens)
        except Exception:
            debug_output_str = output_tokens
        metadata = (
            {
                "__debug": debug_str,
                "__debug_input": debug_input_str,
                "__debug_output": debug_output_str,
            }
            if debug_mode
            else {}
        )

        return ResponseObject(
            created_at=int(datetime.datetime.now().timestamp()),
            status="completed",
            output=output,
            text={"format": {"type": "text"}},
            usage=usage,
            max_output_tokens=request_body.max_output_tokens,
            error=error,
            metadata=metadata,
            id=response_id,
            previous_response_id=previous_response_id,
        )

    class StreamResponsesEvents:
        """Async generator driving inference and emitting response.* events.

        Feeds tokens from ``infer_next_token`` through a ``StreamableParser``
        and yields the Responses API streaming event sequence, optionally
        SSE-formatted. Also executes browser tool calls inline.
        """

        initial_tokens: list[int]
        tokens: list[int]
        output_tokens: list[int]
        output_text: str
        request_body: ResponsesRequest
        request: Request
        sequence_number: int

        def __init__(
            self,
            initial_tokens,
            request_body: ResponsesRequest,
            as_sse: bool = False,
            request: Optional[Request] = None,
            response_id: Optional[str] = None,
            store_callback: Optional[
                Callable[[str, ResponsesRequest, ResponseObject], None]
            ] = None,
            browser_tool: Optional[SimpleBrowserTool] = None,
        ):
            self.initial_tokens = initial_tokens
            # Working token buffer: prompt plus everything generated so far.
            self.tokens = initial_tokens.copy()
            self.output_tokens = []
            self.output_text = ""
            self.request_body = request_body
            self.parser = StreamableParser(encoding, role=Role.ASSISTANT)
            self.as_sse = as_sse
            self.debug_mode = request_body.metadata.get(
                "__debug", False
            )  # we use this for demo purposes
            # Set temperature for this stream, fallback to DEFAULT_TEMPERATURE if not set
            self.temperature = (
                request_body.temperature
                if request_body.temperature is not None
                else DEFAULT_TEMPERATURE
            )
            self.request = request
            self.sequence_number = 0
            self.function_call_ids: list[tuple[str, str]] = []
            self.response_id = response_id
            self.store_callback = store_callback
            self.new_request = True
            self.browser_tool = browser_tool
            self.use_browser_tool = browser_tool is not None
            self.browser_call_ids: list[str] = []

        def _send_event(self, event: ResponseEvent):
            """Stamp a monotonically increasing sequence number and format."""
            event.sequence_number = self.sequence_number
            self.sequence_number += 1
            if self.as_sse:
                return f"event: {event.type}\ndata: {event.model_dump_json(indent=None)}\n\n"
            else:
                return event

        async def run(self):
            """Main generation loop; yields streaming events until done."""
            browser_tool = self.browser_tool
            self.new_request = True
            initial_response = generate_response(
                self.initial_tokens,
                self.output_tokens,
                self.request_body,
                function_call_ids=self.function_call_ids,
                response_id=self.response_id,
                previous_response_id=self.request_body.previous_response_id,
            )
            initial_response.status = "in_progress"
            yield self._send_event(
                ResponseCreatedEvent(
                    type="response.created",
                    response=initial_response,
                )
            )
            yield self._send_event(
                ResponseInProgressEvent(
                    type="response.in_progress",
                    response=initial_response,
                )
            )

            current_content_index = (
                0  # for this implementation we will always have one content item only
            )
            current_output_index = -1
            sent_output_item_added = False

            # we use this if the model outputs a citation to buffer until completed
            output_delta_buffer = ""
            # we use this to track the current output text content for things like providing the right indices in citations
            current_output_text_content = ""
            current_annotations = []

            while True:
                # Check for client disconnect
                if self.request is not None and await self.request.is_disconnected():
                    print("Client disconnected, stopping token generation.")
                    break
                next_tok = infer_next_token(
                    self.tokens,
                    temperature=self.temperature,
                    new_request=self.new_request,
                )
                self.new_request = False
                self.tokens.append(next_tok)

                try:
                    self.parser.process(next_tok)
                except Exception:
                    # Best-effort: an occasional malformed token must not kill
                    # the stream; the parser picks up again on later tokens.
                    pass

                if self.parser.state == StreamState.EXPECT_START:
                    # A message boundary: finalize the previous item with the
                    # appropriate *.done events before starting the next one.
                    current_output_index += 1
                    sent_output_item_added = False

                    if len(self.parser.messages) > 0:
                        previous_item = self.parser.messages[-1]
                        if previous_item.recipient is not None:
                            recipient = previous_item.recipient
                            if (
                                not recipient.startswith("browser.")
                                and not recipient == "python"
                            ):
                                fc_id = f"fc_{uuid.uuid4().hex}"
                                call_id = f"call_{uuid.uuid4().hex}"
                                self.function_call_ids.append((fc_id, call_id))
                                yield self._send_event(
                                    ResponseOutputItemDone(
                                        type="response.output_item.done",
                                        output_index=current_output_index,
                                        item=FunctionCallItem(
                                            type="function_call",
                                            name=(
                                                previous_item.recipient[
                                                    len("functions.") :
                                                ]
                                                if previous_item.recipient.startswith(
                                                    "functions."
                                                )
                                                else previous_item.recipient
                                            ),
                                            arguments=previous_item.content[0].text,
                                            id=fc_id,
                                            call_id=call_id,
                                        ),
                                    )
                                )
                        if previous_item.channel == "analysis":
                            yield self._send_event(
                                ResponseReasoningTextDone(
                                    type="response.reasoning_text.done",
                                    output_index=current_output_index,
                                    content_index=current_content_index,
                                    text=previous_item.content[0].text,
                                )
                            )
                            yield self._send_event(
                                ResponseContentPartDone(
                                    type="response.content_part.done",
                                    output_index=current_output_index,
                                    content_index=current_content_index,
                                    part=ReasoningTextContentItem(
                                        type="reasoning_text",
                                        text=previous_item.content[0].text,
                                    ),
                                )
                            )
                            yield self._send_event(
                                ResponseOutputItemDone(
                                    type="response.output_item.done",
                                    output_index=current_output_index,
                                    item=ReasoningItem(
                                        type="reasoning",
                                        summary=[],
                                        content=[
                                            ReasoningTextContentItem(
                                                type="reasoning_text",
                                                text=previous_item.content[0].text,
                                            )
                                        ],
                                    ),
                                )
                            )
                        if previous_item.channel == "final":
                            annotations = [UrlCitation(**a) for a in current_annotations]
                            if browser_tool:
                                normalized_text, _annotations, _has_partial_citations = browser_tool.normalize_citations(previous_item.content[0].text)
                            else:
                                normalized_text = previous_item.content[0].text
                                annotations = []
                            text_content = TextContentItem(
                                type="output_text",
                                text=normalized_text,
                                annotations=annotations,
                            )
                            yield self._send_event(
                                ResponseOutputTextDone(
                                    type="response.output_text.done",
                                    output_index=current_output_index,
                                    content_index=current_content_index,
                                    text=normalized_text,
                                )
                            )
                            yield self._send_event(
                                ResponseContentPartDone(
                                    type="response.content_part.done",
                                    output_index=current_output_index,
                                    content_index=current_content_index,
                                    part=text_content,
                                )
                            )
                            yield self._send_event(
                                ResponseOutputItemDone(
                                    type="response.output_item.done",
                                    output_index=current_output_index,
                                    item=Item(
                                        type="message",
                                        role="assistant",
                                        content=[text_content],
                                    ),
                                )
                            )
                            current_annotations = []
                            current_output_text_content = ""

                if (
                    self.parser.last_content_delta
                    and self.parser.current_channel == "final"
                    and self.parser.current_recipient is None
                ):
                    if not sent_output_item_added:
                        sent_output_item_added = True
                        yield self._send_event(
                            ResponseOutputItemAdded(
                                type="response.output_item.added",
                                output_index=current_output_index,
                                item=Item(type="message", role="assistant", content=[]),
                            )
                        )
                        yield self._send_event(
                            ResponseContentPartAdded(
                                type="response.content_part.added",
                                output_index=current_output_index,
                                content_index=current_content_index,
                                part=TextContentItem(type="output_text", text=""),
                            )
                        )

                    output_delta_buffer += self.parser.last_content_delta
                    should_send_output_text_delta = True
                    if browser_tool:
                        # we normalize on the full current text to get the right indices in citations
                        updated_output_text, annotations, has_partial_citations = browser_tool.normalize_citations(current_output_text_content + output_delta_buffer)
                        # remove the current text to get back the delta but now normalized
                        output_delta_buffer = updated_output_text[len(current_output_text_content):]

                        # Filter annotations to only include those whose start_index is not already present in current_annotations
                        # this is to avoid sending duplicate annotations as multiple annotations can't be in the same place
                        existing_start_indices = {a["start_index"] for a in current_annotations}
                        new_annotations = [a for a in annotations if a["start_index"] not in existing_start_indices]

                        for a in new_annotations:
                            current_annotations.append(a)
                            citation = UrlCitation(**a)
                            # NOTE(review): annotation_index is the list length
                            # *after* appending, i.e. one past the new entry's
                            # zero-based index — confirm intended indexing.
                            yield self._send_event(
                                ResponseOutputTextAnnotationAdded(
                                    type="response.output_text.annotation.added",
                                    output_index=current_output_index,
                                    content_index=current_content_index,
                                    annotation_index=len(current_annotations),
                                    annotation=citation,
                                )
                            )

                        if has_partial_citations:
                            # Hold the delta back until the citation marker is
                            # complete so clients never see half a citation.
                            should_send_output_text_delta = False

                    if should_send_output_text_delta:
                        yield self._send_event(
                            ResponseOutputTextDelta(
                                type="response.output_text.delta",
                                output_index=current_output_index,
                                content_index=current_content_index,
                                delta=output_delta_buffer,
                            )
                        )
                        current_output_text_content += output_delta_buffer
                        output_delta_buffer = ""

                if (
                    self.parser.last_content_delta
                    and self.parser.current_channel == "analysis"
                    and self.parser.current_recipient is None
                ):
                    if not sent_output_item_added:
                        sent_output_item_added = True
                        yield self._send_event(
                            ResponseOutputItemAdded(
                                type="response.output_item.added",
                                output_index=current_output_index,
                                item=ReasoningItem(
                                    type="reasoning", summary=[], content=[]
                                ),
                            )
                        )
                        yield self._send_event(
                            ResponseContentPartAdded(
                                type="response.content_part.added",
                                output_index=current_output_index,
                                content_index=current_content_index,
                                part=ReasoningTextContentItem(type="reasoning_text", text=""),
                            )
                        )
                    yield self._send_event(
                        ResponseReasoningTextDelta(
                            type="response.reasoning_text.delta",
                            output_index=current_output_index,
                            content_index=current_content_index,
                            delta=self.parser.last_content_delta,
                        )
                    )

                try:
                    # purely for debugging purposes
                    output_token_text = encoding.decode_utf8([next_tok])
                    self.output_text += output_token_text
                    print(output_token_text, end="", flush=True)
                except RuntimeError:
                    pass

                if next_tok in encoding.stop_tokens_for_assistant_actions():
                    if len(self.parser.messages) > 0:
                        last_message = self.parser.messages[-1]
                        if (
                            self.use_browser_tool
                            and last_message.recipient is not None
                            and last_message.recipient.startswith("browser.")
                        ):
                            # The model invoked the browser tool: run it here,
                            # splice the tool output tokens back into the
                            # stream, and keep generating.
                            function_name = last_message.recipient[len("browser."):]
                            action = None
                            parsed_args = browser_tool.process_arguments(last_message)
                            if function_name == "search":
                                action = WebSearchActionSearch(
                                    type="search",
                                    query=parsed_args["query"],
                                )
                            elif function_name == "open":
                                action = WebSearchActionOpenPage(
                                    type="open_page",
                                    url=parsed_args["url"] if "url" in parsed_args else None,
                                )
                            elif function_name == "find":
                                action = WebSearchActionFind(
                                    type="find",
                                    pattern=parsed_args["pattern"],
                                    url=parsed_args["url"] if "url" in parsed_args else None,
                                )

                            if action is not None:
                                web_search_call_id = f"ws_{uuid.uuid4().hex}"
                                self.browser_call_ids.append(web_search_call_id)
                                yield self._send_event(ResponseOutputItemAdded(
                                    type="response.output_item.added",
                                    output_index=current_output_index,
                                    item=WebSearchCallItem(
                                        type="web_search_call",
                                        id=web_search_call_id,
                                        action=action,
                                    ),
                                ))
                                yield self._send_event(
                                    ResponseWebSearchCallInProgress(
                                        type="response.web_search_call.in_progress",
                                        output_index=current_output_index,
                                        id=web_search_call_id
                                    )
                                )

                            # NOTE(review): the code below assumes the browser
                            # arguments parsed into a valid action;
                            # web_search_call_id is unbound otherwise — verify.
                            async def run_tool():
                                results = []
                                async for msg in browser_tool.process(last_message):
                                    results.append(msg)
                                return results

                            yield self._send_event(
                                ResponseWebSearchCallSearching(
                                    type="response.web_search_call.searching",
                                    output_index=current_output_index,
                                    id=web_search_call_id,
                                )
                            )

                            result = await run_tool()
                            new_tokens = encoding.render_conversation_for_completion(
                                Conversation.from_messages(result), Role.ASSISTANT
                            )
                            print(encoding.decode_utf8(new_tokens))
                            self.output_tokens.append(next_tok)
                            self.tokens.append(encoding.encode('<|end|>', allowed_special="all")[0])

                            # Feed the tool output through the parser so the
                            # stream state stays consistent with self.tokens.
                            for token in new_tokens:
                                self.parser.process(token)
                                self.output_tokens.append(token)
                                self.tokens.append(token)

                            yield self._send_event(
                                ResponseWebSearchCallCompleted(
                                    type="response.web_search_call.completed",
                                    output_index=current_output_index,
                                    id=web_search_call_id,
                                )
                            )
                            yield self._send_event(ResponseOutputItemDone(
                                type="response.output_item.done",
                                output_index=current_output_index,
                                item=WebSearchCallItem(
                                    type="web_search_call",
                                    id=web_search_call_id,
                                    action=action,
                                ),
                            ))

                            current_output_index += 1
                            self.new_request = True
                            continue
                        else:
                            # Plain stop token: generation is complete.
                            break
                    else:
                        raise ValueError("No messages to process")

                # Guarded against max_output_tokens being unset (None).
                if (
                    self.request_body.max_output_tokens is not None
                    and len(self.output_tokens) >= self.request_body.max_output_tokens
                ):
                    break

                # Adding in the end if we know we are not done
                self.output_tokens.append(next_tok)

            if self.request is None or not await self.request.is_disconnected():
                response = generate_response(
                    self.initial_tokens,
                    self.output_tokens,
                    self.request_body,
                    debug_mode=self.debug_mode,
                    function_call_ids=self.function_call_ids,
                    response_id=self.response_id,
                    previous_response_id=self.request_body.previous_response_id,
                    browser_tool=self.browser_tool,
                    browser_call_ids=self.browser_call_ids,
                )
                if self.store_callback and self.request_body.store:
                    self.store_callback(self.response_id, self.request_body, response)
                yield self._send_event(
                    ResponseCompletedEvent(
                        type="response.completed",
                        response=response,
                    )
                )

    @app.post("/v1/responses", response_model=ResponseObject)
    async def generate(body: ResponsesRequest, request: Request):
        """Handle POST /v1/responses: render the prompt, run inference, and
        either stream SSE events or return the final response object."""
        print("request received")

        use_browser_tool = any(
            getattr(tool, "type", None) == "browser_search"
            for tool in (body.tools or [])
        )

        if use_browser_tool:
            backend = ExaBackend(
                source="web",
            )
            browser_tool = SimpleBrowserTool(backend=backend)
        else:
            browser_tool = None

        # Chain onto a stored previous response: prepend its input and output
        # to the new input, and inherit instructions when not overridden.
        if body.previous_response_id:
            prev = responses_store.get(body.previous_response_id)
            if prev:
                prev_req, prev_resp = prev

                def _ensure_list(inp):
                    # A bare string input is equivalent to one user message.
                    if isinstance(inp, str):
                        return [
                            Item(
                                type="message",
                                role="user",
                                content=[TextContentItem(type="input_text", text=inp)],
                            )
                        ]
                    return list(inp)

                merged_input = _ensure_list(prev_req.input) + list(prev_resp.output)
                merged_input.extend(_ensure_list(body.input))
                if body.instructions is None:
                    body.instructions = prev_req.instructions
                body.input = merged_input

        system_message_content = SystemContent.new().with_conversation_start_date(
            datetime.datetime.now().strftime("%Y-%m-%d")
        )

        if body.reasoning is not None:
            reasoning_effort = get_reasoning_effort(body.reasoning.effort)
            system_message_content = system_message_content.with_reasoning_effort(reasoning_effort)

        if use_browser_tool:
            system_message_content = system_message_content.with_tools(browser_tool.tool_config)

        system_message = Message.from_role_and_content(
            Role.SYSTEM, system_message_content
        )

        developer_message_content = DeveloperContent.new().with_instructions(
            body.instructions
        )

        # Collect user-defined function tools for the developer message.
        tools = []
        if body.tools:
            for tool in body.tools:
                if tool.type == "function":
                    tools.append(
                        ToolDescription.new(
                            tool.name,
                            tool.description,
                            tool.parameters,
                        )
                    )
        if len(tools) > 0:
            developer_message_content = developer_message_content.with_function_tools(
                tools
            )

        developer_message = Message.from_role_and_content(
            Role.DEVELOPER, developer_message_content
        )

        messages = [system_message, developer_message]

        if isinstance(body.input, str):
            user_message = Message.from_role_and_content(Role.USER, body.input)
            messages.append(user_message)
        else:
            is_last_message_function_call_output = (
                len(body.input) > 0 and body.input[-1].type == "function_call_output"
            )
            function_call_map = {}

            # Find the index of the last assistant message
            last_assistant_idx = -1
            for idx, item in enumerate(body.input):
                if item.type == "message" and item.role == Role.ASSISTANT:
                    last_assistant_idx = idx

            for idx, item in enumerate(body.input):
                if item.type == "message":
                    # TODO: add system prompt handling
                    if isinstance(item.content, str):
                        messages.append(
                            Message.from_role_and_content(item.role, item.content)
                        )
                    else:
                        for content_item in item.content:
                            messages.append(
                                Message.from_role_and_content(item.role, content_item.text)
                            )
                    # add final channel to the last assistant message if it's from the assistant
                    if item.role == Role.ASSISTANT:
                        messages[-1] = messages[-1].with_channel("final")
                elif item.type == "reasoning":
                    # Only include reasoning if it is after the last assistant message and we are handling a function call at the moment
                    if (
                        idx > last_assistant_idx
                        and is_last_message_function_call_output
                    ):
                        for content_item in item.content:
                            messages.append(
                                Message.from_role_and_content(
                                    Role.ASSISTANT, content_item.text
                                ).with_channel("analysis")
                            )
                elif item.type == "function_call":
                    function_call_map[item.call_id] = item
                    messages.append(
                        Message.from_role_and_content(Role.ASSISTANT, item.arguments)
                        .with_recipient(f"functions.{item.name}")
                        .with_channel("commentary")
                    )
                elif item.type == "function_call_output":
                    function_call = function_call_map.get(item.call_id, None)
                    if not function_call:
                        raise ValueError(f"Function call {item.call_id} not found")
                    messages.append(
                        Message.from_author_and_content(
                            Author.new(Role.TOOL, f"functions.{function_call.name}"),
                            item.output,
                        ).with_recipient("assistant").with_channel("commentary")
                    )

        conversation = Conversation.from_messages(messages)

        initial_tokens = encoding.render_conversation_for_completion(
            conversation, Role.ASSISTANT
        )
        print(encoding.decode_utf8(initial_tokens))

        response_id = f"resp_{uuid.uuid4().hex}"

        def store_callback(rid: str, req: ResponsesRequest, resp: ResponseObject):
            responses_store[rid] = (req, resp)

        event_stream = StreamResponsesEvents(
            initial_tokens,
            body,
            as_sse=body.stream,
            request=request,
            response_id=response_id,
            store_callback=store_callback,
            browser_tool=browser_tool,
        )

        if body.stream:
            return StreamingResponse(event_stream.run(), media_type="text/event-stream")
        else:
            # Drain the event stream; the final event carries the response.
            last_event = None
            async for event in event_stream.run():
                last_event = event
            return last_event.response

    return app