import os
from collections import defaultdict
from multiprocessing.pool import ThreadPool
from typing import Any, Callable

import jinja2
import numpy as np
from tqdm import tqdm

from .types import EvalResult, Message, SingleEvalResult

# Jinja template that renders one evaluated example (prompt conversation,
# sampled response, and scoring details) as an HTML snippet.  It calls the
# message_to_html global registered on jinja_env below.
HTML_JINJA = """

Prompt conversation

{% for message in prompt_messages %} {{ message_to_html(message) | safe }} {% endfor %}

Sampled message

{{ message_to_html(next_message) | safe }}

Results

Correct Answer: {{ correct_answer }}

Extracted Answer: {{ extracted_answer }}

Score: {{ score }}

"""


def _compute_stat(values: list, stat: str):
    """Compute a single named statistic over *values*.

    Supported names: "mean", "std", "min", "max", "n_samples", and
    "bootstrap_std" (the std-dev of 1000 bootstrap-resampled means).

    Raises:
        ValueError: if *stat* is not one of the supported names.
    """
    if stat == "mean":
        return np.mean(values)
    elif stat == "std":
        return np.std(values)
    elif stat == "min":
        return np.min(values)
    elif stat == "max":
        return np.max(values)
    elif stat == "n_samples":
        return len(values)
    elif stat == "bootstrap_std":
        # Estimate the sampling variability of the mean by resampling the
        # values with replacement (np.random.choice replaces by default).
        return np.std(
            [np.mean(np.random.choice(values, len(values))) for _ in range(1000)]
        )
    else:
        raise ValueError(f"Unknown {stat =}")


def aggregate_results(
    single_eval_results: list[SingleEvalResult],
    default_stats: tuple[str, ...] = ("mean", "std"),
    # Fix: was annotated dict[str, tuple[str]] (a 1-tuple type); the values
    # are variable-length stat-name tuples like default_stats.
    name2stats: dict[str, tuple[str, ...]] | None = None,
) -> EvalResult:
    """
    Aggregate results from multiple evaluations into a single EvalResult.

    Args:
        single_eval_results: per-example results to combine.
        default_stats: statistics computed for every metric unless
            overridden via *name2stats*.
        name2stats: optional per-metric override of which statistics to
            compute for that metric name.

    Returns:
        An EvalResult whose ``score`` is the mean of the per-example scores
        (or None if no example had a score) and whose ``metrics`` map
        ``"<name>"`` (for the mean) or ``"<name>:<stat>"`` to the computed
        statistic.
    """
    name2stats = name2stats or {}
    name2values = defaultdict(list)
    htmls = []
    convos = []
    metadata = []
    for single_eval_result in single_eval_results:
        for name, value in single_eval_result.metrics.items():
            name2values[name].append(value)
        if single_eval_result.score is not None:
            name2values["score"].append(single_eval_result.score)
        htmls.append(single_eval_result.html)
        convos.append(single_eval_result.convo)
        metadata.append(single_eval_result.example_level_metadata)
    final_metrics = {}
    for name, values in name2values.items():
        stats = name2stats.get(name, default_stats)
        for stat in stats:
            # The mean keeps the bare metric name; every other stat gets a
            # ":<stat>" suffix so it cannot collide with the mean entry.
            key = name if stat == "mean" else f"{name}:{stat}"
            final_metrics[key] = _compute_stat(values, stat)
    return EvalResult(
        # "score" was aggregated like any other metric; pull its mean out so
        # it becomes the top-level score rather than a metrics entry.
        score=final_metrics.pop("score", None),
        metrics=final_metrics,
        htmls=htmls,
        convos=convos,
        metadata={"example_level_metadata": metadata},
    )


def map_with_progress(
    f: Callable,
    xs: list[Any],
    num_threads: int = 128,
    pbar: bool = True,
):
    """
    Apply f to each element of xs, using a ThreadPool, and show progress.

    NOTE: the threaded path uses ``imap_unordered``, so the returned list is
    in completion order, not the order of *xs* (the debug path is ordered).

    Args:
        f: function applied to each element of *xs*.
        xs: inputs; an empty list returns ``[]`` immediately.
        num_threads: upper bound on the pool size (capped at ``len(xs)``).
        pbar: whether to display a tqdm progress bar.
    """
    # Fix: ThreadPool(0) raises ValueError, so short-circuit on empty input.
    if not xs:
        return []
    pbar_fn = tqdm if pbar else lambda x, *args, **kwargs: x
    if os.getenv("debug"):
        # Serial execution when the "debug" env var is set, which keeps
        # stack traces and debuggers usable.
        return list(map(f, pbar_fn(xs, total=len(xs))))
    else:
        with ThreadPool(min(num_threads, len(xs))) as pool:
            return list(pbar_fn(pool.imap_unordered(f, xs), total=len(xs)))


# Shared Jinja environment.  StrictUndefined makes a missing template
# variable raise instead of silently rendering as an empty string.
jinja_env = jinja2.Environment(
    loader=jinja2.BaseLoader(),
    undefined=jinja2.StrictUndefined,
    autoescape=jinja2.select_autoescape(["html", "xml"]),
)

# Template for rendering a single chat message (role, optional variant,
# then content).
_message_template = """
{{ role }} {% if variant %}({{ variant }}){% endif %}
{{ content }}
"""


def message_to_html(message: Message) -> str:
    """Render one chat *message* as an HTML snippet.

    Reads ``message["role"]`` and ``message["content"]``; ``"variant"`` is
    optional and shown in parentheses after the role when present.
    """
    return jinja_env.from_string(_message_template).render(
        role=message["role"],
        content=message["content"],
        variant=message.get("variant", None),
    )


# Make message_to_html callable from inside templates (HTML_JINJA uses it).
jinja_env.globals["message_to_html"] = message_to_html

# Template for the standalone report: an optional metrics table followed by
# the pre-rendered per-example HTML snippets.
_report_template = """ {% if metrics %}

Metrics

{% for name, value in metrics.items() %} {% endfor %}
Metric Value
Score {{ score | float | round(3) }}
{{ name }} {{ value }}
{% endif %}

Examples

{% for html in htmls %} {{ html | safe }}
{% endfor %} """


def make_report(eval_result: EvalResult) -> str:
    """
    Create a standalone HTML report from an EvalResult.
    """
    return jinja_env.from_string(_report_template).render(
        score=eval_result.score,
        metrics=eval_result.metrics,
        htmls=eval_result.htmls,
    )