#!/usr/bin/env python3
"""
Script to visualize two CSV files using embedding-atlas with different colors for each file.
This script supports both command-line usage and API usage.
"""

import argparse
import json
import math
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

try:
    import matplotlib.pyplot as plt
except ImportError:
    # matplotlib is only needed for the static PNG produced by
    # visualize_csv_comparison(); keep the module importable without it so the
    # API entry points that never plot remain usable.
    plt = None

# Single source of truth for the UMAP settings used when the caller gives none.
_DEFAULT_UMAP_ARGS: Dict[str, Any] = {
    "n_neighbors": 15,
    "min_dist": 0.1,
    "metric": "cosine",
}

# Column that receives the nearest-neighbor output of compute_text_projection.
_NEIGHBORS_COLUMN = "__neighbors"

_MISSING_ATLAS_MESSAGE = (
    "Error: embedding-atlas not found. "
    "Please install it with: pip install embedding-atlas"
)


def _build_metadata(text_column: str, has_neighbors: bool = True) -> Dict[str, Any]:
    """Build the ``props`` metadata dictionary handed to ``DataSource``.

    Args:
        text_column: Name of the column holding the text of each row.
        has_neighbors: Whether the frame contains the ``__neighbors`` column.
            When False the ``neighbors`` entry is omitted so the metadata
            does not reference a column that does not exist.

    Returns:
        Metadata with the ``props`` / ``data`` / ``initialState`` layout used
        throughout this script.
    """
    data: Dict[str, Any] = {
        "id": "_row_index",
        "text": text_column,
        "projection": {"x": "projection_x", "y": "projection_y"},
    }
    if has_neighbors:
        data["neighbors"] = _NEIGHBORS_COLUMN
    return {"props": {"data": data, "initialState": {"version": "0.0.0"}}}


def _to_plain_list(values: Any) -> list:
    """Convert an array-like (for example a numpy array) to a plain list."""
    return values.tolist() if hasattr(values, "tolist") else list(values)


def _neighbors_to_json(entry: Any) -> Optional[str]:
    """Serialize one ``__neighbors`` cell to a JSON string.

    Args:
        entry: A mapping with ``ids`` and ``distances`` sequences, an already
            serialized string, ``None`` or NaN.

    Returns:
        The JSON text, the unchanged string if ``entry`` was already a string,
        or ``None`` for missing values.
    """
    if entry is None or (isinstance(entry, float) and math.isnan(entry)):
        return None
    if isinstance(entry, str):
        # Already serialized (e.g. the column was loaded from a file).
        return entry
    return json.dumps(
        {
            "ids": _to_plain_list(entry["ids"]),
            "distances": _to_plain_list(entry["distances"]),
        }
    )


def _save_parquet(df: pd.DataFrame, parquet_path: str) -> None:
    """Write ``df`` to ``parquet_path`` with neighbors stored as JSON strings.

    The conversion is applied to a derived frame; ``df`` itself keeps its
    Python-object neighbors so it can still be passed to the viewer.
    """
    if _NEIGHBORS_COLUMN in df.columns:
        serialized = df[_NEIGHBORS_COLUMN].apply(_neighbors_to_json)
        df = df.assign(**{_NEIGHBORS_COLUMN: serialized})
    df.to_parquet(parquet_path, index=False)
    print(f"Dataset saved to {parquet_path}")


