Co-authored-by: Zhuohan Li <zhuohan@openai.com> Co-authored-by: Maratyszcza <marat@openai.com> Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
697 lines
22 KiB
Python
697 lines
22 KiB
Python
import contextvars
|
|
import dataclasses
|
|
import functools
|
|
import itertools
|
|
import json
|
|
import re
|
|
import textwrap
|
|
from typing import Any, AsyncIterator, Callable, ParamSpec, Sequence
|
|
from urllib.parse import quote, unquote
|
|
|
|
import pydantic
|
|
import structlog
|
|
import tiktoken
|
|
from aiohttp import ClientSession
|
|
from openai_harmony import (
|
|
Author,
|
|
Content,
|
|
Message,
|
|
Role,
|
|
TextContent,
|
|
ToolNamespaceConfig
|
|
)
|
|
|
|
from ..tool import Tool
|
|
|
|
# from functions import Function, from_python
|
|
from .backend import (
|
|
VIEW_SOURCE_PREFIX,
|
|
Backend,
|
|
BackendError,
|
|
maybe_truncate,
|
|
)
|
|
from .page_contents import Extract, PageContents
|
|
|
|
logger = structlog.stdlib.get_logger(component=__name__)
|
|
|
|
|
|
# TODO(zhuohan): Use the correct encoding at release
ENC_NAME = "o200k_base"
# Header line prepended to every `find` match so the model can cite it.
FIND_PAGE_LINK_FORMAT = "# 【{idx}†{title}】"
# A link fragment cut off at the *start* of a window: a closing 】 with no 【.
PARTIAL_INITIAL_LINK_PATTERN = re.compile(r"^[^【】]*】")
# A link fragment cut off at the *end* of a window: a 【 with no closing 】.
# NOTE: the `content` group is optional, so it may not participate in a match.
PARTIAL_FINAL_LINK_PATTERN = re.compile(
    r"【\d*(?:†(?P<content>[^†】]*)(?:†[^†】]*)?)?$"
)
# A complete inline link: 【id†content(†extra)?】.
LINK_PATTERN = re.compile(r"【\d+†(?P<content>[^†】]+)(?:†[^†】]+)?】")

# A citation emitted by the model: 【cursor†content(†extra)?】.
CITATION_OUTPUT_PATTERN = re.compile(r"【(?P<cursor>\d+)†(?P<content>[^†】]+)(?:†[^†】]+)?】")

CallParams = ParamSpec("CallParams")


_P = ParamSpec("_P")
# Name of the tool function currently executing; set by
# `function_the_model_can_call` so `make_response` can attribute messages.
_live_function_name = contextvars.ContextVar[str]("_live_function_name")
|
|
|
|
|
|
class ToolUsageError(Exception):
    """Raised when the model invokes the browser tool with invalid arguments."""

    pass
|
|
|
|
|
|
def function_the_model_can_call(
    fn: Callable[_P, AsyncIterator[Message]],
) -> Callable[_P, AsyncIterator[Message]]:
    """Mark an async generator as model-callable and track its name while live."""
    fn.__fn_calling_tool_fn_type__ = "function_the_model_can_call"  # type: ignore

    @functools.wraps(fn)
    async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> AsyncIterator[Message]:
        # Record which tool function is currently running so that
        # `make_response` can attribute messages; always restore afterwards.
        reset_token = _live_function_name.set(fn.__name__)
        try:
            async for message in fn(*args, **kwargs):
                yield message
        finally:
            _live_function_name.reset(reset_token)

    return wrapper
|
|
|
|
|
|
@functools.cache
def _tiktoken_vocabulary_lengths(enc_name: str) -> list[int]:
    """Return the decoded character length of every token in the vocabulary.

    Tokens that cannot be decoded on their own (e.g. special tokens or
    partial UTF-8 sequences) are conservatively counted as length 1.
    Cached per encoding name since the scan over the vocabulary is expensive.
    """
    encoding = tiktoken.get_encoding(enc_name)
    results = []
    for token_id in range(encoding.n_vocab):
        try:
            results.append(len(encoding.decode([token_id])))
        except Exception:
            # Undecodable token: fall back to the minimum plausible length.
            results.append(1)
    return results
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class Tokens:
    """A tokenized string together with per-token starting character offsets."""

    # Token ids produced by the tiktoken encoder.
    tokens: list[int]
    # Character offset at which each token begins in the decoded text.
    tok2idx: list[int]  # Offsets = running sum of lengths.
