修复 metadata 结构问题并完善 README 文档
1. 修复了 comparison.py 中 metadata 结构问题,移除了手动添加的 database 字段,让 make_server() 自动处理数据库配置 2. 完善了 README.md 文档,添加了关于工具使用原则、参数说明、高级用法和模型选择的详细说明 3. 添加了首次使用时模型下载失败的解决方案说明
This commit is contained in:
134
README.md
134
README.md
@@ -33,6 +33,118 @@ export HF_HUB_OFFLINE=1
|
|||||||
export HF_ENDPOINT=https://hf-mirror.com
|
export HF_ENDPOINT=https://hf-mirror.com
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 数据可视化工具
|
||||||
|
|
||||||
|
项目包含一个强大的数据可视化工具 [comparison.py](file:///Users/lingyuzeng/project/embedding_atlas/src/visualization/comparison.py),可以比较两个 CSV 文件中的数据并在 2D 空间中可视化。
|
||||||
|
|
||||||
|
### 安装和首次使用注意事项
|
||||||
|
|
||||||
|
首次运行时,系统需要下载 embedding 模型权重(如 `all-MiniLM-L6-v2`)。如果下载失败,请设置 Hugging Face 镜像:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export HF_ENDPOINT=https://hf-mirror.com
|
||||||
|
```
|
||||||
|
|
||||||
|
### 工具使用原则
|
||||||
|
|
||||||
|
在使用该工具时,请遵循以下核心设计原则:
|
||||||
|
|
||||||
|
1. 让 `make_server()` 自动处理数据库配置
|
||||||
|
2. 使用完整的 `props` 格式提供前端所需的所有信息
|
||||||
|
3. 避免手动添加可能与系统自动添加冲突的配置字段
|
||||||
|
|
||||||
|
### 命令行使用方式
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python src/visualization/comparison.py file1.csv file2.csv \
|
||||||
|
--column1 smiles --column2 smiles \
|
||||||
|
--label1 "Dataset A" --label2 "Dataset B" \
|
||||||
|
--interactive --port 5055
|
||||||
|
```
|
||||||
|
|
||||||
|
参数说明:
|
||||||
|
- `file1.csv` 和 `file2.csv`:要比较的两个 CSV 文件
|
||||||
|
- `--column1` 和 `--column2`:分别指定两个文件中用于生成 embedding 的列名
|
||||||
|
- `--label1` 和 `--label2`:在可视化中显示的数据集标签
|
||||||
|
- `--interactive`:启动交互式 Web 查看器
|
||||||
|
- `--port`:指定 Web 服务器端口
|
||||||
|
- `--model`:指定要使用的 embedding 模型(默认:`all-MiniLM-L6-v2`)
|
||||||
|
- `--batch-size`:指定处理数据的批大小(默认:32)
|
||||||
|
- `--output` 或 `-o`:指定输出图像文件路径(默认:`comparison_visualization.png`)
|
||||||
|
- `--host`:指定 Web 服务器主机地址(默认:`0.0.0.0`)
|
||||||
|
|
||||||
|
### Python API 调用方式
|
||||||
|
|
||||||
|
工具也支持作为 Python 模块直接调用:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.visualization.comparison import visualize_csv_comparison
|
||||||
|
|
||||||
|
# 基本用法
|
||||||
|
visualize_csv_comparison(
|
||||||
|
"file1.csv",
|
||||||
|
"file2.csv",
|
||||||
|
column1="smiles",
|
||||||
|
column2="smiles",
|
||||||
|
launch_interactive=True,
|
||||||
|
port=5055
|
||||||
|
)
|
||||||
|
|
||||||
|
# 高级用法 - 自定义模型和参数
|
||||||
|
visualize_csv_comparison(
|
||||||
|
"file1.csv",
|
||||||
|
"file2.csv",
|
||||||
|
column1="smiles",
|
||||||
|
column2="smiles",
|
||||||
|
model="sentence-transformers/all-mpnet-base-v2", # 使用不同的模型
|
||||||
|
batch_size=16, # 调整批处理大小
|
||||||
|
output_path="custom_output.png", # 自定义输出路径
|
||||||
|
launch_interactive=True,
|
||||||
|
port=8080, # 自定义端口
|
||||||
|
umap_args={ # 自定义 UMAP 参数
|
||||||
|
"n_neighbors": 20,
|
||||||
|
"min_dist": 0.2,
|
||||||
|
"metric": "cosine"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 自定义 embedding 服务
|
||||||
|
from src.visualization.comparison import create_embedding_service
|
||||||
|
|
||||||
|
create_embedding_service(
|
||||||
|
["text1", "text2", "text3"], # 第一组文本数据
|
||||||
|
["text4", "text5", "text6"], # 第二组文本数据
|
||||||
|
labels=("Group A", "Group B"),
|
||||||
|
model="sentence-transformers/all-mpnet-base-v2", # 指定模型
|
||||||
|
batch_size=16, # 批处理大小
|
||||||
|
port=5055
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 支持的模型
|
||||||
|
|
||||||
|
工具支持任何兼容 Sentence Transformers 的模型,包括但不限于:
|
||||||
|
|
||||||
|
- `all-MiniLM-L6-v2`(默认)
|
||||||
|
- `all-mpnet-base-v2`
|
||||||
|
- `all-distilroberta-v1`
|
||||||
|
- `paraphrase-multilingual-MiniLM-L12-v2`
|
||||||
|
|
||||||
|
### 高级功能
|
||||||
|
|
||||||
|
1. **自定义 UMAP 参数**:
|
||||||
|
可以通过 `umap_args` 参数调整降维效果:
|
||||||
|
```python
|
||||||
|
umap_args = {
|
||||||
|
"n_neighbors": 15, # 邻居数量
|
||||||
|
"min_dist": 0.1, # 最小距离
|
||||||
|
"metric": "cosine" # 距离度量
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **批处理优化**:
|
||||||
|
对于大型数据集,可以通过调整 `batch_size` 参数来平衡内存使用和处理速度。
|
||||||
|
|
||||||
## 会话编排服务(FastAPI / MCP)
|
## 会话编排服务(FastAPI / MCP)
|
||||||
|
|
||||||
使用 `uv run embedding-backend-api` 可以启动一个同时兼容 FastAPI 与 FastMCP 的后端服务。该服务监听 `/sessions` 路径,负责按需拉起 `embedding-atlas` 容器并在 10 小时后自动清理。
|
使用 `uv run embedding-backend-api` 可以启动一个同时兼容 FastAPI 与 FastMCP 的后端服务。该服务监听 `/sessions` 路径,负责按需拉起 `embedding-atlas` 容器并在 10 小时后自动清理。
|
||||||
@@ -45,7 +157,7 @@ uv run embedding-backend-mcp
|
|||||||
|
|
||||||
### REST API 用法
|
### REST API 用法
|
||||||
|
|
||||||
```bash
|
```
|
||||||
curl -X POST http://localhost:9000/sessions \
|
curl -X POST http://localhost:9000/sessions \
|
||||||
-H 'Content-Type: application/json' \
|
-H 'Content-Type: application/json' \
|
||||||
-d '{
|
-d '{
|
||||||
@@ -79,7 +191,7 @@ curl http://localhost:9000/sessions
|
|||||||
|
|
||||||
示例请求:
|
示例请求:
|
||||||
|
|
||||||
```bash
|
```
|
||||||
curl -X POST http://localhost:9000/sessions \
|
curl -X POST http://localhost:9000/sessions \
|
||||||
-H 'Content-Type: application/json' \
|
-H 'Content-Type: application/json' \
|
||||||
-d '{
|
-d '{
|
||||||
@@ -97,7 +209,7 @@ curl -X POST http://localhost:9000/sessions \
|
|||||||
|
|
||||||
无查询参数,返回当前所有会话的列表:
|
无查询参数,返回当前所有会话的列表:
|
||||||
|
|
||||||
```json
|
```
|
||||||
{
|
{
|
||||||
"sessions": [
|
"sessions": [
|
||||||
{
|
{
|
||||||
@@ -121,7 +233,7 @@ curl -X POST http://localhost:9000/sessions \
|
|||||||
|
|
||||||
根目录的 `fastmcp.json` 示例可直接将本项目注册为 MCP 工具:
|
根目录的 `fastmcp.json` 示例可直接将本项目注册为 MCP 工具:
|
||||||
|
|
||||||
```bash
|
```
|
||||||
uv run embedding-backend-mcp
|
uv run embedding-backend-mcp
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -129,14 +241,14 @@ FastMCP 客户端加载该配置后,可用标准 MCP 协议转发同一套 RES
|
|||||||
|
|
||||||
## 命令行生成嵌入可视化交互
|
## 命令行生成嵌入可视化交互
|
||||||
|
|
||||||
```bash
|
```
|
||||||
uv run embedding-atlas data/drugbank_pre_filtered_mordred_qed_id_selfies.csv --text smiles
|
uv run embedding-atlas data/drugbank_pre_filtered_mordred_qed_id_selfies.csv --text smiles
|
||||||
uv run embedding-atlas data/drugbank_pre_filtered_mordred_qed_id_selfies.csv --export-application data/my_visualization.zip
|
uv run embedding-atlas data/drugbank_pre_filtered_mordred_qed_id_selfies.csv --export-application data/my_visualization.zip
|
||||||
```
|
```
|
||||||
|
|
||||||
`embedding-atlas` 更多用法示例:
|
`embedding-atlas` 更多用法示例:
|
||||||
|
|
||||||
```bash
|
```
|
||||||
# 本地文件
|
# 本地文件
|
||||||
embedding-atlas dataset.parquet
|
embedding-atlas dataset.parquet
|
||||||
# Hugging Face 数据集
|
# Hugging Face 数据集
|
||||||
@@ -153,7 +265,7 @@ embedding-atlas dataset.parquet --x projection_x --y projection_y
|
|||||||
|
|
||||||
### Python API使用方法
|
### Python API使用方法
|
||||||
|
|
||||||
```python
|
```
|
||||||
from script.visualize_csv_comparison import visualize_csv_comparison, create_embedding_service
|
from script.visualize_csv_comparison import visualize_csv_comparison, create_embedding_service
|
||||||
|
|
||||||
# 比较两个CSV文件
|
# 比较两个CSV文件
|
||||||
@@ -190,7 +302,7 @@ create_embedding_service(
|
|||||||
|
|
||||||
### 命令行使用方法
|
### 命令行使用方法
|
||||||
|
|
||||||
```bash
|
```
|
||||||
# 基本用法
|
# 基本用法
|
||||||
python script/visualize_csv_comparison.py file1.csv file2.csv
|
python script/visualize_csv_comparison.py file1.csv file2.csv
|
||||||
|
|
||||||
@@ -226,7 +338,7 @@ python script/visualize_csv_comparison.py file1.csv file2.csv \
|
|||||||
|
|
||||||
## 划分 MolGen 第一轮微调数据集
|
## 划分 MolGen 第一轮微调数据集
|
||||||
|
|
||||||
```bash
|
```
|
||||||
uv run python script/split_drugbank.py \
|
uv run python script/split_drugbank.py \
|
||||||
--in-csv data/drugbank_pre_filtered_mordred_qed_id_selfies.csv \
|
--in-csv data/drugbank_pre_filtered_mordred_qed_id_selfies.csv \
|
||||||
--out-dir splits_v2 --seed 20250922 \
|
--out-dir splits_v2 --seed 20250922 \
|
||||||
@@ -241,13 +353,13 @@ uv run python script/split_drugbank.py \
|
|||||||
|
|
||||||
合并数据集:
|
合并数据集:
|
||||||
|
|
||||||
```bash
|
```
|
||||||
uv run python script/merge_splits.py --input-dir splits_v2/ --output data/drugbank_split_merge.csv
|
uv run python script/merge_splits.py --input-dir splits_v2/ --output data/drugbank_split_merge.csv
|
||||||
```
|
```
|
||||||
|
|
||||||
可视化:
|
可视化:
|
||||||
|
|
||||||
```bash
|
```
|
||||||
uv run embedding-atlas data/drugbank_split_merge.csv --text smiles
|
uv run embedding-atlas data/drugbank_split_merge.csv --text smiles
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 50
|
|||||||
from embedding_atlas.utils import Hasher
|
from embedding_atlas.utils import Hasher
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
# 创建metadata - 使用与CLI一致的props字段结构
|
# 使用props格式,不添加database字段
|
||||||
metadata = {
|
metadata = {
|
||||||
"props": {
|
"props": {
|
||||||
"data": {
|
"data": {
|
||||||
@@ -38,6 +38,7 @@ def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 50
|
|||||||
},
|
},
|
||||||
"initialState": {"version": "0.0.0"}
|
"initialState": {"version": "0.0.0"}
|
||||||
}
|
}
|
||||||
|
# 不要添加database字段!
|
||||||
}
|
}
|
||||||
|
|
||||||
# 生成数据集标识符
|
# 生成数据集标识符
|
||||||
@@ -56,7 +57,7 @@ def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 50
|
|||||||
(pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
|
(pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建并启动服务器 - 使用默认的duckdb_uri参数(即"wasm")
|
# make_server()会自动添加database配置
|
||||||
app = make_server(dataset, static_path=static_path)
|
app = make_server(dataset, static_path=static_path)
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
@@ -122,7 +123,7 @@ def create_embedding_service(
|
|||||||
print(f"Error computing projections: {e}")
|
print(f"Error computing projections: {e}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 4. 创建metadata (使用与CLI一致的props字段结构)
|
# 4. 创建metadata (使用props格式,不添加database字段)
|
||||||
metadata = {
|
metadata = {
|
||||||
"props": {
|
"props": {
|
||||||
"data": {
|
"data": {
|
||||||
@@ -136,15 +137,17 @@ def create_embedding_service(
|
|||||||
},
|
},
|
||||||
"initialState": {"version": "0.0.0"}
|
"initialState": {"version": "0.0.0"}
|
||||||
}
|
}
|
||||||
|
# 不要添加database字段!
|
||||||
}
|
}
|
||||||
|
|
||||||
# 5. 保存原始DataFrame用于DataSource
|
# 5. 保存原始DataFrame用于DataSource
|
||||||
df_for_datasource = combined_df.copy()
|
df_for_datasource = combined_df.copy()
|
||||||
|
|
||||||
# 6. 转换neighbors列为JSON字符串(仅用于保存parquet)
|
# 6. 转换neighbors列为JSON字符串(仅用于保存parquet)
|
||||||
if '__neighbors' in combined_df.columns:
|
df_for_save = combined_df.copy()
|
||||||
|
if '__neighbors' in df_for_save.columns:
|
||||||
import json
|
import json
|
||||||
combined_df['__neighbors'] = combined_df['__neighbors'].apply(
|
df_for_save['__neighbors'] = df_for_save['__neighbors'].apply(
|
||||||
lambda x: json.dumps({
|
lambda x: json.dumps({
|
||||||
'ids': x['ids'].tolist() if hasattr(x['ids'], 'tolist') else list(x['ids']),
|
'ids': x['ids'].tolist() if hasattr(x['ids'], 'tolist') else list(x['ids']),
|
||||||
'distances': x['distances'].tolist() if hasattr(x['distances'], 'tolist') else list(x['distances'])
|
'distances': x['distances'].tolist() if hasattr(x['distances'], 'tolist') else list(x['distances'])
|
||||||
@@ -153,7 +156,7 @@ def create_embedding_service(
|
|||||||
|
|
||||||
# 7. 保存parquet文件
|
# 7. 保存parquet文件
|
||||||
parquet_path = "combined_dataset.parquet"
|
parquet_path = "combined_dataset.parquet"
|
||||||
combined_df.to_parquet(parquet_path, index=False)
|
df_for_save.to_parquet(parquet_path, index=False)
|
||||||
print(f"Dataset saved to {parquet_path}")
|
print(f"Dataset saved to {parquet_path}")
|
||||||
|
|
||||||
# 8. 生成数据集标识符
|
# 8. 生成数据集标识符
|
||||||
@@ -170,7 +173,7 @@ def create_embedding_service(
|
|||||||
(pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
|
(pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
|
||||||
)
|
)
|
||||||
|
|
||||||
# 11. 创建并启动服务器 - 使用默认的duckdb_uri参数(即"wasm")
|
# 11. make_server()会自动添加database配置
|
||||||
app = make_server(dataset, static_path=static_path)
|
app = make_server(dataset, static_path=static_path)
|
||||||
import uvicorn
|
import uvicorn
|
||||||
print(f"Starting interactive viewer on http://{host}:{port}")
|
print(f"Starting interactive viewer on http://{host}:{port}")
|
||||||
@@ -328,11 +331,11 @@ def visualize_csv_comparison(
|
|||||||
|
|
||||||
# Convert __neighbors column to JSON strings before saving
|
# Convert __neighbors column to JSON strings before saving
|
||||||
# This is needed because DuckDB-WASM expects JSON strings for complex data types
|
# This is needed because DuckDB-WASM expects JSON strings for complex data types
|
||||||
if '__neighbors' in combined_df.columns:
|
df_for_save = combined_df.copy()
|
||||||
|
if '__neighbors' in df_for_save.columns:
|
||||||
import json
|
import json
|
||||||
# 为保存到 parquet 文件的数据转换为 JSON 字符串
|
# 为保存到 parquet 文件的数据转换为 JSON 字符串
|
||||||
combined_df_for_save = combined_df.copy()
|
df_for_save['__neighbors'] = df_for_save['__neighbors'].apply(
|
||||||
combined_df_for_save['__neighbors'] = combined_df['__neighbors'].apply(
|
|
||||||
lambda x: json.dumps({
|
lambda x: json.dumps({
|
||||||
'ids': x['ids'].tolist() if hasattr(x['ids'], 'tolist') else list(x['ids']),
|
'ids': x['ids'].tolist() if hasattr(x['ids'], 'tolist') else list(x['ids']),
|
||||||
'distances': x['distances'].tolist() if hasattr(x['distances'], 'tolist') else list(x['distances'])
|
'distances': x['distances'].tolist() if hasattr(x['distances'], 'tolist') else list(x['distances'])
|
||||||
@@ -341,11 +344,7 @@ def visualize_csv_comparison(
|
|||||||
|
|
||||||
# Save dataset for interactive viewing
|
# Save dataset for interactive viewing
|
||||||
parquet_path = "combined_dataset.parquet"
|
parquet_path = "combined_dataset.parquet"
|
||||||
# 保存转换为 JSON 字符串的 DataFrame
|
df_for_save.to_parquet(parquet_path, index=False)
|
||||||
if '__neighbors' in combined_df.columns:
|
|
||||||
combined_df_for_save.to_parquet(parquet_path, index=False)
|
|
||||||
else:
|
|
||||||
combined_df.to_parquet(parquet_path, index=False)
|
|
||||||
print(f"Dataset saved to {parquet_path}")
|
print(f"Dataset saved to {parquet_path}")
|
||||||
|
|
||||||
# Launch interactive viewer with original DataFrame (not converted to JSON)
|
# Launch interactive viewer with original DataFrame (not converted to JSON)
|
||||||
|
|||||||
Reference in New Issue
Block a user