def _serve_dataframe(
    df: pd.DataFrame,
    text_column: str,
    port: int,
    host: str,
    dataset_tag: str,
) -> None:
    """Create an embedding-atlas server for ``df`` and run it with uvicorn.

    This call blocks in ``uvicorn.run`` until the server stops. Import errors
    and runtime errors propagate to the caller.

    Args:
        df: Frame with ``_row_index``, the text column and the projection
            columns; neighbors, if present, must be Python objects rather
            than JSON strings.
        text_column: Name of the column used as text content.
        port: Port the server listens on.
        host: Host address the server binds to.
        dataset_tag: String mixed into the hash that identifies the dataset.
    """
    import pathlib

    import embedding_atlas
    from embedding_atlas.data_source import DataSource
    from embedding_atlas.server import make_server
    from embedding_atlas.utils import Hasher

    metadata = _build_metadata(text_column, _NEIGHBORS_COLUMN in df.columns)

    # Derive the dataset identifier from the tag and the metadata.
    hasher = Hasher()
    hasher.update([dataset_tag])
    hasher.update(metadata)
    identifier = hasher.hexdigest()

    dataset = DataSource(identifier, df, metadata)

    # Static front-end assets live in the "static" directory of the package.
    static_path = str(
        (pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
    )

    # duckdb_uri is left at make_server's default.
    app = make_server(dataset, static_path=static_path)

    import uvicorn

    print(f"Starting interactive viewer on http://{host}:{port}")
    uvicorn.run(app, host=host, port=port)


def _compute_projection(
    df: pd.DataFrame,
    text_column: str,
    model: str,
    batch_size: int,
    umap_args: Optional[Dict[str, Any]],
    project: Optional[Callable[..., Any]] = None,
) -> bool:
    """Run ``compute_text_projection`` on ``df``, reporting errors on stdout.

    Args:
        df: Frame to project; results are written to ``projection_x``,
            ``projection_y`` and ``__neighbors``.
        text_column: Column containing the text to embed.
        model: Embedding model name.
        batch_size: Batch size for embedding computation.
        umap_args: UMAP arguments; the module defaults are used when None.
        project: Projection callable; imported from embedding-atlas when None.

    Returns:
        True on success, False if an error was printed.
    """
    if umap_args is None:
        # Copy so the callee can never mutate the shared defaults.
        umap_args = dict(_DEFAULT_UMAP_ARGS)
    try:
        if project is None:
            from embedding_atlas.projection import (
                compute_text_projection as project,
            )
        project(
            df,
            text=text_column,
            x="projection_x",
            y="projection_y",
            neighbors=_NEIGHBORS_COLUMN,
            model=model,
            batch_size=batch_size,
            umap_args=umap_args,
        )
    except ImportError:
        print(_MISSING_ATLAS_MESSAGE)
        return False
    except Exception as e:
        print(f"Error computing projections: {e}")
        return False
    return True


def _combine_frames(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    column1: str,
    column2: str,
    label1: str,
    label2: str,
) -> Tuple[pd.DataFrame, str]:
    """Tag and concatenate the two input frames.

    Adds ``source`` (label) and ``file_origin`` (0 or 1) to each frame,
    concatenates them and adds ``_row_index``. The inputs are not modified.

    Args:
        df1: Rows of the first file; must contain ``column1``.
        df2: Rows of the second file; must contain ``column2``.
        column1: Text column of the first file.
        column2: Text column of the second file.
        label1: Label stored in ``source`` for rows of the first file.
        label2: Label stored in ``source`` for rows of the second file.

    Returns:
        The combined frame and the name of the text column to embed.
    """
    first = df1.assign(source=label1, file_origin=0)
    second = df2.assign(source=label2, file_origin=1)
    if column2 != column1:
        # Both files must share one text column; otherwise every row of the
        # second file would have a missing value in the embedded column.
        second[column1] = second[column2]
    combined = pd.concat([first, second], ignore_index=True)
    combined["_row_index"] = range(len(combined))
    return combined, column1


def _save_scatter_plot(
    df: pd.DataFrame,
    labels: Tuple[str, str],
    output_path: str,
) -> None:
    """Save a scatter plot of the projection, one color per ``file_origin``."""
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        for origin, (label, color) in enumerate(zip(labels, ("blue", "red"))):
            subset = df[df["file_origin"] == origin]
            if subset.empty:
                continue
            ax.scatter(
                subset["projection_x"],
                subset["projection_y"],
                c=color,
                label=label,
                alpha=0.6,
                s=30,
            )
        ax.set_xlabel("Projection X")
        ax.set_ylabel("Projection Y")
        ax.set_title("Embedding Atlas Visualization Comparison")
        ax.legend(loc="upper right")
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)
    print(f"Visualization saved to {output_path}")


