重构项目结构并更新README.md
1. 重构目录结构: - 创建src/visualization模块用于存放可视化相关功能 - 移动script/visualize_csv_comparison.py到src/visualization/comparison.py - 创建src/visualization/__init__.py导出主要函数 - 整理script目录,按功能分类存放脚本文件 2. 更新README.md: - 添加CSV文件比较可视化部分 - 提供Python API和命令行使用方法说明 - 描述功能特点和使用示例 3. 更新模块引用: - 修正comparison.py中的模块引用路径 - 更新命令行帮助信息中的使用示例
This commit is contained in:
9
src/visualization/__init__.py
Normal file
9
src/visualization/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Visualization module for embedding-atlas comparisons."""
|
||||
|
||||
from .comparison import visualize_csv_comparison, create_embedding_service, launch_interactive_viewer
|
||||
|
||||
__all__ = [
|
||||
"visualize_csv_comparison",
|
||||
"create_embedding_service",
|
||||
"launch_interactive_viewer"
|
||||
]
|
||||
349
src/visualization/comparison.py
Normal file
349
src/visualization/comparison.py
Normal file
@@ -0,0 +1,349 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Script to visualize two CSV files using embedding-atlas with different colors for each file.
|
||||
|
||||
This script supports both command-line usage and API usage.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
from typing import Optional, List, Dict, Any
|
||||
import numpy as np
|
||||
|
||||
def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 5055, host: str = "localhost"):
    """Serve *df* in an interactive embedding-atlas viewer via the Python API.

    Args:
        df: Frame expected to hold the text plus precomputed ``projection_x``,
            ``projection_y``, ``neighbors`` and ``_row_index`` columns.
        text_column: Name of the column whose text is displayed in the viewer.
        port: TCP port the local web server listens on.
        host: Interface the server binds to.

    Errors (missing dependency or server failure) are reported to stdout
    instead of raising, so a failed launch never aborts the caller.
    """
    try:
        import pathlib

        import embedding_atlas
        from embedding_atlas.data_source import DataSource
        from embedding_atlas.server import make_server
        from embedding_atlas.utils import Hasher

        # Column mapping that tells the viewer where to find each field.
        viewer_metadata = {
            "columns": {
                "id": "_row_index",
                "text": text_column,
                "embedding": {"x": "projection_x", "y": "projection_y"},
                "neighbors": "neighbors"
            }
        }

        # Derive a stable identifier from the dataset name + column mapping.
        digest = Hasher()
        digest.update(["combined_dataset"])
        digest.update(viewer_metadata)
        source = DataSource(digest.hexdigest(), df, viewer_metadata)

        # The viewer's frontend assets ship inside the installed package.
        assets_dir = str(
            (pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
        )

        server_app = make_server(source, static_path=assets_dir)

        import uvicorn
        print(f"Starting interactive viewer on http://{host}:{port}")
        uvicorn.run(server_app, host=host, port=port)

    except ImportError:
        print("Error: embedding-atlas not found. Please install it with: pip install embedding-atlas")
    except Exception as e:
        print(f"Error launching interactive viewer: {e}")
|
||||
|
||||
def create_embedding_service(
    texts1: List[str],
    texts2: List[str],
    labels: tuple = ("Group1", "Group2"),
    port: int = 5055,
    host: str = "localhost",
    text_column: str = "text",
    model: str = "all-MiniLM-L6-v2",
    batch_size: int = 32,
    umap_args: Optional[Dict[str, Any]] = None
):
    """Create an embedding visualization service from two groups of texts.

    Combines both text lists into one DataFrame, computes embeddings and a
    2-D UMAP projection, then serves an interactive embedding-atlas viewer.
    Note: the final ``uvicorn.run`` call blocks until the server shuts down.

    Args:
        texts1: First group of texts.
        texts2: Second group of texts.
        labels: Labels for the two groups (stored in a ``source`` column).
        port: Server port.
        host: Server host.
        text_column: Name of the text column in the combined DataFrame.
        model: Embedding model to use.
        batch_size: Batch size for embedding computation.
        umap_args: UMAP parameter dict; sensible defaults are used when None.
    """
    import pathlib

    # BUGFIX: Hasher, DataSource and make_server were previously used below
    # without ever being imported in this function's scope, causing a
    # NameError after the (expensive) projection step. Import everything the
    # function needs up front so a missing dependency fails fast.
    try:
        import embedding_atlas
        from embedding_atlas.data_source import DataSource
        from embedding_atlas.projection import compute_text_projection
        from embedding_atlas.server import make_server
        from embedding_atlas.utils import Hasher
    except ImportError:
        print("Error: embedding-atlas not found. Please install it with: pip install embedding-atlas")
        return

    # Default UMAP parameters.
    if umap_args is None:
        umap_args = {
            "n_neighbors": 15,
            "min_dist": 0.1,
            "metric": "cosine"
        }

    # 1. Build one DataFrame with a 'source' column identifying each group.
    df1 = pd.DataFrame({text_column: texts1, 'source': labels[0]})
    df2 = pd.DataFrame({text_column: texts2, 'source': labels[1]})
    combined_df = pd.concat([df1, df2], ignore_index=True)

    # 2. Add the row-id column required by the viewer.
    combined_df['_row_index'] = range(len(combined_df))

    # 3. Compute embeddings and 2-D projections in place.
    try:
        compute_text_projection(
            combined_df,
            text=text_column,
            x="projection_x",
            y="projection_y",
            neighbors="neighbors",
            model=model,
            batch_size=batch_size,
            umap_args=umap_args
        )
    except Exception as e:
        print(f"Error computing projections: {e}")
        return

    # 4. Column mapping the viewer expects.
    metadata = {
        "columns": {
            "id": "_row_index",
            "text": text_column,
            "embedding": {
                "x": "projection_x",
                "y": "projection_y"
            },
            "neighbors": "neighbors"
        }
    }

    # 5. Stable dataset identifier derived from name + metadata.
    hasher = Hasher()
    hasher.update(["custom_dataset"])
    hasher.update(metadata)
    identifier = hasher.hexdigest()

    # 6. Wrap the data for the server.
    dataset = DataSource(identifier, combined_df, metadata)

    # 7. Frontend assets live inside the installed package.
    static_path = str(
        (pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
    )

    # 8. Create and start the server (blocks until shutdown).
    app = make_server(dataset, static_path=static_path)
    import uvicorn
    print(f"Starting interactive viewer on http://{host}:{port}")
    uvicorn.run(app, host=host, port=port)
|
||||
|
||||
def visualize_csv_comparison(
    csv1_path: str,
    csv2_path: str,
    column1: str = "smiles",
    column2: str = "smiles",
    output_path: str = "comparison_visualization.png",
    label1: Optional[str] = None,
    label2: Optional[str] = None,
    launch_interactive: bool = False,
    port: int = 5055,
    host: str = "localhost",
    model: str = "all-MiniLM-L6-v2",
    batch_size: int = 32,
    umap_args: Optional[Dict[str, Any]] = None
):
    """
    Visualize two CSV files with embedding-atlas and create a combined plot with different colors.

    Args:
        csv1_path: Path to the first CSV file
        csv2_path: Path to the second CSV file
        column1: Column name for the first CSV file (default: "smiles")
        column2: Column name for the second CSV file (default: "smiles")
        output_path: Output visualization file path
        label1: Label for the first dataset (default: filename)
        label2: Label for the second dataset (default: filename)
        launch_interactive: Whether to launch interactive viewer (default: False)
        port: Port for interactive viewer (default: 5055)
        host: Host for interactive viewer (default: "localhost")
        model: Embedding model to use (default: "all-MiniLM-L6-v2")
        batch_size: Batch size for embedding computation (default: 32)
        umap_args: UMAP arguments as dictionary (default: None)
    """

    # Generate default labels from filenames if not provided
    if label1 is None:
        label1 = os.path.splitext(os.path.basename(csv1_path))[0]
    if label2 is None:
        label2 = os.path.splitext(os.path.basename(csv2_path))[0]

    # Read CSV files
    df1 = pd.read_csv(csv1_path)
    df2 = pd.read_csv(csv2_path)

    # Add source column to identify origin of each data point
    df1['source'] = label1
    df2['source'] = label2

    # Add column to identify which file the data came from
    df1['file_origin'] = 0  # First file
    df2['file_origin'] = 1  # Second file

    # Combine dataframes
    combined_df = pd.concat([df1, df2], ignore_index=True)

    # Determine which column to use for text projection. Fail fast with a
    # clear message when neither requested column exists — previously the
    # code silently fell back to column2 and died later with an opaque
    # projection error.
    if column1 in combined_df.columns:
        text_column = column1
    elif column2 in combined_df.columns:
        text_column = column2
    else:
        print(f"Error: neither column '{column1}' nor '{column2}' found in the input files")
        return

    # Check if projection columns already exist (allows precomputed input)
    has_projection = 'projection_x' in combined_df.columns and 'projection_y' in combined_df.columns

    if not has_projection:
        print("Computing embeddings and projections...")
        try:
            from embedding_atlas.projection import compute_text_projection

            # Use default UMAP args if not provided
            if umap_args is None:
                umap_args = {
                    "n_neighbors": 15,
                    "min_dist": 0.1,
                    "metric": "cosine"
                }

            # Compute projections in place on combined_df
            compute_text_projection(
                combined_df,
                text=text_column,
                x="projection_x",
                y="projection_y",
                neighbors="neighbors",
                model=model,
                batch_size=batch_size,
                umap_args=umap_args
            )
        except ImportError:
            print("Error: embedding-atlas not found. Please install it with: pip install embedding-atlas")
            return
        except Exception as e:
            print(f"Error computing projections: {e}")
            return

    # Create static visualization
    print("Creating visualization...")
    plt.figure(figsize=(12, 8))

    # Plot data points with different colors based on file origin
    colors = ['blue', 'red']
    for i, (label, color) in enumerate(zip([label1, label2], colors)):
        subset = combined_df[combined_df['file_origin'] == i]
        if not subset.empty:
            plt.scatter(
                subset['projection_x'],
                subset['projection_y'],
                c=color,
                label=label,
                alpha=0.6,
                s=30
            )

    plt.xlabel('Projection X')
    plt.ylabel('Projection Y')
    plt.title('Embedding Atlas Visualization Comparison')
    plt.legend(loc='upper right')
    plt.grid(True, alpha=0.3)

    # Save the plot
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"Visualization saved to {output_path}")

    # Add _row_index column for embedding-atlas DataSource
    combined_df['_row_index'] = range(len(combined_df))

    # Save dataset for interactive viewing
    parquet_path = "combined_dataset.parquet"
    combined_df.to_parquet(parquet_path, index=False)
    print(f"Dataset saved to {parquet_path}")

    # Launch interactive viewer. BUGFIX: announce the URL *before* launching —
    # launch_interactive_viewer blocks in uvicorn.run, so the old post-launch
    # print only appeared after the server had already shut down.
    if launch_interactive:
        print(f"Interactive viewer available at http://{host}:{port}")
        launch_interactive_viewer(combined_df, text_column, port, host)
|
||||
|
||||
|
||||
def main():
|
||||
"""Command-line interface for the visualization script."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Visualize two CSV files using embedding-atlas with different colors",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
python visualize_csv_comparison.py file1.csv file2.csv
|
||||
python visualize_csv_comparison.py file1.csv file2.csv --column1 smiles --column2 SMILES
|
||||
python visualize_csv_comparison.py file1.csv file2.csv --label1 "Dataset A" --label2 "Dataset B"
|
||||
python visualize_csv_comparison.py file1.csv file2.csv --output comparison.png
|
||||
python visualize_csv_comparison.py file1.csv file2.csv --interactive --port 8080
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument("csv1", help="Path to the first CSV file")
|
||||
parser.add_argument("csv2", help="Path to the second CSV file")
|
||||
parser.add_argument("--column1", default="smiles", help="Column name for the first CSV file (default: smiles)")
|
||||
parser.add_argument("--column2", default="smiles", help="Column name for the second CSV file (default: smiles)")
|
||||
parser.add_argument("--output", "-o", default="comparison_visualization.png", help="Output visualization file path")
|
||||
parser.add_argument("--label1", help="Label for the first dataset (default: filename)")
|
||||
parser.add_argument("--label2", help="Label for the second dataset (default: filename)")
|
||||
parser.add_argument("--interactive", "-i", action="store_true", help="Launch interactive viewer")
|
||||
parser.add_argument("--port", "-p", type=int, default=5055, help="Port for interactive viewer (default: 5055)")
|
||||
parser.add_argument("--host", default="localhost", help="Host for interactive viewer (default: localhost)")
|
||||
parser.add_argument("--model", default="all-MiniLM-L6-v2", help="Embedding model to use (default: all-MiniLM-L6-v2)")
|
||||
parser.add_argument("--batch-size", type=int, default=32, help="Batch size for embedding computation (default: 32)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse UMAP arguments
|
||||
umap_args = {
|
||||
"n_neighbors": 15,
|
||||
"min_dist": 0.1,
|
||||
"metric": "cosine"
|
||||
}
|
||||
|
||||
visualize_csv_comparison(
|
||||
args.csv1,
|
||||
args.csv2,
|
||||
args.column1,
|
||||
args.column2,
|
||||
args.output,
|
||||
args.label1,
|
||||
args.label2,
|
||||
args.interactive,
|
||||
args.port,
|
||||
args.host,
|
||||
args.model,
|
||||
args.batch_size,
|
||||
umap_args
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user