|
|
|
|
|
|
@functools.cache
def max_chars_per_token(enc_name: str) -> int:
    """Longest decoded length of any single token.

    Typical value is 128, but let's be safe and compute it from the vocabulary.
    """
    return max(_tiktoken_vocabulary_lengths(enc_name))
|
|
|
|
|
|
def get_tokens(text: str, enc_name: str) -> Tokens:
    """Tokenize *text* and compute each token's starting character offset.

    ``tok2idx[i]`` is the index in *text* at which token ``i`` begins:
    a prefix sum of per-token decoded lengths, shifted by one (the running
    sum up to-but-excluding each token), with 0 prepended for the first token.
    """
    encoding = tiktoken.get_encoding(enc_name)
    tokens = encoding.encode(text, disallowed_special=())
    # Fixed local-name typo ("lenghts") and dropped the redundant temp binding.
    vocabulary_lengths = _tiktoken_vocabulary_lengths(enc_name)
    tok2idx = [0] + list(
        itertools.accumulate(vocabulary_lengths[i] for i in tokens)
    )[:-1]
    return Tokens(tokens=tokens, tok2idx=tok2idx)
|
|
|
|
|
|
def get_end_loc(
    loc: int,
    num_lines: int,
    total_lines: int,
    lines: list[str],
    view_tokens: int,
    encoding_name: str,
) -> int:
    """Compute the exclusive end line of the viewport that starts at `loc`.

    If `num_lines` is positive it is used directly; otherwise the number of
    lines shown is chosen so the rendered (numbered) text fits within
    `view_tokens` tokens. The result is clamped to `total_lines`.
    """
    if num_lines <= 0:
        # COMPUTE NUMBER OF LINES TO SHOW
        txt = join_lines(lines[loc:], add_line_numbers=True, offset=loc)
        # if the text is very short, no need to truncate at all
        # at least one char per token
        if len(txt) > view_tokens:
            # limit the amount of text we tokenize here
            upper_bound = max_chars_per_token(encoding_name)
            tok2idx = get_tokens(
                txt[: (view_tokens + 1) * upper_bound], encoding_name
            ).tok2idx
            if len(tok2idx) > view_tokens:
                # Character offset at which the (view_tokens+1)-th token starts;
                # everything before it fits in the token budget.
                end_idx = tok2idx[view_tokens]
                num_lines = txt[:end_idx].count("\n") + 1  # round up
            else:
                # The truncated prefix tokenized to fewer tokens than the
                # budget, so the whole remainder fits.
                num_lines = total_lines
        else:
            num_lines = total_lines

    return min(loc + num_lines, total_lines)
|
|
|
|
|
|
def get_page_metadata(
    curr_page: PageContents,
) -> dict[str, str | None | dict[str, str] | list[str]]:
    """Some attributes of the current page: its URL and title."""
    return {
        "url": curr_page.url,
        "title": curr_page.title,
    }
|
|
|
|
|
|
def join_lines(
    lines: list[str], add_line_numbers: bool = False, offset: int = 0
) -> str:
    """Join lines with newlines, optionally prefixing each with `L<n>: `.

    Line numbers start at `offset` and count up by one per line.
    """
    if not add_line_numbers:
        return "\n".join(lines)
    numbered = (f"L{offset + idx}: {text}" for idx, text in enumerate(lines))
    return "\n".join(numbered)