def launch_interactive_viewer(
    df: pd.DataFrame,
    text_column: str,
    port: int = 5055,
    host: str = "0.0.0.0",
):
    """Launch the interactive server through the embedding-atlas Python API.

    Blocks while the server is running. Errors are printed, not raised.

    Args:
        df: The original pandas DataFrame containing the projection
            coordinates and, optionally, neighbor information as Python
            objects (not JSON strings).
        text_column: Name of the column used as text content.
        port: Port to start the server on.
        host: Host address the server binds to.
    """
    try:
        _serve_dataframe(df, text_column, port, host, "combined_dataset")
    except ImportError:
        print(_MISSING_ATLAS_MESSAGE)
    except Exception as e:
        print(f"Error launching interactive viewer: {e}")


def create_embedding_service(
    texts1: List[str],
    texts2: List[str],
    labels: tuple = ("Group1", "Group2"),
    port: int = 5055,
    host: str = "0.0.0.0",
    text_column: str = "text",
    model: str = "all-MiniLM-L6-v2",
    batch_size: int = 32,
    umap_args: Optional[Dict[str, Any]] = None,
    parquet_path: str = "combined_dataset.parquet",
):
    """Create an embedding visualization service from two groups of texts.

    Computes the projection, saves the dataset as parquet and then serves it;
    the call blocks while the server is running.

    Args:
        texts1: Texts of the first group.
        texts2: Texts of the second group.
        labels: Two labels stored in the ``source`` column, one per group.
        port: Port to start the server on.
        host: Host address the server binds to.
        text_column: Name given to the text column.
        model: Embedding model to use.
        batch_size: Batch size for embedding computation.
        umap_args: UMAP arguments; defaults to n_neighbors=15, min_dist=0.1,
            metric="cosine" when None.
        parquet_path: Where the dataset is saved before serving.

    Raises:
        ImportError: If embedding-atlas (or uvicorn) cannot be imported.
    """
    # Imported up front so a missing installation raises before any work.
    from embedding_atlas.projection import compute_text_projection

    # 1. Build one frame holding both groups.
    df1 = pd.DataFrame({text_column: texts1, "source": labels[0]})
    df2 = pd.DataFrame({text_column: texts2, "source": labels[1]})
    combined_df = pd.concat([df1, df2], ignore_index=True)

    # 2. Add the id column referenced by the metadata.
    combined_df["_row_index"] = range(len(combined_df))

    # 3. Compute embeddings and projection.
    projected = _compute_projection(
        combined_df,
        text_column,
        model,
        batch_size,
        umap_args,
        project=compute_text_projection,
    )
    if not projected:
        return

    # 4. Save a parquet copy (neighbors as JSON strings); combined_df keeps
    #    its Python-object neighbors for the DataSource.
    _save_parquet(combined_df, parquet_path)

    # 5. Serve the dataset.
    _serve_dataframe(combined_df, text_column, port, host, "custom_dataset")


