2025-10-23 16:21:52 +08:00
parent 5e21419a67
commit 9f0a0fbcdc
25 changed files with 38489 additions and 1 deletion

2
.gitattributes vendored Normal file

@@ -0,0 +1,2 @@
# SCM syntax highlighting & preventing 3-way merges
pixi.lock merge=binary linguist-language=YAML linguist-generated=true

7
.gitignore vendored

@@ -1,2 +1,7 @@
.venv/
.DS_Store
__pycache__/
*.pyc
# pixi environments
.pixi/*
!.pixi/config.toml
*.egg-info/

115
README.md

@@ -33,6 +33,100 @@ export HF_HUB_OFFLINE=1
export HF_ENDPOINT=https://hf-mirror.com
```
## Session Orchestration Service (FastAPI / MCP)
Run `uv run embedding-backend-api` to start a backend service that is compatible with both FastAPI and FastMCP. The service listens on the `/sessions` path, launches `embedding-atlas` containers on demand, and cleans them up automatically after 10 hours.
```bash
uv run embedding-backend-api
# or start in MCP mode (stdio)
uv run embedding-backend-mcp
```
### REST API Usage
```bash
curl -X POST http://localhost:9000/sessions \
  -H 'Content-Type: application/json' \
  -d '{
        "data_url": "https://example.com/data.csv",
        "extra_args": ["--text", "smiles"]
      }'

# Close a session
curl -X DELETE http://localhost:9000/sessions/<session_id>

# List current sessions
curl http://localhost:9000/sessions
```
The request body supports fields such as `session_id`, `port`, `auto_remove_seconds`, and `environment`; when `session_id` is omitted, one is generated automatically and returned in the response. The response includes the container name and a reachable frontend address, which FastAPI or MCP clients can forward to the frontend.
To exercise the API interactively, open the auto-generated Swagger UI at `http://localhost:9000/docs` in a browser.
#### Request Body Fields
- `session_id`: optional; a custom session ID, generated automatically when omitted.
- `data_url`: required; an HTTPS link to a CSV/Parquet file, downloaded into `/sessions/<session_id>/` on the volume shared by the orchestrator and DinD.
- `input_filename`: optional; the filename used when saving to the shared volume, inferred from the URL by default.
- `extra_args`: optional; an array of extra arguments appended to the `embedding-atlas` CLI, e.g. `["--text", "smiles"]`.
- `environment`: optional; a mapping of environment variables injected into the container.
- `labels`: optional; Docker labels attached to the container (in addition to the default `embedding-backend.*` labels), useful for custom monitoring or tracing.
- `image`: optional; overrides the default `embedding-atlas` image name.
- `host`: optional; the value passed to `embedding-atlas --host`, defaulting to `0.0.0.0`.
- `port`: optional; an explicit host port; when omitted, one is allocated automatically from the configured range.
- `auto_remove_seconds`: optional; overrides the default 10-hour lifetime.
Example request:
```bash
curl -X POST http://localhost:9000/sessions \
  -H 'Content-Type: application/json' \
  -d '{
        "session_id": "demo-session",
        "data_url": "https://example.com/data.csv",
        "input_filename": "data.csv",
        "extra_args": ["--text", "smiles"],
        "environment": {"HF_ENDPOINT": "https://hf-mirror.com"},
        "labels": {"team": "chem"},
        "auto_remove_seconds": 7200
      }'
```
#### GET /sessions Response Example
Takes no query parameters and returns the list of all current sessions:
```json
{
  "sessions": [
    {
      "session_id": "demo-session",
      "container_id": "e556c0f1c35b...",
      "container_name": "embedding-atlas_demo-session",
      "port": 6000,
      "host": "0.0.0.0",
      "started_at": "2025-09-22T14:37:25.206038Z",
      "expires_at": "2025-09-23T00:37:24.416241Z",
      "dataset_path": "/sessions/demo-session/data.csv"
    }
  ]
}
```
- An empty `sessions` array means no containers are currently alive.
- `dataset_path` is the absolute path of the corresponding dataset inside the shared volume.
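For scripted access, here is a minimal Python client sketch built on `httpx` (already a project dependency); the base URL assumes the default `localhost:9000` binding, and the field names mirror the request/response examples above:
```python
import httpx

BASE_URL = "http://localhost:9000"  # assumed default orchestrator address


def create_session(data_url: str) -> dict:
    """POST /sessions and return the service's JSON response."""
    payload = {"data_url": data_url, "extra_args": ["--text", "smiles"]}
    # follow_redirects handles the /sessions -> /sessions/ redirect
    resp = httpx.post(f"{BASE_URL}/sessions", json=payload,
                      follow_redirects=True, timeout=120)
    resp.raise_for_status()
    return resp.json()


def list_sessions() -> dict:
    """GET /sessions."""
    resp = httpx.get(f"{BASE_URL}/sessions", follow_redirects=True)
    resp.raise_for_status()
    return resp.json()


def close_session(session_id: str) -> dict:
    """DELETE /sessions/<session_id>."""
    resp = httpx.delete(f"{BASE_URL}/sessions/{session_id}")
    resp.raise_for_status()
    return resp.json()


if __name__ == "__main__":
    info = create_session("https://example.com/data.csv")
    print(info["endpoint"])             # frontend address to hand to clients
    print(list_sessions()["sessions"])  # currently active sessions
    close_session(info["session_id"])
```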
### MCP Integration
The `fastmcp.json` example in the repository root registers this project as an MCP tool directly:
```bash
uv run embedding-backend-mcp
```
Once a FastMCP client loads that configuration, it can forward the same set of REST endpoints over the standard MCP protocol, keeping behavior consistent with the conventional backend, as the sketch below illustrates.
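For a quick smoke test without a full MCP client, FastMCP's in-memory client can talk to the generated server directly; this is a sketch assuming the FastMCP 2.x `Client` API and the `embedding_backend` package layout introduced later in this commit:
```python
import asyncio

from fastmcp import Client

from embedding_backend.main import get_mcp_server


async def main() -> None:
    # In-memory transport: the client wraps the FastMCP server object directly,
    # so no stdio process or network socket is involved.
    async with Client(get_mcp_server()) as client:
        tools = await client.list_tools()
        # One tool is generated per REST operation (create/list/close session).
        print([tool.name for tool in tools])


asyncio.run(main())
```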
## Generating Interactive Embedding Visualizations from the Command Line
```bash
@@ -79,3 +173,24 @@ uv run python script/merge_splits.py --input-dir splits_v2/ --output data/drugba
```bash
uv run embedding-atlas data/drugbank_split_merge.csv --text smiles
```
## Containerized Deployment
The project ships a `docker/` directory for quickly starting the backend:
1. First build the `embedding-atlas` image that the orchestrator reuses:
```bash
docker build -f docker/embedding-atlas.Dockerfile -t embedding-atlas:latest .
```
2. Start the DinD + orchestrator pair:
```bash
docker compose -f docker/docker-compose.yml up --build
```
- The `engine` service runs `docker:dind` and exposes `tcp://localhost:2375`, the socket through which the orchestrator manages containers;
- Both services share the `/sessions` directory via the `sessions-data` volume, where the backend places downloaded data;
- The `orchestrator` service runs the FastAPI/FastMCP backend, exposed at `http://localhost:9000` by default;
- Session caches live in the `sessions-data` volume, so downloaded datasets survive container restarts.
To change the default image or ports, override the `EMBEDDING_*` environment variables in `docker/docker-compose.yml`, or inject them through an `.env` file at deployment time.
> Note: because the orchestrator runs inside a container and schedules new containers through DinD, reaching the `embedding-atlas` Web UI directly from the host requires forwarding the relevant port from `docker_engine_1` to the host, for example with `podman port` or an additional reverse proxy.
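Before creating sessions, you can sanity-check that the DinD engine is reachable from the host with a short sketch using the `docker` SDK (a declared dependency); the `tcp://localhost:2375` URL matches the port published by the compose file, and the label filter matches the orchestrator's `embedding-backend.managed` label:
```python
import docker

# Talk to the DinD engine published by docker-compose (plain TCP, no TLS).
client = docker.DockerClient(base_url="tcp://localhost:2375")
print("engine reachable:", client.ping())

# List session containers the orchestrator has started inside DinD.
for container in client.containers.list(filters={"label": "embedding-backend.managed"}):
    print(container.name, container.ports)
```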

90
data/ring12_20/README.md Normal file

@@ -0,0 +1,90 @@
Your Filtered Macrolactone Database
11036 compounds have been filtered from MacrolactoneDB based on your specified inputs.
```bash
uv run embedding-atlas data/ring12_20/temp_with_macrocycles_with_ecfp4.parquet --text ecfp4_binary
uv run embedding-atlas data/ring12_20/temp_with_macrocycles_with_ecfp4.parquet --text tanimoto_top_neighbors
uv run embedding-atlas data/ring12_20/temp_with_macrocycles_with_ecfp4.parquet --text smiles
```
## Embedding and Projection Optimization
### How projection_x and projection_y Are Generated
UMAP dimensionality reduction
The two coordinates are produced by the `_run_umap()` function, which uses the UMAP algorithm to reduce high-dimensional embedding vectors to 2D space (projection.py:64-88).
The flow is:
1. Nearest-neighbor computation - `nearest_neighbors()` first finds the k nearest neighbors of every point (projection.py:76-83).
2. UMAP projection - the precomputed neighbor information then drives the UMAP reduction (projection.py:85-86).
3. Coordinate assignment - the first column of the result becomes `projection_x`, the second `projection_y` (projection.py:259-260).
Default parameter settings
The UMAP algorithm uses the following defaults:
- Neighbor count: 15 nearest neighbors (projection.py:74)
- Distance metric: cosine distance (projection.py:73)
Use across data types
Text data: for SMILES molecular data, the system first generates text embeddings with SentenceTransformers and then reduces them with UMAP (projection.py:251-260).
Precomputed vectors: if you already have precomputed ECFP4 vectors, the system runs UMAP directly on those vectors (projection.py:311-318).
Role in the visualization
In the frontend, these coordinates serve as:
- the scatter plot's X/Y axes - each data point's position in 2D space;
- the basis for color encoding - colors can be mapped from the coordinate values (embedding-atlas.md:68-70).
Demo data example
The project's demo data generation follows the same flow: embeddings are computed with SentenceTransformers, and UMAP then produces the projection_x and projection_y coordinates (generate_demo_data.py:42-43).
Notes
The quality of these projection coordinates depends heavily on the quality of the original embeddings and on the choice of UMAP parameters. For chemical data, a dedicated molecular embedding model usually yields a more meaningful 2D projection, with structurally similar molecules clustering together. A sketch of this flow follows below.
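The described flow can be reproduced outside the CLI for experimentation; this is a standalone sketch (not embedding-atlas's internal code) that applies the stated defaults of 15 neighbors and cosine distance, with an illustrative model choice:
```python
import pandas as pd
import umap
from sentence_transformers import SentenceTransformer

df = pd.read_csv("data/ring12_20/temp.csv")  # any table with a 'smiles' column

# Step 1: text embeddings for the SMILES strings (model choice is illustrative).
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = model.encode(df["smiles"].fillna("").tolist())

# Steps 2-3: UMAP with the defaults described above, then column assignment.
reducer = umap.UMAP(n_neighbors=15, metric="cosine")
projection = reducer.fit_transform(embeddings)
df["projection_x"], df["projection_y"] = projection[:, 0], projection[:, 1]
```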
### UMAP Parameter Tuning
You can tune the UMAP parameters for a better visualization:
```bash
# Adjust the neighbor count and distance parameters
uv run embedding-atlas data/ring12_20/temp_with_macrocycles_with_ecfp4.parquet --text smiles \
  --umap-n-neighbors 30 \
  --umap-min-dist 0.1 \
  --umap-metric cosine \
  --umap-random-state 42
```
## Custom Embedding Models
For chemical molecule data you may want a specialized model that satisfies the requirements below.
Supported model types
embedding-atlas supports two kinds of custom models:
Text embedding models
For text data (such as SMILES molecular data), the system uses the SentenceTransformers library (projection.py:118-126). This means any Hugging Face model compatible with SentenceTransformers can be used.
Image embedding models
For image data, the system uses the transformers library's pipeline feature (projection.py:168-180).
Model format requirements
SentenceTransformers compatibility
Text models must be compatible with the SentenceTransformers library (projection.py:98-99). This includes:
- models trained specifically for sentence embeddings;
- models that expose an `.encode()` method;
- models that output fixed-dimension vectors.
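A quick way to verify the requirements above is to load a candidate model and encode a probe string; the model name here is an illustrative assumption, not a project default:
```python
from sentence_transformers import SentenceTransformer

model_name = "sentence-transformers/all-MiniLM-L6-v2"  # illustrative choice
model = SentenceTransformer(model_name)

# The model must expose .encode() and return fixed-dimension vectors.
vectors = model.encode(["CCO", "c1ccccc1"])
assert vectors.ndim == 2 and vectors.shape[0] == 2
print(f"{model_name} outputs {vectors.shape[1]}-dimensional embeddings")
```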

89
data/ring12_20/counts.txt Normal file

@@ -0,0 +1,89 @@
Target Organisms
Homo sapiens 815
Homo sapiens, None 180
Plasmodium falciparum 161
Hepatitis C virus, None 112
Homo sapiens, Plasmodium falciparum 63
Oryctolagus cuniculus 62
Mus musculus 60
Toxoplasma gondii 39
Homo sapiens, Rattus norvegicus 27
Mus musculus, Homo sapiens 24
None, Rattus norvegicus 23
Human immunodeficiency virus 1 20
Hepatitis C virus 18
Rattus norvegicus 17
Homo sapiens, Sus scrofa 11
Homo sapiens, Chlorocebus aethiops 10
Serratia marcescens 9
Escherichia coli 8
Oryctolagus cuniculus, Homo sapiens 7
Streptococcus pneumoniae 6
Oryctolagus cuniculus, Staphylococcus aureus, Raoultella planticola, Bacillus subtilis, Mus musculus, Micrococcus luteus, None, Escherichia coli, Plasmodium falciparum, Streptococcus pneumoniae, Homo sapiens, Escherichia coli K-12, Toxoplasma gondii 6
Plasmodium falciparum K1 5
Bacillus anthracis 5
Mus musculus, Homo sapiens, None 5
Bacillus anthracis, Homo sapiens 4
Candida albicans, Cryptococcus neoformans, Aspergillus fumigatus 4
Mus musculus, None 4
Plasmodium falciparum, Homo sapiens, None 4
None, Homo sapiens, Plasmodium falciparum 3
Bacillus subtilis, Homo sapiens 3
Oryctolagus cuniculus, Homo sapiens, None 3
Sus scrofa, Mus musculus, None, Plasmodium falciparum, Homo sapiens, Rattus norvegicus 2
Homo sapiens, None, Rattus norvegicus 2
Cryptococcus neoformans 2
Homo sapiens, None, Chlorocebus aethiops 2
Staphylococcus aureus 2
Candida albicans, Cryptococcus neoformans, Mycobacterium intracellulare, Aspergillus fumigatus 2
Mus musculus, None, Human immunodeficiency virus 1 2
Escherichia coli (strain K12) 2
Plasmodium falciparum 3D7, Homo sapiens 2
Aspergillus fumigatus 1
Sus scrofa 1
Saccharomyces cerevisiae S288c, Human immunodeficiency virus 1, Human herpesvirus 1, Plasmodium falciparum, None, Homo sapiens, Rattus norvegicus 1
Hepatitis C virus, Homo sapiens, None 1
Plasmodium falciparum 3D7 1
Bacillus subtilis 1
Mus musculus, Homo sapiens, None, Saccharomyces cerevisiae 1
Chlorocebus aethiops 1
Homo sapiens, Escherichia coli K-12, None 1
Hepatitis C virus, Homo sapiens, None, Rattus norvegicus 1
None, Homo sapiens, Human herpesvirus 1 1
Homo sapiens, None, Trypanosoma brucei brucei 1
Homo sapiens, None, Cryptococcus neoformans 1
Homo sapiens, Rattus norvegicus, Human immunodeficiency virus 1 1
None, Plasmodium falciparum, Escherichia coli, Streptococcus pneumoniae, Naegleria fowleri, Homo sapiens, Streptococcus, Toxoplasma gondii 1
Giardia intestinalis, Trypanosoma cruzi, Equus caballus, Bos taurus, Mus musculus, None, Plasmodium falciparum, Chlorocebus aethiops, Homo sapiens 1
Plasmodium falciparum NF54, Trypanosoma cruzi, Trypanosoma brucei rhodesiense, Rattus norvegicus 1
None, Homo sapiens, Plasmodium falciparum K1, Plasmodium falciparum 1
Saccharomyces cerevisiae S288c, Homo sapiens, None, Saccharomyces cerevisiae, Phytophthora sojae 1
Bacillus subtilis, Homo sapiens, Schistosoma mansoni, Saccharomyces cerevisiae, Giardia intestinalis 1
Streptococcus, Homo sapiens, None 1
Mus musculus, Homo sapiens, Rattus norvegicus 1
Homo sapiens, Spinacia oleracea 1
Human immunodeficiency virus 1, Mus musculus, None, Hepatitis C virus, Homo sapiens, Rattus norvegicus 1
None, Plasmodium falciparum, Trypanosoma brucei rhodesiense 1
Hepatitis C virus, None, Rattus norvegicus 1
Homo sapiens, Equus caballus 1
Plasmodium falciparum NF54, Trypanosoma cruzi, Trypanosoma brucei rhodesiense 1
Schistosoma mansoni, Influenza A virus 1
Leishmania chagasi, Trypanosoma cruzi 1
Candida albicans, Cryptococcus neoformans 1
None, Plasmodium falciparum 1
Caenorhabditis elegans 1
Bos taurus, Sus scrofa 1
Plasmodium falciparum, Enterococcus faecium 1
Homo sapiens, Gallus gallus 1
Homo sapiens, Escherichia coli 1
Plasmodium falciparum, Homo sapiens, None, Rattus norvegicus, Schistosoma mansoni 1
Homo sapiens, None, Influenza A virus 1
Mycobacterium tuberculosis, None 1
Escherichia coli, Homo sapiens, Toxoplasma gondii, None, Streptococcus pneumoniae 1
Bacillus subtilis, Oryctolagus cuniculus, Homo sapiens, Schistosoma mansoni, Giardia intestinalis 1
Homo sapiens, None, Rattus norvegicus, Escherichia coli O157:H7 1
Giardia intestinalis, Schistosoma mansoni, Mus musculus, None, Homo sapiens, Saccharomyces cerevisiae 1
Trypanosoma cruzi 1
Influenza A virus 1
Escherichia coli K-12 1
Human herpesvirus 4 (strain B95-8) 1

11037
data/ring12_20/temp.csv Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

41
docker/Dockerfile Normal file

@@ -0,0 +1,41 @@
# syntax=docker/dockerfile:1
FROM python:3.12-slim AS base
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        build-essential \
        curl \
        git \
    && rm -rf /var/lib/apt/lists/*
RUN pip install --no-cache-dir uv
WORKDIR /app
ENV UV_PROJECT_ENVIRONMENT=/app/.venv
ENV UV_PIP_INDEX_URL=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
ENV HF_ENDPOINT=https://hf-mirror.com
ENV HF_HOME=/app/.cache/huggingface
ENV TRANSFORMERS_CACHE=/app/.cache/huggingface/transformers
ENV HF_DATASETS_CACHE=/app/.cache/huggingface/datasets
ENV EMBEDDING_ATLAS_DEVICE=cpu
RUN mkdir -p "$HF_HOME" "$TRANSFORMERS_CACHE" "$HF_DATASETS_CACHE"
VOLUME ["/app/.cache/huggingface"]
COPY pyproject.toml README.md ./
COPY src ./src
COPY script ./script
# Install dependencies with uv; a lock file is used automatically when present.
# The command is repeated as a crude one-shot retry against transient network failures.
RUN uv sync --no-dev || uv sync --no-dev
ENV PATH="/app/.venv/bin:$PATH"
ENV EMBEDDING_API_HOST=0.0.0.0
ENV EMBEDDING_API_PORT=9000
EXPOSE 9000

39
docker/docker-compose.yml Normal file

@@ -0,0 +1,39 @@
version: "3.9"
services:
engine:
image: docker:25.0-dind
privileged: true
environment:
- DOCKER_TLS_CERTDIR=
command: ["--host=tcp://0.0.0.0:2375"]
ports:
- "2375:2375"
volumes:
- engine-data:/var/lib/docker
- sessions-data:/sessions
restart: unless-stopped
orchestrator:
# Ensure the embedding-atlas image exists (e.g. docker build -f docker/embedding-atlas.Dockerfile -t embedding-atlas:latest ..)
build:
context: ..
dockerfile: docker/Dockerfile
depends_on:
- engine
environment:
- EMBEDDING_DOCKER_URL=tcp://engine:2375
- EMBEDDING_CONTAINER_IMAGE=embedding-atlas:latest
- EMBEDDING_CONTAINER_NAME_PREFIX=embedding-atlas
- EMBEDDING_API_HOST=0.0.0.0
- EMBEDDING_API_PORT=9000
- EMBEDDING_SESSION_ROOT=/sessions
volumes:
- sessions-data:/sessions
ports:
- "9000:9000"
restart: unless-stopped
volumes:
engine-data:
sessions-data:

17
docker/embedding-atlas.Dockerfile Normal file

@@ -0,0 +1,17 @@
FROM python:3.12-slim
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        build-essential \
        curl \
    && rm -rf /var/lib/apt/lists/*
RUN pip install --no-cache-dir embedding-atlas
WORKDIR /app
EXPOSE 5055
CMD ["embedding-atlas"]

11
fastmcp.json Normal file

@@ -0,0 +1,11 @@
{
  "mcpServers": {
    "embedding-atlas-session-manager": {
      "command": "embedding-backend-mcp",
      "args": [],
      "env": {},
      "timeout": 30000,
      "description": "Start the embedding orchestrator in MCP stdio mode"
    }
  }
}

10
pixi.toml Normal file

@@ -0,0 +1,10 @@
[workspace]
authors = ["lingyuzeng <pylyzeng@gmail.com>"]
channels = ["conda-forge"]
name = "embedding_atlas"
platforms = ["osx-arm64"]
version = "0.1.0"
[tasks]
[dependencies]

28
pyproject.toml

@@ -1,3 +1,4 @@
[project]
name = "mole-embedding-atlas"
version = "0.1.0"
@@ -12,4 +13,31 @@ dependencies = [
"rdkit",
"pandas",
"selfies==2.1.1",
"fastapi>=0.111",
"uvicorn[standard]>=0.29",
"fastmcp>=2.11",
"docker>=7.1",
"httpx>=0.27",
"pydantic-settings>=2.2",
]
[project.scripts]
embedding-backend = "embedding_backend.main:main"
embedding-backend-api = "embedding_backend.main:run_api"
embedding-backend-mcp = "embedding_backend.main:run_mcp_stdio"
[tool.uv.pip]
index-url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple"
[tool.setuptools]
package-dir = {"" = "src"}
[tool.setuptools.packages.find]
where = ["src"]
[build-system]
requires = ["setuptools>=68", "wheel"]
build-backend = "setuptools.build_meta"
[tool.uv]
package = true


@@ -0,0 +1,186 @@
#!/usr/bin/env python3
"""Augment a CSV with ECFP4 binary fingerprints and Tanimoto neighbor summaries."""
from __future__ import annotations

import argparse
import pathlib
from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import rdFingerprintGenerator


@dataclass
class ECFP4Generator:
    """Generate ECFP4 (Morgan radius 2) fingerprints as RDKit bit vectors."""

    n_bits: int = 2048
    radius: int = 2
    include_chirality: bool = True
    generator: rdFingerprintGenerator.MorganGenerator = field(init=False)

    def __post_init__(self) -> None:
        self.generator = rdFingerprintGenerator.GetMorganGenerator(
            radius=self.radius,
            fpSize=self.n_bits,
            includeChirality=self.include_chirality,
        )

    def fingerprint(self, smiles: str) -> Optional[DataStructs.ExplicitBitVect]:
        if not smiles:
            return None
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        try:
            return self.generator.GetFingerprint(mol)
        except Exception:
            return None

    def to_binary_string(self, fp: DataStructs.ExplicitBitVect) -> str:
        arr = np.zeros((self.n_bits,), dtype=np.uint8)
        DataStructs.ConvertToNumpyArray(fp, arr)
        return ''.join(arr.astype(str))


def tanimoto_top_k(
    fingerprints: Sequence[Optional[DataStructs.ExplicitBitVect]],
    ids: Sequence[str],
    top_k: int = 5,
) -> List[str]:
    """Return semicolon-delimited Tanimoto summaries for each fingerprint."""
    valid_indices = [idx for idx, fp in enumerate(fingerprints) if fp is not None]
    valid_fps = [fingerprints[idx] for idx in valid_indices]
    index_lookup = {pos: original for pos, original in enumerate(valid_indices)}
    summaries = [''] * len(fingerprints)
    if not valid_fps:
        return summaries
    for pos, fp in enumerate(valid_fps):
        sims = DataStructs.BulkTanimotoSimilarity(fp, valid_fps)
        ranked: List[Tuple[int, float]] = []
        for other_pos, score in enumerate(sims):
            if other_pos == pos:
                continue
            ranked.append((other_pos, score))
        ranked.sort(key=lambda item: item[1], reverse=True)
        top_entries = []
        for other_pos, score in ranked[:top_k]:
            original_idx = index_lookup[other_pos]
            top_entries.append(f"{ids[original_idx]}:{score:.4f}")
        summaries[index_lookup[pos]] = ';'.join(top_entries)
    return summaries


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("input_csv", type=pathlib.Path, help="Source CSV with SMILES")
    parser.add_argument(
        "--output",
        type=pathlib.Path,
        default=None,
        help="Destination file (default: <input>_with_ecfp4.parquet)",
    )
    parser.add_argument(
        "--smiles-column",
        default="smiles",
        help="Name of the column containing SMILES (default: smiles)",
    )
    parser.add_argument(
        "--id-column",
        default="generated_id",
        help="Column used to label Tanimoto neighbors (default: generated_id)",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of nearest neighbors to report in Tanimoto summaries (default: 5)",
    )
    parser.add_argument(
        "--format",
        choices=("parquet", "csv", "auto"),
        default="parquet",
        help="Output format; defaults to parquet unless overridden or inferred from --output",
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    if not args.input_csv.exists():
        raise FileNotFoundError(f"Input file not found: {args.input_csv}")

    def resolve_output_path() -> pathlib.Path:
        if args.output is not None:
            return args.output
        suffix = ".parquet" if args.format in ("parquet", "auto") else ".csv"
        return args.input_csv.with_name(f"{args.input_csv.stem}_with_ecfp4{suffix}")

    def resolve_format(path: pathlib.Path) -> str:
        suffix = path.suffix.lower()
        if suffix in {".parquet", ".pq"}:
            return "parquet"
        if suffix == ".csv":
            return "csv"
        if args.format == "parquet":
            return "parquet"
        if args.format == "csv":
            return "csv"
        raise ValueError(
            "Cannot infer the output format from the destination path; "
            "give --output a .parquet/.csv suffix or pass --format explicitly.",
        )

    output_path = resolve_output_path()
    output_format = resolve_format(output_path)
    if args.input_csv.resolve() == output_path.resolve():
        raise ValueError("Output path must differ from input path to avoid overwriting input.")

    df = pd.read_csv(args.input_csv)
    if args.smiles_column not in df.columns:
        raise ValueError(f"Column '{args.smiles_column}' not found in input data")
    smiles_series = df[args.smiles_column].fillna('')

    if args.id_column in df.columns:
        ids = df[args.id_column].astype(str).tolist()
    else:
        ids = [f"D{idx:06d}" for idx in range(1, len(df) + 1)]
        df[args.id_column] = ids

    generator = ECFP4Generator()
    fingerprints: List[Optional[DataStructs.ExplicitBitVect]] = []
    binary_repr: List[str] = []
    for smiles in smiles_series:
        fp = generator.fingerprint(smiles)
        fingerprints.append(fp)
        binary_repr.append(generator.to_binary_string(fp) if fp is not None else '')

    df['ecfp4_binary'] = binary_repr
    df['tanimoto_top_neighbors'] = tanimoto_top_k(fingerprints, ids, top_k=args.top_k)

    if output_format == "parquet":
        df.to_parquet(output_path, index=False)
    else:
        df.to_csv(output_path, index=False)
    print(f"Wrote augmented data with {len(df)} rows to {output_path} ({output_format})")


if __name__ == "__main__":
    main()


@@ -0,0 +1,127 @@
#!/usr/bin/env python3
"""Add generated IDs and macrolactone ring size annotations to a CSV file."""
import argparse
import pathlib
import warnings
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Optional, Set, Tuple
import pandas as pd
from rdkit import Chem
@dataclass
class MacrolactoneRingDetector:
"""Detect macrolactone rings in SMILES strings via SMARTS patterns."""
min_size: int = 12
max_size: int = 20
patterns: Dict[int, Optional[Chem.Mol]] = field(init=False, default_factory=dict)
def __post_init__(self) -> None:
self.patterns = {
size: Chem.MolFromSmarts(f"[r{size}]([#8][#6](=[#8]))")
for size in range(self.min_size, self.max_size + 1)
}
def ring_sizes(self, smiles: str) -> List[int]:
"""Return a sorted list of macrolactone ring sizes present in the SMILES."""
if not smiles:
return []
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return []
ring_atoms = mol.GetRingInfo().AtomRings()
if not ring_atoms:
return []
matched_rings: Set[Tuple[int, ...]] = set()
matched_sizes: Set[int] = set()
for size, query in self.patterns.items():
if query is None:
continue
for match in mol.GetSubstructMatches(query, uniquify=True):
ring = self._pick_ring(match, ring_atoms, size)
if ring and ring not in matched_rings:
matched_rings.add(ring)
matched_sizes.add(size)
sizes = sorted(matched_sizes)
if len(sizes) > 1:
warnings.warn(
"Multiple macrolactone ring sizes detected",
RuntimeWarning,
stacklevel=2,
)
print(f"Multiple ring sizes {sizes} for SMILES: {smiles}")
return sizes
@staticmethod
def _pick_ring(
match: Tuple[int, ...], rings: Iterable[Tuple[int, ...]], expected_size: int
) -> Optional[Tuple[int, ...]]:
ring_atom = match[0]
for ring in rings:
if len(ring) == expected_size and ring_atom in ring:
return tuple(sorted(ring))
return None
def add_columns(df: pd.DataFrame, detector: MacrolactoneRingDetector) -> pd.DataFrame:
result = df.copy()
result["generated_id"] = [f"D{index:06d}" for index in range(1, len(result) + 1)]
smiles_series = result["smiles"].fillna("") if "smiles" in result.columns else pd.Series(
[""] * len(result), index=result.index
)
def format_sizes(smiles: str) -> str:
sizes = detector.ring_sizes(smiles)
return ";".join(str(size) for size in sizes) if sizes else ""
result["macrocycle_ring_sizes"] = smiles_series.apply(format_sizes)
return result
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("input_csv", type=pathlib.Path, help="Path to the source CSV file")
parser.add_argument(
"--output",
type=pathlib.Path,
default=None,
help="Destination for the augmented CSV (default: <input>_with_macrocycles.csv)",
)
return parser.parse_args()
def main() -> None:
args = parse_args()
if not args.input_csv.exists():
raise FileNotFoundError(f"Input file not found: {args.input_csv}")
output_path = args.output or args.input_csv.with_name(
f"{args.input_csv.stem}_with_macrocycles.csv"
)
if args.input_csv.resolve() == output_path.resolve():
raise ValueError("Output path must differ from input path to avoid overwriting.")
df = pd.read_csv(args.input_csv)
detector = MacrolactoneRingDetector()
augmented = add_columns(df, detector)
augmented.to_csv(output_path, index=False)
print(f"Wrote augmented data with {len(augmented)} rows to {output_path}")
if __name__ == "__main__":
main()


@@ -0,0 +1,439 @@
#!/usr/bin/env python3
"""
Optimized ECFP4 Fingerprinting with UMAP Visualization for Macrolactone Molecules
This script processes SMILES data to:
1. Generate ECFP4 fingerprints using RDKit
2. Detect ring numbers in macrolactone molecules using SMARTS patterns
3. Generate unique IDs for molecules without existing IDs
4. Perform UMAP dimensionality reduction with Tanimoto distance
5. Prepare data for embedding-atlas visualization
Optimized for large datasets with progress tracking and memory efficiency.
"""
import os
import sys
import argparse
import subprocess
from typing import Optional, List
# RDKit imports
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors, DataStructs
from rdkit.Chem.MolStandardize import rdMolStandardize
# Data processing
import pandas as pd
import numpy as np
# UMAP and visualization
import umap
import matplotlib.pyplot as plt
# Suppress warnings
import warnings
warnings.filterwarnings('ignore')
# Progress bar
try:
from tqdm import tqdm
HAS_TQDM = True
except ImportError:
HAS_TQDM = False
class MacrolactoneProcessor:
"""Process macrolactone molecules for embedding visualization."""
def __init__(self, n_bits: int = 2048, radius: int = 2, chirality: bool = True):
"""
Initialize processor with ECFP4 parameters.
Args:
n_bits: Number of fingerprint bits (default: 2048)
radius: Morgan fingerprint radius (default: 2 for ECFP4)
chirality: Include chirality information (default: True)
"""
self.n_bits = n_bits
self.radius = radius
self.chirality = chirality
# Standardizer for molecule preprocessing
self.standardizer = rdMolStandardize.MetalDisconnector()
# SMARTS patterns for different ring sizes (12-20 membered rings)
self.ring_smarts = {
12: '[r12][#8][#6](=[#8])', # 12-membered ring with lactone
13: '[r13][#8][#6](=[#8])', # 13-membered ring with lactone
14: '[r14][#8][#6](=[#8])', # 14-membered ring with lactone
15: '[r15][#8][#6](=[#8])', # 15-membered ring with lactone
16: '[r16][#8][#6](=[#8])', # 16-membered ring with lactone
17: '[r17][#8][#6](=[#8])', # 17-membered ring with lactone
18: '[r18][#8][#6](=[#8])', # 18-membered ring with lactone
19: '[r19][#8][#6](=[#8])', # 19-membered ring with lactone
20: '[r20][#8][#6](=[#8])', # 20-membered ring with lactone
}
def standardize_molecule(self, mol: Chem.Mol) -> Optional[Chem.Mol]:
"""Standardize molecule using RDKit standardization."""
try:
# Remove metals
mol = self.standardizer.Disconnect(mol)
# Normalize
mol = rdMolStandardize.Normalize(mol)
# Remove fragments
mol = rdMolStandardize.FragmentParent(mol)
# Neutralize charges
mol = rdMolStandardize.ChargeParent(mol)
return mol
except:
return None
def ecfp4_fingerprint(self, smiles: str) -> Optional[np.ndarray]:
"""Generate ECFP4 fingerprint from SMILES string using newer RDKit API."""
try:
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
# Standardize molecule
mol = self.standardize_molecule(mol)
if mol is None:
return None
# Generate Morgan fingerprint using the newer API to avoid deprecation warnings
from rdkit.Chem import rdFingerprintGenerator
generator = rdFingerprintGenerator.GetMorganGenerator(
radius=self.radius,
fpSize=self.n_bits,
includeChirality=self.chirality
)
bv = generator.GetFingerprint(mol)
# Convert to numpy array
arr = np.zeros((self.n_bits,), dtype=np.uint8)
DataStructs.ConvertToNumpyArray(bv, arr)
return arr
except Exception as e:
print(f"Error processing SMILES {smiles[:50]}...: {e}")
return None
def detect_ring_number(self, smiles: str) -> int:
"""Detect the ring number in macrolactone molecule using SMARTS patterns."""
try:
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return 0
# Check each ring size pattern
for ring_size, smarts in self.ring_smarts.items():
query = Chem.MolFromSmarts(smarts)
if query:
matches = mol.GetSubstructMatches(query)
if matches:
return ring_size
# Alternative: check for any large ring with lactone
generic_pattern = Chem.MolFromSmarts('[r{12-20}][#8][#6](=[#8])')
if generic_pattern:
matches = mol.GetSubstructMatches(generic_pattern)
if matches:
# Try to determine ring size from the first match
for match in matches:
# Get the ring atoms
for atom_idx in match:
atom = mol.GetAtomWithIdx(atom_idx)
if atom.IsInRing():
# Find the ring size
for ring in atom.GetOwningMol().GetRingInfo().AtomRings():
if atom_idx in ring:
ring_size = len(ring)
if 12 <= ring_size <= 20:
return ring_size
return 0
except Exception as e:
print(f"Error detecting ring number for {smiles}: {e}")
return 0
def generate_unique_id(self, index: int, existing_id: Optional[str] = None) -> str:
"""Generate unique ID for molecule."""
if existing_id and pd.notna(existing_id) and existing_id != '':
return str(existing_id)
else:
return f"D{index:07d}"
def tanimoto_similarity(self, fp1: np.ndarray, fp2: np.ndarray) -> float:
"""Calculate Tanimoto similarity between two fingerprints."""
# Bit count
bit_count1 = np.sum(fp1)
bit_count2 = np.sum(fp2)
common_bits = np.sum(fp1 & fp2)
if bit_count1 + bit_count2 - common_bits == 0:
return 0.0
return common_bits / (bit_count1 + bit_count2 - common_bits)
def find_neighbors(self, X: np.ndarray, k: int = 15, batch_size: int = 1000) -> List[str]:
"""Find k nearest neighbors for each molecule based on Tanimoto similarity."""
n_samples = X.shape[0]
neighbors = []
# Progress bar
if HAS_TQDM:
pbar = tqdm(total=n_samples, desc="Finding neighbors")
for i in range(n_samples):
similarities = []
# Batch processing for memory efficiency
for j in range(0, n_samples, batch_size):
end_j = min(j + batch_size, n_samples)
batch_X = X[j:end_j]
# Calculate similarities for this batch
for batch_idx, fp in enumerate(batch_X):
orig_idx = j + batch_idx
if i != orig_idx:
sim = self.tanimoto_similarity(X[i], fp)
similarities.append((orig_idx, sim))
# Sort by similarity (descending)
similarities.sort(key=lambda x: x[1], reverse=True)
# Get top k neighbors
top_neighbors = [str(idx) for idx, _ in similarities[:k]]
neighbors.append(','.join(top_neighbors))
if HAS_TQDM:
pbar.update(1)
if HAS_TQDM:
pbar.close()
return neighbors
def perform_umap(self, X: np.ndarray, n_neighbors: int = 30,
min_dist: float = 0.1, metric: str = 'jaccard') -> np.ndarray:
"""Perform UMAP dimensionality reduction."""
reducer = umap.UMAP(
n_neighbors=n_neighbors,
min_dist=min_dist,
metric=metric,
random_state=42
)
return reducer.fit_transform(X)
def process_dataframe(self, df: pd.DataFrame, smiles_col: str = 'smiles',
id_col: Optional[str] = None, max_molecules: Optional[int] = None) -> pd.DataFrame:
"""Process dataframe with SMILES strings."""
print(f"Processing {len(df)} molecules...")
# Limit molecules if requested
if max_molecules:
df = df.head(max_molecules)
print(f"Limited to {max_molecules} molecules")
# Ensure we have a smiles column
if smiles_col not in df.columns:
raise ValueError(f"Column '{smiles_col}' not found in dataframe")
# Create a working copy
result_df = df.copy()
# Generate unique IDs if needed
if id_col and id_col in df.columns:
result_df['molecule_id'] = [self.generate_unique_id(i, existing_id)
for i, existing_id in enumerate(result_df[id_col])]
else:
result_df['molecule_id'] = [self.generate_unique_id(i)
for i in range(len(result_df))]
# Process fingerprints
print("Generating ECFP4 fingerprints...")
fingerprints = []
valid_indices = []
# Progress tracking
iterator = enumerate(result_df[smiles_col])
if HAS_TQDM:
iterator = tqdm(iterator, total=len(result_df), desc="Processing fingerprints")
for idx, smiles in iterator:
if pd.notna(smiles) and smiles != '':
fp = self.ecfp4_fingerprint(smiles)
if fp is not None:
fingerprints.append(fp)
valid_indices.append(idx)
else:
print(f"Failed to generate fingerprint for index {idx}: {smiles[:50]}...")
else:
print(f"Invalid SMILES at index {idx}")
# Filter dataframe to valid molecules only
result_df = result_df.iloc[valid_indices].reset_index(drop=True)
if not fingerprints:
raise ValueError("No valid fingerprints generated")
# Convert fingerprints to numpy array
X = np.array(fingerprints)
print(f"Generated fingerprints for {len(fingerprints)} molecules")
# Detect ring numbers
print("Detecting ring numbers...")
ring_numbers = []
iterator = result_df[smiles_col]
if HAS_TQDM:
iterator = tqdm(iterator, desc="Detecting rings")
for smiles in iterator:
ring_num = self.detect_ring_number(smiles)
ring_numbers.append(ring_num)
result_df['ring_num'] = ring_numbers
# Perform UMAP
print("Performing UMAP dimensionality reduction...")
embedding = self.perform_umap(X)
result_df['projection_x'] = embedding[:, 0]
result_df['projection_y'] = embedding[:, 1]
# Find neighbors for embedding-atlas
print("Finding nearest neighbors...")
neighbors = self.find_neighbors(X, k=15)
result_df['neighbors'] = neighbors
# Add fingerprint information
result_df['fingerprint_bits'] = [fp.tolist() for fp in fingerprints]
return result_df
def create_visualization(self, df: pd.DataFrame, output_path: str):
"""Create visualization of the UMAP embedding."""
plt.figure(figsize=(12, 8))
# Color by ring number
scatter = plt.scatter(df['projection_x'], df['projection_y'],
c=df['ring_num'], cmap='viridis', alpha=0.6, s=30)
plt.colorbar(scatter, label='Ring Number')
plt.xlabel('UMAP 1')
plt.ylabel('UMAP 2')
plt.title('Macrolactone Molecules - ECFP4 + UMAP Visualization')
# Add some annotations for ring numbers
for ring_num in sorted(df['ring_num'].unique()):
if ring_num > 0:
subset = df[df['ring_num'] == ring_num]
if len(subset) > 0:
center_x = subset['projection_x'].mean()
center_y = subset['projection_y'].mean()
plt.annotate(f'{ring_num} ring', (center_x, center_y),
fontsize=10, fontweight='bold')
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"Visualization saved to {output_path}")
def main():
"""Main function to run the processing pipeline."""
parser = argparse.ArgumentParser(description='ECFP4 + UMAP for Macrolactone Molecules')
parser.add_argument('--input', '-i', required=True,
help='Input CSV file path')
parser.add_argument('--output', '-o', required=True,
help='Output CSV file path')
parser.add_argument('--smiles-col', default='smiles',
help='Name of SMILES column (default: smiles)')
parser.add_argument('--id-col', default=None,
help='Name of ID column (optional)')
parser.add_argument('--visualization', '-v', default='umap_visualization.png',
help='Output visualization file path')
parser.add_argument('--max-molecules', type=int, default=None,
help='Maximum number of molecules to process (for testing)')
parser.add_argument('--launch-atlas', action='store_true',
help='Launch embedding-atlas process')
parser.add_argument('--atlas-port', type=int, default=8080,
help='Port for embedding-atlas server')
args = parser.parse_args()
# Initialize processor
processor = MacrolactoneProcessor(n_bits=2048, radius=2, chirality=True)
# Load data
print(f"Loading data from {args.input}")
try:
df = pd.read_csv(args.input)
print(f"Loaded {len(df)} molecules")
print(f"Columns: {list(df.columns)}")
except Exception as e:
print(f"Error loading data: {e}")
return 1
# Process dataframe
try:
processed_df = processor.process_dataframe(df,
smiles_col=args.smiles_col,
id_col=args.id_col,
max_molecules=args.max_molecules)
print(f"Successfully processed {len(processed_df)} molecules")
except Exception as e:
print(f"Error processing data: {e}")
return 1
# Save results
try:
processed_df.to_csv(args.output, index=False)
print(f"Results saved to {args.output}")
except Exception as e:
print(f"Error saving results: {e}")
return 1
# Create visualization
try:
processor.create_visualization(processed_df, args.visualization)
except Exception as e:
print(f"Error creating visualization: {e}")
# Launch embedding-atlas if requested
if args.launch_atlas:
print("Launching embedding-atlas process...")
try:
# Prepare command for embedding-atlas
cmd = [
'embedding-atlas', 'data', args.output,
'--text', args.smiles_col,
'--port', str(args.atlas_port),
'--neighbors', 'neighbors',
'--x', 'projection_x',
'--y', 'projection_y'
]
print(f"Running command: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode == 0:
print("Embedding-atlas process launched successfully")
print(f"Access the visualization at: http://localhost:{args.atlas_port}")
else:
print(f"Error launching embedding-atlas: {result.stderr}")
except FileNotFoundError:
print("embedding-atlas command not found. Please install it first.")
print("You can install it with: pip install embedding-atlas")
except Exception as e:
print(f"Error launching embedding-atlas: {e}")
print("Processing complete!")
return 0
if __name__ == '__main__':
sys.exit(main())

5
src/embedding_backend/__init__.py Normal file

@@ -0,0 +1,5 @@
"""Embedding Atlas session orchestrator package."""
from .main import create_app, get_mcp_server
__all__ = ["create_app", "get_mcp_server"]

90
src/embedding_backend/config.py Normal file

@@ -0,0 +1,90 @@
"""Configuration for the embedding session orchestrator."""
from __future__ import annotations
from pathlib import Path
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Application settings loaded from the environment."""
docker_url: str = Field(
default="unix:///var/run/docker.sock",
description="Docker socket URL accessed by the orchestrator.",
)
container_image: str = Field(
default="embedding-atlas:latest",
description="Default Docker image used for embedding-atlas sessions.",
)
container_name_prefix: str = Field(
default="embedding-atlas",
description="Prefix applied to managed container names.",
)
container_host: str = Field(
default="0.0.0.0",
description="Host binding used inside managed containers.",
)
api_host: str = Field(
default="0.0.0.0",
description="Listening host for the FastAPI server.",
)
api_port: int = Field(
default=9000,
description="Listening port for the FastAPI server.",
)
default_port: int = Field(
default=5055,
description="Fallback embedding-atlas port when no override is provided.",
)
port_range_start: int = Field(
default=6000,
description="Lower bound for dynamically assigned ports.",
)
port_range_end: int = Field(
default=6999,
description="Upper bound for dynamically assigned ports (inclusive).",
)
session_root: Path = Field(
default=Path("runtime/sessions"),
description="Directory used to persist downloaded datasets per session.",
)
download_chunk_size: int = Field(
default=64 * 1024,
description="Chunk size in bytes when streaming remote datasets.",
)
download_timeout_seconds: int = Field(
default=300,
description="Timeout applied to dataset downloads.",
)
auto_remove_seconds: int = Field(
default=36000,
description="Default inactivity window before a session is garbage-collected (10h).",
)
cleanup_interval_seconds: int = Field(
default=600,
description="Background cleanup cadence in seconds.",
)
container_network: Optional[str] = Field(
default=None,
description="Optional Docker network to attach managed containers to.",
)
model_config = SettingsConfigDict(
env_prefix="EMBEDDING_",
env_file=".env",
env_file_encoding="utf-8",
)
def prepare(self) -> None:
"""Ensure directories referenced by the settings exist."""
self.session_root = self.session_root.resolve()
self.session_root.mkdir(parents=True, exist_ok=True)
settings = Settings()
settings.prepare()

376
src/embedding_backend/docker_manager.py Normal file

@@ -0,0 +1,376 @@
"""Asynchronous session manager backed by Docker containers."""
from __future__ import annotations
import asyncio
import logging
import socket
from contextlib import suppress
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Dict, Iterable, Optional
import docker
import httpx
from docker.errors import APIError, NotFound
from docker.models.containers import Container
from .config import Settings, settings
from .models import (
DATASET_LABEL,
EXPIRY_LABEL,
MANAGED_LABEL,
SESSION_ID_LABEL,
SessionCreateRequest,
SessionRecord,
default_labels,
)
logger = logging.getLogger(__name__)
class SessionError(RuntimeError):
"""Base class for orchestration failures."""
class SessionExistsError(SessionError):
"""Raised when a duplicate session identifier is requested."""
class SessionNotFoundError(SessionError):
"""Raised when attempting to operate on a missing session."""
class DatasetDownloadError(SessionError):
"""Raised when the dataset cannot be fetched."""
class PortUnavailableError(SessionError):
"""Raised when the requested port is already in use."""
class SessionManager:
"""Coordinate embedding-atlas containers per incoming request."""
def __init__(self, app_settings: Settings | None = None) -> None:
self.settings = app_settings or settings
self.client = docker.DockerClient(base_url=self.settings.docker_url)
self._sessions: Dict[str, SessionRecord] = {}
self._lock = asyncio.Lock()
self._cleanup_task: Optional[asyncio.Task[None]] = None
async def start(self) -> None:
"""Initialise caches and background cleanup."""
await self._load_existing_sessions()
if self._cleanup_task is None:
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
async def stop(self) -> None:
"""Stop background tasks and close Docker handles."""
if self._cleanup_task is not None:
self._cleanup_task.cancel()
with suppress(asyncio.CancelledError):
await self._cleanup_task
self._cleanup_task = None
await asyncio.to_thread(self.client.close)
async def create_session(self, payload: SessionCreateRequest) -> SessionRecord:
"""Download the dataset and spin up an embedding-atlas container."""
session_id = payload.normalised_session_id()
async with self._lock:
if session_id in self._sessions:
raise SessionExistsError(f"Session {session_id} already active.")
host = payload.host or self.settings.container_host
if payload.port is not None:
is_free = await asyncio.to_thread(self._port_available, payload.port)
if not is_free:
raise PortUnavailableError(f"Requested port {payload.port} is unavailable")
port = payload.port
else:
port = await asyncio.to_thread(self._allocate_port)
expiry_seconds = payload.auto_remove_seconds or self.settings.auto_remove_seconds
expires_at = datetime.now(tz=timezone.utc) + timedelta(seconds=expiry_seconds)
dataset_path = await self._download_dataset(session_id, payload)
try:
record = await asyncio.to_thread(
self._run_container,
session_id,
dataset_path,
payload,
host,
port,
expires_at,
)
except Exception:
# Drop downloaded artefacts if startup failed
with suppress(Exception):
self._purge_dataset(dataset_path)
raise
async with self._lock:
self._sessions[session_id] = record
return record
async def close_session(self, session_id: str) -> bool:
"""Terminate the associated container and remove cached data."""
record = await self._get_session(session_id)
await asyncio.to_thread(self._stop_container, record.container_name)
await asyncio.to_thread(self._purge_dataset, Path(record.dataset_path))
async with self._lock:
self._sessions.pop(session_id, None)
return True
async def list_sessions(self) -> Iterable[SessionRecord]:
"""Return the internal session registry."""
async with self._lock:
return list(self._sessions.values())
async def cleanup_expired(self) -> None:
"""Stop containers that exceeded their TTL."""
now = datetime.now(tz=timezone.utc)
async with self._lock:
sessions = list(self._sessions.values())
for record in sessions:
if record.expires_at <= now:
logger.info("Session %s expired; stopping container %s", record.session_id, record.container_name)
with suppress(SessionError, APIError, NotFound):
await self.close_session(record.session_id)
async def _cleanup_loop(self) -> None:
"""Periodic cleanup coroutine."""
try:
while True:
await asyncio.sleep(self.settings.cleanup_interval_seconds)
await self.cleanup_expired()
except asyncio.CancelledError:
logger.debug("Cleanup loop cancelled")
raise
async def _get_session(self, session_id: str) -> SessionRecord:
async with self._lock:
record = self._sessions.get(session_id)
if record is None:
raise SessionNotFoundError(f"Session {session_id} not found")
return record
async def _load_existing_sessions(self) -> None:
"""Recover managed containers into the local cache on startup."""
def _list() -> list[Container]:
return self.client.containers.list(all=True, filters={"label": MANAGED_LABEL})
containers = await asyncio.to_thread(_list)
restored: Dict[str, SessionRecord] = {}
for container in containers:
labels = container.labels or {}
session_id = labels.get(SESSION_ID_LABEL)
if not session_id:
continue
try:
port = self._extract_port(container)
except ValueError:
logger.warning("Unable to restore port mapping for container %s", container.name)
continue
expires_at = self._parse_expiry(labels.get(EXPIRY_LABEL))
dataset_path = labels.get(DATASET_LABEL, "")
restored[session_id] = SessionRecord(
session_id=session_id,
container_id=container.id,
container_name=container.name,
port=port,
host=self.settings.container_host,
started_at=self._parse_created(container.attrs.get("Created")),
expires_at=expires_at,
dataset_path=dataset_path,
)
async with self._lock:
self._sessions.update(restored)
if restored:
logger.info("Restored %d managed sessions from existing containers", len(restored))
def _run_container(
self,
session_id: str,
dataset_path: Path,
payload: SessionCreateRequest,
host: str,
port: int,
expires_at: datetime,
) -> SessionRecord:
container_name = f"{self.settings.container_name_prefix}_{session_id}"
self._guard_duplicate_container(container_name)
bind_root = dataset_path.parent
container_dataset_path = Path("/session") / dataset_path.name
command = [
"embedding-atlas",
str(container_dataset_path),
"--host",
host,
"--port",
str(port),
"--no-auto-port",
]
if payload.extra_args:
command.extend(payload.extra_args)
env = payload.environment.copy()
env.setdefault("SESSION_ID", session_id)
labels = default_labels(session_id, expires_at, str(dataset_path))
labels.update(payload.labels)
logger.info(
"Starting embedding-atlas container %s on port %s (session %s)",
container_name,
port,
session_id,
)
port_bindings = {f"{port}/tcp": port}
container = self.client.containers.run(
payload.image or self.settings.container_image,
command,
name=container_name,
detach=True,
environment=env,
labels=labels,
volumes={
str(bind_root): {"bind": "/session", "mode": "ro"}
},
ports=port_bindings,
network=self.settings.container_network,
)
return SessionRecord(
session_id=session_id,
container_id=container.id,
container_name=container.name,
port=port,
host=host,
started_at=datetime.now(tz=timezone.utc),
expires_at=expires_at,
dataset_path=str(dataset_path),
)
def _stop_container(self, container_name: str) -> None:
try:
container = self.client.containers.get(container_name)
except NotFound as exc:
raise SessionNotFoundError(f"Container {container_name} not found") from exc
container.stop(timeout=30)
container.remove(v=True)
logger.info("Removed container %s", container_name)
def _purge_dataset(self, dataset_path: Path) -> None:
try:
dataset_path.unlink(missing_ok=True)
parent = dataset_path.parent
if parent.exists() and not any(parent.iterdir()):
parent.rmdir()
except Exception as exc:
logger.warning("Failed to clean dataset path %s: %s", dataset_path, exc)
async def _download_dataset(self, session_id: str, payload: SessionCreateRequest) -> Path:
filename = payload.input_filename or Path(payload.data_url.path).name
if not filename:
filename = f"dataset-{session_id}"
session_dir = self.settings.session_root / session_id
session_dir.mkdir(parents=True, exist_ok=True)
target = session_dir / filename
logger.info("Downloading dataset for session %s from %s", session_id, payload.data_url)
timeout = httpx.Timeout(self.settings.download_timeout_seconds)
async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
try:
async with client.stream("GET", str(payload.data_url)) as response:
response.raise_for_status()
with target.open("wb") as file_handle:
async for chunk in response.aiter_bytes(chunk_size=self.settings.download_chunk_size):
file_handle.write(chunk)
except Exception as exc:
raise DatasetDownloadError(f"Failed to fetch dataset: {exc}") from exc
return target
def _allocate_port(self) -> int:
if self.settings.port_range_start >= self.settings.port_range_end:
return self._random_free_port()
candidates = range(self.settings.port_range_start, self.settings.port_range_end + 1)
for candidate in candidates:
if self._port_available(candidate):
return candidate
return self._random_free_port()
@staticmethod
def _random_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("", 0))
return sock.getsockname()[1]
@staticmethod
def _port_available(port: int) -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
result = sock.connect_ex(("127.0.0.1", port))
return result != 0
def _guard_duplicate_container(self, container_name: str) -> None:
try:
container = self.client.containers.get(container_name)
except NotFound:
return
raise SessionExistsError(
f"Container name {container_name} already exists (id={container.id})."
)
@staticmethod
def _extract_port(container: Container) -> int:
ports = container.attrs.get("NetworkSettings", {}).get("Ports", {})
if not ports:
raise ValueError("No exposed ports found")
# Expect single port binding
(container_port, host_bindings), = ports.items()
if not host_bindings:
raise ValueError("Container port not published")
host_port = host_bindings[0].get("HostPort")
if host_port is None:
raise ValueError("Missing host port")
return int(host_port)
@staticmethod
def _parse_expiry(value: Optional[str]) -> datetime:
if not value:
return datetime.now(tz=timezone.utc)
try:
return datetime.fromtimestamp(int(value), tz=timezone.utc)
except ValueError:
return datetime.now(tz=timezone.utc)
@staticmethod
def _parse_created(value: Optional[str]) -> datetime:
if not value:
return datetime.now(tz=timezone.utc)
try:
return datetime.fromisoformat(value.replace('Z', '+00:00'))
except ValueError:
return datetime.now(tz=timezone.utc)
__all__ = [
"SessionManager",
"SessionError",
"SessionExistsError",
"SessionNotFoundError",
"DatasetDownloadError",
"PortUnavailableError",
]

120
src/embedding_backend/main.py Normal file

@@ -0,0 +1,120 @@
"""Entry points for launching the orchestrator as FastAPI or FastMCP."""
from __future__ import annotations
import argparse
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import Optional
import uvicorn
from fastapi import FastAPI
from fastmcp import FastMCP
from .config import Settings, settings
from .docker_manager import SessionManager
from .routes import get_router
logger = logging.getLogger(__name__)
def create_app(app_settings: Optional[Settings] = None) -> FastAPI:
"""Create and configure the FastAPI application."""
app_settings = app_settings or settings
manager = SessionManager(app_settings)
@asynccontextmanager
async def lifespan(_: FastAPI):
await manager.start()
try:
yield
finally:
await manager.stop()
app = FastAPI(
title="Embedding Atlas Orchestrator",
version="0.1.0",
description=(
"Session orchestration API that launches short-lived embedding-atlas containers on demand."
),
docs_url="/docs",
redoc_url=None,
openapi_url="/openapi.json",
lifespan=lifespan,
)
@app.get("/", tags=["meta"], summary="Service metadata")
async def root() -> dict[str, str]:
return {"message": "Embedding Atlas session manager is running. Visit /docs for the Swagger UI."}
app.include_router(get_router(manager))
app.state.session_manager = manager
return app
_app = create_app()
app = _app
_mcp_server = FastMCP.from_fastapi(_app, name="embedding-atlas-session-manager")
mcp_server = _mcp_server
def get_mcp_server() -> FastMCP:
return mcp_server
def run_api() -> None:
"""Launch the FastAPI server via uvicorn."""
uvicorn.run(
app,
host=settings.api_host,
port=settings.api_port,
)
async def _run_mcp_stdio_async() -> None:
server = get_mcp_server()
await server.run_stdio_async()
def run_mcp_stdio() -> None:
"""Launch the MCP server over stdio (experimental)."""
asyncio.run(_run_mcp_stdio_async())
def main() -> None:
"""CLI entry point allowing API or MCP execution modes."""
parser = argparse.ArgumentParser(description="Embedding Atlas orchestrator")
parser.add_argument(
"--mode",
choices=["api", "mcp"],
default="api",
help="Select whether to launch the HTTP API or the MCP stdio server.",
)
parser.add_argument(
"--host",
default=settings.api_host,
help="Override FastAPI bind host.",
)
parser.add_argument(
"--port",
type=int,
default=settings.api_port,
help="Override FastAPI bind port.",
)
args = parser.parse_args()
if args.mode == "api":
logger.info("Starting FastAPI server on %s:%s", args.host, args.port)
uvicorn.run(app, host=args.host, port=args.port)
else:
logger.info("Starting FastMCP stdio server")
run_mcp_stdio()
if __name__ == "__main__":
main()

122
src/embedding_backend/models.py Normal file

@@ -0,0 +1,122 @@
"""Pydantic models for the orchestration API."""
from __future__ import annotations
from datetime import datetime
from typing import Dict, List, Optional
from uuid import uuid4
from pydantic import BaseModel, Field, HttpUrl
class SessionCreateRequest(BaseModel):
"""Incoming payload describing a requested embedding-atlas session."""
session_id: Optional[str] = Field(
default=None,
description="Optional session identifier; generated when omitted.",
)
data_url: HttpUrl = Field(
description="HTTPS location of the dataset to visualize.",
)
input_filename: Optional[str] = Field(
default=None,
description="Filename to persist the downloaded dataset under.",
)
extra_args: List[str] = Field(
default_factory=list,
description="Additional CLI arguments forwarded to embedding-atlas.",
)
environment: Dict[str, str] = Field(
default_factory=dict,
description="Environment variables injected into the managed container.",
)
labels: Dict[str, str] = Field(
default_factory=dict,
description="Extra Docker labels assigned to the managed container.",
)
image: Optional[str] = Field(
default=None,
description="Override Docker image for the managed container.",
)
host: Optional[str] = Field(
default=None,
description="Host binding forwarded to embedding-atlas (defaults to settings.container_host).",
)
port: Optional[int] = Field(
default=None,
description="Preferred host port for the embedding-atlas UI.",
)
auto_remove_seconds: Optional[int] = Field(
default=None,
description="Lifetime of the session before automatic cleanup triggers.",
)
def normalised_session_id(self) -> str:
return self.session_id or uuid4().hex
class SessionRecord(BaseModel):
"""Internal representation of a live session."""
session_id: str
container_id: str
container_name: str
port: int
host: str
started_at: datetime
expires_at: datetime
dataset_path: str
class SessionResponse(BaseModel):
"""HTTP response returned after a session is launched."""
session_id: str
container_name: str
endpoint: str
port: int
host: str
expires_at: datetime
class SessionCloseResponse(BaseModel):
"""HTTP response returned after a session is terminated."""
session_id: str
removed: bool
class SessionListResponse(BaseModel):
"""Shape returned when listing active sessions."""
sessions: List[SessionRecord]
MANAGED_LABEL = "embedding-backend.managed"
EXPIRY_LABEL = "embedding-backend.expires-at"
SESSION_ID_LABEL = "embedding-backend.session-id"
DATASET_LABEL = "embedding-backend.dataset-path"
def default_labels(session_id: str, expires_at: datetime, dataset_path: str) -> Dict[str, str]:
return {
MANAGED_LABEL: "true",
SESSION_ID_LABEL: session_id,
EXPIRY_LABEL: str(int(expires_at.timestamp())),
DATASET_LABEL: dataset_path,
}
__all__ = [
"SessionCreateRequest",
"SessionRecord",
"SessionResponse",
"SessionCloseResponse",
"SessionListResponse",
"default_labels",
"MANAGED_LABEL",
"EXPIRY_LABEL",
"SESSION_ID_LABEL",
"DATASET_LABEL",
]

71
src/embedding_backend/routes.py Normal file

@@ -0,0 +1,71 @@
"""FastAPI routes exposing the session management API."""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, status
from .docker_manager import (
DatasetDownloadError,
PortUnavailableError,
SessionExistsError,
SessionManager,
SessionNotFoundError,
)
from .models import (
SessionCloseResponse,
SessionCreateRequest,
SessionListResponse,
SessionRecord,
SessionResponse,
)
def get_router(manager: SessionManager) -> APIRouter:
router = APIRouter(prefix="/sessions", tags=["sessions"])
async def get_manager() -> SessionManager:
return manager
@router.post("/", response_model=SessionResponse, status_code=status.HTTP_201_CREATED)
async def create_session(
payload: SessionCreateRequest,
manager: SessionManager = Depends(get_manager),
) -> SessionResponse:
try:
record = await manager.create_session(payload)
except SessionExistsError as exc:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
except DatasetDownloadError as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
except PortUnavailableError as exc:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
endpoint = f"http://{record.host}:{record.port}"
return SessionResponse(
session_id=record.session_id,
container_name=record.container_name,
endpoint=endpoint,
port=record.port,
host=record.host,
expires_at=record.expires_at,
)
@router.delete("/{session_id}", response_model=SessionCloseResponse)
async def close_session(
session_id: str,
manager: SessionManager = Depends(get_manager),
) -> SessionCloseResponse:
try:
await manager.close_session(session_id)
except SessionNotFoundError as exc:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
return SessionCloseResponse(session_id=session_id, removed=True)
@router.get("/", response_model=SessionListResponse)
async def list_sessions(manager: SessionManager = Depends(get_manager)) -> SessionListResponse:
sessions = await manager.list_sessions()
return SessionListResponse(sessions=list(sessions))
return router
__all__ = ["get_router"]

3394
uv.lock generated Normal file

File diff suppressed because it is too large