|
|
|
|
|
|
def wrap_lines(text: str, width: int = 80) -> list[str]:
    """Hard-wrap `text` to `width` columns, one physical line at a time.

    Empty lines are preserved (textwrap.wrap would drop them); whitespace
    is neither collapsed nor trimmed, so character offsets stay faithful.
    """
    wrapped: list[str] = []
    for line in text.split("\n"):
        if not line:
            # preserve empty lines
            wrapped.append("")
            continue
        wrapped.extend(
            textwrap.wrap(
                line,
                width=width,
                replace_whitespace=False,
                drop_whitespace=False,
            )
        )
    return wrapped
|
|
|
|
|
|
def strip_links(text: str) -> str:
    """Remove link markup from `text`, keeping only the readable link text."""
    # Drop a leading partial link whose opening bracket is out of view.
    text = re.sub(PARTIAL_INITIAL_LINK_PATTERN, "", text)
    # For a trailing partial link, keep whatever content was captured.
    # The `content` group is optional in PARTIAL_FINAL_LINK_PATTERN (e.g. a
    # bare trailing "【12"), in which case group() returns None — substitute
    # "" instead of letting re.sub raise TypeError.
    text = re.sub(
        PARTIAL_FINAL_LINK_PATTERN, lambda mo: mo.group("content") or "", text
    )
    # Replace complete links with their content.
    text = re.sub(LINK_PATTERN, lambda mo: mo.group("content"), text)
    return text
|
|
|
|
|
|
def maybe_get_function_args(
    message: Message, tool_name: str = "browser"
) -> dict[str, Any] | None:
    """Parse a tool-call message body as a JSON object.

    Returns None when the message is not addressed to this tool, or when the
    body is valid JSON but not an object, or not valid JSON at all. An empty
    body means "call with no arguments" and yields {}.
    """
    if not message.recipient.startswith(f"{tool_name}."):
        return None

    body = ""
    if len(message.content) == 1 and isinstance(message.content[0], TextContent):
        body = message.content[0].text

    if not body:
        return {}

    try:
        parsed = json.loads(body)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None
|
|
|
|
|
|
async def run_find_in_page(
    pattern: str,
    page: PageContents,
    max_results: int = 50,
    num_show_lines: int = 4,
) -> PageContents:
    """Search `page` for `pattern` and build a synthetic find-results page.

    `pattern` is matched as a plain substring against lowercased lines (the
    caller lowercases the pattern). Each match contributes a snippet of
    `num_show_lines` lines; at most `max_results` matches are collected.
    """
    # Re-wrap and strip link markup so the reported line numbers line up
    # with what show_page displays for the same page.
    lines = wrap_lines(text=page.text)
    txt = join_lines(lines, add_line_numbers=False)
    without_links = strip_links(txt)
    lines = without_links.split("\n")

    result_chunks, snippets = [], []
    line_idx, match_idx = 0, 0
    while line_idx < len(lines):
        line = lines[line_idx]
        if pattern not in line.lower():
            line_idx += 1
            continue
        snippet = "\n".join(lines[line_idx : line_idx + num_show_lines])
        link_title = FIND_PAGE_LINK_FORMAT.format(
            idx=f"{match_idx}", title=f"match at L{line_idx}"
        )
        result_chunks.append(f"{link_title}\n{snippet}")
        snippets.append(
            Extract(
                url=page.url, text=snippet, title=f"#{match_idx}", line_idx=line_idx
            )
        )
        if len(result_chunks) == max_results:
            break
        match_idx += 1
        # Jump past the displayed snippet so matches inside the window
        # are not reported a second time.
        line_idx += num_show_lines

    # One source URL per result chunk; all point back at the original page.
    urls = [page.url for _ in result_chunks]

    if result_chunks:
        display_text = "\n\n".join(result_chunks)
    else:
        display_text = f"No `find` results for pattern: `{pattern}`"

    result_page = PageContents(
        url=f"{page.url}/find?pattern={quote(pattern)}",
        title=f"Find results for text: `{pattern}` in `{page.title}`",
        text=display_text,
        urls={str(i): url for i, url in enumerate(urls)},
        snippets={str(i): snip for i, snip in enumerate(snippets)},
    )
    return result_page
