update
2
.gitattributes
vendored
Normal file
@@ -0,0 +1,2 @@
# SCM syntax highlighting & preventing 3-way merges
pixi.lock merge=binary linguist-language=YAML linguist-generated=true
5
.gitignore
vendored
@@ -1,2 +1,7 @@
.venv/
.DS_Store
__pycache__/
*.pyc
# pixi environments
.pixi/*
!.pixi/config.toml
*.egg-info/
115
README.md
@@ -33,6 +33,100 @@ export HF_HUB_OFFLINE=1
export HF_ENDPOINT=https://hf-mirror.com
```

## Session Orchestration Service (FastAPI / MCP)

Run `uv run embedding-backend-api` to start a backend service compatible with both FastAPI and FastMCP. The service listens on the `/sessions` path, launches `embedding-atlas` containers on demand, and automatically cleans them up after 10 hours.

```bash
uv run embedding-backend-api
# or start in MCP mode (stdio)
uv run embedding-backend-mcp
```

### REST API Usage

```bash
curl -X POST http://localhost:9000/sessions \
  -H 'Content-Type: application/json' \
  -d '{
    "data_url": "https://example.com/data.csv",
    "extra_args": ["--text", "smiles"]
  }'

# Stop a session
curl -X DELETE http://localhost:9000/sessions/<session_id>

# List current sessions
curl http://localhost:9000/sessions
```

The request body supports fields such as `session_id`, `port`, `auto_remove_seconds`, and `environment`; when `session_id` is omitted, one is generated automatically and returned in the response. The response includes the container name and a reachable frontend URL, which a FastAPI or MCP client can forward to the frontend.

To test the API interactively, open the auto-generated Swagger UI at `http://localhost:9000/docs` in a browser.

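The same request can also be issued programmatically. A minimal sketch of assembling the request body in Python (field names as documented in this README; the endpoint assumes the default `http://localhost:9000`, and `build_session_request` is a hypothetical helper, not part of this project):

```python
import json

def build_session_request(data_url, session_id=None, extra_args=None,
                          environment=None, auto_remove_seconds=None):
    """Assemble a /sessions request body, omitting unset optional fields."""
    body = {"data_url": data_url}
    if session_id is not None:
        body["session_id"] = session_id
    if extra_args:
        body["extra_args"] = list(extra_args)
    if environment:
        body["environment"] = dict(environment)
    if auto_remove_seconds is not None:
        body["auto_remove_seconds"] = auto_remove_seconds
    return body

payload = build_session_request(
    "https://example.com/data.csv",
    extra_args=["--text", "smiles"],
    auto_remove_seconds=7200,
)
print(json.dumps(payload, indent=2))
# POST it with any HTTP client, e.g.
# httpx.post("http://localhost:9000/sessions", json=payload)
```
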
#### Request Body Fields

- `session_id`: optional; a custom session ID. Generated automatically when omitted.
- `data_url`: required; an HTTPS link to a CSV/Parquet file, downloaded into `/sessions/<session_id>/` on the volume shared by the orchestrator and DinD.
- `input_filename`: optional; the filename used when saving to the shared volume. Inferred from the URL by default.
- `extra_args`: optional; an array of extra arguments for the `embedding-atlas` CLI, e.g. `["--text", "smiles"]`.
- `environment`: optional; a mapping of environment variables injected into the container.
- `labels`: optional; Docker labels attached to the container (in addition to the default `embedding-backend.*` labels), useful for custom monitoring or tracking.
- `image`: optional; overrides the default `embedding-atlas` image name.
- `host`: optional; the value passed to `embedding-atlas --host`. Defaults to `0.0.0.0`.
- `port`: optional; an explicit host port. Allocated automatically from the configured range when omitted.
- `auto_remove_seconds`: optional; overrides the default 10-hour lifetime.

Example request:

```bash
curl -X POST http://localhost:9000/sessions \
  -H 'Content-Type: application/json' \
  -d '{
    "session_id": "demo-session",
    "data_url": "https://example.com/data.csv",
    "input_filename": "data.csv",
    "extra_args": ["--text", "smiles"],
    "environment": {"HF_ENDPOINT": "https://hf-mirror.com"},
    "labels": {"team": "chem"},
    "auto_remove_seconds": 7200
  }'
```

#### GET /sessions Response Example

Takes no query parameters; returns a list of all current sessions:

```json
{
  "sessions": [
    {
      "session_id": "demo-session",
      "container_id": "e556c0f1c35b...",
      "container_name": "embedding-atlas_demo-session",
      "port": 6000,
      "host": "0.0.0.0",
      "started_at": "2025-09-22T14:37:25.206038Z",
      "expires_at": "2025-09-23T00:37:24.416241Z",
      "dataset_path": "/sessions/demo-session/data.csv"
    }
  ]
}
```

- An empty `sessions` array means no containers are currently alive.
- `dataset_path` is the absolute path of the dataset inside the shared volume.

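A client can use the `expires_at` timestamp to decide when a session needs to be refreshed or recreated. A stdlib-only sketch (timestamp format as in the example response; `remaining_seconds` is an illustrative helper, not part of the API):

```python
from datetime import datetime, timezone

def remaining_seconds(session, now=None):
    """Seconds until the session's container is auto-removed (0.0 if expired)."""
    expires = datetime.fromisoformat(session["expires_at"].replace("Z", "+00:00"))
    now = now or datetime.now(timezone.utc)
    return max(0.0, (expires - now).total_seconds())

session = {
    "session_id": "demo-session",
    "started_at": "2025-09-22T14:37:25.206038Z",
    "expires_at": "2025-09-23T00:37:24.416241Z",
}
start = datetime.fromisoformat(session["started_at"].replace("Z", "+00:00"))
print(remaining_seconds(session, now=start))  # just under 10 hours (36000 s)
```
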
### MCP Integration

The `fastmcp.json` example in the repository root registers this project as an MCP tool directly:

```bash
uv run embedding-backend-mcp
```

Once a FastMCP client loads this configuration, it can forward the same REST interface over the standard MCP protocol, keeping behavior consistent with the traditional backend.

## Generating Interactive Embedding Visualizations from the Command Line

```bash
@@ -79,3 +173,24 @@ uv run python script/merge_splits.py --input-dir splits_v2/ --output data/drugba

```bash
uv run embedding-atlas data/drugbank_split_merge.csv --text smiles
```

## Containerized Deployment

The `docker/` directory lets you start the backend quickly:

1. First build the `embedding-atlas` image that the orchestrator reuses:
   ```bash
   docker build -f docker/embedding-atlas.Dockerfile -t embedding-atlas:latest .
   ```
2. Start the DinD + orchestrator combination:
   ```bash
   docker compose -f docker/docker-compose.yml up --build
   ```
   - The `engine` service runs `docker:dind` and exposes `tcp://localhost:2375` so the orchestrator can manage containers over the socket.
   - The two services share the `/sessions` directory through the `sessions-data` volume; the backend places downloaded data there.
   - The `orchestrator` service runs the FastAPI/FastMCP backend, exposed at `http://localhost:9000` by default.
   - The session cache lives in the `sessions-data` volume, so downloaded datasets survive container restarts.

To change the default image or ports, override the `EMBEDDING_*` environment variables in `docker/docker-compose.yml`, or inject them via a `.env` file at deploy time.

> Note: because the orchestrator runs inside a container and schedules new containers through DinD, reaching the `embedding-atlas` Web UI directly from the host requires forwarding the corresponding port from `docker_engine_1` to the host, using `podman port` or an additional reverse proxy as needed.

90
data/ring12_20/README.md
Normal file
@@ -0,0 +1,90 @@
Your Filtered Macrolactone Database

11036 compounds have been filtered from MacrolactoneDB based on your specified inputs.

```bash
uv run embedding-atlas data/ring12_20/temp_with_macrocycles_with_ecfp4.parquet --text ecfp4_binary
uv run embedding-atlas data/ring12_20/temp_with_macrocycles_with_ecfp4.parquet --text tanimoto_top_neighbors
uv run embedding-atlas data/ring12_20/temp_with_macrocycles_with_ecfp4.parquet --text smiles
```

## Embedding and Projection Optimization

### How projection_x and projection_y Are Generated

#### UMAP Dimensionality Reduction

Both coordinates are produced by the `_run_umap()` function, which uses the UMAP algorithm to reduce high-dimensional embedding vectors to 2D space (projection.py:64-88).

The flow is as follows:

1. Compute nearest neighbors - `nearest_neighbors()` first finds the k nearest neighbors of each point (projection.py:76-83).
2. UMAP projection - the precomputed neighbor information then drives the UMAP reduction (projection.py:85-86).
3. Coordinate assignment - the first column of the result becomes `projection_x` and the second becomes `projection_y` (projection.py:259-260).

#### Default Parameters

The UMAP algorithm uses the following defaults:

- Number of neighbors: 15 nearest neighbors (projection.py:74)
- Distance metric: cosine (projection.py:73)
#### Use Across Data Types

##### Text data

For your SMILES molecular data, the system first generates text embeddings with SentenceTransformers and then reduces them with UMAP (projection.py:251-260).

##### Precomputed vectors

If you have precomputed ECFP4 vectors, the system runs UMAP directly on those vectors (projection.py:311-318).

#### Role in Visualization

In the frontend visualization, these coordinates serve as:

- The X/Y axes of the scatter plot - each data point's position in 2D space
- The basis for color coding - colors can be mapped from coordinate values (embedding-atlas.md:68-70)

#### Demo Data Example

The project's demo data generation follows the same pipeline: embeddings are computed with SentenceTransformers, and UMAP then produces the projection_x and projection_y coordinates (generate_demo_data.py:42-43).

#### Notes

The quality of these projection coordinates depends heavily on the quality of the original embeddings and on the choice of UMAP parameters. For chemical molecule data, a dedicated molecular embedding model usually yields a more meaningful 2D projection, with structurally similar molecules clustering together in projection space.

### UMAP Parameter Tuning

You can adjust the UMAP parameters for a better visualization:

```bash
# Tune the neighbor count and distance parameters
uv run embedding-atlas data/ring12_20/temp_with_macrocycles_with_ecfp4.parquet --text smiles \
  --umap-n-neighbors 30 \
  --umap-min-dist 0.1 \
  --umap-metric cosine \
  --umap-random-state 42
```

## Custom Embedding Models

For chemical molecule data you may want a specialized model, provided it meets the constraints below.

### Supported Model Types

embedding-atlas supports two types of custom models:

#### Text embedding models

For text data (such as your SMILES molecular data), the system uses the SentenceTransformers library (projection.py:118-126). This means you can use any Hugging Face model compatible with SentenceTransformers.

#### Image embedding models

For image data, the system uses the pipeline feature of the transformers library (projection.py:168-180).

### Model Format Requirements

#### SentenceTransformers compatibility

Text models must be compatible with the SentenceTransformers library (projection.py:98-99). This includes:

- Models trained specifically for sentence embeddings
- Models that support the `.encode()` method
- Models that output fixed-dimensional vectors
89
data/ring12_20/counts.txt
Normal file
@@ -0,0 +1,89 @@
Target Organisms
Homo sapiens 815
Homo sapiens, None 180
Plasmodium falciparum 161
Hepatitis C virus, None 112
Homo sapiens, Plasmodium falciparum 63
Oryctolagus cuniculus 62
Mus musculus 60
Toxoplasma gondii 39
Homo sapiens, Rattus norvegicus 27
Mus musculus, Homo sapiens 24
None, Rattus norvegicus 23
Human immunodeficiency virus 1 20
Hepatitis C virus 18
Rattus norvegicus 17
Homo sapiens, Sus scrofa 11
Homo sapiens, Chlorocebus aethiops 10
Serratia marcescens 9
Escherichia coli 8
Oryctolagus cuniculus, Homo sapiens 7
Streptococcus pneumoniae 6
Oryctolagus cuniculus, Staphylococcus aureus, Raoultella planticola, Bacillus subtilis, Mus musculus, Micrococcus luteus, None, Escherichia coli, Plasmodium falciparum, Streptococcus pneumoniae, Homo sapiens, Escherichia coli K-12, Toxoplasma gondii 6
Plasmodium falciparum K1 5
Bacillus anthracis 5
Mus musculus, Homo sapiens, None 5
Bacillus anthracis, Homo sapiens 4
Candida albicans, Cryptococcus neoformans, Aspergillus fumigatus 4
Mus musculus, None 4
Plasmodium falciparum, Homo sapiens, None 4
None, Homo sapiens, Plasmodium falciparum 3
Bacillus subtilis, Homo sapiens 3
Oryctolagus cuniculus, Homo sapiens, None 3
Sus scrofa, Mus musculus, None, Plasmodium falciparum, Homo sapiens, Rattus norvegicus 2
Homo sapiens, None, Rattus norvegicus 2
Cryptococcus neoformans 2
Homo sapiens, None, Chlorocebus aethiops 2
Staphylococcus aureus 2
Candida albicans, Cryptococcus neoformans, Mycobacterium intracellulare, Aspergillus fumigatus 2
Mus musculus, None, Human immunodeficiency virus 1 2
Escherichia coli (strain K12) 2
Plasmodium falciparum 3D7, Homo sapiens 2
Aspergillus fumigatus 1
Sus scrofa 1
Saccharomyces cerevisiae S288c, Human immunodeficiency virus 1, Human herpesvirus 1, Plasmodium falciparum, None, Homo sapiens, Rattus norvegicus 1
Hepatitis C virus, Homo sapiens, None 1
Plasmodium falciparum 3D7 1
Bacillus subtilis 1
Mus musculus, Homo sapiens, None, Saccharomyces cerevisiae 1
Chlorocebus aethiops 1
Homo sapiens, Escherichia coli K-12, None 1
Hepatitis C virus, Homo sapiens, None, Rattus norvegicus 1
None, Homo sapiens, Human herpesvirus 1 1
Homo sapiens, None, Trypanosoma brucei brucei 1
Homo sapiens, None, Cryptococcus neoformans 1
Homo sapiens, Rattus norvegicus, Human immunodeficiency virus 1 1
None, Plasmodium falciparum, Escherichia coli, Streptococcus pneumoniae, Naegleria fowleri, Homo sapiens, Streptococcus, Toxoplasma gondii 1
Giardia intestinalis, Trypanosoma cruzi, Equus caballus, Bos taurus, Mus musculus, None, Plasmodium falciparum, Chlorocebus aethiops, Homo sapiens 1
Plasmodium falciparum NF54, Trypanosoma cruzi, Trypanosoma brucei rhodesiense, Rattus norvegicus 1
None, Homo sapiens, Plasmodium falciparum K1, Plasmodium falciparum 1
Saccharomyces cerevisiae S288c, Homo sapiens, None, Saccharomyces cerevisiae, Phytophthora sojae 1
Bacillus subtilis, Homo sapiens, Schistosoma mansoni, Saccharomyces cerevisiae, Giardia intestinalis 1
Streptococcus, Homo sapiens, None 1
Mus musculus, Homo sapiens, Rattus norvegicus 1
Homo sapiens, Spinacia oleracea 1
Human immunodeficiency virus 1, Mus musculus, None, Hepatitis C virus, Homo sapiens, Rattus norvegicus 1
None, Plasmodium falciparum, Trypanosoma brucei rhodesiense 1
Hepatitis C virus, None, Rattus norvegicus 1
Homo sapiens, Equus caballus 1
Plasmodium falciparum NF54, Trypanosoma cruzi, Trypanosoma brucei rhodesiense 1
Schistosoma mansoni, Influenza A virus 1
Leishmania chagasi, Trypanosoma cruzi 1
Candida albicans, Cryptococcus neoformans 1
None, Plasmodium falciparum 1
Caenorhabditis elegans 1
Bos taurus, Sus scrofa 1
Plasmodium falciparum, Enterococcus faecium 1
Homo sapiens, Gallus gallus 1
Homo sapiens, Escherichia coli 1
Plasmodium falciparum, Homo sapiens, None, Rattus norvegicus, Schistosoma mansoni 1
Homo sapiens, None, Influenza A virus 1
Mycobacterium tuberculosis, None 1
Escherichia coli, Homo sapiens, Toxoplasma gondii, None, Streptococcus pneumoniae 1
Bacillus subtilis, Oryctolagus cuniculus, Homo sapiens, Schistosoma mansoni, Giardia intestinalis 1
Homo sapiens, None, Rattus norvegicus, Escherichia coli O157:H7 1
Giardia intestinalis, Schistosoma mansoni, Mus musculus, None, Homo sapiens, Saccharomyces cerevisiae 1
Trypanosoma cruzi 1
Influenza A virus 1
Escherichia coli K-12 1
Human herpesvirus 4 (strain B95-8) 1
11037
data/ring12_20/temp.csv
Normal file
File diff suppressed because one or more lines are too long
11037
data/ring12_20/temp_with_macrocycles.csv
Normal file
File diff suppressed because one or more lines are too long
11037
data/ring12_20/temp_with_macrocycles_with_ecfp4.csv
Normal file
File diff suppressed because one or more lines are too long
BIN
data/ring12_20/temp_with_macrocycles_with_ecfp4.parquet
Normal file
Binary file not shown.
41
docker/Dockerfile
Normal file
@@ -0,0 +1,41 @@
# syntax=docker/dockerfile:1

FROM python:3.12-slim AS base

ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        build-essential \
        curl \
        git \
    && rm -rf /var/lib/apt/lists/*

RUN pip install --no-cache-dir uv

WORKDIR /app

ENV UV_PROJECT_ENVIRONMENT=/app/.venv
ENV UV_PIP_INDEX_URL=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
ENV HF_ENDPOINT=https://hf-mirror.com
ENV HF_HOME=/app/.cache/huggingface
ENV TRANSFORMERS_CACHE=/app/.cache/huggingface/transformers
ENV HF_DATASETS_CACHE=/app/.cache/huggingface/datasets
ENV EMBEDDING_ATLAS_DEVICE=cpu

RUN mkdir -p "$HF_HOME" "$TRANSFORMERS_CACHE" "$HF_DATASETS_CACHE"
VOLUME ["/app/.cache/huggingface"]

COPY pyproject.toml README.md ./
COPY src ./src
COPY script ./script

# Install dependencies with uv; the "||" retries once to cover transient index failures.
# When a lock file is present it is used automatically.
RUN uv sync --no-dev || uv sync --no-dev

ENV PATH="/app/.venv/bin:$PATH"
ENV EMBEDDING_API_HOST=0.0.0.0
ENV EMBEDDING_API_PORT=9000

EXPOSE 9000
39
docker/docker-compose.yml
Normal file
@@ -0,0 +1,39 @@
version: "3.9"

services:
  engine:
    image: docker:25.0-dind
    privileged: true
    environment:
      - DOCKER_TLS_CERTDIR=
    command: ["--host=tcp://0.0.0.0:2375"]
    ports:
      - "2375:2375"
    volumes:
      - engine-data:/var/lib/docker
      - sessions-data:/sessions
    restart: unless-stopped

  orchestrator:
    # Ensure the embedding-atlas image exists (e.g. docker build -f docker/embedding-atlas.Dockerfile -t embedding-atlas:latest ..)
    build:
      context: ..
      dockerfile: docker/Dockerfile
    depends_on:
      - engine
    environment:
      - EMBEDDING_DOCKER_URL=tcp://engine:2375
      - EMBEDDING_CONTAINER_IMAGE=embedding-atlas:latest
      - EMBEDDING_CONTAINER_NAME_PREFIX=embedding-atlas
      - EMBEDDING_API_HOST=0.0.0.0
      - EMBEDDING_API_PORT=9000
      - EMBEDDING_SESSION_ROOT=/sessions
    volumes:
      - sessions-data:/sessions
    ports:
      - "9000:9000"
    restart: unless-stopped

volumes:
  engine-data:
  sessions-data:
17
docker/embedding-atlas.Dockerfile
Normal file
@@ -0,0 +1,17 @@
FROM python:3.12-slim

ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        build-essential \
        curl \
    && rm -rf /var/lib/apt/lists/*

RUN pip install --no-cache-dir embedding-atlas

WORKDIR /app

EXPOSE 5055

CMD ["embedding-atlas"]
11
fastmcp.json
Normal file
@@ -0,0 +1,11 @@
{
  "mcpServers": {
    "embedding-atlas-session-manager": {
      "command": "embedding-backend-mcp",
      "args": [],
      "env": {},
      "timeout": 30000,
      "description": "Start the embedding orchestrator in MCP stdio mode"
    }
  }
}
10
pixi.toml
Normal file
@@ -0,0 +1,10 @@
[workspace]
authors = ["lingyuzeng <pylyzeng@gmail.com>"]
channels = ["conda-forge"]
name = "embedding_atlas"
platforms = ["osx-arm64"]
version = "0.1.0"

[tasks]

[dependencies]
@@ -1,3 +1,4 @@

[project]
name = "mole-embedding-atlas"
version = "0.1.0"
@@ -12,4 +13,31 @@ dependencies = [
    "rdkit",
    "pandas",
    "selfies==2.1.1",
    "fastapi>=0.111",
    "uvicorn[standard]>=0.29",
    "fastmcp>=2.11",
    "docker>=7.1",
    "httpx>=0.27",
    "pydantic-settings>=2.2",
]

[project.scripts]
embedding-backend = "embedding_backend.main:main"
embedding-backend-api = "embedding_backend.main:run_api"
embedding-backend-mcp = "embedding_backend.main:run_mcp_stdio"

[tool.uv.pip]
index-url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple"

[tool.setuptools]
package-dir = {"" = "src"}

[tool.setuptools.packages.find]
where = ["src"]

[build-system]
requires = ["setuptools>=68", "wheel"]
build-backend = "setuptools.build_meta"

[tool.uv]
package = true
186
script/add_ecfp4_tanimoto.py
Normal file
@@ -0,0 +1,186 @@
#!/usr/bin/env python3
"""Augment a CSV with ECFP4 binary fingerprints and Tanimoto neighbor summaries."""

from __future__ import annotations

import argparse
import pathlib
from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import rdFingerprintGenerator


@dataclass
class ECFP4Generator:
    """Generate ECFP4 (Morgan radius 2) fingerprints as RDKit bit vectors."""

    n_bits: int = 2048
    radius: int = 2
    include_chirality: bool = True
    generator: rdFingerprintGenerator.MorganGenerator = field(init=False)

    def __post_init__(self) -> None:
        self.generator = rdFingerprintGenerator.GetMorganGenerator(
            radius=self.radius,
            fpSize=self.n_bits,
            includeChirality=self.include_chirality,
        )

    def fingerprint(self, smiles: str) -> Optional[DataStructs.ExplicitBitVect]:
        if not smiles:
            return None

        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None

        try:
            return self.generator.GetFingerprint(mol)
        except Exception:
            return None

    def to_binary_string(self, fp: DataStructs.ExplicitBitVect) -> str:
        arr = np.zeros((self.n_bits,), dtype=np.uint8)
        DataStructs.ConvertToNumpyArray(fp, arr)
        return ''.join(arr.astype(str))


def tanimoto_top_k(
    fingerprints: Sequence[Optional[DataStructs.ExplicitBitVect]],
    ids: Sequence[str],
    top_k: int = 5,
) -> List[str]:
    """Return semicolon-delimited Tanimoto summaries for each fingerprint."""
    valid_indices = [idx for idx, fp in enumerate(fingerprints) if fp is not None]
    valid_fps = [fingerprints[idx] for idx in valid_indices]

    index_lookup = {pos: original for pos, original in enumerate(valid_indices)}
    summaries = [''] * len(fingerprints)

    if not valid_fps:
        return summaries

    for pos, fp in enumerate(valid_fps):
        sims = DataStructs.BulkTanimotoSimilarity(fp, valid_fps)
        ranked: List[Tuple[int, float]] = []
        for other_pos, score in enumerate(sims):
            if other_pos == pos:
                continue
            ranked.append((other_pos, score))

        ranked.sort(key=lambda item: item[1], reverse=True)
        top_entries = []
        for other_pos, score in ranked[:top_k]:
            original_idx = index_lookup[other_pos]
            top_entries.append(f"{ids[original_idx]}:{score:.4f}")

        summaries[index_lookup[pos]] = ';'.join(top_entries)

    return summaries


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("input_csv", type=pathlib.Path, help="Source CSV with SMILES")
    parser.add_argument(
        "--output",
        type=pathlib.Path,
        default=None,
        help="Destination file (default: <input>_with_ecfp4.parquet)",
    )
    parser.add_argument(
        "--smiles-column",
        default="smiles",
        help="Name of the column containing SMILES (default: smiles)",
    )
    parser.add_argument(
        "--id-column",
        default="generated_id",
        help="Column used to label Tanimoto neighbors (default: generated_id)",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of nearest neighbors to report in Tanimoto summaries (default: 5)",
    )
    parser.add_argument(
        "--format",
        choices=("parquet", "csv", "auto"),
        default="parquet",
        help="Output format; defaults to parquet unless overridden or inferred from --output",
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()

    if not args.input_csv.exists():
        raise FileNotFoundError(f"Input file not found: {args.input_csv}")

    def resolve_output_path() -> pathlib.Path:
        if args.output is not None:
            return args.output

        suffix = ".parquet" if args.format in ("parquet", "auto") else ".csv"
        return args.input_csv.with_name(f"{args.input_csv.stem}_with_ecfp4{suffix}")

    def resolve_format(path: pathlib.Path) -> str:
        suffix = path.suffix.lower()
        if suffix in {".parquet", ".pq"}:
            return "parquet"
        if suffix == ".csv":
            return "csv"
        if args.format == "parquet":
            return "parquet"
        if args.format == "csv":
            return "csv"
        raise ValueError(
            "Cannot infer the format from the output file; give --output a .parquet/.csv suffix or use --format",
        )

    output_path = resolve_output_path()
    output_format = resolve_format(output_path)

    if args.input_csv.resolve() == output_path.resolve():
        raise ValueError("Output path must differ from input path to avoid overwriting input.")

    df = pd.read_csv(args.input_csv)
    if args.smiles_column not in df.columns:
        raise ValueError(f"Column '{args.smiles_column}' not found in input data")

    smiles_series = df[args.smiles_column].fillna('')

    if args.id_column in df.columns:
        ids = df[args.id_column].astype(str).tolist()
    else:
        ids = [f"D{idx:06d}" for idx in range(1, len(df) + 1)]
        df[args.id_column] = ids

    generator = ECFP4Generator()
    fingerprints: List[Optional[DataStructs.ExplicitBitVect]] = []
    binary_repr: List[str] = []

    for smiles in smiles_series:
        fp = generator.fingerprint(smiles)
        fingerprints.append(fp)
        binary_repr.append(generator.to_binary_string(fp) if fp is not None else '')

    df['ecfp4_binary'] = binary_repr
    df['tanimoto_top_neighbors'] = tanimoto_top_k(fingerprints, ids, top_k=args.top_k)

    if output_format == "parquet":
        df.to_parquet(output_path, index=False)
    else:
        df.to_csv(output_path, index=False)

    print(f"Wrote augmented data with {len(df)} rows to {output_path} ({output_format})")


if __name__ == "__main__":
    main()
127
script/add_macrocycle_columns.py
Normal file
@@ -0,0 +1,127 @@
#!/usr/bin/env python3
"""Add generated IDs and macrolactone ring size annotations to a CSV file."""

import argparse
import pathlib
import warnings
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Optional, Set, Tuple

import pandas as pd
from rdkit import Chem


@dataclass
class MacrolactoneRingDetector:
    """Detect macrolactone rings in SMILES strings via SMARTS patterns."""

    min_size: int = 12
    max_size: int = 20
    patterns: Dict[int, Optional[Chem.Mol]] = field(init=False, default_factory=dict)

    def __post_init__(self) -> None:
        self.patterns = {
            size: Chem.MolFromSmarts(f"[r{size}]([#8][#6](=[#8]))")
            for size in range(self.min_size, self.max_size + 1)
        }

    def ring_sizes(self, smiles: str) -> List[int]:
        """Return a sorted list of macrolactone ring sizes present in the SMILES."""
        if not smiles:
            return []

        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return []

        ring_atoms = mol.GetRingInfo().AtomRings()
        if not ring_atoms:
            return []

        matched_rings: Set[Tuple[int, ...]] = set()
        matched_sizes: Set[int] = set()

        for size, query in self.patterns.items():
            if query is None:
                continue

            for match in mol.GetSubstructMatches(query, uniquify=True):
                ring = self._pick_ring(match, ring_atoms, size)
                if ring and ring not in matched_rings:
                    matched_rings.add(ring)
                    matched_sizes.add(size)

        sizes = sorted(matched_sizes)

        if len(sizes) > 1:
            warnings.warn(
                "Multiple macrolactone ring sizes detected",
                RuntimeWarning,
                stacklevel=2,
            )
            print(f"Multiple ring sizes {sizes} for SMILES: {smiles}")

        return sizes

    @staticmethod
    def _pick_ring(
        match: Tuple[int, ...], rings: Iterable[Tuple[int, ...]], expected_size: int
    ) -> Optional[Tuple[int, ...]]:
        ring_atom = match[0]
        for ring in rings:
            if len(ring) == expected_size and ring_atom in ring:
                return tuple(sorted(ring))
        return None


def add_columns(df: pd.DataFrame, detector: MacrolactoneRingDetector) -> pd.DataFrame:
    result = df.copy()
    result["generated_id"] = [f"D{index:06d}" for index in range(1, len(result) + 1)]

    smiles_series = result["smiles"].fillna("") if "smiles" in result.columns else pd.Series(
        [""] * len(result), index=result.index
    )

    def format_sizes(smiles: str) -> str:
        sizes = detector.ring_sizes(smiles)
        return ";".join(str(size) for size in sizes) if sizes else ""

    result["macrocycle_ring_sizes"] = smiles_series.apply(format_sizes)
    return result


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("input_csv", type=pathlib.Path, help="Path to the source CSV file")
    parser.add_argument(
        "--output",
        type=pathlib.Path,
        default=None,
        help="Destination for the augmented CSV (default: <input>_with_macrocycles.csv)",
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()

    if not args.input_csv.exists():
        raise FileNotFoundError(f"Input file not found: {args.input_csv}")

    output_path = args.output or args.input_csv.with_name(
        f"{args.input_csv.stem}_with_macrocycles.csv"
    )

    if args.input_csv.resolve() == output_path.resolve():
        raise ValueError("Output path must differ from input path to avoid overwriting.")

    df = pd.read_csv(args.input_csv)
    detector = MacrolactoneRingDetector()
    augmented = add_columns(df, detector)
    augmented.to_csv(output_path, index=False)

    print(f"Wrote augmented data with {len(augmented)} rows to {output_path}")


if __name__ == "__main__":
    main()
439
script/ecfp4_umap_embedding_optimized.py
Normal file
@@ -0,0 +1,439 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Optimized ECFP4 Fingerprinting with UMAP Visualization for Macrolactone Molecules
|
||||
|
||||
This script processes SMILES data to:
|
||||
1. Generate ECFP4 fingerprints using RDKit
|
||||
2. Detect ring numbers in macrolactone molecules using SMARTS patterns
|
||||
3. Generate unique IDs for molecules without existing IDs
|
||||
4. Perform UMAP dimensionality reduction with Tanimoto distance
|
||||
5. Prepare data for embedding-atlas visualization
|
||||
|
||||
Optimized for large datasets with progress tracking and memory efficiency.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import subprocess
|
||||
from typing import Optional, List
|
||||
|
||||
# RDKit imports
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem import rdMolDescriptors, DataStructs
|
||||
from rdkit.Chem.MolStandardize import rdMolStandardize
|
||||
|
||||
# Data processing
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
# UMAP and visualization
|
||||
import umap
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Suppress warnings
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
# Progress bar
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
HAS_TQDM = True
|
||||
except ImportError:
|
||||
HAS_TQDM = False
|
||||
|
||||
class MacrolactoneProcessor:
    """Process macrolactone molecules for embedding visualization."""

    def __init__(self, n_bits: int = 2048, radius: int = 2, chirality: bool = True):
        """
        Initialize processor with ECFP4 parameters.

        Args:
            n_bits: Number of fingerprint bits (default: 2048)
            radius: Morgan fingerprint radius (default: 2 for ECFP4)
            chirality: Include chirality information (default: True)
        """
        self.n_bits = n_bits
        self.radius = radius
        self.chirality = chirality

        # Standardizer for molecule preprocessing
        self.standardizer = rdMolStandardize.MetalDisconnector()

        # SMARTS patterns for macrolactone ring sizes (12- to 20-membered rings with a lactone)
        self.ring_smarts = {
            size: f'[r{size}][#8][#6](=[#8])' for size in range(12, 21)
        }

    def standardize_molecule(self, mol: Chem.Mol) -> Optional[Chem.Mol]:
        """Standardize molecule using RDKit standardization."""
        try:
            # Remove metals
            mol = self.standardizer.Disconnect(mol)
            # Normalize functional groups
            mol = rdMolStandardize.Normalize(mol)
            # Keep the parent fragment only
            mol = rdMolStandardize.FragmentParent(mol)
            # Neutralize charges
            mol = rdMolStandardize.ChargeParent(mol)
            return mol
        except Exception:
            return None

    def ecfp4_fingerprint(self, smiles: str) -> Optional[np.ndarray]:
        """Generate an ECFP4 fingerprint from a SMILES string using the newer RDKit API."""
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return None

            # Standardize molecule
            mol = self.standardize_molecule(mol)
            if mol is None:
                return None

            # Generate Morgan fingerprint with the generator API to avoid deprecation warnings
            from rdkit.Chem import rdFingerprintGenerator
            generator = rdFingerprintGenerator.GetMorganGenerator(
                radius=self.radius,
                fpSize=self.n_bits,
                includeChirality=self.chirality
            )
            bv = generator.GetFingerprint(mol)

            # Convert to numpy array
            arr = np.zeros((self.n_bits,), dtype=np.uint8)
            DataStructs.ConvertToNumpyArray(bv, arr)
            return arr

        except Exception as e:
            print(f"Error processing SMILES {smiles[:50]}...: {e}")
            return None

    def detect_ring_number(self, smiles: str) -> int:
        """Detect the macrolactone ring size using SMARTS patterns; return 0 if none is found."""
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return 0

            # Check each ring size pattern
            for ring_size, smarts in self.ring_smarts.items():
                query = Chem.MolFromSmarts(smarts)
                if query and mol.GetSubstructMatches(query):
                    return ring_size

            # Alternative: check for any large ring with a lactone (RDKit range query)
            generic_pattern = Chem.MolFromSmarts('[r{12-20}][#8][#6](=[#8])')
            if generic_pattern:
                for match in mol.GetSubstructMatches(generic_pattern):
                    for atom_idx in match:
                        atom = mol.GetAtomWithIdx(atom_idx)
                        if atom.IsInRing():
                            # Determine the size of the ring containing this atom
                            for ring in mol.GetRingInfo().AtomRings():
                                if atom_idx in ring and 12 <= len(ring) <= 20:
                                    return len(ring)

            return 0

        except Exception as e:
            print(f"Error detecting ring number for {smiles}: {e}")
            return 0

    def generate_unique_id(self, index: int, existing_id: Optional[str] = None) -> str:
        """Generate a unique ID for a molecule, preferring any existing identifier."""
        if existing_id and pd.notna(existing_id) and existing_id != '':
            return str(existing_id)
        return f"D{index:07d}"

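The ID scheme can be sanity-checked in isolation; this is a standalone sketch mirroring `generate_unique_id` (the `CHEMBL25` value is a hypothetical existing identifier):

```python
# Standalone mirror of generate_unique_id: zero-padded 7-digit IDs with a "D"
# prefix, unless the input row already carries a usable identifier.
def generate_unique_id(index, existing_id=None):
    if existing_id not in (None, ''):
        return str(existing_id)
    return f"D{index:07d}"

print(generate_unique_id(42))             # D0000042
print(generate_unique_id(3, "CHEMBL25"))  # CHEMBL25
```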
    def tanimoto_similarity(self, fp1: np.ndarray, fp2: np.ndarray) -> float:
        """Calculate Tanimoto similarity between two binary fingerprints."""
        bit_count1 = np.sum(fp1)
        bit_count2 = np.sum(fp2)
        common_bits = np.sum(fp1 & fp2)

        denominator = bit_count1 + bit_count2 - common_bits
        if denominator == 0:
            return 0.0
        return common_bits / denominator

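A minimal check of the Tanimoto formula on toy bit vectors (hypothetical values, not pipeline data):

```python
import numpy as np

def tanimoto(fp1, fp2):
    # |A ∩ B| / |A ∪ B| for binary fingerprints
    common = np.sum(fp1 & fp2)
    denom = np.sum(fp1) + np.sum(fp2) - common
    return 0.0 if denom == 0 else common / denom

a = np.array([1, 1, 0, 1, 0], dtype=np.uint8)
b = np.array([1, 0, 0, 1, 1], dtype=np.uint8)
print(tanimoto(a, b))  # 2 common bits / (3 + 3 - 2) = 0.5
```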
    def find_neighbors(self, X: np.ndarray, k: int = 15, batch_size: int = 1000) -> List[str]:
        """Find k nearest neighbors for each molecule based on Tanimoto similarity."""
        n_samples = X.shape[0]
        neighbors = []

        pbar = tqdm(total=n_samples, desc="Finding neighbors") if HAS_TQDM else None

        for i in range(n_samples):
            similarities = []

            # Batch processing for memory efficiency
            for j in range(0, n_samples, batch_size):
                end_j = min(j + batch_size, n_samples)
                for batch_idx, fp in enumerate(X[j:end_j]):
                    orig_idx = j + batch_idx
                    if i != orig_idx:
                        similarities.append((orig_idx, self.tanimoto_similarity(X[i], fp)))

            # Keep the top-k most similar molecules (descending similarity)
            similarities.sort(key=lambda x: x[1], reverse=True)
            neighbors.append(','.join(str(idx) for idx, _ in similarities[:k]))

            if pbar is not None:
                pbar.update(1)

        if pbar is not None:
            pbar.close()

        return neighbors

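The per-pair loop above costs O(n²) Python-level calls. A vectorized formulation (a sketch, not what the script uses) computes the full similarity matrix with one matrix product:

```python
import numpy as np

def pairwise_tanimoto(X: np.ndarray) -> np.ndarray:
    # X: (n, bits) binary matrix. Tanimoto = common / (bits_i + bits_j - common).
    X = X.astype(np.float64)
    common = X @ X.T                      # pairwise intersection counts
    bits = X.sum(axis=1)
    union = bits[:, None] + bits[None, :] - common
    with np.errstate(divide='ignore', invalid='ignore'):
        sim = np.where(union > 0, common / union, 0.0)
    return sim

X = np.array([[1, 1, 0, 1],
              [1, 0, 0, 1],
              [0, 0, 0, 0]], dtype=np.uint8)
print(pairwise_tanimoto(X)[0, 1])  # 2 / (3 + 2 - 2) ≈ 0.667
```

The `np.where` guard mirrors the zero-denominator check in `tanimoto_similarity`, so all-zero fingerprints get similarity 0.0 rather than NaN.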
    def perform_umap(self, X: np.ndarray, n_neighbors: int = 30,
                     min_dist: float = 0.1, metric: str = 'jaccard') -> np.ndarray:
        """Perform UMAP dimensionality reduction on the fingerprint matrix."""
        reducer = umap.UMAP(
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            metric=metric,
            random_state=42
        )
        return reducer.fit_transform(X)

    def process_dataframe(self, df: pd.DataFrame, smiles_col: str = 'smiles',
                          id_col: Optional[str] = None,
                          max_molecules: Optional[int] = None) -> pd.DataFrame:
        """Process a dataframe of SMILES strings into fingerprints, ring sizes, and a UMAP embedding."""
        print(f"Processing {len(df)} molecules...")

        # Limit molecules if requested
        if max_molecules:
            df = df.head(max_molecules)
            print(f"Limited to {max_molecules} molecules")

        # Ensure we have a SMILES column
        if smiles_col not in df.columns:
            raise ValueError(f"Column '{smiles_col}' not found in dataframe")

        # Create a working copy
        result_df = df.copy()

        # Generate unique IDs if needed
        if id_col and id_col in df.columns:
            result_df['molecule_id'] = [self.generate_unique_id(i, existing_id)
                                        for i, existing_id in enumerate(result_df[id_col])]
        else:
            result_df['molecule_id'] = [self.generate_unique_id(i)
                                        for i in range(len(result_df))]

        # Process fingerprints
        print("Generating ECFP4 fingerprints...")
        fingerprints = []
        valid_indices = []

        iterator = enumerate(result_df[smiles_col])
        if HAS_TQDM:
            iterator = tqdm(iterator, total=len(result_df), desc="Processing fingerprints")

        for idx, smiles in iterator:
            if pd.notna(smiles) and smiles != '':
                fp = self.ecfp4_fingerprint(smiles)
                if fp is not None:
                    fingerprints.append(fp)
                    valid_indices.append(idx)
                else:
                    print(f"Failed to generate fingerprint for index {idx}: {smiles[:50]}...")
            else:
                print(f"Invalid SMILES at index {idx}")

        # Filter dataframe to valid molecules only
        result_df = result_df.iloc[valid_indices].reset_index(drop=True)

        if not fingerprints:
            raise ValueError("No valid fingerprints generated")

        # Convert fingerprints to numpy array
        X = np.array(fingerprints)
        print(f"Generated fingerprints for {len(fingerprints)} molecules")

        # Detect ring numbers
        print("Detecting ring numbers...")
        smiles_iter = result_df[smiles_col]
        if HAS_TQDM:
            smiles_iter = tqdm(smiles_iter, desc="Detecting rings")
        result_df['ring_num'] = [self.detect_ring_number(s) for s in smiles_iter]

        # Perform UMAP
        print("Performing UMAP dimensionality reduction...")
        embedding = self.perform_umap(X)
        result_df['projection_x'] = embedding[:, 0]
        result_df['projection_y'] = embedding[:, 1]

        # Find neighbors for embedding-atlas
        print("Finding nearest neighbors...")
        result_df['neighbors'] = self.find_neighbors(X, k=15)

        # Add fingerprint information
        result_df['fingerprint_bits'] = [fp.tolist() for fp in fingerprints]

        return result_df

    def create_visualization(self, df: pd.DataFrame, output_path: str):
        """Create a scatter plot of the UMAP embedding, colored by ring size."""
        plt.figure(figsize=(12, 8))

        # Color by ring number
        scatter = plt.scatter(df['projection_x'], df['projection_y'],
                              c=df['ring_num'], cmap='viridis', alpha=0.6, s=30)

        plt.colorbar(scatter, label='Ring Number')
        plt.xlabel('UMAP 1')
        plt.ylabel('UMAP 2')
        plt.title('Macrolactone Molecules - ECFP4 + UMAP Visualization')

        # Annotate the centroid of each ring-size cluster
        for ring_num in sorted(df['ring_num'].unique()):
            if ring_num > 0:
                subset = df[df['ring_num'] == ring_num]
                if len(subset) > 0:
                    center_x = subset['projection_x'].mean()
                    center_y = subset['projection_y'].mean()
                    plt.annotate(f'{ring_num} ring', (center_x, center_y),
                                 fontsize=10, fontweight='bold')

        plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Visualization saved to {output_path}")


def main():
    """Main function to run the processing pipeline."""

    parser = argparse.ArgumentParser(description='ECFP4 + UMAP for Macrolactone Molecules')
    parser.add_argument('--input', '-i', required=True,
                        help='Input CSV file path')
    parser.add_argument('--output', '-o', required=True,
                        help='Output CSV file path')
    parser.add_argument('--smiles-col', default='smiles',
                        help='Name of SMILES column (default: smiles)')
    parser.add_argument('--id-col', default=None,
                        help='Name of ID column (optional)')
    parser.add_argument('--visualization', '-v', default='umap_visualization.png',
                        help='Output visualization file path')
    parser.add_argument('--max-molecules', type=int, default=None,
                        help='Maximum number of molecules to process (for testing)')
    parser.add_argument('--launch-atlas', action='store_true',
                        help='Launch embedding-atlas process')
    parser.add_argument('--atlas-port', type=int, default=8080,
                        help='Port for embedding-atlas server')

    args = parser.parse_args()

    # Initialize processor
    processor = MacrolactoneProcessor(n_bits=2048, radius=2, chirality=True)

    # Load data
    print(f"Loading data from {args.input}")
    try:
        df = pd.read_csv(args.input)
        print(f"Loaded {len(df)} molecules")
        print(f"Columns: {list(df.columns)}")
    except Exception as e:
        print(f"Error loading data: {e}")
        return 1

    # Process dataframe
    try:
        processed_df = processor.process_dataframe(df,
                                                   smiles_col=args.smiles_col,
                                                   id_col=args.id_col,
                                                   max_molecules=args.max_molecules)
        print(f"Successfully processed {len(processed_df)} molecules")
    except Exception as e:
        print(f"Error processing data: {e}")
        return 1

    # Save results
    try:
        processed_df.to_csv(args.output, index=False)
        print(f"Results saved to {args.output}")
    except Exception as e:
        print(f"Error saving results: {e}")
        return 1

    # Create visualization
    try:
        processor.create_visualization(processed_df, args.visualization)
    except Exception as e:
        print(f"Error creating visualization: {e}")

    # Launch embedding-atlas if requested
    if args.launch_atlas:
        print("Launching embedding-atlas process...")
        try:
            # Prepare command for embedding-atlas
            cmd = [
                'embedding-atlas', 'data', args.output,
                '--text', args.smiles_col,
                '--port', str(args.atlas_port),
                '--neighbors', 'neighbors',
                '--x', 'projection_x',
                '--y', 'projection_y'
            ]

            print(f"Running command: {' '.join(cmd)}")
            result = subprocess.run(cmd, capture_output=True, text=True)

            if result.returncode == 0:
                print("Embedding-atlas process launched successfully")
                print(f"Access the visualization at: http://localhost:{args.atlas_port}")
            else:
                print(f"Error launching embedding-atlas: {result.stderr}")

        except FileNotFoundError:
            print("embedding-atlas command not found. Please install it first.")
            print("You can install it with: pip install embedding-atlas")
        except Exception as e:
            print(f"Error launching embedding-atlas: {e}")

    print("Processing complete!")
    return 0


if __name__ == '__main__':
    sys.exit(main())
5
src/embedding_backend/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""Embedding Atlas session orchestrator package."""

from .main import create_app, get_mcp_server

__all__ = ["create_app", "get_mcp_server"]
90
src/embedding_backend/config.py
Normal file
@@ -0,0 +1,90 @@
"""Configuration for the embedding session orchestrator."""

from __future__ import annotations

from pathlib import Path
from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    """Application settings loaded from the environment."""

    docker_url: str = Field(
        default="unix:///var/run/docker.sock",
        description="Docker socket URL accessed by the orchestrator.",
    )
    container_image: str = Field(
        default="embedding-atlas:latest",
        description="Default Docker image used for embedding-atlas sessions.",
    )
    container_name_prefix: str = Field(
        default="embedding-atlas",
        description="Prefix applied to managed container names.",
    )
    container_host: str = Field(
        default="0.0.0.0",
        description="Host binding used inside managed containers.",
    )
    api_host: str = Field(
        default="0.0.0.0",
        description="Listening host for the FastAPI server.",
    )
    api_port: int = Field(
        default=9000,
        description="Listening port for the FastAPI server.",
    )
    default_port: int = Field(
        default=5055,
        description="Fallback embedding-atlas port when no override is provided.",
    )
    port_range_start: int = Field(
        default=6000,
        description="Lower bound for dynamically assigned ports.",
    )
    port_range_end: int = Field(
        default=6999,
        description="Upper bound for dynamically assigned ports (inclusive).",
    )
    session_root: Path = Field(
        default=Path("runtime/sessions"),
        description="Directory used to persist downloaded datasets per session.",
    )
    download_chunk_size: int = Field(
        default=64 * 1024,
        description="Chunk size in bytes when streaming remote datasets.",
    )
    download_timeout_seconds: int = Field(
        default=300,
        description="Timeout applied to dataset downloads.",
    )
    auto_remove_seconds: int = Field(
        default=36000,
        description="Default inactivity window before a session is garbage-collected (10h).",
    )
    cleanup_interval_seconds: int = Field(
        default=600,
        description="Background cleanup cadence in seconds.",
    )
    container_network: Optional[str] = Field(
        default=None,
        description="Optional Docker network to attach managed containers to.",
    )

    model_config = SettingsConfigDict(
        env_prefix="EMBEDDING_",
        env_file=".env",
        env_file_encoding="utf-8",
    )

    def prepare(self) -> None:
        """Ensure directories referenced by the settings exist."""

        self.session_root = self.session_root.resolve()
        self.session_root.mkdir(parents=True, exist_ok=True)


settings = Settings()
settings.prepare()
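With `env_prefix="EMBEDDING_"`, each field can be overridden by an environment variable named `EMBEDDING_` plus the upper-cased field name (e.g. `EMBEDDING_API_PORT`). The mapping works roughly like this stdlib-only sketch; pydantic-settings additionally handles type coercion and `.env` files:

```python
import os

# Hypothetical override: EMBEDDING_API_PORT maps onto Settings.api_port.
os.environ["EMBEDDING_API_PORT"] = "9100"

prefix = "EMBEDDING_"
overrides = {
    key[len(prefix):].lower(): value
    for key, value in os.environ.items()
    if key.startswith(prefix)
}
print(overrides["api_port"])  # 9100
```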
376
src/embedding_backend/docker_manager.py
Normal file
@@ -0,0 +1,376 @@
"""Asynchronous session manager backed by Docker containers."""

from __future__ import annotations

import asyncio
import logging
import socket
from contextlib import suppress
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Dict, Iterable, Optional

import docker
import httpx
from docker.errors import APIError, NotFound
from docker.models.containers import Container

from .config import Settings, settings
from .models import (
    DATASET_LABEL,
    EXPIRY_LABEL,
    MANAGED_LABEL,
    SESSION_ID_LABEL,
    SessionCreateRequest,
    SessionRecord,
    default_labels,
)

logger = logging.getLogger(__name__)


class SessionError(RuntimeError):
    """Base class for orchestration failures."""


class SessionExistsError(SessionError):
    """Raised when a duplicate session identifier is requested."""


class SessionNotFoundError(SessionError):
    """Raised when attempting to operate on a missing session."""


class DatasetDownloadError(SessionError):
    """Raised when the dataset cannot be fetched."""


class PortUnavailableError(SessionError):
    """Raised when the requested port is already in use."""


class SessionManager:
    """Coordinate embedding-atlas containers per incoming request."""

    def __init__(self, app_settings: Settings | None = None) -> None:
        self.settings = app_settings or settings
        self.client = docker.DockerClient(base_url=self.settings.docker_url)
        self._sessions: Dict[str, SessionRecord] = {}
        self._lock = asyncio.Lock()
        self._cleanup_task: Optional[asyncio.Task[None]] = None

    async def start(self) -> None:
        """Initialise caches and background cleanup."""

        await self._load_existing_sessions()
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop(self) -> None:
        """Stop background tasks and close Docker handles."""

        if self._cleanup_task is not None:
            self._cleanup_task.cancel()
            with suppress(asyncio.CancelledError):
                await self._cleanup_task
            self._cleanup_task = None
        await asyncio.to_thread(self.client.close)

    async def create_session(self, payload: SessionCreateRequest) -> SessionRecord:
        """Download the dataset and spin up an embedding-atlas container."""

        session_id = payload.normalised_session_id()
        async with self._lock:
            if session_id in self._sessions:
                raise SessionExistsError(f"Session {session_id} already active.")

        host = payload.host or self.settings.container_host
        if payload.port is not None:
            is_free = await asyncio.to_thread(self._port_available, payload.port)
            if not is_free:
                raise PortUnavailableError(f"Requested port {payload.port} is unavailable")
            port = payload.port
        else:
            port = await asyncio.to_thread(self._allocate_port)
        expiry_seconds = payload.auto_remove_seconds or self.settings.auto_remove_seconds
        expires_at = datetime.now(tz=timezone.utc) + timedelta(seconds=expiry_seconds)
        dataset_path = await self._download_dataset(session_id, payload)

        try:
            record = await asyncio.to_thread(
                self._run_container,
                session_id,
                dataset_path,
                payload,
                host,
                port,
                expires_at,
            )
        except Exception:
            # Drop downloaded artefacts if startup failed
            with suppress(Exception):
                self._purge_dataset(dataset_path)
            raise

        async with self._lock:
            self._sessions[session_id] = record
        return record

    async def close_session(self, session_id: str) -> bool:
        """Terminate the associated container and remove cached data."""

        record = await self._get_session(session_id)
        await asyncio.to_thread(self._stop_container, record.container_name)
        await asyncio.to_thread(self._purge_dataset, Path(record.dataset_path))
        async with self._lock:
            self._sessions.pop(session_id, None)
        return True

    async def list_sessions(self) -> Iterable[SessionRecord]:
        """Return a snapshot of the internal session registry."""

        async with self._lock:
            return list(self._sessions.values())

    async def cleanup_expired(self) -> None:
        """Stop containers that exceeded their TTL."""

        now = datetime.now(tz=timezone.utc)
        async with self._lock:
            sessions = list(self._sessions.values())
        # Iterate outside the lock: close_session re-acquires it.
        for record in sessions:
            if record.expires_at <= now:
                logger.info("Session %s expired; stopping container %s",
                            record.session_id, record.container_name)
                with suppress(SessionError, APIError, NotFound):
                    await self.close_session(record.session_id)

    async def _cleanup_loop(self) -> None:
        """Periodic cleanup coroutine."""

        try:
            while True:
                await asyncio.sleep(self.settings.cleanup_interval_seconds)
                await self.cleanup_expired()
        except asyncio.CancelledError:
            logger.debug("Cleanup loop cancelled")
            raise

    async def _get_session(self, session_id: str) -> SessionRecord:
        async with self._lock:
            record = self._sessions.get(session_id)
        if record is None:
            raise SessionNotFoundError(f"Session {session_id} not found")
        return record

    async def _load_existing_sessions(self) -> None:
        """Recover managed containers into the local cache on startup."""

        def _list() -> list[Container]:
            return self.client.containers.list(all=True, filters={"label": MANAGED_LABEL})

        containers = await asyncio.to_thread(_list)
        restored: Dict[str, SessionRecord] = {}
        for container in containers:
            labels = container.labels or {}
            session_id = labels.get(SESSION_ID_LABEL)
            if not session_id:
                continue
            try:
                port = self._extract_port(container)
            except ValueError:
                logger.warning("Unable to restore port mapping for container %s", container.name)
                continue
            expires_at = self._parse_expiry(labels.get(EXPIRY_LABEL))
            dataset_path = labels.get(DATASET_LABEL, "")
            restored[session_id] = SessionRecord(
                session_id=session_id,
                container_id=container.id,
                container_name=container.name,
                port=port,
                host=self.settings.container_host,
                started_at=self._parse_created(container.attrs.get("Created")),
                expires_at=expires_at,
                dataset_path=dataset_path,
            )
        async with self._lock:
            self._sessions.update(restored)
        if restored:
            logger.info("Restored %d managed sessions from existing containers", len(restored))

    def _run_container(
        self,
        session_id: str,
        dataset_path: Path,
        payload: SessionCreateRequest,
        host: str,
        port: int,
        expires_at: datetime,
    ) -> SessionRecord:
        container_name = f"{self.settings.container_name_prefix}_{session_id}"
        self._guard_duplicate_container(container_name)

        bind_root = dataset_path.parent
        container_dataset_path = Path("/session") / dataset_path.name

        command = [
            "embedding-atlas",
            str(container_dataset_path),
            "--host",
            host,
            "--port",
            str(port),
            "--no-auto-port",
        ]
        if payload.extra_args:
            command.extend(payload.extra_args)

        env = payload.environment.copy()
        env.setdefault("SESSION_ID", session_id)

        labels = default_labels(session_id, expires_at, str(dataset_path))
        labels.update(payload.labels)

        logger.info(
            "Starting embedding-atlas container %s on port %s (session %s)",
            container_name,
            port,
            session_id,
        )

        port_bindings = {f"{port}/tcp": port}
        container = self.client.containers.run(
            payload.image or self.settings.container_image,
            command,
            name=container_name,
            detach=True,
            environment=env,
            labels=labels,
            volumes={
                str(bind_root): {"bind": "/session", "mode": "ro"}
            },
            ports=port_bindings,
            network=self.settings.container_network,
        )

        return SessionRecord(
            session_id=session_id,
            container_id=container.id,
            container_name=container.name,
            port=port,
            host=host,
            started_at=datetime.now(tz=timezone.utc),
            expires_at=expires_at,
            dataset_path=str(dataset_path),
        )

    def _stop_container(self, container_name: str) -> None:
        try:
            container = self.client.containers.get(container_name)
        except NotFound as exc:
            raise SessionNotFoundError(f"Container {container_name} not found") from exc
        container.stop(timeout=30)
        container.remove(v=True)
        logger.info("Removed container %s", container_name)

    def _purge_dataset(self, dataset_path: Path) -> None:
        try:
            dataset_path.unlink(missing_ok=True)
            parent = dataset_path.parent
            if parent.exists() and not any(parent.iterdir()):
                parent.rmdir()
        except Exception as exc:
            logger.warning("Failed to clean dataset path %s: %s", dataset_path, exc)

    async def _download_dataset(self, session_id: str, payload: SessionCreateRequest) -> Path:
        filename = payload.input_filename or Path(payload.data_url.path).name
        if not filename:
            filename = f"dataset-{session_id}"
        session_dir = self.settings.session_root / session_id
        session_dir.mkdir(parents=True, exist_ok=True)
        target = session_dir / filename

        logger.info("Downloading dataset for session %s from %s", session_id, payload.data_url)
        timeout = httpx.Timeout(self.settings.download_timeout_seconds)
        async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
            try:
                async with client.stream("GET", str(payload.data_url)) as response:
                    response.raise_for_status()
                    with target.open("wb") as file_handle:
                        async for chunk in response.aiter_bytes(chunk_size=self.settings.download_chunk_size):
                            file_handle.write(chunk)
            except Exception as exc:
                raise DatasetDownloadError(f"Failed to fetch dataset: {exc}") from exc
        return target

    def _allocate_port(self) -> int:
        if self.settings.port_range_start >= self.settings.port_range_end:
            return self._random_free_port()
        candidates = range(self.settings.port_range_start, self.settings.port_range_end + 1)
        for candidate in candidates:
            if self._port_available(candidate):
                return candidate
        return self._random_free_port()

    @staticmethod
    def _random_free_port() -> int:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    @staticmethod
    def _port_available(port: int) -> bool:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            result = sock.connect_ex(("127.0.0.1", port))
            return result != 0

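The availability probe relies on `connect_ex` returning 0 only when something is already listening. A standalone sketch of the same check, exercised against a deliberately occupied ephemeral port:

```python
import socket

def port_available(port: int) -> bool:
    # Mirrors SessionManager._port_available: connect_ex returns 0 when
    # something is listening, so a nonzero result means the port is free.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return sock.connect_ex(("127.0.0.1", port)) != 0

# Occupy an ephemeral port, then probe it.
listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listener.bind(("127.0.0.1", 0))
listener.listen(1)
busy_port = listener.getsockname()[1]
busy = port_available(busy_port)
listener.close()
print(busy)  # False while the listener holds the port
```

Note this only proves nothing is listening at probe time; a container started a moment later can still race for the same port.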
    def _guard_duplicate_container(self, container_name: str) -> None:
        try:
            container = self.client.containers.get(container_name)
        except NotFound:
            return
        raise SessionExistsError(
            f"Container name {container_name} already exists (id={container.id})."
        )

    @staticmethod
    def _extract_port(container: Container) -> int:
        ports = container.attrs.get("NetworkSettings", {}).get("Ports", {})
        if not ports:
            raise ValueError("No exposed ports found")
        # Expect a single port binding
        (container_port, host_bindings), = ports.items()
        if not host_bindings:
            raise ValueError("Container port not published")
        host_port = host_bindings[0].get("HostPort")
        if host_port is None:
            raise ValueError("Missing host port")
        return int(host_port)

    @staticmethod
    def _parse_expiry(value: Optional[str]) -> datetime:
        if not value:
            return datetime.now(tz=timezone.utc)
        try:
            return datetime.fromtimestamp(int(value), tz=timezone.utc)
        except ValueError:
            return datetime.now(tz=timezone.utc)

    @staticmethod
    def _parse_created(value: Optional[str]) -> datetime:
        if not value:
            return datetime.now(tz=timezone.utc)
        try:
            return datetime.fromisoformat(value.replace('Z', '+00:00'))
        except ValueError:
            return datetime.now(tz=timezone.utc)

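The `Z`-to-`+00:00` rewrite in `_parse_created` exists because `datetime.fromisoformat` before Python 3.11 rejects a trailing `Z`. A quick check with a hypothetical timestamp:

```python
from datetime import datetime

# Docker-style ISO-8601 timestamp with a trailing "Z" (hypothetical value).
created = "2024-05-01T12:30:00Z"
parsed = datetime.fromisoformat(created.replace('Z', '+00:00'))
print(parsed.utcoffset().total_seconds())  # 0.0
```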

__all__ = [
    "SessionManager",
    "SessionError",
    "SessionExistsError",
    "SessionNotFoundError",
    "DatasetDownloadError",
    "PortUnavailableError",
]
120
src/embedding_backend/main.py
Normal file
@@ -0,0 +1,120 @@
"""Entry points for launching the orchestrator as FastAPI or FastMCP."""

from __future__ import annotations

import argparse
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import Optional

import uvicorn
from fastapi import FastAPI
from fastmcp import FastMCP

from .config import Settings, settings
from .docker_manager import SessionManager
from .routes import get_router

logger = logging.getLogger(__name__)


def create_app(app_settings: Optional[Settings] = None) -> FastAPI:
    """Create and configure the FastAPI application."""

    app_settings = app_settings or settings
    manager = SessionManager(app_settings)

    @asynccontextmanager
    async def lifespan(_: FastAPI):
        await manager.start()
        try:
            yield
        finally:
            await manager.stop()

    app = FastAPI(
        title="Embedding Atlas Orchestrator",
        version="0.1.0",
        description=(
            "Session orchestration API that launches short-lived embedding-atlas containers on demand."
        ),
        docs_url="/docs",
        redoc_url=None,
        openapi_url="/openapi.json",
        lifespan=lifespan,
    )

    @app.get("/", tags=["meta"], summary="Service metadata")
    async def root() -> dict[str, str]:
        return {"message": "Embedding Atlas session manager is running. Visit /docs for the Swagger UI."}

    app.include_router(get_router(manager))
    app.state.session_manager = manager
    return app

_app = create_app()
|
||||
app = _app
|
||||
_mcp_server = FastMCP.from_fastapi(_app, name="embedding-atlas-session-manager")
|
||||
mcp_server = _mcp_server
|
||||
|
||||
|
||||
def get_mcp_server() -> FastMCP:
|
||||
return mcp_server
|
||||
|
||||
|
||||
def run_api() -> None:
|
||||
"""Launch the FastAPI server via uvicorn."""
|
||||
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=settings.api_host,
|
||||
port=settings.api_port,
|
||||
)
|
||||
|
||||
|
||||
async def _run_mcp_stdio_async() -> None:
|
||||
server = get_mcp_server()
|
||||
await server.run_stdio_async()
|
||||
|
||||
|
||||
def run_mcp_stdio() -> None:
|
||||
"""Launch the MCP server over stdio (experimental)."""
|
||||
|
||||
asyncio.run(_run_mcp_stdio_async())
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""CLI entry point allowing API or MCP execution modes."""
|
||||
|
||||
parser = argparse.ArgumentParser(description="Embedding Atlas orchestrator")
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=["api", "mcp"],
|
||||
default="api",
|
||||
help="Select whether to launch the HTTP API or the MCP stdio server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
default=settings.api_host,
|
||||
help="Override FastAPI bind host.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=settings.api_port,
|
||||
help="Override FastAPI bind port.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.mode == "api":
|
||||
logger.info("Starting FastAPI server on %s:%s", args.host, args.port)
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
else:
|
||||
logger.info("Starting FastMCP stdio server")
|
||||
run_mcp_stdio()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
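The `main()` entry point above dispatches on a single `--mode` flag. A minimal, standalone sketch of that argparse setup (the default host/port values here are illustrative placeholders, not the actual `settings` values):

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Mirrors the dispatch in main(): one flag selects the HTTP API
    # versus the MCP stdio server; host/port only apply to api mode.
    parser = argparse.ArgumentParser(description="Embedding Atlas orchestrator")
    parser.add_argument("--mode", choices=["api", "mcp"], default="api")
    parser.add_argument("--host", default="0.0.0.0")  # placeholder default
    parser.add_argument("--port", type=int, default=9000)  # placeholder default
    return parser


# Parsing explicit argv (instead of sys.argv) makes the dispatch testable.
args = build_parser().parse_args(["--mode", "mcp", "--port", "9100"])
print(args.mode, args.port)
```

Invalid modes (anything outside `api`/`mcp`) fail at parse time thanks to `choices`, so the `if/else` dispatch never sees an unexpected value.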
122
src/embedding_backend/models.py
Normal file
@@ -0,0 +1,122 @@
"""Pydantic models for the orchestration API."""

from __future__ import annotations

from datetime import datetime
from typing import Dict, List, Optional
from uuid import uuid4

from pydantic import BaseModel, Field, HttpUrl


class SessionCreateRequest(BaseModel):
    """Incoming payload describing a requested embedding-atlas session."""

    session_id: Optional[str] = Field(
        default=None,
        description="Optional session identifier; generated when omitted.",
    )
    data_url: HttpUrl = Field(
        description="HTTPS location of the dataset to visualize.",
    )
    input_filename: Optional[str] = Field(
        default=None,
        description="Filename to persist the downloaded dataset under.",
    )
    extra_args: List[str] = Field(
        default_factory=list,
        description="Additional CLI arguments forwarded to embedding-atlas.",
    )
    environment: Dict[str, str] = Field(
        default_factory=dict,
        description="Environment variables injected into the managed container.",
    )
    labels: Dict[str, str] = Field(
        default_factory=dict,
        description="Extra Docker labels assigned to the managed container.",
    )
    image: Optional[str] = Field(
        default=None,
        description="Override Docker image for the managed container.",
    )
    host: Optional[str] = Field(
        default=None,
        description="Host binding forwarded to embedding-atlas (defaults to settings.container_host).",
    )
    port: Optional[int] = Field(
        default=None,
        description="Preferred host port for the embedding-atlas UI.",
    )
    auto_remove_seconds: Optional[int] = Field(
        default=None,
        description="Lifetime of the session before automatic cleanup triggers.",
    )

    def normalised_session_id(self) -> str:
        return self.session_id or uuid4().hex


class SessionRecord(BaseModel):
    """Internal representation of a live session."""

    session_id: str
    container_id: str
    container_name: str
    port: int
    host: str
    started_at: datetime
    expires_at: datetime
    dataset_path: str


class SessionResponse(BaseModel):
    """HTTP response returned after a session is launched."""

    session_id: str
    container_name: str
    endpoint: str
    port: int
    host: str
    expires_at: datetime


class SessionCloseResponse(BaseModel):
    """HTTP response returned after a session is terminated."""

    session_id: str
    removed: bool


class SessionListResponse(BaseModel):
    """Shape returned when listing active sessions."""

    sessions: List[SessionRecord]


MANAGED_LABEL = "embedding-backend.managed"
EXPIRY_LABEL = "embedding-backend.expires-at"
SESSION_ID_LABEL = "embedding-backend.session-id"
DATASET_LABEL = "embedding-backend.dataset-path"


def default_labels(session_id: str, expires_at: datetime, dataset_path: str) -> Dict[str, str]:
    return {
        MANAGED_LABEL: "true",
        SESSION_ID_LABEL: session_id,
        EXPIRY_LABEL: str(int(expires_at.timestamp())),
        DATASET_LABEL: dataset_path,
    }


__all__ = [
    "SessionCreateRequest",
    "SessionRecord",
    "SessionResponse",
    "SessionCloseResponse",
    "SessionListResponse",
    "default_labels",
    "MANAGED_LABEL",
    "EXPIRY_LABEL",
    "SESSION_ID_LABEL",
    "DATASET_LABEL",
]
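The session-ID fallback and the `default_labels()` helper in `models.py` can be sketched without the pydantic dependency; these stand-in functions mirror the module's logic using only the stdlib:

```python
from datetime import datetime, timedelta, timezone
from uuid import uuid4


def normalised_session_id(session_id=None):
    # Same fallback as SessionCreateRequest.normalised_session_id():
    # an explicit ID wins, otherwise a 32-char random hex string.
    return session_id or uuid4().hex


def default_labels(session_id, expires_at, dataset_path):
    # Same label keys as models.py; expiry is stored as a Unix timestamp
    # string so a sweeper can compare it without parsing datetimes.
    return {
        "embedding-backend.managed": "true",
        "embedding-backend.session-id": session_id,
        "embedding-backend.expires-at": str(int(expires_at.timestamp())),
        "embedding-backend.dataset-path": dataset_path,
    }


sid = normalised_session_id()
expires = datetime.now(timezone.utc) + timedelta(hours=10)
labels = default_labels(sid, expires, f"/sessions/{sid}/data.csv")
print(sid, labels["embedding-backend.expires-at"])
```

Storing `expires-at` directly on the container as a label is what lets the orchestrator recover pending cleanups after a restart: the labels travel with the container, not with in-process state.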
71
src/embedding_backend/routes.py
Normal file
@@ -0,0 +1,71 @@
"""FastAPI routes exposing the session management API."""

from __future__ import annotations

from fastapi import APIRouter, Depends, HTTPException, status

from .docker_manager import (
    DatasetDownloadError,
    PortUnavailableError,
    SessionExistsError,
    SessionManager,
    SessionNotFoundError,
)
from .models import (
    SessionCloseResponse,
    SessionCreateRequest,
    SessionListResponse,
    SessionRecord,
    SessionResponse,
)


def get_router(manager: SessionManager) -> APIRouter:
    router = APIRouter(prefix="/sessions", tags=["sessions"])

    async def get_manager() -> SessionManager:
        return manager

    @router.post("/", response_model=SessionResponse, status_code=status.HTTP_201_CREATED)
    async def create_session(
        payload: SessionCreateRequest,
        manager: SessionManager = Depends(get_manager),
    ) -> SessionResponse:
        try:
            record = await manager.create_session(payload)
        except SessionExistsError as exc:
            raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
        except DatasetDownloadError as exc:
            raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
        except PortUnavailableError as exc:
            raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
        endpoint = f"http://{record.host}:{record.port}"
        return SessionResponse(
            session_id=record.session_id,
            container_name=record.container_name,
            endpoint=endpoint,
            port=record.port,
            host=record.host,
            expires_at=record.expires_at,
        )

    @router.delete("/{session_id}", response_model=SessionCloseResponse)
    async def close_session(
        session_id: str,
        manager: SessionManager = Depends(get_manager),
    ) -> SessionCloseResponse:
        try:
            await manager.close_session(session_id)
        except SessionNotFoundError as exc:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
        return SessionCloseResponse(session_id=session_id, removed=True)

    @router.get("/", response_model=SessionListResponse)
    async def list_sessions(manager: SessionManager = Depends(get_manager)) -> SessionListResponse:
        sessions = await manager.list_sessions()
        return SessionListResponse(sessions=list(sessions))

    return router


__all__ = ["get_router"]
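The router above maps each domain exception from `docker_manager` to a fixed HTTP status. That mapping can be sketched in isolation, using throwaway exception classes in place of the real ones:

```python
# Stand-in exception classes; the real ones live in docker_manager.
class SessionExistsError(Exception): ...
class DatasetDownloadError(Exception): ...
class PortUnavailableError(Exception): ...
class SessionNotFoundError(Exception): ...


# The status codes the route handlers attach via HTTPException:
# conflicts (duplicate session, no free port) -> 409, a failed
# upstream dataset fetch -> 502, an unknown session -> 404.
STATUS_BY_ERROR = {
    SessionExistsError: 409,
    PortUnavailableError: 409,
    DatasetDownloadError: 502,
    SessionNotFoundError: 404,
}


def status_for(exc: Exception) -> int:
    # Anything unmapped falls through to a generic 500.
    return STATUS_BY_ERROR.get(type(exc), 500)


print(status_for(DatasetDownloadError("fetch failed")))
```

Keeping this mapping inside the route layer (rather than raising `HTTPException` from `docker_manager` itself) leaves the manager reusable from non-HTTP callers such as the MCP entry point.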