Files
embedding_atlas/src/visualization/comparison.py

412 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Script to visualize two CSV files using embedding-atlas with different colors for each file.
This script supports both command-line usage and API usage.
"""
import argparse
import pandas as pd
import matplotlib.pyplot as plt
import os
from typing import Optional, List, Dict, Any
import numpy as np
def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 5055, host: str = "0.0.0.0"):
    """Launch the interactive embedding-atlas viewer via its Python API.

    Blocks in ``uvicorn.run`` until the server is stopped.

    Args:
        df: DataFrame that already contains the projection coordinates and the
            neighbor info as Python objects (dict + numpy arrays), NOT as JSON
            strings.
        text_column: Name of the column holding the text content.
        port: Port to start the server on.
        host: Host address the server binds to.
    """
    try:
        from embedding_atlas.server import make_server
        from embedding_atlas.data_source import DataSource
        from embedding_atlas.utils import Hasher
        import pathlib
        # Build metadata using the same "props" field structure as the CLI.
        metadata = {
            "props": {
                "data": {
                    "id": "_row_index",
                    "text": text_column,
                    "projection": {"x": "projection_x", "y": "projection_y"},
                    "neighbors": "__neighbors"
                },
                "initialState": {"version": "0.0.0"}
            }
        }
        # Derive a dataset identifier from a stable hash of name + metadata.
        hasher = Hasher()
        hasher.update(["combined_dataset"])
        hasher.update(metadata)
        identifier = hasher.hexdigest()
        # Create the DataSource directly from the passed-in DataFrame.
        # df['__neighbors'] must hold Python dicts + numpy arrays, not JSON strings.
        dataset = DataSource(identifier, df, metadata)
        # Resolve the path of the bundled static frontend assets.
        import embedding_atlas
        static_path = str(
            (pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
        )
        # Create and start the server (uses the default duckdb_uri, i.e. "wasm").
        app = make_server(dataset, static_path=static_path)
        import uvicorn
        print(f"Starting interactive viewer on http://{host}:{port}")
        uvicorn.run(app, host=host, port=port)
    except ImportError:
        print("Error: embedding-atlas not found. Please install it with: pip install embedding-atlas")
    except Exception as e:
        print(f"Error launching interactive viewer: {e}")
def create_embedding_service(
    texts1: List[str],
    texts2: List[str],
    labels: tuple = ("Group1", "Group2"),
    port: int = 5055,
    host: str = "0.0.0.0",
    text_column: str = "text",
    model: str = "all-MiniLM-L6-v2",
    batch_size: int = 32,
    umap_args: Optional[Dict[str, Any]] = None
):
    """Create and launch an embedding visualization service from two text groups.

    Args:
        texts1: Texts for the first group.
        texts2: Texts for the second group.
        labels: ``(label1, label2)`` values written to the 'source' column.
        port: Port for the interactive server.
        host: Host address the server binds to.
        text_column: Column name under which the texts are stored.
        model: Embedding model passed to ``compute_text_projection``.
        batch_size: Batch size for embedding computation.
        umap_args: Optional UMAP parameters; defaults are used when omitted.

    Note:
        Writes 'combined_dataset.parquet' to the current working directory and
        then blocks in ``uvicorn.run`` until the server is stopped.
    """
    import json
    import pathlib
    import embedding_atlas
    from embedding_atlas.projection import compute_text_projection
    from embedding_atlas.server import make_server
    from embedding_atlas.data_source import DataSource
    from embedding_atlas.utils import Hasher

    def _neighbors_to_json(record):
        # One neighbors record is a dict with 'ids' and 'distances' holding
        # numpy arrays (or plain sequences); serialize it so the parquet file
        # stores a JSON string instead of nested arrays.
        if record is None:
            return None
        return json.dumps({
            'ids': record['ids'].tolist() if hasattr(record['ids'], 'tolist') else list(record['ids']),
            'distances': record['distances'].tolist() if hasattr(record['distances'], 'tolist') else list(record['distances'])
        })

    # Default UMAP parameters.
    if umap_args is None:
        umap_args = {
            "n_neighbors": 15,
            "min_dist": 0.1,
            "metric": "cosine"
        }

    # 1. Build one DataFrame with a 'source' tag per group.
    df1 = pd.DataFrame({text_column: texts1, 'source': labels[0]})
    df2 = pd.DataFrame({text_column: texts2, 'source': labels[1]})
    combined_df = pd.concat([df1, df2], ignore_index=True)

    # 2. Row id column required by the embedding-atlas DataSource.
    combined_df['_row_index'] = range(len(combined_df))

    # 3. Compute embeddings and the 2-D projection in place.
    try:
        compute_text_projection(
            combined_df,
            text=text_column,
            x="projection_x",
            y="projection_y",
            neighbors="__neighbors",  # double underscore, matching the CLI convention
            model=model,
            batch_size=batch_size,
            umap_args=umap_args
        )
    except ImportError:
        print("Error: embedding-atlas not found. Please install it with: pip install embedding-atlas")
        return
    except Exception as e:
        print(f"Error computing projections: {e}")
        return

    # 4. Metadata using the same "props" field structure as the CLI.
    metadata = {
        "props": {
            "data": {
                "id": "_row_index",
                "text": text_column,
                "projection": {
                    "x": "projection_x",
                    "y": "projection_y"
                },
                "neighbors": "__neighbors"
            },
            "initialState": {"version": "0.0.0"}
        }
    }

    # 5. Keep the unserialized objects for DataSource: the viewer expects
    # Python dicts + numpy arrays, not JSON strings.
    df_for_datasource = combined_df.copy()

    # 6. Serialize the neighbors column to JSON strings (only for parquet).
    if '__neighbors' in combined_df.columns:
        combined_df['__neighbors'] = combined_df['__neighbors'].apply(_neighbors_to_json)

    # 7. Persist the dataset.
    parquet_path = "combined_dataset.parquet"
    combined_df.to_parquet(parquet_path, index=False)
    print(f"Dataset saved to {parquet_path}")

    # 8. Derive a dataset identifier from a stable hash of name + metadata.
    hasher = Hasher()
    hasher.update(["custom_dataset"])
    hasher.update(metadata)
    identifier = hasher.hexdigest()

    # 9. Build the DataSource from the unserialized DataFrame.
    dataset = DataSource(identifier, df_for_datasource, metadata)

    # 10. Resolve the path of the bundled static frontend assets.
    static_path = str(
        (pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
    )

    # 11. Create and start the server (uses the default duckdb_uri, i.e. "wasm").
    app = make_server(dataset, static_path=static_path)
    import uvicorn
    print(f"Starting interactive viewer on http://{host}:{port}")
    uvicorn.run(app, host=host, port=port)
def _column_exists(df: pd.DataFrame, column: str, csv_path: str) -> bool:
    """Return True if `column` is in `df`; otherwise print diagnostics and return False."""
    if column not in df.columns:
        print(f"Error: Column '{column}' not found in {csv_path}")
        print(f"Available columns: {list(df.columns)}")
        return False
    return True


def _serialize_neighbors(record):
    """Serialize one __neighbors record (dict of 'ids'/'distances' arrays) to a JSON string."""
    import json
    if record is None:
        return None
    return json.dumps({
        'ids': record['ids'].tolist() if hasattr(record['ids'], 'tolist') else list(record['ids']),
        'distances': record['distances'].tolist() if hasattr(record['distances'], 'tolist') else list(record['distances'])
    })


def _plot_comparison(combined_df: pd.DataFrame, label1: str, label2: str, output_path: str) -> None:
    """Render both datasets as a colored scatter plot and save it to `output_path`."""
    plt.figure(figsize=(12, 8))
    colors = ['blue', 'red']
    for i, (label, color) in enumerate(zip([label1, label2], colors)):
        subset = combined_df[combined_df['file_origin'] == i]
        if not subset.empty:
            plt.scatter(
                subset['projection_x'],
                subset['projection_y'],
                c=color,
                label=label,
                alpha=0.6,
                s=30
            )
    plt.xlabel('Projection X')
    plt.ylabel('Projection Y')
    plt.title('Embedding Atlas Visualization Comparison')
    plt.legend(loc='upper right')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Visualization saved to {output_path}")


def visualize_csv_comparison(
    csv1_path: str,
    csv2_path: str,
    column1: str = "smiles",
    column2: str = "smiles",
    output_path: str = "comparison_visualization.png",
    label1: Optional[str] = None,
    label2: Optional[str] = None,
    launch_interactive: bool = False,
    port: int = 5055,
    host: str = "0.0.0.0",
    model: str = "all-MiniLM-L6-v2",
    batch_size: int = 32,
    umap_args: Optional[Dict[str, Any]] = None
):
    """
    Visualize two CSV files with embedding-atlas and create a combined plot with different colors.

    Also writes 'combined_dataset.parquet' to the current working directory; if
    `launch_interactive` is True, blocks while the viewer server is running.

    Args:
        csv1_path: Path to the first CSV file
        csv2_path: Path to the second CSV file
        column1: Column name for the first CSV file (default: "smiles")
        column2: Column name for the second CSV file (default: "smiles")
        output_path: Output visualization file path
        label1: Label for the first dataset (default: filename)
        label2: Label for the second dataset (default: filename)
        launch_interactive: Whether to launch interactive viewer (default: False)
        port: Port for interactive viewer (default: 5055)
        host: Host for interactive viewer (default: "0.0.0.0")
        model: Embedding model to use (default: "all-MiniLM-L6-v2")
        batch_size: Batch size for embedding computation (default: 32)
        umap_args: UMAP arguments as dictionary (default: None)
    """
    # Default the labels to the CSV basenames (without extension).
    if label1 is None:
        label1 = os.path.splitext(os.path.basename(csv1_path))[0]
    if label2 is None:
        label2 = os.path.splitext(os.path.basename(csv2_path))[0]

    # Read both CSV files, reporting (not raising) common failure modes.
    try:
        df1 = pd.read_csv(csv1_path)
        df2 = pd.read_csv(csv2_path)
    except FileNotFoundError as e:
        print(f"Error: CSV file not found - {e}")
        return
    except pd.errors.EmptyDataError:
        print("Error: One of the CSV files is empty")
        return
    except Exception as e:
        print(f"Error reading CSV files: {e}")
        return

    if not _column_exists(df1, column1, csv1_path):
        return
    if not _column_exists(df2, column2, csv2_path):
        return

    # Tag each row with its dataset label and a numeric file origin (0/1).
    df1['source'] = label1
    df2['source'] = label2
    df1['file_origin'] = 0  # First file
    df2['file_origin'] = 1  # Second file
    combined_df = pd.concat([df1, df2], ignore_index=True)

    # Row id column required by the embedding-atlas DataSource.
    combined_df['_row_index'] = range(len(combined_df))

    # Prefer column1 as the text column; fall back to column2.
    text_column = column1 if column1 in combined_df.columns else column2

    # Compute projections only when the CSVs did not already provide them.
    has_projection = 'projection_x' in combined_df.columns and 'projection_y' in combined_df.columns
    if not has_projection:
        print("Computing embeddings and projections...")
        try:
            from embedding_atlas.projection import compute_text_projection
            if umap_args is None:
                umap_args = {
                    "n_neighbors": 15,
                    "min_dist": 0.1,
                    "metric": "cosine"
                }
            compute_text_projection(
                combined_df,
                text=text_column,
                x="projection_x",
                y="projection_y",
                neighbors="__neighbors",  # double underscore, matching the CLI convention
                model=model,
                batch_size=batch_size,
                umap_args=umap_args
            )
        except ImportError:
            print("Error: embedding-atlas not found. Please install it with: pip install embedding-atlas")
            return
        except Exception as e:
            print(f"Error computing projections: {e}")
            return

    print("Creating visualization...")
    _plot_comparison(combined_df, label1, label2, output_path)

    # The interactive viewer needs the original Python objects, while the
    # parquet file needs __neighbors serialized to JSON strings (DuckDB-WASM
    # cannot read the nested numpy arrays directly).
    df_for_viewer = combined_df.copy()
    df_for_save = combined_df
    if '__neighbors' in combined_df.columns:
        df_for_save = combined_df.copy()
        df_for_save['__neighbors'] = combined_df['__neighbors'].apply(_serialize_neighbors)

    # Save dataset for interactive viewing.
    parquet_path = "combined_dataset.parquet"
    df_for_save.to_parquet(parquet_path, index=False)
    print(f"Dataset saved to {parquet_path}")

    if launch_interactive:
        # Bug fix: announce the URL before launching — launch_interactive_viewer
        # blocks in uvicorn.run, so the original code only printed this message
        # after the server had already shut down.
        print(f"Interactive viewer available at http://{host}:{port}")
        launch_interactive_viewer(df_for_viewer, text_column, port, host)
def main():
    """Command-line interface for the visualization script."""
    parser = argparse.ArgumentParser(
        description="Visualize two CSV files using embedding-atlas with different colors",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python -m visualization.comparison file1.csv file2.csv
python -m visualization.comparison file1.csv file2.csv --column1 smiles --column2 SMILES
python -m visualization.comparison file1.csv file2.csv --label1 "Dataset A" --label2 "Dataset B"
python -m visualization.comparison file1.csv file2.csv --output comparison.png
python -m visualization.comparison file1.csv file2.csv --interactive --port 8080
"""
    )
    # Positional inputs.
    parser.add_argument("csv1", help="Path to the first CSV file")
    parser.add_argument("csv2", help="Path to the second CSV file")
    # Column selection and labeling.
    parser.add_argument("--column1", default="smiles", help="Column name for the first CSV file (default: smiles)")
    parser.add_argument("--column2", default="smiles", help="Column name for the second CSV file (default: smiles)")
    parser.add_argument("--label1", help="Label for the first dataset (default: filename)")
    parser.add_argument("--label2", help="Label for the second dataset (default: filename)")
    # Output and viewer options.
    parser.add_argument("--output", "-o", default="comparison_visualization.png", help="Output visualization file path")
    parser.add_argument("--interactive", "-i", action="store_true", help="Launch interactive viewer")
    parser.add_argument("--port", "-p", type=int, default=5055, help="Port for interactive viewer (default: 5055)")
    parser.add_argument("--host", default="0.0.0.0", help="Host for interactive viewer (default: 0.0.0.0)")
    # Embedding options.
    parser.add_argument("--model", default="all-MiniLM-L6-v2", help="Embedding model to use (default: all-MiniLM-L6-v2)")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size for embedding computation (default: 32)")
    opts = parser.parse_args()

    # Fixed UMAP defaults (not currently configurable from the CLI).
    umap_defaults = {
        "n_neighbors": 15,
        "min_dist": 0.1,
        "metric": "cosine"
    }

    visualize_csv_comparison(
        opts.csv1,
        opts.csv2,
        column1=opts.column1,
        column2=opts.column2,
        output_path=opts.output,
        label1=opts.label1,
        label2=opts.label2,
        launch_interactive=opts.interactive,
        port=opts.port,
        host=opts.host,
        model=opts.model,
        batch_size=opts.batch_size,
        umap_args=umap_defaults
    )


if __name__ == "__main__":
    main()