|
|
|
|
|
|
def handle_errors(
    func: Callable[CallParams, AsyncIterator["Message"]],
) -> Callable[CallParams, AsyncIterator["Message"]]:
    """Decorator: convert expected tool failures into error response messages.

    Only ToolUsageError and BackendError are caught; anything else is a real
    bug and propagates.
    """

    @functools.wraps(func)
    async def inner(
        *args: CallParams.args, **kwargs: CallParams.kwargs
    ) -> AsyncIterator[Message]:
        # `func` is an unbound method, so args[0] is the tool instance.
        tool = args[0]
        # Could be cool to type this explicitly, but mypy makes it hard
        assert isinstance(tool, SimpleBrowserTool)
        try:
            async for msg in func(*args, **kwargs):
                yield msg
        except (ToolUsageError, BackendError) as e:
            # Expected failure: surface it to the model instead of raising.
            yield tool.make_error_message(e)

    return inner
|
|
|
|
|
|
class SimpleBrowserState(pydantic.BaseModel):
    """Serializable browsing history: fetched pages plus a navigation stack."""

    # maps page url to page contents
    pages: dict[str, PageContents] = pydantic.Field(default_factory=dict)
    # a sequential list of page urls; the cursor indexes into this stack
    page_stack: list[str] = pydantic.Field(default_factory=list)

    @property
    def current_cursor(self) -> int:
        # Cursor of the most recent page; -1 when nothing has been opened yet.
        return len(self.page_stack) - 1

    def add_page(self, page: PageContents) -> None:
        """Record `page` (replacing any cached copy) and make it current."""
        self.pages[page.url] = page
        self.page_stack.append(page.url)

    def get_page(self, cursor: int = -1) -> PageContents:
        """Return the page at `cursor` (default: the current page).

        Raises ToolUsageError if there is no history, or if `cursor` is not
        an integer or is out of range.
        """
        if self.current_cursor < 0:
            raise ToolUsageError("No pages to access!")
        if cursor == -1 or cursor == self.current_cursor:
            return self.pages[self.page_stack[-1]]
        try:
            page_url = self.page_stack[cursor]
        except TypeError as e:
            raise ToolUsageError(
                f"`cursor` should be an integer, not `{type(cursor).__name__}`"
            ) from e
        except IndexError as e:
            raise ToolUsageError(
                f"Cursor `{cursor}` is out of range. "
                f"Available cursor indices: [0 - {self.current_cursor}]."
            ) from e
        return self.pages[page_url]

    def get_page_by_url(self, url: str) -> PageContents | None:
        """Return the cached page for `url`, or None if it was never fetched."""
        if url in self.pages:
            return self.pages[url]
        return None

    def pop_page_stack(self) -> None:
        """Drop the most recent page from the navigation stack."""
        assert self.current_cursor >= 0, "No page to pop!"
        self.page_stack.pop()
