#!/usr/bin/env python3
"""
Script to visualize two CSV files using embedding-atlas with different colors for each file.

This script supports both command-line usage and API usage.
"""

import argparse
import os
from typing import Optional, List, Dict, Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 5055, host: str = "0.0.0.0"):
    """Launch an interactive embedding-atlas server via the Python API.

    Args:
        df: DataFrame that already contains the projection coordinates and
            neighbor info (as Python objects, NOT JSON strings).
        text_column: Name of the column holding the text content.
        port: Port to serve on.
        host: Host address to bind the server to.
    """
    try:
        from embedding_atlas.server import make_server
        from embedding_atlas.data_source import DataSource
        from embedding_atlas.utils import Hasher
        import pathlib

        # Metadata uses the same "props" field structure the embedding-atlas
        # CLI produces, so the frontend finds the expected columns.
        metadata = {
            "props": {
                "data": {
                    "id": "_row_index",
                    "text": text_column,
                    "projection": {"x": "projection_x", "y": "projection_y"},
                    "neighbors": "__neighbors"
                },
                "initialState": {"version": "0.0.0"}
            }
        }

        # Derive a stable dataset identifier from a fixed name + the metadata.
        hasher = Hasher()
        hasher.update(["combined_dataset"])
        hasher.update(metadata)
        identifier = hasher.hexdigest()

        # Build the DataSource directly from the caller's DataFrame.
        # df['__neighbors'] must hold Python dicts + numpy arrays, not JSON strings.
        dataset = DataSource(identifier, df, metadata)

        # Locate the frontend assets bundled with the installed package.
        import embedding_atlas
        static_path = str(
            (pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
        )

        # Create and run the server with the default duckdb_uri ("wasm").
        app = make_server(dataset, static_path=static_path)

        import uvicorn
        print(f"Starting interactive viewer on http://{host}:{port}")
        uvicorn.run(app, host=host, port=port)

    except ImportError:
        print("Error: embedding-atlas not found. Please install it with: pip install embedding-atlas")
    except Exception as e:
        # NOTE(review): launch failures are reported but deliberately swallowed
        # so a broken viewer does not crash the caller; re-raise here if a hard
        # failure is preferred.
        print(f"Error launching interactive viewer: {e}")
def create_embedding_service(
    texts1: List[str],
    texts2: List[str],
    labels: tuple = ("Group1", "Group2"),
    port: int = 5055,
    host: str = "0.0.0.0",
    text_column: str = "text",
    model: str = "all-MiniLM-L6-v2",
    batch_size: int = 32,
    umap_args: Optional[Dict[str, Any]] = None
):
    """Build and serve an embedding-atlas visualization from two groups of texts.

    Args:
        texts1: Texts for the first group.
        texts2: Texts for the second group.
        labels: (label_for_group1, label_for_group2) stored in the 'source' column.
        port: Port for the interactive server.
        host: Host address to bind the server to.
        text_column: Column name under which the texts are stored.
        model: Sentence-embedding model name.
        batch_size: Batch size for embedding computation.
        umap_args: Optional UMAP parameter overrides (defaults applied if None).
    """
    # FIX: these imports previously sat at the top of the function body,
    # OUTSIDE any try block, so a missing embedding-atlas raised an uncaught
    # ImportError and the friendly install hint below was unreachable.
    try:
        import pathlib
        import embedding_atlas
        from embedding_atlas.projection import compute_text_projection
        from embedding_atlas.server import make_server
        from embedding_atlas.data_source import DataSource
        from embedding_atlas.utils import Hasher
    except ImportError:
        print("Error: embedding-atlas not found. Please install it with: pip install embedding-atlas")
        return

    # Default UMAP parameters.
    if umap_args is None:
        umap_args = {
            "n_neighbors": 15,
            "min_dist": 0.1,
            "metric": "cosine"
        }

    # 1. One DataFrame per group, tagged with its source label, then combined.
    df1 = pd.DataFrame({text_column: texts1, 'source': labels[0]})
    df2 = pd.DataFrame({text_column: texts2, 'source': labels[1]})
    combined_df = pd.concat([df1, df2], ignore_index=True)

    # 2. Row-id column required by the embedding-atlas DataSource.
    combined_df['_row_index'] = range(len(combined_df))

    # 3. Compute embeddings and the 2-D projection in place.
    try:
        compute_text_projection(
            combined_df,
            text=text_column,
            x="projection_x",
            y="projection_y",
            neighbors="__neighbors",  # double-underscore name expected by the viewer
            model=model,
            batch_size=batch_size,
            umap_args=umap_args
        )
    except Exception as e:
        print(f"Error computing projections: {e}")
        return

    # 4. Metadata mirrors the "props" field structure the CLI produces.
    metadata = {
        "props": {
            "data": {
                "id": "_row_index",
                "text": text_column,
                "projection": {
                    "x": "projection_x",
                    "y": "projection_y"
                },
                "neighbors": "__neighbors"
            },
            "initialState": {"version": "0.0.0"}
        }
    }

    # 5. Keep the original (Python-object) DataFrame for the DataSource.
    df_for_datasource = combined_df.copy()

    # 6. Serialize the neighbors column to JSON strings (parquet copy only).
    if '__neighbors' in combined_df.columns:
        import json
        combined_df['__neighbors'] = combined_df['__neighbors'].apply(
            lambda x: json.dumps({
                'ids': x['ids'].tolist() if hasattr(x['ids'], 'tolist') else list(x['ids']),
                'distances': x['distances'].tolist() if hasattr(x['distances'], 'tolist') else list(x['distances'])
            }) if x is not None else None
        )

    # 7. Persist the combined dataset.
    parquet_path = "combined_dataset.parquet"
    combined_df.to_parquet(parquet_path, index=False)
    print(f"Dataset saved to {parquet_path}")

    # 8. Stable dataset identifier derived from a fixed name + metadata.
    hasher = Hasher()
    hasher.update(["custom_dataset"])
    hasher.update(metadata)
    identifier = hasher.hexdigest()

    # 9. DataSource built from the unconverted DataFrame.
    dataset = DataSource(identifier, df_for_datasource, metadata)

    # 10. Locate the bundled frontend assets.
    static_path = str(
        (pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
    )

    # 11. Create and run the server with the default duckdb_uri ("wasm").
    app = make_server(dataset, static_path=static_path)
    import uvicorn
    print(f"Starting interactive viewer on http://{host}:{port}")
    uvicorn.run(app, host=host, port=port)
def _neighbors_to_json(entry):
    """Serialize one `__neighbors` cell (dict with 'ids'/'distances' arrays) to JSON.

    DuckDB-WASM cannot ingest arbitrary Python objects from parquet, so complex
    cells are stored as JSON strings. None cells pass through unchanged.
    """
    if entry is None:
        return None
    import json
    return json.dumps({
        'ids': entry['ids'].tolist() if hasattr(entry['ids'], 'tolist') else list(entry['ids']),
        'distances': entry['distances'].tolist() if hasattr(entry['distances'], 'tolist') else list(entry['distances'])
    })


def visualize_csv_comparison(
    csv1_path: str,
    csv2_path: str,
    column1: str = "smiles",
    column2: str = "smiles",
    output_path: str = "comparison_visualization.png",
    label1: Optional[str] = None,
    label2: Optional[str] = None,
    launch_interactive: bool = False,
    port: int = 5055,
    host: str = "0.0.0.0",
    model: str = "all-MiniLM-L6-v2",
    batch_size: int = 32,
    umap_args: Optional[Dict[str, Any]] = None
):
    """
    Visualize two CSV files with embedding-atlas and create a combined plot with different colors.

    Args:
        csv1_path: Path to the first CSV file
        csv2_path: Path to the second CSV file
        column1: Column name for the first CSV file (default: "smiles")
        column2: Column name for the second CSV file (default: "smiles")
        output_path: Output visualization file path
        label1: Label for the first dataset (default: filename)
        label2: Label for the second dataset (default: filename)
        launch_interactive: Whether to launch interactive viewer (default: False)
        port: Port for interactive viewer (default: 5055)
        host: Host for interactive viewer (default: "0.0.0.0")
        model: Embedding model to use (default: "all-MiniLM-L6-v2")
        batch_size: Batch size for embedding computation (default: 32)
        umap_args: UMAP arguments as dictionary (default: None)
    """
    # Default labels fall back to the CSV file names (without extension).
    if label1 is None:
        label1 = os.path.splitext(os.path.basename(csv1_path))[0]
    if label2 is None:
        label2 = os.path.splitext(os.path.basename(csv2_path))[0]

    # Read both CSVs, reporting (not raising) any load failure.
    try:
        df1 = pd.read_csv(csv1_path)
        df2 = pd.read_csv(csv2_path)
    except FileNotFoundError as e:
        print(f"Error: CSV file not found - {e}")
        return
    except pd.errors.EmptyDataError:
        print("Error: One of the CSV files is empty")
        return
    except Exception as e:
        print(f"Error reading CSV files: {e}")
        return

    # Validate that the requested text columns exist before any heavy work.
    if column1 not in df1.columns:
        print(f"Error: Column '{column1}' not found in {csv1_path}")
        print(f"Available columns: {list(df1.columns)}")
        return
    if column2 not in df2.columns:
        print(f"Error: Column '{column2}' not found in {csv2_path}")
        print(f"Available columns: {list(df2.columns)}")
        return

    # Tag each row with its origin: a human-readable label plus a 0/1 flag.
    df1['source'] = label1
    df2['source'] = label2
    df1['file_origin'] = 0  # First file
    df2['file_origin'] = 1  # Second file

    # Combine dataframes.
    combined_df = pd.concat([df1, df2], ignore_index=True)

    # Row-id column required by the embedding-atlas DataSource.
    combined_df['_row_index'] = range(len(combined_df))

    # Column used for text projection (falls back to column2 if needed).
    text_column = column1 if column1 in combined_df.columns else column2

    # Skip the expensive embedding step when projections are already present.
    has_projection = 'projection_x' in combined_df.columns and 'projection_y' in combined_df.columns

    if not has_projection:
        print("Computing embeddings and projections...")
        try:
            from embedding_atlas.projection import compute_text_projection

            # Use default UMAP args if not provided.
            if umap_args is None:
                umap_args = {
                    "n_neighbors": 15,
                    "min_dist": 0.1,
                    "metric": "cosine"
                }

            compute_text_projection(
                combined_df,
                text=text_column,
                x="projection_x",
                y="projection_y",
                neighbors="__neighbors",  # double-underscore name expected by the viewer
                model=model,
                batch_size=batch_size,
                umap_args=umap_args
            )
        except ImportError:
            print("Error: embedding-atlas not found. Please install it with: pip install embedding-atlas")
            return
        except Exception as e:
            print(f"Error computing projections: {e}")
            return

    # Static scatter plot, colored by file of origin.
    print("Creating visualization...")
    plt.figure(figsize=(12, 8))
    colors = ['blue', 'red']
    for i, (label, color) in enumerate(zip([label1, label2], colors)):
        subset = combined_df[combined_df['file_origin'] == i]
        if not subset.empty:
            plt.scatter(
                subset['projection_x'],
                subset['projection_y'],
                c=color,
                label=label,
                alpha=0.6,
                s=30
            )

    plt.xlabel('Projection X')
    plt.ylabel('Projection Y')
    plt.title('Embedding Atlas Visualization Comparison')
    plt.legend(loc='upper right')
    plt.grid(True, alpha=0.3)

    # Save the plot.
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Visualization saved to {output_path}")

    # The interactive viewer expects Python objects, not JSON strings, so keep
    # an unconverted copy before serializing neighbors for parquet.
    df_for_viewer = combined_df.copy()

    # FIX: the neighbors-serialization lambda was duplicated inline; use the
    # shared _neighbors_to_json helper on a save-only copy instead.
    if '__neighbors' in combined_df.columns:
        combined_df_for_save = combined_df.copy()
        combined_df_for_save['__neighbors'] = combined_df['__neighbors'].apply(_neighbors_to_json)
    else:
        combined_df_for_save = combined_df

    # Save dataset for interactive viewing.
    parquet_path = "combined_dataset.parquet"
    combined_df_for_save.to_parquet(parquet_path, index=False)
    print(f"Dataset saved to {parquet_path}")

    # Launch interactive viewer with the original (non-JSON) DataFrame.
    if launch_interactive:
        launch_interactive_viewer(df_for_viewer, text_column, port, host)
        print(f"Interactive viewer available at http://{host}:{port}")
def main():
    """Parse command-line arguments and run the CSV comparison visualization."""
    parser = argparse.ArgumentParser(
        description="Visualize two CSV files using embedding-atlas with different colors",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python -m visualization.comparison file1.csv file2.csv
  python -m visualization.comparison file1.csv file2.csv --column1 smiles --column2 SMILES
  python -m visualization.comparison file1.csv file2.csv --label1 "Dataset A" --label2 "Dataset B"
  python -m visualization.comparison file1.csv file2.csv --output comparison.png
  python -m visualization.comparison file1.csv file2.csv --interactive --port 8080
        """
    )

    # Declarative option table keeps the CLI surface easy to scan.
    cli_options = [
        (("csv1",), dict(help="Path to the first CSV file")),
        (("csv2",), dict(help="Path to the second CSV file")),
        (("--column1",), dict(default="smiles", help="Column name for the first CSV file (default: smiles)")),
        (("--column2",), dict(default="smiles", help="Column name for the second CSV file (default: smiles)")),
        (("--output", "-o"), dict(default="comparison_visualization.png", help="Output visualization file path")),
        (("--label1",), dict(help="Label for the first dataset (default: filename)")),
        (("--label2",), dict(help="Label for the second dataset (default: filename)")),
        (("--interactive", "-i"), dict(action="store_true", help="Launch interactive viewer")),
        (("--port", "-p"), dict(type=int, default=5055, help="Port for interactive viewer (default: 5055)")),
        (("--host",), dict(default="0.0.0.0", help="Host for interactive viewer (default: 0.0.0.0)")),
        (("--model",), dict(default="all-MiniLM-L6-v2", help="Embedding model to use (default: all-MiniLM-L6-v2)")),
        (("--batch-size",), dict(type=int, default=32, help="Batch size for embedding computation (default: 32)")),
    ]
    for flags, options in cli_options:
        parser.add_argument(*flags, **options)

    opts = parser.parse_args()

    # Fixed UMAP configuration applied to every CLI run.
    umap_config = {
        "n_neighbors": 15,
        "min_dist": 0.1,
        "metric": "cosine"
    }

    visualize_csv_comparison(
        opts.csv1,
        opts.csv2,
        opts.column1,
        opts.column2,
        opts.output,
        opts.label1,
        opts.label2,
        opts.interactive,
        opts.port,
        opts.host,
        opts.model,
        opts.batch_size,
        umap_config
    )
# Script entry point: run the command-line interface when executed directly.
if __name__ == "__main__":
    main()