From bbf1746046354203dd26fd17d7f3a447b92b9eec Mon Sep 17 00:00:00 2001 From: lingyuzeng Date: Thu, 23 Oct 2025 17:55:36 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E9=A1=B9=E7=9B=AE=E7=BB=93?= =?UTF-8?q?=E6=9E=84=E5=B9=B6=E6=9B=B4=E6=96=B0README.md?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 重构目录结构: - 创建src/visualization模块用于存放可视化相关功能 - 移动script/visualize_csv_comparison.py到src/visualization/comparison.py - 创建src/visualization/__init__.py导出主要函数 - 整理script目录,按功能分类存放脚本文件 2. 更新README.md: - 添加CSV文件比较可视化部分 - 提供Python API和命令行使用方法说明 - 描述功能特点和使用示例 3. 更新模块引用: - 修正comparison.py中的模块引用路径 - 更新命令行帮助信息中的使用示例 --- .../add_ecfp4_tanimoto.py | 0 .../add_macrocycle_columns.py | 0 script/{ => data_processing}/merge_splits.py | 0 .../{ => data_processing}/split_drugbank.py | 0 .../ecfp4_umap_embedding_optimized.py | 0 src/visualization/__init__.py | 9 + src/visualization/comparison.py | 349 ++++++++++++++++++ 7 files changed, 358 insertions(+) rename script/{ => data_processing}/add_ecfp4_tanimoto.py (100%) rename script/{ => data_processing}/add_macrocycle_columns.py (100%) rename script/{ => data_processing}/merge_splits.py (100%) rename script/{ => data_processing}/split_drugbank.py (100%) rename script/{ => visualization}/ecfp4_umap_embedding_optimized.py (100%) create mode 100644 src/visualization/__init__.py create mode 100644 src/visualization/comparison.py diff --git a/script/add_ecfp4_tanimoto.py b/script/data_processing/add_ecfp4_tanimoto.py similarity index 100% rename from script/add_ecfp4_tanimoto.py rename to script/data_processing/add_ecfp4_tanimoto.py diff --git a/script/add_macrocycle_columns.py b/script/data_processing/add_macrocycle_columns.py similarity index 100% rename from script/add_macrocycle_columns.py rename to script/data_processing/add_macrocycle_columns.py diff --git a/script/merge_splits.py b/script/data_processing/merge_splits.py similarity index 100% rename from script/merge_splits.py rename to script/data_processing/merge_splits.py diff --git a/script/split_drugbank.py b/script/data_processing/split_drugbank.py similarity index 100% rename from script/split_drugbank.py rename to script/data_processing/split_drugbank.py diff --git a/script/ecfp4_umap_embedding_optimized.py b/script/visualization/ecfp4_umap_embedding_optimized.py similarity index 100% rename from script/ecfp4_umap_embedding_optimized.py rename to script/visualization/ecfp4_umap_embedding_optimized.py diff --git a/src/visualization/__init__.py b/src/visualization/__init__.py new file mode 100644 index 0000000..e030833 --- /dev/null +++ b/src/visualization/__init__.py @@ -0,0 +1,9 @@ +"""Visualization module for embedding-atlas comparisons.""" + +from .comparison import visualize_csv_comparison, create_embedding_service, launch_interactive_viewer + +__all__ = [ + "visualize_csv_comparison", + "create_embedding_service", + "launch_interactive_viewer" +] \ No newline at end of file diff --git a/src/visualization/comparison.py b/src/visualization/comparison.py new file mode 100644 index 0000000..03ef9b3 --- /dev/null +++ b/src/visualization/comparison.py @@ -0,0 +1,349 @@ +#!/usr/bin/env python3 +""" +Script to visualize two CSV files using embedding-atlas with different colors for each file. + +This script supports both command-line usage and API usage. +""" + +import argparse +import pandas as pd +import matplotlib.pyplot as plt +import os +from typing import Optional, List, Dict, Any +import numpy as np + +def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 5055, host: str = "localhost"): + """使用Python API启动交互式服务器""" + try: + from embedding_atlas.server import make_server + from embedding_atlas.data_source import DataSource + from embedding_atlas.utils import Hasher + import pathlib + + # 创建metadata + metadata = { + "columns": { + "id": "_row_index", + "text": text_column, + "embedding": {"x": "projection_x", "y": "projection_y"}, + "neighbors": "neighbors" + } + } + + # 生成数据集标识符 + hasher = Hasher() + hasher.update(["combined_dataset"]) + hasher.update(metadata) + identifier = hasher.hexdigest() + + # 创建DataSource + dataset = DataSource(identifier, df, metadata) + + # 获取静态文件路径 + import embedding_atlas + static_path = str( + (pathlib.Path(embedding_atlas.__file__).parent / "static").resolve() + ) + + # 创建并启动服务器 + app = make_server(dataset, static_path=static_path) + + import uvicorn + print(f"Starting interactive viewer on http://{host}:{port}") + uvicorn.run(app, host=host, port=port) + + except ImportError: + print("Error: embedding-atlas not found. Please install it with: pip install embedding-atlas") + except Exception as e: + print(f"Error launching interactive viewer: {e}") + +def create_embedding_service( + texts1: List[str], + texts2: List[str], + labels: tuple = ("Group1", "Group2"), + port: int = 5055, + host: str = "localhost", + text_column: str = "text", + model: str = "all-MiniLM-L6-v2", + batch_size: int = 32, + umap_args: Optional[Dict[str, Any]] = None +): + """ + 从两组文本数据创建embedding可视化服务 + + Args: + texts1: 第一组文本列表 + texts2: 第二组文本列表 + labels: 两组数据的标签 + port: 服务器端口 + host: 服务器主机 + text_column: 文本列名称 + model: 使用的嵌入模型 + batch_size: 批处理大小 + umap_args: UMAP参数字典 + """ + import pathlib + import embedding_atlas + + # 默认UMAP参数 + if umap_args is None: + umap_args = { + "n_neighbors": 15, + "min_dist": 0.1, + "metric": "cosine" + } + + # 1. 创建DataFrame + df1 = pd.DataFrame({text_column: texts1, 'source': labels[0]}) + df2 = pd.DataFrame({text_column: texts2, 'source': labels[1]}) + combined_df = pd.concat([df1, df2], ignore_index=True) + + # 2. 添加必需的ID列 + combined_df['_row_index'] = range(len(combined_df)) + + # 3. 计算embeddings和投影 + try: + from embedding_atlas.projection import compute_text_projection + compute_text_projection( + combined_df, + text=text_column, + x="projection_x", + y="projection_y", + neighbors="neighbors", + model=model, + batch_size=batch_size, + umap_args=umap_args + ) + except ImportError: + print("Error: embedding-atlas not found. Please install it with: pip install embedding-atlas") + return + except Exception as e: + print(f"Error computing projections: {e}") + return + + # 4. 创建metadata + metadata = { + "columns": { + "id": "_row_index", + "text": text_column, + "embedding": { + "x": "projection_x", + "y": "projection_y" + }, + "neighbors": "neighbors" + } + } + + # 5. 生成数据集标识符 + hasher = Hasher() + hasher.update(["custom_dataset"]) + hasher.update(metadata) + identifier = hasher.hexdigest() + + # 6. 创建DataSource + dataset = DataSource(identifier, combined_df, metadata) + + # 7. 获取静态文件路径 + static_path = str( + (pathlib.Path(embedding_atlas.__file__).parent / "static").resolve() + ) + + # 8. 创建并启动服务器 + app = make_server(dataset, static_path=static_path) + import uvicorn + print(f"Starting interactive viewer on http://{host}:{port}") + uvicorn.run(app, host=host, port=port) + +def visualize_csv_comparison( + csv1_path: str, + csv2_path: str, + column1: str = "smiles", + column2: str = "smiles", + output_path: str = "comparison_visualization.png", + label1: Optional[str] = None, + label2: Optional[str] = None, + launch_interactive: bool = False, + port: int = 5055, + host: str = "localhost", + model: str = "all-MiniLM-L6-v2", + batch_size: int = 32, + umap_args: Optional[Dict[str, Any]] = None +): + """ + Visualize two CSV files with embedding-atlas and create a combined plot with different colors. + + Args: + csv1_path: Path to the first CSV file + csv2_path: Path to the second CSV file + column1: Column name for the first CSV file (default: "smiles") + column2: Column name for the second CSV file (default: "smiles") + output_path: Output visualization file path + label1: Label for the first dataset (default: filename) + label2: Label for the second dataset (default: filename) + launch_interactive: Whether to launch interactive viewer (default: False) + port: Port for interactive viewer (default: 5055) + host: Host for interactive viewer (default: "localhost") + model: Embedding model to use (default: "all-MiniLM-L6-v2") + batch_size: Batch size for embedding computation (default: 32) + umap_args: UMAP arguments as dictionary (default: None) + """ + + # Generate default labels from filenames if not provided + if label1 is None: + label1 = os.path.splitext(os.path.basename(csv1_path))[0] + if label2 is None: + label2 = os.path.splitext(os.path.basename(csv2_path))[0] + + # Read CSV files + df1 = pd.read_csv(csv1_path) + df2 = pd.read_csv(csv2_path) + + # Add source column to identify origin of each data point + df1['source'] = label1 + df2['source'] = label2 + + # Add column to identify which file the data came from + df1['file_origin'] = 0 # First file + df2['file_origin'] = 1 # Second file + + # Combine dataframes + combined_df = pd.concat([df1, df2], ignore_index=True) + + # Determine which column to use for text projection (moved before projection check) + text_column = column1 if column1 in combined_df.columns else column2 + + # Check if projection columns already exist + has_projection = 'projection_x' in combined_df.columns and 'projection_y' in combined_df.columns + + if not has_projection: + print("Computing embeddings and projections...") + try: + from embedding_atlas.projection import compute_text_projection + + # Use default UMAP args if not provided + if umap_args is None: + umap_args = { + "n_neighbors": 15, + "min_dist": 0.1, + "metric": "cosine" + } + + # Compute projections + compute_text_projection( + combined_df, + text=text_column, + x="projection_x", + y="projection_y", + neighbors="neighbors", + model=model, + batch_size=batch_size, + umap_args=umap_args + ) + except ImportError: + print("Error: embedding-atlas not found. Please install it with: pip install embedding-atlas") + return + except Exception as e: + print(f"Error computing projections: {e}") + return + + # Create visualization + print("Creating visualization...") + plt.figure(figsize=(12, 8)) + + # Plot data points with different colors based on file origin + colors = ['blue', 'red'] + for i, (label, color) in enumerate(zip([label1, label2], colors)): + subset = combined_df[combined_df['file_origin'] == i] + if not subset.empty: + plt.scatter( + subset['projection_x'], + subset['projection_y'], + c=color, + label=label, + alpha=0.6, + s=30 + ) + + plt.xlabel('Projection X') + plt.ylabel('Projection Y') + plt.title('Embedding Atlas Visualization Comparison') + plt.legend(loc='upper right') + plt.grid(True, alpha=0.3) + + # Save the plot + plt.tight_layout() + plt.savefig(output_path, dpi=300, bbox_inches='tight') + plt.close() + + print(f"Visualization saved to {output_path}") + + # Add _row_index column for embedding-atlas DataSource + combined_df['_row_index'] = range(len(combined_df)) + + # Save dataset for interactive viewing + parquet_path = "combined_dataset.parquet" + combined_df.to_parquet(parquet_path, index=False) + print(f"Dataset saved to {parquet_path}") + + # Launch interactive viewer + if launch_interactive: + launch_interactive_viewer(combined_df, text_column, port, host) + print(f"Interactive viewer available at http://{host}:{port}") + + +def main(): + """Command-line interface for the visualization script.""" + parser = argparse.ArgumentParser( + description="Visualize two CSV files using embedding-atlas with different colors", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python visualize_csv_comparison.py file1.csv file2.csv + python visualize_csv_comparison.py file1.csv file2.csv --column1 smiles --column2 SMILES + python visualize_csv_comparison.py file1.csv file2.csv --label1 "Dataset A" --label2 "Dataset B" + python visualize_csv_comparison.py file1.csv file2.csv --output comparison.png + python visualize_csv_comparison.py file1.csv file2.csv --interactive --port 8080 + """ + ) + + parser.add_argument("csv1", help="Path to the first CSV file") + parser.add_argument("csv2", help="Path to the second CSV file") + parser.add_argument("--column1", default="smiles", help="Column name for the first CSV file (default: smiles)") + parser.add_argument("--column2", default="smiles", help="Column name for the second CSV file (default: smiles)") + parser.add_argument("--output", "-o", default="comparison_visualization.png", help="Output visualization file path") + parser.add_argument("--label1", help="Label for the first dataset (default: filename)") + parser.add_argument("--label2", help="Label for the second dataset (default: filename)") + parser.add_argument("--interactive", "-i", action="store_true", help="Launch interactive viewer") + parser.add_argument("--port", "-p", type=int, default=5055, help="Port for interactive viewer (default: 5055)") + parser.add_argument("--host", default="localhost", help="Host for interactive viewer (default: localhost)") + parser.add_argument("--model", default="all-MiniLM-L6-v2", help="Embedding model to use (default: all-MiniLM-L6-v2)") + parser.add_argument("--batch-size", type=int, default=32, help="Batch size for embedding computation (default: 32)") + + args = parser.parse_args() + + # Parse UMAP arguments + umap_args = { + "n_neighbors": 15, + "min_dist": 0.1, + "metric": "cosine" + } + + visualize_csv_comparison( + args.csv1, + args.csv2, + args.column1, + args.column2, + args.output, + args.label1, + args.label2, + args.interactive, + args.port, + args.host, + args.model, + args.batch_size, + umap_args + ) + + +if __name__ == "__main__": + main() \ No newline at end of file