|
|
|
|
|
|
class SimpleBrowserTool(Tool):
    """A text-only browser tool (search / open / find) backed by a `Backend`.

    Pages are cached in a `SimpleBrowserState`; every view of a page pushes
    a new cursor onto the page stack, which the model cites with
    `【{cursor}†L{line_start}-L{line_end}】` markers.
    """

    def __init__(
        self,
        backend: Backend,
        encoding_name: str = ENC_NAME,
        max_search_results: int = 20,
        tool_state: dict[str, Any] | None = None,
        view_tokens: int = 1024,
        name: str = "browser",
    ):
        assert name == "browser"
        self.backend = backend
        if tool_state is None:
            self.tool_state = SimpleBrowserState()
        else:
            # Restore a previously serialized browsing session.
            self.tool_state = SimpleBrowserState.model_validate(tool_state)

        self.encoding_name = encoding_name
        self.max_search_results = max_search_results
        # Token budget for a single page view.
        self.view_tokens = view_tokens

    def get_tool_state(self) -> dict[str, Any]:
        """Serialize the browsing state so a later session can restore it."""
        return {"tool_state": self.tool_state.model_dump()}

    @classmethod
    def get_tool_name(cls) -> str:
        return "browser"

    @property
    def name(self) -> str:
        return self.get_tool_name()

    @property
    def tool_config(self) -> ToolNamespaceConfig:
        """Harmony namespace config with usage and citation instructions."""
        config = ToolNamespaceConfig.browser()
        config.name = self.name
        config.description = """Tool for browsing.
The `cursor` appears in brackets before each browsing display: `[{cursor}]`.
Cite information from the tool using the following format:
`【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.
Do not quote more than 10 words directly from the tool output.
sources=""" + self.backend.source
        return config

    @property
    def instruction(self) -> str:
        return self.tool_config.description

    def _render_browsing_display(
        self,
        tether_id: int,
        result: str,
        summary: str | None = None,
    ) -> str:
        """Prefix the rendered page with its cursor id (and header summary)."""
        to_return = ""
        # Always show summaries.
        if summary:
            to_return += summary
        to_return += result
        to_return = f"[{tether_id}] {to_return}"
        return to_return

    def _make_response(
        self,
        page: PageContents,
        cursor: int,
        body: str,
        scrollbar: str,
    ) -> Message:
        """Build the tool message shown to the model for one page view."""
        domain = maybe_truncate(unquote(page.url))
        header = f"{page.title}"
        if domain:
            header += f" ({domain})"
        header += f"\n**{scrollbar}**\n\n"

        content = TextContent(text=self._render_browsing_display(cursor, body, header))
        return self.make_response(
            content=content, metadata=get_page_metadata(self.tool_state.get_page())
        )

    async def show_page(self, loc: int = 0, num_lines: int = -1) -> Message:
        """Render the current page starting at line `loc`.

        When `num_lines` is non-positive the view length is chosen so the
        output fits within `self.view_tokens` tokens.

        Raises:
            ToolUsageError: if `loc` is past the end of the page.
        """
        page = self.tool_state.get_page()
        cursor = self.tool_state.current_cursor
        lines = wrap_lines(text=page.text)
        total_lines = len(lines)

        if loc >= total_lines:
            err_msg = (
                f"Invalid location parameter: `{loc}`. "
                f"Cannot exceed page maximum of {total_lines - 1}."
            )
            raise ToolUsageError(err_msg)

        end_loc = get_end_loc(
            loc, num_lines, total_lines, lines, self.view_tokens, self.encoding_name
        )

        lines_to_show = lines[loc:end_loc]
        body = join_lines(lines_to_show, add_line_numbers=True, offset=loc)

        scrollbar = f"viewing lines [{loc} - {end_loc - 1}] of {total_lines - 1}"
        return self._make_response(page, cursor, body, scrollbar)

    async def show_page_safely(self, loc: int = 0, num_lines: int = -1) -> Message:
        """Like `show_page`, but undo the page push when rendering fails."""
        try:
            return await self.show_page(loc=loc, num_lines=num_lines)
        except ToolUsageError as e:
            # The caller already pushed the page; pop it so the history
            # does not end on an unviewable entry.
            self.tool_state.pop_page_stack()
            raise e

    async def _open_url(self, url: str, direct_url_open: bool) -> PageContents:
        """Fetch `url`. Use the cache, if available.

        Raises:
            BackendError: wrapping any fetch failure.
        """
        backend = self.backend
        # direct_url_open should be regarded as a refresh
        if not direct_url_open and (page := self.tool_state.get_page_by_url(url)):
            assert page.url == url
            return page

        try:
            async with ClientSession() as session:
                page = await backend.fetch(url, session=session)
                return page
        except Exception as e:
            msg = maybe_truncate(str(e))
            logger.warning("Error fetching URL in lean browser tool", exc_info=e)
            raise BackendError(
                f"Error fetching URL `{maybe_truncate(url)}`: {msg}"
            ) from e

    def make_error_message(self, error: Exception) -> Message:
        """Uses the message creation codepath from the base class."""
        content = TextContent(text=str(error))
        return self.make_response(content=content)

    @function_the_model_can_call
    @handle_errors
    async def search(
        self,
        query: str,
        topn: int = 10,
        top_n: int = 10,
        source: str | None = None,
    ) -> AsyncIterator[Message]:
        """Run a backend search for `query` and show the results page.

        `topn`/`top_n` are accepted (the model may emit either spelling) but
        deliberately ignored in favor of `self.max_search_results`.
        `source` is accepted for schema compatibility and unused.
        """
        del topn
        del top_n
        try:
            async with ClientSession() as session:
                search_page = await self.backend.search(
                    query=query,
                    topn=self.max_search_results,
                    session=session,
                )
        except Exception as e:
            msg = maybe_truncate(str(e))
            raise BackendError(f"Error during search for `{query}`: {msg}") from e

        self.tool_state.add_page(search_page)
        yield await self.show_page_safely(loc=0)

    @function_the_model_can_call
    @handle_errors
    async def open(
        self,
        id: int | str = -1,
        cursor: int = -1,
        loc: int = -1,
        num_lines: int = -1,
        view_source: bool = False,
        source: str | None = None,
    ) -> AsyncIterator[Message]:
        """Open a URL, follow a link, or re-view a page.

        - str `id`: treat it as a URL and fetch it directly (a refresh).
        - non-negative int `id`: follow that link on the page at `cursor`.
        - otherwise: scroll/re-view the page at `cursor` from `loc`.
        """
        curr_page: PageContents | None = None
        stay_on_current_page = False
        direct_url_open = False
        if isinstance(id, str):
            snippet = None
            url = id
            direct_url_open = True
        else:  # Operate on a previously opened page
            curr_page = self.tool_state.get_page(cursor)

            if id >= 0:  # click a link
                try:
                    url = curr_page.urls[str(id)]
                except KeyError as e:
                    raise ToolUsageError(f"Invalid link id `{id}`.") from e
                snippet = (curr_page.snippets or {}).get(str(id))
                if snippet and curr_page.url == "":
                    # current page is a search result page
                    assert isinstance(snippet, Extract)
            else:  # navigate to new position on the current page
                if not view_source:
                    stay_on_current_page = True
                url = curr_page.url
                snippet = None

        new_page: PageContents
        if view_source:
            # Viewing source always refetches under the view-source scheme.
            url = f"{VIEW_SOURCE_PREFIX}{url}"
            snippet = None
        if stay_on_current_page:
            assert curr_page is not None
            new_page = curr_page
        else:
            new_page = await self._open_url(url, direct_url_open)

        self.tool_state.add_page(new_page)

        if loc < 0:  # unset
            if snippet is not None and snippet.line_idx is not None:
                # Jump to the snippet, backing up a few lines for context.
                loc = snippet.line_idx
                if loc > 4:
                    loc -= 4
            else:
                loc = 0
        yield await self.show_page_safely(loc=loc, num_lines=num_lines)

    @function_the_model_can_call
    @handle_errors
    async def find(self, pattern: str, cursor: int = -1) -> AsyncIterator[Message]:
        """Find `pattern` (case-insensitive substring) in the page at `cursor`."""
        page = self.tool_state.get_page(cursor)
        if page.snippets is not None:
            # Search/find result pages carry snippets; finding within them
            # would produce misleading line numbers.
            raise ToolUsageError(
                "Cannot run `find` on search results page or find results page"
            )

        pc = await run_find_in_page(
            pattern=str(pattern).lower(),
            page=page,
        )
        self.tool_state.add_page(pc)
        yield await self.show_page_safely(loc=0)

    def make_response(
        self,
        content: Content,
        *,
        metadata: dict[str, Any] | None = None,
        author: Author | None = None,
    ) -> Message:
        """
        Make a response message.

        Should be used from `@function_the_model_can_call` if author is not provided.

        NOTE(review): `metadata` is accepted but currently unused here —
        confirm whether it should be attached to the outgoing Message.
        """
        if author is None:
            tool_name = self.get_tool_name()
            function_name = _live_function_name.get()
            assert function_name is not None
            author = Author(role=Role.TOOL, name=f"{tool_name}.{function_name}")

        return Message(
            author=author,
            content=[content],
        ).with_recipient("assistant")

    def process_arguments(self, message: Message) -> dict[str, Any]:
        """Resolve a tool-call message into arguments with an explicit `url`.

        Raises:
            ValueError: if the message body is not a valid argument object.
        """
        function_args = maybe_get_function_args(message, tool_name=self.name)
        if function_args is None:
            raise ValueError("Invalid function arguments")

        if "cursor" in function_args and function_args["cursor"] >= 0:
            page = self.tool_state.get_page(cursor=function_args["cursor"])
            if "id" in function_args:
                function_args["url"] = page.urls[str(function_args["id"])]
            else:
                function_args["url"] = page.url
        elif "id" in function_args and isinstance(function_args["id"], str):
            # A string id is itself a direct URL.
            function_args["url"] = function_args["id"]
        return function_args

    async def _process(self, message: Message) -> AsyncIterator[Message]:
        """Dispatch a tool-call message to search/open/find, yielding results."""

        def make_error_message(error: str) -> Message:
            return self.make_response(
                content=TextContent(text=json.dumps({"error": error})),
                author=Author(role=Role.TOOL, name=message.recipient),
            )

        function_args = maybe_get_function_args(message, tool_name=self.name)
        if function_args is None:
            yield make_error_message("Invalid function arguments")
            return

        _, function_name = message.recipient.split(".")
        if function_name not in ["search", "open", "find"]:
            yield make_error_message(f"Unknown function: {function_name}")
            return

        if function_name == "search":
            async for msg in self.search(**function_args):
                yield msg
        elif function_name == "open":
            async for msg in self.open(**function_args):
                yield msg
        elif function_name == "find":
            async for msg in self.find(**function_args):
                yield msg
        else:
            raise ValueError("should not be here")

    def normalize_citations(self, old_content: str, hide_partial_citations: bool = False) -> tuple[str, list[dict[str, Any]], bool]:
        """
        Returns a tuple of (new_message, annotations, has_partial_citations)
        - new_message: Message with citations replaced by ([domain](url))
        - annotations: list of dicts with start_index, end_index, and title (url)
        - has_partial_citations: whether the text includes an unfinished citation
        """

        has_partial_citations = PARTIAL_FINAL_LINK_PATTERN.search(old_content) is not None
        if hide_partial_citations and has_partial_citations:
            old_content = PARTIAL_FINAL_LINK_PATTERN.sub("", old_content)

        matches = []
        for match in CITATION_OUTPUT_PATTERN.finditer(old_content):
            matches.append({
                "cursor": match.group("cursor"),
                "content": match.group("content"),
                "start": match.start(),
                "end": match.end(),
            })

        # Build a mapping from cursor to url
        cursor_to_url = {}
        for idx, url in enumerate(self.tool_state.page_stack):
            cursor_to_url[str(idx)] = url

        def extract_domain(url):
            # Best effort: netloc of the unquoted URL, else the raw URL.
            try:
                return unquote(url).split("/")[2]
            except Exception:
                return url

        new_content = ""
        last_idx = 0
        annotations = []

        for m in matches:
            url = cursor_to_url.get(m["cursor"], None)
            orig_start = m["start"]
            orig_end = m["end"]

            # Add text before the citation
            new_content += old_content[last_idx:orig_start]

            if url:
                domain = extract_domain(url)
                replacement = f" ([{domain}]({url})) "
                # The start and end indices in the new content
                start_index = len(new_content)
                end_index = start_index + len(replacement)
                annotations.append({
                    "start_index": start_index,
                    "end_index": end_index,
                    "title": domain,
                    "url": url,
                    "type": "url_citation",
                })
                new_content += replacement
            else:
                # Keep the original citation format if cursor is missing.
                # (Removed dead locals: running_offset and the unused
                # start/end indices that were computed in this branch.)
                new_content += old_content[orig_start:orig_end]

            last_idx = orig_end

        new_content += old_content[last_idx:]
        return new_content, annotations, has_partial_citations
|
|
|