update
2 .gitattributes vendored (new file)
@@ -0,0 +1,2 @@
# SCM syntax highlighting & preventing 3-way merges
pixi.lock merge=binary linguist-language=YAML linguist-generated=true
5 .gitignore vendored
@@ -1,2 +1,7 @@
.venv/
.DS_Store
__pycache__/
*.pyc# pixi environments
.pixi/*
!.pixi/config.toml
*.egg-info/
115 README.md
@@ -33,6 +33,100 @@ export HF_HUB_OFFLINE=1
export HF_ENDPOINT=https://hf-mirror.com
```

## Session Orchestration Service (FastAPI / MCP)

Running `uv run embedding-backend-api` starts a backend service compatible with both FastAPI and FastMCP. It listens on the `/sessions` path, launches `embedding-atlas` containers on demand, and cleans them up automatically after 10 hours.

```bash
uv run embedding-backend-api
# Or start in MCP mode (stdio)
uv run embedding-backend-mcp
```

### REST API usage

```bash
curl -X POST http://localhost:9000/sessions \
  -H 'Content-Type: application/json' \
  -d '{
    "data_url": "https://example.com/data.csv",
    "extra_args": ["--text", "smiles"]
  }'

# Close a session
curl -X DELETE http://localhost:9000/sessions/<session_id>

# List current sessions
curl http://localhost:9000/sessions
```

The request body supports fields such as `session_id`, `port`, `auto_remove_seconds`, and `environment`; when `session_id` is omitted, one is generated automatically and returned in the response. The response includes the container name and the reachable frontend URL, which a FastAPI or MCP client can forward to the frontend.

To exercise the API interactively, open the auto-generated Swagger UI at `http://localhost:9000/docs` in a browser.

#### Request body fields

- `session_id`: optional; a custom session ID. Generated automatically when omitted.
- `data_url`: required; an HTTPS link to a CSV/Parquet file. It is downloaded into `/sessions/<session_id>/`, which is shared between the orchestrator and DinD.
- `input_filename`: optional; the filename used when saving to the shared volume. Inferred from the URL by default.
- `extra_args`: optional; an array of extra arguments passed to the `embedding-atlas` CLI, e.g. `["--text", "smiles"]`.
- `environment`: optional; a mapping of environment variables injected into the container.
- `labels`: optional; Docker labels attached to the container (in addition to the default `embedding-backend.*` labels), useful for custom monitoring or tracing.
- `image`: optional; overrides the default `embedding-atlas` image name.
- `host`: optional; the value passed to `embedding-atlas --host`. Defaults to `0.0.0.0`.
- `port`: optional; an explicit host port. Allocated automatically from the configured range when omitted.
- `auto_remove_seconds`: optional; overrides the default 10-hour lifetime.

Example request:

```bash
curl -X POST http://localhost:9000/sessions \
  -H 'Content-Type: application/json' \
  -d '{
    "session_id": "demo-session",
    "data_url": "https://example.com/data.csv",
    "input_filename": "data.csv",
    "extra_args": ["--text", "smiles"],
    "environment": {"HF_ENDPOINT": "https://hf-mirror.com"},
    "labels": {"team": "chem"},
    "auto_remove_seconds": 7200
  }'
```
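The same request can be issued from Python with the standard library alone. A minimal sketch; the `build_session_payload` helper is ours (not part of the backend), and the field names follow the table above:

```python
import json
from urllib import request

ALLOWED_FIELDS = {
    "session_id", "input_filename", "extra_args", "environment",
    "labels", "image", "host", "port", "auto_remove_seconds",
}

def build_session_payload(data_url: str, **options) -> dict:
    """Assemble a /sessions request body; reject fields the API does not document."""
    unknown = set(options) - ALLOWED_FIELDS
    if unknown:
        raise ValueError(f"unknown fields: {sorted(unknown)}")
    return {"data_url": data_url, **options}

def create_session(base_url: str, payload: dict) -> dict:
    """POST the payload to the orchestrator (requires a running backend)."""
    req = request.Request(
        f"{base_url}/sessions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with request.urlopen(req) as resp:
        return json.load(resp)

# Build the same body as the curl example above.
payload = build_session_payload(
    "https://example.com/data.csv",
    session_id="demo-session",
    extra_args=["--text", "smiles"],
    auto_remove_seconds=7200,
)
```

Validating the payload client-side keeps typos in optional field names from being silently ignored by the server.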
#### GET /sessions response example

Takes no query parameters and returns the list of all current sessions:

```json
{
  "sessions": [
    {
      "session_id": "demo-session",
      "container_id": "e556c0f1c35b...",
      "container_name": "embedding-atlas_demo-session",
      "port": 6000,
      "host": "0.0.0.0",
      "started_at": "2025-09-22T14:37:25.206038Z",
      "expires_at": "2025-09-23T00:37:24.416241Z",
      "dataset_path": "/sessions/demo-session/data.csv"
    }
  ]
}
```

- An empty `sessions` array means no containers are currently alive.
- `dataset_path` is the absolute path of the dataset inside the shared volume.
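The `started_at` and `expires_at` timestamps are ISO 8601 with a `Z` suffix; a small standard-library sketch for computing a session's remaining lifetime (the helper name is ours, not part of the API):

```python
from datetime import datetime, timezone
from typing import Optional

def seconds_remaining(expires_at: str, now: Optional[datetime] = None) -> float:
    """Seconds until a session expires; negative once it has expired."""
    # datetime.fromisoformat does not accept a bare "Z" on older Pythons,
    # so normalize it to an explicit UTC offset first.
    expires = datetime.fromisoformat(expires_at.replace("Z", "+00:00"))
    reference = now or datetime.now(timezone.utc)
    return (expires - reference).total_seconds()

# Measured from the example's started_at, the session has roughly
# ten hours of lifetime left at creation time.
started = datetime.fromisoformat("2025-09-22T14:37:25.206038+00:00")
remaining = seconds_remaining("2025-09-23T00:37:24.416241Z", now=started)
```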
### MCP Integration

The `fastmcp.json` example in the repository root registers this project as an MCP tool directly:

```bash
uv run embedding-backend-mcp
```

Once a FastMCP client loads this configuration, it can forward the same set of REST endpoints over the standard MCP protocol, keeping behavior consistent with the conventional backend.

## Generating interactive embedding visualizations from the command line

```bash
@@ -79,3 +173,24 @@ uv run python script/merge_splits.py --input-dir splits_v2/ --output data/drugba
```bash
uv run embedding-atlas data/drugbank_split_merge.csv --text smiles
```

## Containerized deployment

The project ships a `docker/` directory for quickly starting the backend:

1. First build the `embedding-atlas` image that the orchestrator reuses:

   ```bash
   docker build -f docker/embedding-atlas.Dockerfile -t embedding-atlas:latest .
   ```

2. Start the DinD + orchestrator pair:

   ```bash
   docker compose -f docker/docker-compose.yml up --build
   ```

- The `engine` service runs `docker:dind` and exposes `tcp://localhost:2375` so the orchestrator can manage containers through the socket.
- The two services share the `/sessions` directory via the `sessions-data` volume; the backend stores downloaded data there.
- The `orchestrator` service runs the FastAPI/FastMCP backend, exposed at `http://localhost:9000` by default.
- The session cache lives in the `sessions-data` volume, so downloaded datasets survive container restarts.

To change the default image or ports, override the `EMBEDDING_*` environment variables in `docker/docker-compose.yml`, or inject them via a `.env` file at deployment time.

> Note: because the orchestrator runs inside a container and schedules new containers through DinD, accessing the `embedding-atlas` web UI directly from the host requires forwarding the corresponding port from `docker_engine_1` to the host, e.g. with `podman port` or an extra reverse proxy.
90 data/ring12_20/README.md (new file)
@@ -0,0 +1,90 @@
Your Filtered Macrolactone Database

11036 compounds have been filtered from MacrolactoneDB based on your specified inputs.

```bash
uv run embedding-atlas data/ring12_20/temp_with_macrocycles_with_ecfp4.parquet --text ecfp4_binary
uv run embedding-atlas data/ring12_20/temp_with_macrocycles_with_ecfp4.parquet --text tanimoto_top_neighbors
uv run embedding-atlas data/ring12_20/temp_with_macrocycles_with_ecfp4.parquet --text smiles
```
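The `ecfp4_binary` column stores each fingerprint as a fixed-length 0/1 string. Tanimoto similarity between two such strings can be computed with the standard library alone (a toy sketch with short strings; the real ECFP4 fingerprints in this dataset are 2048 bits):

```python
def tanimoto_from_bits(a: str, b: str) -> float:
    """Tanimoto similarity of two equal-length 0/1 fingerprint strings."""
    if len(a) != len(b):
        raise ValueError("fingerprints must have the same length")
    on_a = {i for i, bit in enumerate(a) if bit == "1"}
    on_b = {i for i, bit in enumerate(b) if bit == "1"}
    union = on_a | on_b
    if not union:
        return 0.0  # two all-zero fingerprints: define similarity as 0
    return len(on_a & on_b) / len(union)

# Identical fingerprints score 1.0; disjoint ones score 0.0.
```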
## Embedding and projection optimization

### How projection_x and projection_y are generated

**UMAP dimensionality reduction**

The two coordinates are produced by the `_run_umap()` function, which uses the UMAP algorithm to reduce the high-dimensional embedding vectors to 2D space (projection.py:64-88).

The flow is:

1. Nearest neighbors - `nearest_neighbors()` first computes the k nearest neighbors of every point (projection.py:76-83).
2. UMAP projection - UMAP then reduces dimensionality using the precomputed neighbor information (projection.py:85-86).
3. Coordinate assignment - the first column of the result becomes `projection_x` and the second becomes `projection_y` (projection.py:259-260).
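The coordinate-assignment step amounts to splitting the 2-column UMAP output into two named columns. A toy standard-library sketch (the input list stands in for real UMAP output):

```python
def assign_projection(embedding_2d):
    """Map each 2D point to projection_x / projection_y fields."""
    return [{"projection_x": x, "projection_y": y} for x, y in embedding_2d]

rows = assign_projection([(0.12, -1.5), (2.3, 0.7)])
```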
**Default parameters**

The UMAP step uses the following defaults:

- Number of neighbors: 15 nearest neighbors (projection.py:74)
- Distance metric: cosine distance (projection.py:73)

**Use across data types**

Text data: for SMILES molecular data, the system first generates text embeddings with SentenceTransformers and then reduces them with UMAP (projection.py:251-260).

Precomputed vectors: if you have precomputed ECFP4 vectors, the system runs UMAP on them directly (projection.py:311-318).

**Role in the visualization**

In the frontend, these coordinates serve as:

- The X/Y axes of the scatter plot - each data point's position in 2D space.
- A basis for color coding - colors can be mapped from the coordinate values (embedding-atlas.md:68-70).

**Demo data example**

The project's demo data generation follows the same pipeline: embeddings are computed with SentenceTransformers, then UMAP produces the `projection_x` and `projection_y` coordinates (generate_demo_data.py:42-43).

**Notes**

The quality of these projections depends heavily on the quality of the original embeddings and on the choice of UMAP parameters. For chemical data, a dedicated molecular embedding model usually yields a more meaningful 2D projection in which structurally similar molecules cluster together.

### UMAP parameter tuning

You can adjust the UMAP parameters for a better visualization:

```bash
# Adjust the neighbor count and distance parameters
uv run embedding-atlas data/ring12_20/temp_with_macrocycles_with_ecfp4.parquet --text smiles \
  --umap-n-neighbors 30 \
  --umap-min-dist 0.1 \
  --umap-metric cosine \
  --umap-random-state 42
```
## Custom embedding models

For chemical molecular data you may want a specialized model, subject to the following:

**Supported model types**

embedding-atlas supports two kinds of custom models:

**Text embedding models**

For text data (such as your SMILES strings), the system uses the SentenceTransformers library (projection.py:118-126). This means you can use any Hugging Face model compatible with SentenceTransformers.

**Image embedding models**

For image data, the system uses the transformers library's pipeline facility (projection.py:168-180).

**Model format requirements**

SentenceTransformers compatibility: text models must be compatible with the SentenceTransformers library (projection.py:98-99). That includes:

- models trained specifically for sentence embeddings
- models that support the `.encode()` method
- models that output fixed-dimension vectors
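A minimal sketch of what "SentenceTransformers compatible" means in practice: any object exposing `.encode()` and returning fixed-dimension vectors fits the interface. The toy encoder below is purely illustrative (hash-based, no learning involved):

```python
import hashlib

class ToyEncoder:
    """Illustrative stand-in for a SentenceTransformers-style model."""

    def __init__(self, dim: int = 8):
        self.dim = dim

    def encode(self, sentences):
        """Map each string to a fixed-dimension vector (here: bytes of its hash)."""
        vectors = []
        for text in sentences:
            digest = hashlib.sha256(text.encode("utf-8")).digest()
            vectors.append([b / 255.0 for b in digest[: self.dim]])
        return vectors

model = ToyEncoder(dim=8)
embeddings = model.encode(["CCO", "c1ccccc1"])
```

A real replacement would be a trained model loaded via `SentenceTransformer(...)`; the point here is only the shape of the contract: `.encode()` in, fixed-width vectors out.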
89 data/ring12_20/counts.txt (new file)
@@ -0,0 +1,89 @@
Target Organisms
Homo sapiens 815
Homo sapiens, None 180
Plasmodium falciparum 161
Hepatitis C virus, None 112
Homo sapiens, Plasmodium falciparum 63
Oryctolagus cuniculus 62
Mus musculus 60
Toxoplasma gondii 39
Homo sapiens, Rattus norvegicus 27
Mus musculus, Homo sapiens 24
None, Rattus norvegicus 23
Human immunodeficiency virus 1 20
Hepatitis C virus 18
Rattus norvegicus 17
Homo sapiens, Sus scrofa 11
Homo sapiens, Chlorocebus aethiops 10
Serratia marcescens 9
Escherichia coli 8
Oryctolagus cuniculus, Homo sapiens 7
Streptococcus pneumoniae 6
Oryctolagus cuniculus, Staphylococcus aureus, Raoultella planticola, Bacillus subtilis, Mus musculus, Micrococcus luteus, None, Escherichia coli, Plasmodium falciparum, Streptococcus pneumoniae, Homo sapiens, Escherichia coli K-12, Toxoplasma gondii 6
Plasmodium falciparum K1 5
Bacillus anthracis 5
Mus musculus, Homo sapiens, None 5
Bacillus anthracis, Homo sapiens 4
Candida albicans, Cryptococcus neoformans, Aspergillus fumigatus 4
Mus musculus, None 4
Plasmodium falciparum, Homo sapiens, None 4
None, Homo sapiens, Plasmodium falciparum 3
Bacillus subtilis, Homo sapiens 3
Oryctolagus cuniculus, Homo sapiens, None 3
Sus scrofa, Mus musculus, None, Plasmodium falciparum, Homo sapiens, Rattus norvegicus 2
Homo sapiens, None, Rattus norvegicus 2
Cryptococcus neoformans 2
Homo sapiens, None, Chlorocebus aethiops 2
Staphylococcus aureus 2
Candida albicans, Cryptococcus neoformans, Mycobacterium intracellulare, Aspergillus fumigatus 2
Mus musculus, None, Human immunodeficiency virus 1 2
Escherichia coli (strain K12) 2
Plasmodium falciparum 3D7, Homo sapiens 2
Aspergillus fumigatus 1
Sus scrofa 1
Saccharomyces cerevisiae S288c, Human immunodeficiency virus 1, Human herpesvirus 1, Plasmodium falciparum, None, Homo sapiens, Rattus norvegicus 1
Hepatitis C virus, Homo sapiens, None 1
Plasmodium falciparum 3D7 1
Bacillus subtilis 1
Mus musculus, Homo sapiens, None, Saccharomyces cerevisiae 1
Chlorocebus aethiops 1
Homo sapiens, Escherichia coli K-12, None 1
Hepatitis C virus, Homo sapiens, None, Rattus norvegicus 1
None, Homo sapiens, Human herpesvirus 1 1
Homo sapiens, None, Trypanosoma brucei brucei 1
Homo sapiens, None, Cryptococcus neoformans 1
Homo sapiens, Rattus norvegicus, Human immunodeficiency virus 1 1
None, Plasmodium falciparum, Escherichia coli, Streptococcus pneumoniae, Naegleria fowleri, Homo sapiens, Streptococcus, Toxoplasma gondii 1
Giardia intestinalis, Trypanosoma cruzi, Equus caballus, Bos taurus, Mus musculus, None, Plasmodium falciparum, Chlorocebus aethiops, Homo sapiens 1
Plasmodium falciparum NF54, Trypanosoma cruzi, Trypanosoma brucei rhodesiense, Rattus norvegicus 1
None, Homo sapiens, Plasmodium falciparum K1, Plasmodium falciparum 1
Saccharomyces cerevisiae S288c, Homo sapiens, None, Saccharomyces cerevisiae, Phytophthora sojae 1
Bacillus subtilis, Homo sapiens, Schistosoma mansoni, Saccharomyces cerevisiae, Giardia intestinalis 1
Streptococcus, Homo sapiens, None 1
Mus musculus, Homo sapiens, Rattus norvegicus 1
Homo sapiens, Spinacia oleracea 1
Human immunodeficiency virus 1, Mus musculus, None, Hepatitis C virus, Homo sapiens, Rattus norvegicus 1
None, Plasmodium falciparum, Trypanosoma brucei rhodesiense 1
Hepatitis C virus, None, Rattus norvegicus 1
Homo sapiens, Equus caballus 1
Plasmodium falciparum NF54, Trypanosoma cruzi, Trypanosoma brucei rhodesiense 1
Schistosoma mansoni, Influenza A virus 1
Leishmania chagasi, Trypanosoma cruzi 1
Candida albicans, Cryptococcus neoformans 1
None, Plasmodium falciparum 1
Caenorhabditis elegans 1
Bos taurus, Sus scrofa 1
Plasmodium falciparum, Enterococcus faecium 1
Homo sapiens, Gallus gallus 1
Homo sapiens, Escherichia coli 1
Plasmodium falciparum, Homo sapiens, None, Rattus norvegicus, Schistosoma mansoni 1
Homo sapiens, None, Influenza A virus 1
Mycobacterium tuberculosis, None 1
Escherichia coli, Homo sapiens, Toxoplasma gondii, None, Streptococcus pneumoniae 1
Bacillus subtilis, Oryctolagus cuniculus, Homo sapiens, Schistosoma mansoni, Giardia intestinalis 1
Homo sapiens, None, Rattus norvegicus, Escherichia coli O157:H7 1
Giardia intestinalis, Schistosoma mansoni, Mus musculus, None, Homo sapiens, Saccharomyces cerevisiae 1
Trypanosoma cruzi 1
Influenza A virus 1
Escherichia coli K-12 1
Human herpesvirus 4 (strain B95-8) 1
11037 data/ring12_20/temp.csv (new file) - file diff suppressed because one or more lines are too long
11037 data/ring12_20/temp_with_macrocycles.csv (new file) - file diff suppressed because one or more lines are too long
11037 data/ring12_20/temp_with_macrocycles_with_ecfp4.csv (new file) - file diff suppressed because one or more lines are too long
BIN data/ring12_20/temp_with_macrocycles_with_ecfp4.parquet (new file) - binary file not shown
41 docker/Dockerfile (new file)
@@ -0,0 +1,41 @@
# syntax=docker/dockerfile:1

FROM python:3.12-slim AS base

ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        build-essential \
        curl \
        git \
    && rm -rf /var/lib/apt/lists/*

RUN pip install --no-cache-dir uv

WORKDIR /app

ENV UV_PROJECT_ENVIRONMENT=/app/.venv
ENV UV_PIP_INDEX_URL=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
ENV HF_ENDPOINT=https://hf-mirror.com
ENV HF_HOME=/app/.cache/huggingface
ENV TRANSFORMERS_CACHE=/app/.cache/huggingface/transformers
ENV HF_DATASETS_CACHE=/app/.cache/huggingface/datasets
ENV EMBEDDING_ATLAS_DEVICE=cpu

RUN mkdir -p "$HF_HOME" "$TRANSFORMERS_CACHE" "$HF_DATASETS_CACHE"
VOLUME ["/app/.cache/huggingface"]

COPY pyproject.toml README.md ./
COPY src ./src
COPY script ./script

# Install dependencies with uv; retried once to tolerate transient network
# failures. When a lock file is present it is used automatically.
RUN uv sync --no-dev || uv sync --no-dev

ENV PATH="/app/.venv/bin:$PATH"
ENV EMBEDDING_API_HOST=0.0.0.0
ENV EMBEDDING_API_PORT=9000

EXPOSE 9000
39 docker/docker-compose.yml (new file)
@@ -0,0 +1,39 @@
version: "3.9"

services:
  engine:
    image: docker:25.0-dind
    privileged: true
    environment:
      - DOCKER_TLS_CERTDIR=
    command: ["--host=tcp://0.0.0.0:2375"]
    ports:
      - "2375:2375"
    volumes:
      - engine-data:/var/lib/docker
      - sessions-data:/sessions
    restart: unless-stopped

  orchestrator:
    # Ensure the embedding-atlas image exists (e.g. docker build -f docker/embedding-atlas.Dockerfile -t embedding-atlas:latest ..)
    build:
      context: ..
      dockerfile: docker/Dockerfile
    depends_on:
      - engine
    environment:
      - EMBEDDING_DOCKER_URL=tcp://engine:2375
      - EMBEDDING_CONTAINER_IMAGE=embedding-atlas:latest
      - EMBEDDING_CONTAINER_NAME_PREFIX=embedding-atlas
      - EMBEDDING_API_HOST=0.0.0.0
      - EMBEDDING_API_PORT=9000
      - EMBEDDING_SESSION_ROOT=/sessions
    volumes:
      - sessions-data:/sessions
    ports:
      - "9000:9000"
    restart: unless-stopped

volumes:
  engine-data:
  sessions-data:
17 docker/embedding-atlas.Dockerfile (new file)
@@ -0,0 +1,17 @@
FROM python:3.12-slim

ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        build-essential \
        curl \
    && rm -rf /var/lib/apt/lists/*

RUN pip install --no-cache-dir embedding-atlas

WORKDIR /app

EXPOSE 5055

CMD ["embedding-atlas"]
11 fastmcp.json (new file)
@@ -0,0 +1,11 @@
{
  "mcpServers": {
    "embedding-atlas-session-manager": {
      "command": "embedding-backend-mcp",
      "args": [],
      "env": {},
      "timeout": 30000,
      "description": "Start the embedding orchestrator in MCP stdio mode"
    }
  }
}
10 pixi.toml (new file)
@@ -0,0 +1,10 @@
[workspace]
authors = ["lingyuzeng <pylyzeng@gmail.com>"]
channels = ["conda-forge"]
name = "embedding_atlas"
platforms = ["osx-arm64"]
version = "0.1.0"

[tasks]

[dependencies]
pyproject.toml
@@ -1,3 +1,4 @@
[project]
name = "mole-embedding-atlas"
version = "0.1.0"
@@ -12,4 +13,31 @@ dependencies = [
    "rdkit",
    "pandas",
    "selfies==2.1.1",
    "fastapi>=0.111",
    "uvicorn[standard]>=0.29",
    "fastmcp>=2.11",
    "docker>=7.1",
    "httpx>=0.27",
    "pydantic-settings>=2.2",
]

[project.scripts]
embedding-backend = "embedding_backend.main:main"
embedding-backend-api = "embedding_backend.main:run_api"
embedding-backend-mcp = "embedding_backend.main:run_mcp_stdio"

[tool.uv.pip]
index-url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple"

[tool.setuptools]
package-dir = {"" = "src"}

[tool.setuptools.packages.find]
where = ["src"]

[build-system]
requires = ["setuptools>=68", "wheel"]
build-backend = "setuptools.build_meta"

[tool.uv]
package = true
186 script/add_ecfp4_tanimoto.py (new file)
@@ -0,0 +1,186 @@
#!/usr/bin/env python3
"""Augment a CSV with ECFP4 binary fingerprints and Tanimoto neighbor summaries."""

from __future__ import annotations

import argparse
import pathlib
from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import rdFingerprintGenerator


@dataclass
class ECFP4Generator:
    """Generate ECFP4 (Morgan radius 2) fingerprints as RDKit bit vectors."""

    n_bits: int = 2048
    radius: int = 2
    include_chirality: bool = True
    generator: rdFingerprintGenerator.MorganGenerator = field(init=False)

    def __post_init__(self) -> None:
        self.generator = rdFingerprintGenerator.GetMorganGenerator(
            radius=self.radius,
            fpSize=self.n_bits,
            includeChirality=self.include_chirality,
        )

    def fingerprint(self, smiles: str) -> Optional[DataStructs.ExplicitBitVect]:
        if not smiles:
            return None

        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None

        try:
            return self.generator.GetFingerprint(mol)
        except Exception:
            return None

    def to_binary_string(self, fp: DataStructs.ExplicitBitVect) -> str:
        arr = np.zeros((self.n_bits,), dtype=np.uint8)
        DataStructs.ConvertToNumpyArray(fp, arr)
        return ''.join(arr.astype(str))


def tanimoto_top_k(
    fingerprints: Sequence[Optional[DataStructs.ExplicitBitVect]],
    ids: Sequence[str],
    top_k: int = 5,
) -> List[str]:
    """Return semicolon-delimited Tanimoto summaries for each fingerprint."""
    valid_indices = [idx for idx, fp in enumerate(fingerprints) if fp is not None]
    valid_fps = [fingerprints[idx] for idx in valid_indices]

    index_lookup = {pos: original for pos, original in enumerate(valid_indices)}
    summaries = [''] * len(fingerprints)

    if not valid_fps:
        return summaries

    for pos, fp in enumerate(valid_fps):
        sims = DataStructs.BulkTanimotoSimilarity(fp, valid_fps)
        ranked: List[Tuple[int, float]] = []
        for other_pos, score in enumerate(sims):
            if other_pos == pos:
                continue
            ranked.append((other_pos, score))

        ranked.sort(key=lambda item: item[1], reverse=True)
        top_entries = []
        for other_pos, score in ranked[:top_k]:
            original_idx = index_lookup[other_pos]
            top_entries.append(f"{ids[original_idx]}:{score:.4f}")

        summaries[index_lookup[pos]] = ';'.join(top_entries)

    return summaries


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("input_csv", type=pathlib.Path, help="Source CSV with SMILES")
    parser.add_argument(
        "--output",
        type=pathlib.Path,
        default=None,
        help="Destination file (default: <input>_with_ecfp4.parquet)",
    )
    parser.add_argument(
        "--smiles-column",
        default="smiles",
        help="Name of the column containing SMILES (default: smiles)",
    )
    parser.add_argument(
        "--id-column",
        default="generated_id",
        help="Column used to label Tanimoto neighbors (default: generated_id)",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of nearest neighbors to report in Tanimoto summaries (default: 5)",
    )
    parser.add_argument(
        "--format",
        choices=("parquet", "csv", "auto"),
        default="parquet",
        help="Output format; defaults to parquet unless overridden or inferred from --output",
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()

    if not args.input_csv.exists():
        raise FileNotFoundError(f"Input file not found: {args.input_csv}")

    def resolve_output_path() -> pathlib.Path:
        if args.output is not None:
            return args.output

        suffix = ".parquet" if args.format in ("parquet", "auto") else ".csv"
        return args.input_csv.with_name(f"{args.input_csv.stem}_with_ecfp4{suffix}")

    def resolve_format(path: pathlib.Path) -> str:
        suffix = path.suffix.lower()
        if suffix in {".parquet", ".pq"}:
            return "parquet"
        if suffix == ".csv":
            return "csv"
        if args.format == "parquet":
            return "parquet"
        if args.format == "csv":
            return "csv"
        raise ValueError(
            "Cannot infer the format from the output file; give --output a .parquet/.csv suffix or use --format",
        )

    output_path = resolve_output_path()
    output_format = resolve_format(output_path)

    if args.input_csv.resolve() == output_path.resolve():
        raise ValueError("Output path must differ from input path to avoid overwriting input.")

    df = pd.read_csv(args.input_csv)
    if args.smiles_column not in df.columns:
        raise ValueError(f"Column '{args.smiles_column}' not found in input data")

    smiles_series = df[args.smiles_column].fillna('')

    if args.id_column in df.columns:
        ids = df[args.id_column].astype(str).tolist()
    else:
        ids = [f"D{idx:06d}" for idx in range(1, len(df) + 1)]
        df[args.id_column] = ids

    generator = ECFP4Generator()
    fingerprints: List[Optional[DataStructs.ExplicitBitVect]] = []
    binary_repr: List[str] = []

    for smiles in smiles_series:
        fp = generator.fingerprint(smiles)
        fingerprints.append(fp)
        binary_repr.append(generator.to_binary_string(fp) if fp is not None else '')

    df['ecfp4_binary'] = binary_repr
    df['tanimoto_top_neighbors'] = tanimoto_top_k(fingerprints, ids, top_k=args.top_k)

    if output_format == "parquet":
        df.to_parquet(output_path, index=False)
    else:
        df.to_csv(output_path, index=False)

    print(f"Wrote augmented data with {len(df)} rows to {output_path} ({output_format})")


if __name__ == "__main__":
    main()
127
script/add_macrocycle_columns.py
Normal file
127
script/add_macrocycle_columns.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Add generated IDs and macrolactone ring size annotations to a CSV file."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
import warnings
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from rdkit import Chem
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MacrolactoneRingDetector:
|
||||||
|
"""Detect macrolactone rings in SMILES strings via SMARTS patterns."""
|
||||||
|
|
||||||
|
min_size: int = 12
|
||||||
|
max_size: int = 20
|
||||||
|
patterns: Dict[int, Optional[Chem.Mol]] = field(init=False, default_factory=dict)
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
self.patterns = {
|
||||||
|
size: Chem.MolFromSmarts(f"[r{size}]([#8][#6](=[#8]))")
|
||||||
|
for size in range(self.min_size, self.max_size + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
def ring_sizes(self, smiles: str) -> List[int]:
|
||||||
|
"""Return a sorted list of macrolactone ring sizes present in the SMILES."""
|
||||||
|
if not smiles:
|
||||||
|
return []
|
||||||
|
|
||||||
|
mol = Chem.MolFromSmiles(smiles)
|
||||||
|
if mol is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
ring_atoms = mol.GetRingInfo().AtomRings()
|
||||||
|
if not ring_atoms:
|
||||||
|
return []
|
||||||
|
|
||||||
|
matched_rings: Set[Tuple[int, ...]] = set()
|
||||||
|
matched_sizes: Set[int] = set()
|
||||||
|
|
||||||
|
for size, query in self.patterns.items():
|
||||||
|
if query is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for match in mol.GetSubstructMatches(query, uniquify=True):
|
||||||
|
ring = self._pick_ring(match, ring_atoms, size)
|
||||||
|
if ring and ring not in matched_rings:
|
||||||
|
matched_rings.add(ring)
|
||||||
|
matched_sizes.add(size)
|
||||||
|
|
||||||
|
sizes = sorted(matched_sizes)
|
||||||
|
|
||||||
|
if len(sizes) > 1:
|
||||||
|
warnings.warn(
|
||||||
|
"Multiple macrolactone ring sizes detected",
|
||||||
|
RuntimeWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
            print(f"Multiple ring sizes {sizes} for SMILES: {smiles}")
        return sizes

    @staticmethod
    def _pick_ring(
        match: Tuple[int, ...], rings: Iterable[Tuple[int, ...]], expected_size: int
    ) -> Optional[Tuple[int, ...]]:
        ring_atom = match[0]
        for ring in rings:
            if len(ring) == expected_size and ring_atom in ring:
                return tuple(sorted(ring))
        return None


def add_columns(df: pd.DataFrame, detector: MacrolactoneRingDetector) -> pd.DataFrame:
    result = df.copy()
    result["generated_id"] = [f"D{index:06d}" for index in range(1, len(result) + 1)]

    smiles_series = result["smiles"].fillna("") if "smiles" in result.columns else pd.Series(
        [""] * len(result), index=result.index
    )

    def format_sizes(smiles: str) -> str:
        sizes = detector.ring_sizes(smiles)
        return ";".join(str(size) for size in sizes) if sizes else ""

    result["macrocycle_ring_sizes"] = smiles_series.apply(format_sizes)
    return result


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("input_csv", type=pathlib.Path, help="Path to the source CSV file")
    parser.add_argument(
        "--output",
        type=pathlib.Path,
        default=None,
        help="Destination for the augmented CSV (default: <input>_with_macrocycles.csv)",
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()

    if not args.input_csv.exists():
        raise FileNotFoundError(f"Input file not found: {args.input_csv}")

    output_path = args.output or args.input_csv.with_name(
        f"{args.input_csv.stem}_with_macrocycles.csv"
    )

    if args.input_csv.resolve() == output_path.resolve():
        raise ValueError("Output path must differ from input path to avoid overwriting.")

    df = pd.read_csv(args.input_csv)
    detector = MacrolactoneRingDetector()
    augmented = add_columns(df, detector)
    augmented.to_csv(output_path, index=False)

    print(f"Wrote augmented data with {len(augmented)} rows to {output_path}")


if __name__ == "__main__":
    main()
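The `add_columns` helper above derives two columns: a sequential `generated_id` (`D000001`, `D000002`, …) and a semicolon-joined `macrocycle_ring_sizes` string. A minimal standalone sketch of that formatting (the free-standing function names here are illustrative, not part of the script):

```python
def generated_id(index: int) -> str:
    # Sequential IDs in the same "D" + zero-padded-to-6-digits format
    return f"D{index:06d}"


def format_sizes(sizes: list) -> str:
    # Semicolon-joined sizes; empty string when no macrocycle was found
    return ";".join(str(size) for size in sizes) if sizes else ""


print(generated_id(1))         # D000001
print(format_sizes([12, 14]))  # 12;14
print(format_sizes([]))        # (empty string)
```

An empty result rather than `None` keeps the column a plain string dtype when written back to CSV.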
439 script/ecfp4_umap_embedding_optimized.py Normal file
@@ -0,0 +1,439 @@
#!/usr/bin/env python3
"""
Optimized ECFP4 Fingerprinting with UMAP Visualization for Macrolactone Molecules

This script processes SMILES data to:
1. Generate ECFP4 fingerprints using RDKit
2. Detect ring numbers in macrolactone molecules using SMARTS patterns
3. Generate unique IDs for molecules without existing IDs
4. Perform UMAP dimensionality reduction with Tanimoto distance
5. Prepare data for embedding-atlas visualization

Optimized for large datasets with progress tracking and memory efficiency.
"""

import os
import sys
import argparse
import subprocess
from typing import Optional, List

# RDKit imports
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors, DataStructs
from rdkit.Chem.MolStandardize import rdMolStandardize

# Data processing
import pandas as pd
import numpy as np

# UMAP and visualization
import umap
import matplotlib.pyplot as plt

# Suppress warnings
import warnings
warnings.filterwarnings('ignore')

# Progress bar
try:
    from tqdm import tqdm
    HAS_TQDM = True
except ImportError:
    HAS_TQDM = False


class MacrolactoneProcessor:
    """Process macrolactone molecules for embedding visualization."""

    def __init__(self, n_bits: int = 2048, radius: int = 2, chirality: bool = True):
        """
        Initialize processor with ECFP4 parameters.

        Args:
            n_bits: Number of fingerprint bits (default: 2048)
            radius: Morgan fingerprint radius (default: 2 for ECFP4)
            chirality: Include chirality information (default: True)
        """
        self.n_bits = n_bits
        self.radius = radius
        self.chirality = chirality

        # Standardizer for molecule preprocessing
        self.standardizer = rdMolStandardize.MetalDisconnector()

        # SMARTS patterns for different ring sizes (12-20 membered rings)
        self.ring_smarts = {
            12: '[r12][#8][#6](=[#8])',  # 12-membered ring with lactone
            13: '[r13][#8][#6](=[#8])',  # 13-membered ring with lactone
            14: '[r14][#8][#6](=[#8])',  # 14-membered ring with lactone
            15: '[r15][#8][#6](=[#8])',  # 15-membered ring with lactone
            16: '[r16][#8][#6](=[#8])',  # 16-membered ring with lactone
            17: '[r17][#8][#6](=[#8])',  # 17-membered ring with lactone
            18: '[r18][#8][#6](=[#8])',  # 18-membered ring with lactone
            19: '[r19][#8][#6](=[#8])',  # 19-membered ring with lactone
            20: '[r20][#8][#6](=[#8])',  # 20-membered ring with lactone
        }

    def standardize_molecule(self, mol: Chem.Mol) -> Optional[Chem.Mol]:
        """Standardize molecule using RDKit standardization."""
        try:
            # Remove metals
            mol = self.standardizer.Disconnect(mol)
            # Normalize
            mol = rdMolStandardize.Normalize(mol)
            # Remove fragments
            mol = rdMolStandardize.FragmentParent(mol)
            # Neutralize charges
            mol = rdMolStandardize.ChargeParent(mol)
            return mol
        except Exception:
            return None

    def ecfp4_fingerprint(self, smiles: str) -> Optional[np.ndarray]:
        """Generate ECFP4 fingerprint from SMILES string using newer RDKit API."""
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return None

            # Standardize molecule
            mol = self.standardize_molecule(mol)
            if mol is None:
                return None

            # Generate Morgan fingerprint using the newer API to avoid deprecation warnings
            from rdkit.Chem import rdFingerprintGenerator
            generator = rdFingerprintGenerator.GetMorganGenerator(
                radius=self.radius,
                fpSize=self.n_bits,
                includeChirality=self.chirality
            )
            bv = generator.GetFingerprint(mol)

            # Convert to numpy array
            arr = np.zeros((self.n_bits,), dtype=np.uint8)
            DataStructs.ConvertToNumpyArray(bv, arr)
            return arr

        except Exception as e:
            print(f"Error processing SMILES {smiles[:50]}...: {e}")
            return None

    def detect_ring_number(self, smiles: str) -> int:
        """Detect the ring number in macrolactone molecule using SMARTS patterns."""
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return 0

            # Check each ring size pattern
            for ring_size, smarts in self.ring_smarts.items():
                query = Chem.MolFromSmarts(smarts)
                if query:
                    matches = mol.GetSubstructMatches(query)
                    if matches:
                        return ring_size

            # Alternative: check for any large ring with lactone
            generic_pattern = Chem.MolFromSmarts('[r{12-20}][#8][#6](=[#8])')
            if generic_pattern:
                matches = mol.GetSubstructMatches(generic_pattern)
                if matches:
                    # Try to determine ring size from the first match
                    for match in matches:
                        # Get the ring atoms
                        for atom_idx in match:
                            atom = mol.GetAtomWithIdx(atom_idx)
                            if atom.IsInRing():
                                # Find the ring size
                                for ring in atom.GetOwningMol().GetRingInfo().AtomRings():
                                    if atom_idx in ring:
                                        ring_size = len(ring)
                                        if 12 <= ring_size <= 20:
                                            return ring_size

            return 0

        except Exception as e:
            print(f"Error detecting ring number for {smiles}: {e}")
            return 0

    def generate_unique_id(self, index: int, existing_id: Optional[str] = None) -> str:
        """Generate unique ID for molecule."""
        if existing_id and pd.notna(existing_id) and existing_id != '':
            return str(existing_id)
        else:
            return f"D{index:07d}"

    def tanimoto_similarity(self, fp1: np.ndarray, fp2: np.ndarray) -> float:
        """Calculate Tanimoto similarity between two fingerprints."""
        # Bit count
        bit_count1 = np.sum(fp1)
        bit_count2 = np.sum(fp2)
        common_bits = np.sum(fp1 & fp2)

        if bit_count1 + bit_count2 - common_bits == 0:
            return 0.0

        return common_bits / (bit_count1 + bit_count2 - common_bits)

    def find_neighbors(self, X: np.ndarray, k: int = 15, batch_size: int = 1000) -> List[str]:
        """Find k nearest neighbors for each molecule based on Tanimoto similarity."""
        n_samples = X.shape[0]
        neighbors = []

        # Progress bar
        if HAS_TQDM:
            pbar = tqdm(total=n_samples, desc="Finding neighbors")

        for i in range(n_samples):
            similarities = []

            # Batch processing for memory efficiency
            for j in range(0, n_samples, batch_size):
                end_j = min(j + batch_size, n_samples)
                batch_X = X[j:end_j]

                # Calculate similarities for this batch
                for batch_idx, fp in enumerate(batch_X):
                    orig_idx = j + batch_idx
                    if i != orig_idx:
                        sim = self.tanimoto_similarity(X[i], fp)
                        similarities.append((orig_idx, sim))

            # Sort by similarity (descending)
            similarities.sort(key=lambda x: x[1], reverse=True)

            # Get top k neighbors
            top_neighbors = [str(idx) for idx, _ in similarities[:k]]
            neighbors.append(','.join(top_neighbors))

            if HAS_TQDM:
                pbar.update(1)

        if HAS_TQDM:
            pbar.close()

        return neighbors

    def perform_umap(self, X: np.ndarray, n_neighbors: int = 30,
                     min_dist: float = 0.1, metric: str = 'jaccard') -> np.ndarray:
        """Perform UMAP dimensionality reduction."""
        reducer = umap.UMAP(
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            metric=metric,
            random_state=42
        )

        return reducer.fit_transform(X)

    def process_dataframe(self, df: pd.DataFrame, smiles_col: str = 'smiles',
                          id_col: Optional[str] = None, max_molecules: Optional[int] = None) -> pd.DataFrame:
        """Process dataframe with SMILES strings."""
        print(f"Processing {len(df)} molecules...")

        # Limit molecules if requested
        if max_molecules:
            df = df.head(max_molecules)
            print(f"Limited to {max_molecules} molecules")

        # Ensure we have a smiles column
        if smiles_col not in df.columns:
            raise ValueError(f"Column '{smiles_col}' not found in dataframe")

        # Create a working copy
        result_df = df.copy()

        # Generate unique IDs if needed
        if id_col and id_col in df.columns:
            result_df['molecule_id'] = [self.generate_unique_id(i, existing_id)
                                        for i, existing_id in enumerate(result_df[id_col])]
        else:
            result_df['molecule_id'] = [self.generate_unique_id(i)
                                        for i in range(len(result_df))]

        # Process fingerprints
        print("Generating ECFP4 fingerprints...")
        fingerprints = []
        valid_indices = []

        # Progress tracking
        iterator = enumerate(result_df[smiles_col])
        if HAS_TQDM:
            iterator = tqdm(iterator, total=len(result_df), desc="Processing fingerprints")

        for idx, smiles in iterator:
            if pd.notna(smiles) and smiles != '':
                fp = self.ecfp4_fingerprint(smiles)
                if fp is not None:
                    fingerprints.append(fp)
                    valid_indices.append(idx)
                else:
                    print(f"Failed to generate fingerprint for index {idx}: {smiles[:50]}...")
            else:
                print(f"Missing or empty SMILES at index {idx}")

        # Filter dataframe to valid molecules only
        result_df = result_df.iloc[valid_indices].reset_index(drop=True)

        if not fingerprints:
            raise ValueError("No valid fingerprints generated")

        # Convert fingerprints to numpy array
        X = np.array(fingerprints)
        print(f"Generated fingerprints for {len(fingerprints)} molecules")

        # Detect ring numbers
        print("Detecting ring numbers...")
        ring_numbers = []

        iterator = result_df[smiles_col]
        if HAS_TQDM:
            iterator = tqdm(iterator, desc="Detecting rings")

        for smiles in iterator:
            ring_num = self.detect_ring_number(smiles)
            ring_numbers.append(ring_num)

        result_df['ring_num'] = ring_numbers

        # Perform UMAP
        print("Performing UMAP dimensionality reduction...")
        embedding = self.perform_umap(X)
        result_df['projection_x'] = embedding[:, 0]
        result_df['projection_y'] = embedding[:, 1]

        # Find neighbors for embedding-atlas
        print("Finding nearest neighbors...")
        neighbors = self.find_neighbors(X, k=15)
        result_df['neighbors'] = neighbors

        # Add fingerprint information
        result_df['fingerprint_bits'] = [fp.tolist() for fp in fingerprints]

        return result_df

    def create_visualization(self, df: pd.DataFrame, output_path: str):
        """Create visualization of the UMAP embedding."""
        plt.figure(figsize=(12, 8))

        # Color by ring number
        scatter = plt.scatter(df['projection_x'], df['projection_y'],
                              c=df['ring_num'], cmap='viridis', alpha=0.6, s=30)

        plt.colorbar(scatter, label='Ring Number')
        plt.xlabel('UMAP 1')
        plt.ylabel('UMAP 2')
        plt.title('Macrolactone Molecules - ECFP4 + UMAP Visualization')

        # Add some annotations for ring numbers
        for ring_num in sorted(df['ring_num'].unique()):
            if ring_num > 0:
                subset = df[df['ring_num'] == ring_num]
                if len(subset) > 0:
                    center_x = subset['projection_x'].mean()
                    center_y = subset['projection_y'].mean()
                    plt.annotate(f'{ring_num} ring', (center_x, center_y),
                                 fontsize=10, fontweight='bold')

        plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Visualization saved to {output_path}")


def main():
    """Main function to run the processing pipeline."""

    parser = argparse.ArgumentParser(description='ECFP4 + UMAP for Macrolactone Molecules')
    parser.add_argument('--input', '-i', required=True,
                        help='Input CSV file path')
    parser.add_argument('--output', '-o', required=True,
                        help='Output CSV file path')
    parser.add_argument('--smiles-col', default='smiles',
                        help='Name of SMILES column (default: smiles)')
    parser.add_argument('--id-col', default=None,
                        help='Name of ID column (optional)')
    parser.add_argument('--visualization', '-v', default='umap_visualization.png',
                        help='Output visualization file path')
    parser.add_argument('--max-molecules', type=int, default=None,
                        help='Maximum number of molecules to process (for testing)')
    parser.add_argument('--launch-atlas', action='store_true',
                        help='Launch embedding-atlas process')
    parser.add_argument('--atlas-port', type=int, default=8080,
                        help='Port for embedding-atlas server')

    args = parser.parse_args()

    # Initialize processor
    processor = MacrolactoneProcessor(n_bits=2048, radius=2, chirality=True)

    # Load data
    print(f"Loading data from {args.input}")
    try:
        df = pd.read_csv(args.input)
        print(f"Loaded {len(df)} molecules")
        print(f"Columns: {list(df.columns)}")
    except Exception as e:
        print(f"Error loading data: {e}")
        return 1

    # Process dataframe
    try:
        processed_df = processor.process_dataframe(df,
                                                   smiles_col=args.smiles_col,
                                                   id_col=args.id_col,
                                                   max_molecules=args.max_molecules)
        print(f"Successfully processed {len(processed_df)} molecules")
    except Exception as e:
        print(f"Error processing data: {e}")
        return 1

    # Save results
    try:
        processed_df.to_csv(args.output, index=False)
        print(f"Results saved to {args.output}")
    except Exception as e:
        print(f"Error saving results: {e}")
        return 1

    # Create visualization
    try:
        processor.create_visualization(processed_df, args.visualization)
    except Exception as e:
        print(f"Error creating visualization: {e}")

    # Launch embedding-atlas if requested
    if args.launch_atlas:
        print("Launching embedding-atlas process...")
        try:
            # Prepare command for embedding-atlas
            cmd = [
                'embedding-atlas', 'data', args.output,
                '--text', args.smiles_col,
                '--port', str(args.atlas_port),
                '--neighbors', 'neighbors',
                '--x', 'projection_x',
                '--y', 'projection_y'
            ]

            print(f"Running command: {' '.join(cmd)}")
            result = subprocess.run(cmd, capture_output=True, text=True)

            if result.returncode == 0:
                print("Embedding-atlas process launched successfully")
                print(f"Access the visualization at: http://localhost:{args.atlas_port}")
            else:
                print(f"Error launching embedding-atlas: {result.stderr}")

        except FileNotFoundError:
            print("embedding-atlas command not found. Please install it first.")
            print("You can install it with: pip install embedding-atlas")
        except Exception as e:
            print(f"Error launching embedding-atlas: {e}")

    print("Processing complete!")
    return 0


if __name__ == '__main__':
    sys.exit(main())
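`tanimoto_similarity` above implements the standard formula |A∩B| / (|A| + |B| − |A∩B|) over binary fingerprint vectors; note that the `jaccard` metric passed to UMAP is the distance form, 1 − Tanimoto, of the same quantity. A standalone re-derivation for quick verification (not the script's own code):

```python
import numpy as np


def tanimoto(fp1: np.ndarray, fp2: np.ndarray) -> float:
    # Shared on-bits over total distinct on-bits; 0.0 for two empty fingerprints
    common = np.sum(fp1 & fp2)
    union = np.sum(fp1) + np.sum(fp2) - common
    return float(common / union) if union else 0.0


a = np.array([1, 1, 0, 1], dtype=np.uint8)
b = np.array([1, 0, 0, 1], dtype=np.uint8)
print(tanimoto(a, b))  # 2 common bits over 3 distinct on-bits
```

For real workloads, RDKit's `DataStructs.BulkTanimotoSimilarity` on bit vectors is much faster than this per-pair numpy loop.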
5 src/embedding_backend/__init__.py Normal file
@@ -0,0 +1,5 @@
"""Embedding Atlas session orchestrator package."""

from .main import create_app, get_mcp_server

__all__ = ["create_app", "get_mcp_server"]
90 src/embedding_backend/config.py Normal file
@@ -0,0 +1,90 @@
"""Configuration for the embedding session orchestrator."""

from __future__ import annotations

from pathlib import Path
from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    """Application settings loaded from the environment."""

    docker_url: str = Field(
        default="unix:///var/run/docker.sock",
        description="Docker socket URL accessed by the orchestrator.",
    )
    container_image: str = Field(
        default="embedding-atlas:latest",
        description="Default Docker image used for embedding-atlas sessions.",
    )
    container_name_prefix: str = Field(
        default="embedding-atlas",
        description="Prefix applied to managed container names.",
    )
    container_host: str = Field(
        default="0.0.0.0",
        description="Host binding used inside managed containers.",
    )
    api_host: str = Field(
        default="0.0.0.0",
        description="Listening host for the FastAPI server.",
    )
    api_port: int = Field(
        default=9000,
        description="Listening port for the FastAPI server.",
    )
    default_port: int = Field(
        default=5055,
        description="Fallback embedding-atlas port when no override is provided.",
    )
    port_range_start: int = Field(
        default=6000,
        description="Lower bound for dynamically assigned ports.",
    )
    port_range_end: int = Field(
        default=6999,
        description="Upper bound for dynamically assigned ports (inclusive).",
    )
    session_root: Path = Field(
        default=Path("runtime/sessions"),
        description="Directory used to persist downloaded datasets per session.",
    )
    download_chunk_size: int = Field(
        default=64 * 1024,
        description="Chunk size in bytes when streaming remote datasets.",
    )
    download_timeout_seconds: int = Field(
        default=300,
        description="Timeout applied to dataset downloads.",
    )
    auto_remove_seconds: int = Field(
        default=36000,
        description="Default inactivity window before a session is garbage-collected (10h).",
    )
    cleanup_interval_seconds: int = Field(
        default=600,
        description="Background cleanup cadence in seconds.",
    )
    container_network: Optional[str] = Field(
        default=None,
        description="Optional Docker network to attach managed containers to.",
    )

    model_config = SettingsConfigDict(
        env_prefix="EMBEDDING_",
        env_file=".env",
        env_file_encoding="utf-8",
    )

    def prepare(self) -> None:
        """Ensure directories referenced by the settings exist."""

        self.session_root = self.session_root.resolve()
        self.session_root.mkdir(parents=True, exist_ok=True)


settings = Settings()
settings.prepare()
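Because of `env_prefix="EMBEDDING_"`, pydantic-settings resolves each field from an upper-cased, prefixed environment variable (e.g. `api_port` from `EMBEDDING_API_PORT`), falling back to the declared default. A stdlib-only illustration of that lookup rule (this sketch only mimics the naming convention; pydantic-settings additionally handles `.env` files, type coercion, and validation):

```python
import os

ENV_PREFIX = "EMBEDDING_"


def env_int(field_name: str, default: int) -> int:
    # Field "api_port" -> environment variable EMBEDDING_API_PORT
    raw = os.environ.get(ENV_PREFIX + field_name.upper())
    return int(raw) if raw is not None else default


os.environ["EMBEDDING_API_PORT"] = "9100"
print(env_int("api_port", 9000))      # overridden via the environment
print(env_int("default_port", 5055))  # unset, falls back to the default
```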
376 src/embedding_backend/docker_manager.py Normal file
@@ -0,0 +1,376 @@
|
|||||||
|
"""Asynchronous session manager backed by Docker containers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import socket
|
||||||
|
from contextlib import suppress
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Iterable, Optional
|
||||||
|
|
||||||
|
import docker
|
||||||
|
import httpx
|
||||||
|
from docker.errors import APIError, NotFound
|
||||||
|
from docker.models.containers import Container
|
||||||
|
|
||||||
|
from .config import Settings, settings
|
||||||
|
from .models import (
|
||||||
|
DATASET_LABEL,
|
||||||
|
EXPIRY_LABEL,
|
||||||
|
MANAGED_LABEL,
|
||||||
|
SESSION_ID_LABEL,
|
||||||
|
SessionCreateRequest,
|
||||||
|
SessionRecord,
|
||||||
|
default_labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionError(RuntimeError):
|
||||||
|
"""Base class for orchestration failures."""
|
||||||
|
|
||||||
|
|
||||||
|
class SessionExistsError(SessionError):
|
||||||
|
"""Raised when a duplicate session identifier is requested."""
|
||||||
|
|
||||||
|
|
||||||
|
class SessionNotFoundError(SessionError):
|
||||||
|
"""Raised when attempting to operate on a missing session."""
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetDownloadError(SessionError):
|
||||||
|
"""Raised when the dataset cannot be fetched."""
|
||||||
|
|
||||||
|
|
||||||
|
class PortUnavailableError(SessionError):
|
||||||
|
"""Raised when the requested port is already in use."""
|
||||||
|
|
||||||
|
|
||||||
|
class SessionManager:
|
||||||
|
"""Coordinate embedding-atlas containers per incoming request."""
|
||||||
|
|
||||||
|
def __init__(self, app_settings: Settings | None = None) -> None:
|
||||||
|
self.settings = app_settings or settings
|
||||||
|
self.client = docker.DockerClient(base_url=self.settings.docker_url)
|
||||||
|
self._sessions: Dict[str, SessionRecord] = {}
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
self._cleanup_task: Optional[asyncio.Task[None]] = None
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""Initialise caches and background cleanup."""
|
||||||
|
|
||||||
|
await self._load_existing_sessions()
|
||||||
|
if self._cleanup_task is None:
|
||||||
|
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""Stop background tasks and close Docker handles."""
|
||||||
|
|
||||||
|
if self._cleanup_task is not None:
|
||||||
|
self._cleanup_task.cancel()
|
||||||
|
with suppress(asyncio.CancelledError):
|
||||||
|
await self._cleanup_task
|
||||||
|
self._cleanup_task = None
|
||||||
|
await asyncio.to_thread(self.client.close)
|
||||||
|
|
||||||
|
async def create_session(self, payload: SessionCreateRequest) -> SessionRecord:
|
||||||
|
"""Download the dataset and spin up an embedding-atlas container."""
|
||||||
|
|
||||||
|
session_id = payload.normalised_session_id()
|
||||||
|
async with self._lock:
|
||||||
|
if session_id in self._sessions:
|
||||||
|
raise SessionExistsError(f"Session {session_id} already active.")
|
||||||
|
|
||||||
|
host = payload.host or self.settings.container_host
|
||||||
|
if payload.port is not None:
|
||||||
|
is_free = await asyncio.to_thread(self._port_available, payload.port)
|
||||||
|
if not is_free:
|
||||||
|
raise PortUnavailableError(f"Requested port {payload.port} is unavailable")
|
||||||
|
port = payload.port
|
||||||
|
else:
|
||||||
|
port = await asyncio.to_thread(self._allocate_port)
|
||||||
|
expiry_seconds = payload.auto_remove_seconds or self.settings.auto_remove_seconds
|
||||||
|
expires_at = datetime.now(tz=timezone.utc) + timedelta(seconds=expiry_seconds)
|
||||||
|
dataset_path = await self._download_dataset(session_id, payload)
|
||||||
|
|
||||||
|
try:
|
||||||
|
record = await asyncio.to_thread(
|
||||||
|
self._run_container,
|
||||||
|
session_id,
|
||||||
|
dataset_path,
|
||||||
|
payload,
|
||||||
|
host,
|
||||||
|
port,
|
||||||
|
expires_at,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Drop downloaded artefacts if startup failed
|
||||||
|
with suppress(Exception):
|
||||||
|
self._purge_dataset(dataset_path)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
self._sessions[session_id] = record
|
||||||
|
return record
|
||||||
|
|
||||||
|
async def close_session(self, session_id: str) -> bool:
|
||||||
|
"""Terminate the associated container and remove cached data."""
|
||||||
|
|
||||||
|
record = await self._get_session(session_id)
|
||||||
|
await asyncio.to_thread(self._stop_container, record.container_name)
|
||||||
|
await asyncio.to_thread(self._purge_dataset, Path(record.dataset_path))
|
||||||
|
async with self._lock:
|
||||||
|
self._sessions.pop(session_id, None)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def list_sessions(self) -> Iterable[SessionRecord]:
|
||||||
|
"""Return the internal session registry."""
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
return list(self._sessions.values())
|
||||||
|
|
||||||
|
async def cleanup_expired(self) -> None:
|
||||||
|
"""Stop containers that exceeded their TTL."""
|
||||||
|
|
||||||
|
now = datetime.now(tz=timezone.utc)
|
||||||
|
async with self._lock:
|
||||||
|
sessions = list(self._sessions.values())
|
||||||
|
for record in sessions:
|
||||||
|
if record.expires_at <= now:
|
||||||
|
logger.info("Session %s expired; stopping container %s", record.session_id, record.container_name)
|
||||||
|
with suppress(SessionError, APIError, NotFound):
|
||||||
|
await self.close_session(record.session_id)
|
||||||
|
|
||||||
|
async def _cleanup_loop(self) -> None:
|
||||||
|
"""Periodic cleanup coroutine."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(self.settings.cleanup_interval_seconds)
|
||||||
|
await self.cleanup_expired()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.debug("Cleanup loop cancelled")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _get_session(self, session_id: str) -> SessionRecord:
|
||||||
|
async with self._lock:
|
||||||
|
record = self._sessions.get(session_id)
|
||||||
|
if record is None:
|
||||||
|
raise SessionNotFoundError(f"Session {session_id} not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
async def _load_existing_sessions(self) -> None:
|
||||||
|
"""Recover managed containers into the local cache on startup."""
|
||||||
|
|
||||||
|
def _list() -> list[Container]:
|
||||||
|
return self.client.containers.list(all=True, filters={"label": MANAGED_LABEL})
|
||||||
|
|
||||||
|
containers = await asyncio.to_thread(_list)
|
||||||
|
restored: Dict[str, SessionRecord] = {}
|
||||||
|
for container in containers:
|
||||||
|
labels = container.labels or {}
|
||||||
|
session_id = labels.get(SESSION_ID_LABEL)
|
||||||
|
if not session_id:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
port = self._extract_port(container)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("Unable to restore port mapping for container %s", container.name)
|
||||||
|
continue
|
||||||
|
expires_at = self._parse_expiry(labels.get(EXPIRY_LABEL))
|
||||||
|
dataset_path = labels.get(DATASET_LABEL, "")
|
||||||
|
restored[session_id] = SessionRecord(
|
||||||
|
session_id=session_id,
|
||||||
|
container_id=container.id,
|
||||||
|
container_name=container.name,
|
||||||
|
port=port,
|
||||||
|
host=self.settings.container_host,
|
||||||
|
started_at=self._parse_created(container.attrs.get("Created")),
|
||||||
|
expires_at=expires_at,
|
||||||
|
dataset_path=dataset_path,
|
||||||
|
)
|
||||||
|
async with self._lock:
|
||||||
|
self._sessions.update(restored)
|
||||||
|
if restored:
|
||||||
|
logger.info("Restored %d managed sessions from existing containers", len(restored))
|
||||||
|
|
||||||
    def _run_container(
        self,
        session_id: str,
        dataset_path: Path,
        payload: SessionCreateRequest,
        host: str,
        port: int,
        expires_at: datetime,
    ) -> SessionRecord:
        container_name = f"{self.settings.container_name_prefix}_{session_id}"
        self._guard_duplicate_container(container_name)

        bind_root = dataset_path.parent
        container_dataset_path = Path("/session") / dataset_path.name

        command = [
            "embedding-atlas",
            str(container_dataset_path),
            "--host",
            host,
            "--port",
            str(port),
            "--no-auto-port",
        ]
        if payload.extra_args:
            command.extend(payload.extra_args)

        env = payload.environment.copy()
        env.setdefault("SESSION_ID", session_id)

        labels = default_labels(session_id, expires_at, str(dataset_path))
        labels.update(payload.labels)

        logger.info(
            "Starting embedding-atlas container %s on port %s (session %s)",
            container_name,
            port,
            session_id,
        )

        port_bindings = {f"{port}/tcp": port}
        container = self.client.containers.run(
            payload.image or self.settings.container_image,
            command,
            name=container_name,
            detach=True,
            environment=env,
            labels=labels,
            volumes={
                str(bind_root): {"bind": "/session", "mode": "ro"}
            },
            ports=port_bindings,
            network=self.settings.container_network,
        )

        return SessionRecord(
            session_id=session_id,
            container_id=container.id,
            container_name=container.name,
            port=port,
            host=host,
            started_at=datetime.now(tz=timezone.utc),
            expires_at=expires_at,
            dataset_path=str(dataset_path),
        )

    def _stop_container(self, container_name: str) -> None:
        try:
            container = self.client.containers.get(container_name)
        except NotFound as exc:
            raise SessionNotFoundError(f"Container {container_name} not found") from exc
        container.stop(timeout=30)
        container.remove(v=True)
        logger.info("Removed container %s", container_name)

    def _purge_dataset(self, dataset_path: Path) -> None:
        try:
            dataset_path.unlink(missing_ok=True)
            parent = dataset_path.parent
            if parent.exists() and not any(parent.iterdir()):
                parent.rmdir()
        except Exception as exc:
            logger.warning("Failed to clean dataset path %s: %s", dataset_path, exc)
    async def _download_dataset(self, session_id: str, payload: SessionCreateRequest) -> Path:
        filename = payload.input_filename or Path(payload.data_url.path).name
        if not filename:
            filename = f"dataset-{session_id}"
        session_dir = self.settings.session_root / session_id
        session_dir.mkdir(parents=True, exist_ok=True)
        target = session_dir / filename

        logger.info("Downloading dataset for session %s from %s", session_id, payload.data_url)
        timeout = httpx.Timeout(self.settings.download_timeout_seconds)
        async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
            try:
                async with client.stream("GET", str(payload.data_url)) as response:
                    response.raise_for_status()
                    with target.open("wb") as file_handle:
                        async for chunk in response.aiter_bytes(chunk_size=self.settings.download_chunk_size):
                            file_handle.write(chunk)
            except Exception as exc:
                raise DatasetDownloadError(f"Failed to fetch dataset: {exc}") from exc
        return target
    def _allocate_port(self) -> int:
        if self.settings.port_range_start >= self.settings.port_range_end:
            return self._random_free_port()
        candidates = range(self.settings.port_range_start, self.settings.port_range_end + 1)
        for candidate in candidates:
            if self._port_available(candidate):
                return candidate
        return self._random_free_port()

    @staticmethod
    def _random_free_port() -> int:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    @staticmethod
    def _port_available(port: int) -> bool:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            result = sock.connect_ex(("127.0.0.1", port))
            return result != 0
    def _guard_duplicate_container(self, container_name: str) -> None:
        try:
            container = self.client.containers.get(container_name)
        except NotFound:
            return
        raise SessionExistsError(
            f"Container name {container_name} already exists (id={container.id})."
        )

    @staticmethod
    def _extract_port(container: Container) -> int:
        ports = container.attrs.get("NetworkSettings", {}).get("Ports", {})
        if not ports:
            raise ValueError("No exposed ports found")
        # Expect a single port binding
        (container_port, host_bindings), = ports.items()
        if not host_bindings:
            raise ValueError("Container port not published")
        host_port = host_bindings[0].get("HostPort")
        if host_port is None:
            raise ValueError("Missing host port")
        return int(host_port)
    @staticmethod
    def _parse_expiry(value: Optional[str]) -> datetime:
        if not value:
            return datetime.now(tz=timezone.utc)
        try:
            return datetime.fromtimestamp(int(value), tz=timezone.utc)
        except ValueError:
            return datetime.now(tz=timezone.utc)

    @staticmethod
    def _parse_created(value: Optional[str]) -> datetime:
        if not value:
            return datetime.now(tz=timezone.utc)
        try:
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError:
            return datetime.now(tz=timezone.utc)


__all__ = [
    "SessionManager",
    "SessionError",
    "SessionExistsError",
    "SessionNotFoundError",
    "DatasetDownloadError",
    "PortUnavailableError",
]
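`_port_available` treats a port as free when a TCP connect to it fails, and `_random_free_port` asks the OS for an ephemeral port by binding to port 0. A minimal standalone sketch of the same probing logic (function names here are illustrative, not part of this module):

```python
import socket

def random_free_port() -> int:
    """Ask the OS for an unused ephemeral port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]

def port_available(port: int) -> bool:
    """A port is considered free when a TCP connect to it fails."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(0.5)
        return sock.connect_ex(("127.0.0.1", port)) != 0

port = random_free_port()
print(port)                  # an OS-assigned port number
print(port_available(port))  # usually True: nothing is listening on it yet
```

Note that a connect-based probe is inherently racy: another process can claim the port between the check and the `docker run`, which is one reason `_allocate_port` keeps `_random_free_port` as a fallback.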
120 src/embedding_backend/main.py Normal file
@@ -0,0 +1,120 @@
"""Entry points for launching the orchestrator as FastAPI or FastMCP."""

from __future__ import annotations

import argparse
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import Optional

import uvicorn
from fastapi import FastAPI
from fastmcp import FastMCP

from .config import Settings, settings
from .docker_manager import SessionManager
from .routes import get_router

logger = logging.getLogger(__name__)


def create_app(app_settings: Optional[Settings] = None) -> FastAPI:
    """Create and configure the FastAPI application."""
    app_settings = app_settings or settings
    manager = SessionManager(app_settings)

    @asynccontextmanager
    async def lifespan(_: FastAPI):
        await manager.start()
        try:
            yield
        finally:
            await manager.stop()

    app = FastAPI(
        title="Embedding Atlas Orchestrator",
        version="0.1.0",
        description=(
            "Session orchestration API that launches short-lived embedding-atlas containers on demand."
        ),
        docs_url="/docs",
        redoc_url=None,
        openapi_url="/openapi.json",
        lifespan=lifespan,
    )

    @app.get("/", tags=["meta"], summary="Service metadata")
    async def root() -> dict[str, str]:
        return {"message": "Embedding Atlas session manager is running. Visit /docs for the Swagger UI."}

    app.include_router(get_router(manager))
    app.state.session_manager = manager
    return app

_app = create_app()
app = _app
_mcp_server = FastMCP.from_fastapi(_app, name="embedding-atlas-session-manager")
mcp_server = _mcp_server


def get_mcp_server() -> FastMCP:
    return mcp_server


def run_api() -> None:
    """Launch the FastAPI server via uvicorn."""
    uvicorn.run(
        app,
        host=settings.api_host,
        port=settings.api_port,
    )


async def _run_mcp_stdio_async() -> None:
    server = get_mcp_server()
    await server.run_stdio_async()


def run_mcp_stdio() -> None:
    """Launch the MCP server over stdio (experimental)."""
    asyncio.run(_run_mcp_stdio_async())


def main() -> None:
    """CLI entry point allowing API or MCP execution modes."""
    parser = argparse.ArgumentParser(description="Embedding Atlas orchestrator")
    parser.add_argument(
        "--mode",
        choices=["api", "mcp"],
        default="api",
        help="Select whether to launch the HTTP API or the MCP stdio server.",
    )
    parser.add_argument(
        "--host",
        default=settings.api_host,
        help="Override FastAPI bind host.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=settings.api_port,
        help="Override FastAPI bind port.",
    )
    args = parser.parse_args()

    if args.mode == "api":
        logger.info("Starting FastAPI server on %s:%s", args.host, args.port)
        uvicorn.run(app, host=args.host, port=args.port)
    else:
        logger.info("Starting FastMCP stdio server")
        run_mcp_stdio()


if __name__ == "__main__":
    main()
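The `--mode` flag in `main()` selects between the HTTP API and the MCP stdio server. A reduced sketch of that dispatch, with the launch calls stubbed out and placeholder defaults standing in for `settings.api_host` / `settings.api_port` (the real entry point calls `uvicorn.run` and `run_mcp_stdio`):

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Embedding Atlas orchestrator")
    parser.add_argument("--mode", choices=["api", "mcp"], default="api")
    parser.add_argument("--host", default="0.0.0.0")  # placeholder for settings.api_host
    parser.add_argument("--port", type=int, default=9000)  # placeholder for settings.api_port
    return parser

def dispatch(argv: list[str]) -> str:
    args = build_parser().parse_args(argv)
    if args.mode == "api":
        return f"api:{args.host}:{args.port}"  # stands in for uvicorn.run(...)
    return "mcp:stdio"                         # stands in for run_mcp_stdio()

print(dispatch([]))                 # api:0.0.0.0:9000
print(dispatch(["--mode", "mcp"]))  # mcp:stdio
```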
122 src/embedding_backend/models.py Normal file
@@ -0,0 +1,122 @@
"""Pydantic models for the orchestration API."""

from __future__ import annotations

from datetime import datetime
from typing import Dict, List, Optional
from uuid import uuid4

from pydantic import BaseModel, Field, HttpUrl


class SessionCreateRequest(BaseModel):
    """Incoming payload describing a requested embedding-atlas session."""

    session_id: Optional[str] = Field(
        default=None,
        description="Optional session identifier; generated when omitted.",
    )
    data_url: HttpUrl = Field(
        description="HTTPS location of the dataset to visualize.",
    )
    input_filename: Optional[str] = Field(
        default=None,
        description="Filename to persist the downloaded dataset under.",
    )
    extra_args: List[str] = Field(
        default_factory=list,
        description="Additional CLI arguments forwarded to embedding-atlas.",
    )
    environment: Dict[str, str] = Field(
        default_factory=dict,
        description="Environment variables injected into the managed container.",
    )
    labels: Dict[str, str] = Field(
        default_factory=dict,
        description="Extra Docker labels assigned to the managed container.",
    )
    image: Optional[str] = Field(
        default=None,
        description="Override Docker image for the managed container.",
    )
    host: Optional[str] = Field(
        default=None,
        description="Host binding forwarded to embedding-atlas (defaults to settings.container_host).",
    )
    port: Optional[int] = Field(
        default=None,
        description="Preferred host port for the embedding-atlas UI.",
    )
    auto_remove_seconds: Optional[int] = Field(
        default=None,
        description="Lifetime of the session before automatic cleanup triggers.",
    )

    def normalised_session_id(self) -> str:
        return self.session_id or uuid4().hex

class SessionRecord(BaseModel):
    """Internal representation of a live session."""

    session_id: str
    container_id: str
    container_name: str
    port: int
    host: str
    started_at: datetime
    expires_at: datetime
    dataset_path: str


class SessionResponse(BaseModel):
    """HTTP response returned after a session is launched."""

    session_id: str
    container_name: str
    endpoint: str
    port: int
    host: str
    expires_at: datetime


class SessionCloseResponse(BaseModel):
    """HTTP response returned after a session is terminated."""

    session_id: str
    removed: bool


class SessionListResponse(BaseModel):
    """Shape returned when listing active sessions."""

    sessions: List[SessionRecord]


MANAGED_LABEL = "embedding-backend.managed"
EXPIRY_LABEL = "embedding-backend.expires-at"
SESSION_ID_LABEL = "embedding-backend.session-id"
DATASET_LABEL = "embedding-backend.dataset-path"


def default_labels(session_id: str, expires_at: datetime, dataset_path: str) -> Dict[str, str]:
    return {
        MANAGED_LABEL: "true",
        SESSION_ID_LABEL: session_id,
        EXPIRY_LABEL: str(int(expires_at.timestamp())),
        DATASET_LABEL: dataset_path,
    }


__all__ = [
    "SessionCreateRequest",
    "SessionRecord",
    "SessionResponse",
    "SessionCloseResponse",
    "SessionListResponse",
    "default_labels",
    "MANAGED_LABEL",
    "EXPIRY_LABEL",
    "SESSION_ID_LABEL",
    "DATASET_LABEL",
]
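`default_labels` stores the expiry as a Unix-timestamp string in a Docker label so the manager can later recover it from a running container (`_parse_expiry` in `docker_manager.py` does the reverse). A stdlib-only sketch of that round trip, with the helper functions reduced to the two labels that matter here:

```python
from datetime import datetime, timezone

EXPIRY_LABEL = "embedding-backend.expires-at"

def default_labels(session_id: str, expires_at: datetime, dataset_path: str) -> dict[str, str]:
    # Labels must be strings, so the expiry is serialized as an integer timestamp.
    return {
        "embedding-backend.managed": "true",
        "embedding-backend.session-id": session_id,
        EXPIRY_LABEL: str(int(expires_at.timestamp())),
        "embedding-backend.dataset-path": dataset_path,
    }

def parse_expiry(value: str) -> datetime:
    # Reverse of the serialization above: label string -> aware UTC datetime.
    return datetime.fromtimestamp(int(value), tz=timezone.utc)

expires = datetime(2025, 1, 1, tzinfo=timezone.utc)
labels = default_labels("abc123", expires, "/sessions/abc123/data.csv")
print(labels[EXPIRY_LABEL])                            # 1735689600
assert parse_expiry(labels[EXPIRY_LABEL]) == expires   # lossless round trip
```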
71 src/embedding_backend/routes.py Normal file
@@ -0,0 +1,71 @@
"""FastAPI routes exposing the session management API."""

from __future__ import annotations

from fastapi import APIRouter, Depends, HTTPException, status

from .docker_manager import (
    DatasetDownloadError,
    PortUnavailableError,
    SessionExistsError,
    SessionManager,
    SessionNotFoundError,
)
from .models import (
    SessionCloseResponse,
    SessionCreateRequest,
    SessionListResponse,
    SessionRecord,
    SessionResponse,
)


def get_router(manager: SessionManager) -> APIRouter:
    router = APIRouter(prefix="/sessions", tags=["sessions"])

    async def get_manager() -> SessionManager:
        return manager

    @router.post("/", response_model=SessionResponse, status_code=status.HTTP_201_CREATED)
    async def create_session(
        payload: SessionCreateRequest,
        manager: SessionManager = Depends(get_manager),
    ) -> SessionResponse:
        try:
            record = await manager.create_session(payload)
        except SessionExistsError as exc:
            raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
        except DatasetDownloadError as exc:
            raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
        except PortUnavailableError as exc:
            raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
        endpoint = f"http://{record.host}:{record.port}"
        return SessionResponse(
            session_id=record.session_id,
            container_name=record.container_name,
            endpoint=endpoint,
            port=record.port,
            host=record.host,
            expires_at=record.expires_at,
        )

    @router.delete("/{session_id}", response_model=SessionCloseResponse)
    async def close_session(
        session_id: str,
        manager: SessionManager = Depends(get_manager),
    ) -> SessionCloseResponse:
        try:
            await manager.close_session(session_id)
        except SessionNotFoundError as exc:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
        return SessionCloseResponse(session_id=session_id, removed=True)

    @router.get("/", response_model=SessionListResponse)
    async def list_sessions(manager: SessionManager = Depends(get_manager)) -> SessionListResponse:
        sessions = await manager.list_sessions()
        return SessionListResponse(sessions=list(sessions))

    return router


__all__ = ["get_router"]
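The route handlers translate domain errors into HTTP statuses: 409 for duplicate sessions and unavailable ports, 502 for a failed dataset download, and 404 for deleting an unknown session. A condensed sketch of that mapping, using placeholder exception classes in place of the real ones from `docker_manager`:

```python
class SessionExistsError(Exception): ...
class DatasetDownloadError(Exception): ...
class PortUnavailableError(Exception): ...
class SessionNotFoundError(Exception): ...

# HTTP status chosen for each domain error (mirrors the handlers in get_router)
ERROR_STATUS: dict[type, int] = {
    SessionExistsError: 409,    # duplicate session / container name
    PortUnavailableError: 409,  # requested port already taken
    DatasetDownloadError: 502,  # upstream dataset fetch failed
    SessionNotFoundError: 404,  # unknown session id on DELETE
}

def status_for(exc: Exception) -> int:
    """Look up the HTTP status for a domain error; anything else is a 500."""
    return ERROR_STATUS.get(type(exc), 500)

print(status_for(DatasetDownloadError("timeout")))  # 502
print(status_for(RuntimeError("boom")))             # 500
```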