重构项目结构并更新README.md

1. 重构目录结构:
   - 创建src/visualization模块用于存放可视化相关功能
   - 移动script/visualize_csv_comparison.py到src/visualization/comparison.py
   - 创建src/visualization/__init__.py导出主要函数
   - 整理script目录,按功能分类存放脚本文件

2. 更新README.md:
   - 添加CSV文件比较可视化部分
   - 提供Python API和命令行使用方法说明
   - 描述功能特点和使用示例

3. 更新模块引用:
   - 修正comparison.py中的模块引用路径
   - 更新命令行帮助信息中的使用示例
This commit is contained in:
2025-10-23 17:55:36 +08:00
parent 9f0a0fbcdc
commit bbf1746046
7 changed files with 358 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
"""Visualization module for embedding-atlas comparisons."""
from .comparison import visualize_csv_comparison, create_embedding_service, launch_interactive_viewer
__all__ = [
"visualize_csv_comparison",
"create_embedding_service",
"launch_interactive_viewer"
]

View File

@@ -0,0 +1,349 @@
#!/usr/bin/env python3
"""
Script to visualize two CSV files using embedding-atlas with different colors for each file.
This script supports both command-line usage and API usage.
"""
import argparse
import pandas as pd
import matplotlib.pyplot as plt
import os
from typing import Optional, List, Dict, Any
import numpy as np
def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 5055, host: str = "localhost"):
"""使用Python API启动交互式服务器"""
try:
from embedding_atlas.server import make_server
from embedding_atlas.data_source import DataSource
from embedding_atlas.utils import Hasher
import pathlib
# 创建metadata
metadata = {
"columns": {
"id": "_row_index",
"text": text_column,
"embedding": {"x": "projection_x", "y": "projection_y"},
"neighbors": "neighbors"
}
}
# 生成数据集标识符
hasher = Hasher()
hasher.update(["combined_dataset"])
hasher.update(metadata)
identifier = hasher.hexdigest()
# 创建DataSource
dataset = DataSource(identifier, df, metadata)
# 获取静态文件路径
import embedding_atlas
static_path = str(
(pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
)
# 创建并启动服务器
app = make_server(dataset, static_path=static_path)
import uvicorn
print(f"Starting interactive viewer on http://{host}:{port}")
uvicorn.run(app, host=host, port=port)
except ImportError:
print("Error: embedding-atlas not found. Please install it with: pip install embedding-atlas")
except Exception as e:
print(f"Error launching interactive viewer: {e}")
def create_embedding_service(
texts1: List[str],
texts2: List[str],
labels: tuple = ("Group1", "Group2"),
port: int = 5055,
host: str = "localhost",
text_column: str = "text",
model: str = "all-MiniLM-L6-v2",
batch_size: int = 32,
umap_args: Optional[Dict[str, Any]] = None
):
"""
从两组文本数据创建embedding可视化服务
Args:
texts1: 第一组文本列表
texts2: 第二组文本列表
labels: 两组数据的标签
port: 服务器端口
host: 服务器主机
text_column: 文本列名称
model: 使用的嵌入模型
batch_size: 批处理大小
umap_args: UMAP参数字典
"""
import pathlib
import embedding_atlas
# 默认UMAP参数
if umap_args is None:
umap_args = {
"n_neighbors": 15,
"min_dist": 0.1,
"metric": "cosine"
}
# 1. 创建DataFrame
df1 = pd.DataFrame({text_column: texts1, 'source': labels[0]})
df2 = pd.DataFrame({text_column: texts2, 'source': labels[1]})
combined_df = pd.concat([df1, df2], ignore_index=True)
# 2. 添加必需的ID列
combined_df['_row_index'] = range(len(combined_df))
# 3. 计算embeddings和投影
try:
from embedding_atlas.projection import compute_text_projection
compute_text_projection(
combined_df,
text=text_column,
x="projection_x",
y="projection_y",
neighbors="neighbors",
model=model,
batch_size=batch_size,
umap_args=umap_args
)
except ImportError:
print("Error: embedding-atlas not found. Please install it with: pip install embedding-atlas")
return
except Exception as e:
print(f"Error computing projections: {e}")
return
# 4. 创建metadata
metadata = {
"columns": {
"id": "_row_index",
"text": text_column,
"embedding": {
"x": "projection_x",
"y": "projection_y"
},
"neighbors": "neighbors"
}
}
# 5. 生成数据集标识符
hasher = Hasher()
hasher.update(["custom_dataset"])
hasher.update(metadata)
identifier = hasher.hexdigest()
# 6. 创建DataSource
dataset = DataSource(identifier, combined_df, metadata)
# 7. 获取静态文件路径
static_path = str(
(pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
)
# 8. 创建并启动服务器
app = make_server(dataset, static_path=static_path)
import uvicorn
print(f"Starting interactive viewer on http://{host}:{port}")
uvicorn.run(app, host=host, port=port)
def visualize_csv_comparison(
csv1_path: str,
csv2_path: str,
column1: str = "smiles",
column2: str = "smiles",
output_path: str = "comparison_visualization.png",
label1: Optional[str] = None,
label2: Optional[str] = None,
launch_interactive: bool = False,
port: int = 5055,
host: str = "localhost",
model: str = "all-MiniLM-L6-v2",
batch_size: int = 32,
umap_args: Optional[Dict[str, Any]] = None
):
"""
Visualize two CSV files with embedding-atlas and create a combined plot with different colors.
Args:
csv1_path: Path to the first CSV file
csv2_path: Path to the second CSV file
column1: Column name for the first CSV file (default: "smiles")
column2: Column name for the second CSV file (default: "smiles")
output_path: Output visualization file path
label1: Label for the first dataset (default: filename)
label2: Label for the second dataset (default: filename)
launch_interactive: Whether to launch interactive viewer (default: False)
port: Port for interactive viewer (default: 5055)
host: Host for interactive viewer (default: "localhost")
model: Embedding model to use (default: "all-MiniLM-L6-v2")
batch_size: Batch size for embedding computation (default: 32)
umap_args: UMAP arguments as dictionary (default: None)
"""
# Generate default labels from filenames if not provided
if label1 is None:
label1 = os.path.splitext(os.path.basename(csv1_path))[0]
if label2 is None:
label2 = os.path.splitext(os.path.basename(csv2_path))[0]
# Read CSV files
df1 = pd.read_csv(csv1_path)
df2 = pd.read_csv(csv2_path)
# Add source column to identify origin of each data point
df1['source'] = label1
df2['source'] = label2
# Add column to identify which file the data came from
df1['file_origin'] = 0 # First file
df2['file_origin'] = 1 # Second file
# Combine dataframes
combined_df = pd.concat([df1, df2], ignore_index=True)
# Determine which column to use for text projection (moved before projection check)
text_column = column1 if column1 in combined_df.columns else column2
# Check if projection columns already exist
has_projection = 'projection_x' in combined_df.columns and 'projection_y' in combined_df.columns
if not has_projection:
print("Computing embeddings and projections...")
try:
from embedding_atlas.projection import compute_text_projection
# Use default UMAP args if not provided
if umap_args is None:
umap_args = {
"n_neighbors": 15,
"min_dist": 0.1,
"metric": "cosine"
}
# Compute projections
compute_text_projection(
combined_df,
text=text_column,
x="projection_x",
y="projection_y",
neighbors="neighbors",
model=model,
batch_size=batch_size,
umap_args=umap_args
)
except ImportError:
print("Error: embedding-atlas not found. Please install it with: pip install embedding-atlas")
return
except Exception as e:
print(f"Error computing projections: {e}")
return
# Create visualization
print("Creating visualization...")
plt.figure(figsize=(12, 8))
# Plot data points with different colors based on file origin
colors = ['blue', 'red']
for i, (label, color) in enumerate(zip([label1, label2], colors)):
subset = combined_df[combined_df['file_origin'] == i]
if not subset.empty:
plt.scatter(
subset['projection_x'],
subset['projection_y'],
c=color,
label=label,
alpha=0.6,
s=30
)
plt.xlabel('Projection X')
plt.ylabel('Projection Y')
plt.title('Embedding Atlas Visualization Comparison')
plt.legend(loc='upper right')
plt.grid(True, alpha=0.3)
# Save the plot
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"Visualization saved to {output_path}")
# Add _row_index column for embedding-atlas DataSource
combined_df['_row_index'] = range(len(combined_df))
# Save dataset for interactive viewing
parquet_path = "combined_dataset.parquet"
combined_df.to_parquet(parquet_path, index=False)
print(f"Dataset saved to {parquet_path}")
# Launch interactive viewer
if launch_interactive:
launch_interactive_viewer(combined_df, text_column, port, host)
print(f"Interactive viewer available at http://{host}:{port}")
def main():
"""Command-line interface for the visualization script."""
parser = argparse.ArgumentParser(
description="Visualize two CSV files using embedding-atlas with different colors",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python visualize_csv_comparison.py file1.csv file2.csv
python visualize_csv_comparison.py file1.csv file2.csv --column1 smiles --column2 SMILES
python visualize_csv_comparison.py file1.csv file2.csv --label1 "Dataset A" --label2 "Dataset B"
python visualize_csv_comparison.py file1.csv file2.csv --output comparison.png
python visualize_csv_comparison.py file1.csv file2.csv --interactive --port 8080
"""
)
parser.add_argument("csv1", help="Path to the first CSV file")
parser.add_argument("csv2", help="Path to the second CSV file")
parser.add_argument("--column1", default="smiles", help="Column name for the first CSV file (default: smiles)")
parser.add_argument("--column2", default="smiles", help="Column name for the second CSV file (default: smiles)")
parser.add_argument("--output", "-o", default="comparison_visualization.png", help="Output visualization file path")
parser.add_argument("--label1", help="Label for the first dataset (default: filename)")
parser.add_argument("--label2", help="Label for the second dataset (default: filename)")
parser.add_argument("--interactive", "-i", action="store_true", help="Launch interactive viewer")
parser.add_argument("--port", "-p", type=int, default=5055, help="Port for interactive viewer (default: 5055)")
parser.add_argument("--host", default="localhost", help="Host for interactive viewer (default: localhost)")
parser.add_argument("--model", default="all-MiniLM-L6-v2", help="Embedding model to use (default: all-MiniLM-L6-v2)")
parser.add_argument("--batch-size", type=int, default=32, help="Batch size for embedding computation (default: 32)")
args = parser.parse_args()
# Parse UMAP arguments
umap_args = {
"n_neighbors": 15,
"min_dist": 0.1,
"metric": "cosine"
}
visualize_csv_comparison(
args.csv1,
args.csv2,
args.column1,
args.column2,
args.output,
args.label1,
args.label2,
args.interactive,
args.port,
args.host,
args.model,
args.batch_size,
umap_args
)
if __name__ == "__main__":
main()