def visualize_csv_comparison(
    csv1_path: str,
    csv2_path: str,
    column1: str = "smiles",
    column2: str = "smiles",
    output_path: str = "comparison_visualization.png",
    label1: Optional[str] = None,
    label2: Optional[str] = None,
    launch_interactive: bool = False,
    port: int = 5055,
    host: str = "0.0.0.0",
    model: str = "all-MiniLM-L6-v2",
    batch_size: int = 32,
    umap_args: Optional[Dict[str, Any]] = None,
    parquet_path: str = "combined_dataset.parquet",
):
    """
    Visualize two CSV files with embedding-atlas and create a combined plot with different colors.

    Errors (missing files, missing columns, projection failures) are printed
    and the function returns None without raising.

    Args:
        csv1_path: Path to the first CSV file
        csv2_path: Path to the second CSV file
        column1: Column name for the first CSV file (default: "smiles")
        column2: Column name for the second CSV file (default: "smiles")
        output_path: Output visualization file path
        label1: Label for the first dataset (default: filename)
        label2: Label for the second dataset (default: filename)
        launch_interactive: Whether to launch interactive viewer (default: False)
        port: Port for interactive viewer (default: 5055)
        host: Host for interactive viewer (default: "0.0.0.0")
        model: Embedding model to use (default: "all-MiniLM-L6-v2")
        batch_size: Batch size for embedding computation (default: 32)
        umap_args: UMAP arguments as dictionary (default: None)
        parquet_path: Where the combined dataset is saved
            (default: "combined_dataset.parquet")
    """
    # Generate default labels from filenames if not provided
    if label1 is None:
        label1 = os.path.splitext(os.path.basename(csv1_path))[0]
    if label2 is None:
        label2 = os.path.splitext(os.path.basename(csv2_path))[0]

    # Read CSV files with error handling
    try:
        df1 = pd.read_csv(csv1_path)
        df2 = pd.read_csv(csv2_path)
    except FileNotFoundError as e:
        print(f"Error: CSV file not found - {e}")
        return
    except pd.errors.EmptyDataError:
        print("Error: One of the CSV files is empty")
        return
    except Exception as e:
        print(f"Error reading CSV files: {e}")
        return

    # Validate columns
    for path, frame, column in ((csv1_path, df1, column1), (csv2_path, df2, column2)):
        if column not in frame.columns:
            print(f"Error: Column '{column}' not found in {path}")
            print(f"Available columns: {list(frame.columns)}")
            return

    if plt is None:
        print("Error: matplotlib not found. Please install it with: pip install matplotlib")
        return

    combined_df, text_column = _combine_frames(
        df1, df2, column1, column2, label1, label2
    )

    # Reuse projection columns when the CSV files already provide them.
    has_projection = (
        "projection_x" in combined_df.columns
        and "projection_y" in combined_df.columns
    )
    if not has_projection:
        print("Computing embeddings and projections...")
        if not _compute_projection(
            combined_df, text_column, model, batch_size, umap_args
        ):
            return

    print("Creating visualization...")
    _save_scatter_plot(combined_df, (label1, label2), output_path)

    # The parquet copy stores neighbors as JSON strings; combined_df itself is
    # left untouched because the viewer expects Python objects.
    _save_parquet(combined_df, parquet_path)

    if launch_interactive:
        # Blocks until the server stops; the viewer prints its own URL when
        # it starts, so nothing is announced afterwards.
        launch_interactive_viewer(combined_df, text_column, port, host)


def main():
    """Command-line interface for the visualization script."""
    parser = argparse.ArgumentParser(
        description="Visualize two CSV files using embedding-atlas with different colors",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python -m visualization.comparison file1.csv file2.csv
  python -m visualization.comparison file1.csv file2.csv --column1 smiles --column2 SMILES
  python -m visualization.comparison file1.csv file2.csv --label1 "Dataset A" --label2 "Dataset B"
  python -m visualization.comparison file1.csv file2.csv --output comparison.png
  python -m visualization.comparison file1.csv file2.csv --interactive --port 8080
        """,
    )

    parser.add_argument("csv1", help="Path to the first CSV file")
    parser.add_argument("csv2", help="Path to the second CSV file")
    parser.add_argument("--column1", default="smiles",
                        help="Column name for the first CSV file (default: smiles)")
    parser.add_argument("--column2", default="smiles",
                        help="Column name for the second CSV file (default: smiles)")
    parser.add_argument("--output", "-o", default="comparison_visualization.png",
                        help="Output visualization file path")
    parser.add_argument("--label1", help="Label for the first dataset (default: filename)")
    parser.add_argument("--label2", help="Label for the second dataset (default: filename)")
    parser.add_argument("--interactive", "-i", action="store_true",
                        help="Launch interactive viewer")
    parser.add_argument("--port", "-p", type=int, default=5055,
                        help="Port for interactive viewer (default: 5055)")
    parser.add_argument("--host", default="0.0.0.0",
                        help="Host for interactive viewer (default: 0.0.0.0)")
    parser.add_argument("--model", default="all-MiniLM-L6-v2",
                        help="Embedding model to use (default: all-MiniLM-L6-v2)")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size for embedding computation (default: 32)")

    args = parser.parse_args()

    # UMAP settings are not exposed on the command line; the library defaults
    # (n_neighbors=15, min_dist=0.1, metric="cosine") apply.
    visualize_csv_comparison(
        args.csv1,
        args.csv2,
        column1=args.column1,
        column2=args.column2,
        output_path=args.output,
        label1=args.label1,
        label2=args.label2,
        launch_interactive=args.interactive,
        port=args.port,
        host=args.host,
        model=args.model,
        batch_size=args.batch_size,
    )


if __name__ == "__main__":
    main()