diff --git a/README.md b/README.md index 82cc190..e0362c2 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,118 @@ export HF_HUB_OFFLINE=1 export HF_ENDPOINT=https://hf-mirror.com ``` +## 数据可视化工具 + +项目包含一个强大的数据可视化工具 [comparison.py](file:///Users/lingyuzeng/project/embedding_atlas/src/visualization/comparison.py),可以比较两个 CSV 文件中的数据并在 2D 空间中可视化。 + +### 安装和首次使用注意事项 + +首次运行时,系统需要下载 embedding 模型权重(如 `all-MiniLM-L6-v2`)。如果下载失败,请设置 Hugging Face 镜像: + +```bash +export HF_ENDPOINT=https://hf-mirror.com +``` + +### 工具使用原则 + +在使用该工具时,请遵循以下核心设计原则: + +1. 让 `make_server()` 自动处理数据库配置 +2. 使用完整的 `props` 格式提供前端所需的所有信息 +3. 避免手动添加可能与系统自动添加冲突的配置字段 + +### 命令行使用方式 + +```bash +python src/visualization/comparison.py file1.csv file2.csv \ + --column1 smiles --column2 smiles \ + --label1 "Dataset A" --label2 "Dataset B" \ + --interactive --port 5055 +``` + +参数说明: +- `file1.csv` 和 `file2.csv`:要比较的两个 CSV 文件 +- `--column1` 和 `--column2`:分别指定两个文件中用于生成 embedding 的列名 +- `--label1` 和 `--label2`:在可视化中显示的数据集标签 +- `--interactive`:启动交互式 Web 查看器 +- `--port`:指定 Web 服务器端口 +- `--model`:指定要使用的 embedding 模型(默认:`all-MiniLM-L6-v2`) +- `--batch-size`:指定处理数据的批大小(默认:32) +- `--output` 或 `-o`:指定输出图像文件路径(默认:`comparison_visualization.png`) +- `--host`:指定 Web 服务器主机地址(默认:`0.0.0.0`) + +### Python API 调用方式 + +工具也支持作为 Python 模块直接调用: + +```python +from src.visualization.comparison import visualize_csv_comparison + +# 基本用法 +visualize_csv_comparison( + "file1.csv", + "file2.csv", + column1="smiles", + column2="smiles", + launch_interactive=True, + port=5055 +) + +# 高级用法 - 自定义模型和参数 +visualize_csv_comparison( + "file1.csv", + "file2.csv", + column1="smiles", + column2="smiles", + model="sentence-transformers/all-mpnet-base-v2", # 使用不同的模型 + batch_size=16, # 调整批处理大小 + output_path="custom_output.png", # 自定义输出路径 + launch_interactive=True, + port=8080, # 自定义端口 + umap_args={ # 自定义 UMAP 参数 + "n_neighbors": 20, + "min_dist": 0.2, + "metric": "cosine" + } +) + +# 自定义 embedding 服务 +from src.visualization.comparison import create_embedding_service + +create_embedding_service( + ["text1", "text2", "text3"], # 第一组文本数据 + ["text4", "text5", "text6"], # 第二组文本数据 + labels=("Group A", "Group B"), + model="sentence-transformers/all-mpnet-base-v2", # 指定模型 + batch_size=16, # 批处理大小 + port=5055 +) +``` + +### 支持的模型 + +工具支持任何兼容 Sentence Transformers 的模型,包括但不限于: + +- `all-MiniLM-L6-v2`(默认) +- `all-mpnet-base-v2` +- `all-distilroberta-v1` +- `paraphrase-multilingual-MiniLM-L12-v2` + +### 高级功能 + +1. **自定义 UMAP 参数**: + 可以通过 `umap_args` 参数调整降维效果: + ```python + umap_args = { + "n_neighbors": 15, # 邻居数量 + "min_dist": 0.1, # 最小距离 + "metric": "cosine" # 距离度量 + } + ``` + +2. **批处理优化**: + 对于大型数据集,可以通过调整 `batch_size` 参数来平衡内存使用和处理速度。 + ## 会话编排服务(FastAPI / MCP) 使用 `uv run embedding-backend-api` 可以启动一个同时兼容 FastAPI 与 FastMCP 的后端服务。该服务监听 `/sessions` 路径,负责按需拉起 `embedding-atlas` 容器并在 10 小时后自动清理。 @@ -45,7 +157,7 @@ uv run embedding-backend-mcp ### REST API 用法 -```bash +``` curl -X POST http://localhost:9000/sessions \ -H 'Content-Type: application/json' \ -d '{ @@ -79,7 +191,7 @@ curl http://localhost:9000/sessions 示例请求: -```bash +``` curl -X POST http://localhost:9000/sessions \ -H 'Content-Type: application/json' \ -d '{ @@ -97,7 +209,7 @@ curl -X POST http://localhost:9000/sessions \ 无查询参数,返回当前所有会话的列表: -```json +``` { "sessions": [ { @@ -121,7 +233,7 @@ curl -X POST http://localhost:9000/sessions \ 根目录的 `fastmcp.json` 示例可直接将本项目注册为 MCP 工具: -```bash +``` uv run embedding-backend-mcp ``` @@ -129,14 +241,14 @@ FastMCP 客户端加载该配置后,可用标准 MCP 协议转发同一套 RES ## 命令行生成嵌入可视化交互 -```bash +``` uv run embedding-atlas data/drugbank_pre_filtered_mordred_qed_id_selfies.csv --text smiles uv run embedding-atlas data/drugbank_pre_filtered_mordred_qed_id_selfies.csv --export-application data/my_visualization.zip ``` `embedding-atlas` 更多用法示例: -```bash +``` # 本地文件 embedding-atlas dataset.parquet # Hugging Face 数据集 @@ -153,7 +265,7 @@ embedding-atlas dataset.parquet --x projection_x --y projection_y ### Python API使用方法 -```python +``` from script.visualize_csv_comparison import visualize_csv_comparison, create_embedding_service # 比较两个CSV文件 @@ -190,7 +302,7 @@ create_embedding_service( ### 命令行使用方法 -```bash +``` # 基本用法 python script/visualize_csv_comparison.py file1.csv file2.csv @@ -226,7 +338,7 @@ python script/visualize_csv_comparison.py file1.csv file2.csv \ ## 划分 MolGen 第一轮微调数据集 -```bash +``` uv run python script/split_drugbank.py \ --in-csv data/drugbank_pre_filtered_mordred_qed_id_selfies.csv \ --out-dir splits_v2 --seed 20250922 \ @@ -241,13 +353,13 @@ uv run python script/split_drugbank.py \ 合并数据集: -```bash +``` uv run python script/merge_splits.py --input-dir splits_v2/ --output data/drugbank_split_merge.csv ``` 可视化: -```bash +``` uv run embedding-atlas data/drugbank_split_merge.csv --text smiles ``` diff --git a/src/visualization/comparison.py b/src/visualization/comparison.py index f86b852..186ab7b 100644 --- a/src/visualization/comparison.py +++ b/src/visualization/comparison.py @@ -27,7 +27,7 @@ def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 50 from embedding_atlas.utils import Hasher import pathlib - # 创建metadata - 使用与CLI一致的props字段结构 + # 使用props格式,不添加database字段 metadata = { "props": { "data": { @@ -38,6 +38,7 @@ def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 50 }, "initialState": {"version": "0.0.0"} } + # 不要添加database字段! } # 生成数据集标识符 @@ -56,7 +57,7 @@ def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 50 (pathlib.Path(embedding_atlas.__file__).parent / "static").resolve() ) - # 创建并启动服务器 - 使用默认的duckdb_uri参数(即"wasm") + # make_server()会自动添加database配置 app = make_server(dataset, static_path=static_path) import uvicorn @@ -122,7 +123,7 @@ def create_embedding_service( print(f"Error computing projections: {e}") return - # 4. 创建metadata (使用与CLI一致的props字段结构) + # 4. 创建metadata (使用props格式,不添加database字段) metadata = { "props": { "data": { @@ -136,15 +137,17 @@ def create_embedding_service( }, "initialState": {"version": "0.0.0"} } + # 不要添加database字段! } # 5. 保存原始DataFrame用于DataSource df_for_datasource = combined_df.copy() # 6. 转换neighbors列为JSON字符串(仅用于保存parquet) - if '__neighbors' in combined_df.columns: + df_for_save = combined_df.copy() + if '__neighbors' in df_for_save.columns: import json - combined_df['__neighbors'] = combined_df['__neighbors'].apply( + df_for_save['__neighbors'] = df_for_save['__neighbors'].apply( lambda x: json.dumps({ 'ids': x['ids'].tolist() if hasattr(x['ids'], 'tolist') else list(x['ids']), 'distances': x['distances'].tolist() if hasattr(x['distances'], 'tolist') else list(x['distances']) @@ -153,7 +156,7 @@ def create_embedding_service( # 7. 保存parquet文件 parquet_path = "combined_dataset.parquet" - combined_df.to_parquet(parquet_path, index=False) + df_for_save.to_parquet(parquet_path, index=False) print(f"Dataset saved to {parquet_path}") # 8. 生成数据集标识符 @@ -170,7 +173,7 @@ def create_embedding_service( (pathlib.Path(embedding_atlas.__file__).parent / "static").resolve() ) - # 11. 创建并启动服务器 - 使用默认的duckdb_uri参数(即"wasm") + # 11. make_server()会自动添加database配置 app = make_server(dataset, static_path=static_path) import uvicorn print(f"Starting interactive viewer on http://{host}:{port}") @@ -328,11 +331,11 @@ def visualize_csv_comparison( # Convert __neighbors column to JSON strings before saving # This is needed because DuckDB-WASM expects JSON strings for complex data types - if '__neighbors' in combined_df.columns: + df_for_save = combined_df.copy() + if '__neighbors' in df_for_save.columns: import json # 为保存到 parquet 文件的数据转换为 JSON 字符串 - combined_df_for_save = combined_df.copy() - combined_df_for_save['__neighbors'] = combined_df['__neighbors'].apply( + df_for_save['__neighbors'] = df_for_save['__neighbors'].apply( lambda x: json.dumps({ 'ids': x['ids'].tolist() if hasattr(x['ids'], 'tolist') else list(x['ids']), 'distances': x['distances'].tolist() if hasattr(x['distances'], 'tolist') else list(x['distances']) @@ -341,11 +344,7 @@ def visualize_csv_comparison( # Save dataset for interactive viewing parquet_path = "combined_dataset.parquet" - # 保存转换为 JSON 字符串的 DataFrame - if '__neighbors' in combined_df.columns: - combined_df_for_save.to_parquet(parquet_path, index=False) - else: - combined_df.to_parquet(parquet_path, index=False) + df_for_save.to_parquet(parquet_path, index=False) print(f"Dataset saved to {parquet_path}") # Launch interactive viewer with original DataFrame (not converted to JSON)