修复create_embedding_service和visualize_csv_comparison函数中的问题

1. 修复create_embedding_service函数:
   - 添加缺失的导入语句
   - 修正metadata中neighbors列名不一致问题
   - 添加database配置确保数据能正确加载

2. 优化visualize_csv_comparison函数:
   - 调整_row_index列添加时机
   - 添加CSV文件读取错误处理
   - 添加列名验证功能
   - 保持与create_embedding_service一致的neighbors列名
This commit is contained in:
2025-10-23 18:09:33 +08:00
parent 991bcc491f
commit deecbfe0fc

View File

@@ -72,22 +72,13 @@ def create_embedding_service(
batch_size: int = 32,
umap_args: Optional[Dict[str, Any]] = None
):
"""
从两组文本数据创建embedding可视化服务
Args:
texts1: 第一组文本列表
texts2: 第二组文本列表
labels: 两组数据的标签
port: 服务器端口
host: 服务器主机
text_column: 文本列名称
model: 使用的嵌入模型
batch_size: 批处理大小
umap_args: UMAP参数字典
"""
"""从两组文本数据创建embedding可视化服务"""
import pathlib
import embedding_atlas
from embedding_atlas.projection import compute_text_projection
from embedding_atlas.server import make_server
from embedding_atlas.data_source import DataSource
from embedding_atlas.utils import Hasher
# 默认UMAP参数
if umap_args is None:
@@ -107,7 +98,6 @@ def create_embedding_service(
# 3. 计算embeddings和投影
try:
from embedding_atlas.projection import compute_text_projection
compute_text_projection(
combined_df,
text=text_column,
@@ -125,7 +115,7 @@ def create_embedding_service(
print(f"Error computing projections: {e}")
return
# 4. 创建metadata
# 4. 创建metadata (修复: 添加database配置和正确的neighbors列名)
metadata = {
"columns": {
"id": "_row_index",
@@ -134,7 +124,11 @@ def create_embedding_service(
"x": "projection_x",
"y": "projection_y"
},
"neighbors": "neighbors"
"neighbors": "__neighbors" # 修复: 使用双下划线
},
"database": {
"type": "wasm",
"load": True # 修复: 添加database配置
}
}
@@ -198,9 +192,29 @@ def visualize_csv_comparison(
if label2 is None:
label2 = os.path.splitext(os.path.basename(csv2_path))[0]
# Read CSV files
df1 = pd.read_csv(csv1_path)
df2 = pd.read_csv(csv2_path)
# Read CSV files with error handling
try:
df1 = pd.read_csv(csv1_path)
df2 = pd.read_csv(csv2_path)
except FileNotFoundError as e:
print(f"Error: CSV file not found - {e}")
return
except pd.errors.EmptyDataError:
print("Error: One of the CSV files is empty")
return
except Exception as e:
print(f"Error reading CSV files: {e}")
return
# Validate columns
if column1 not in df1.columns:
print(f"Error: Column '{column1}' not found in {csv1_path}")
print(f"Available columns: {list(df1.columns)}")
return
if column2 not in df2.columns:
print(f"Error: Column '{column2}' not found in {csv2_path}")
print(f"Available columns: {list(df2.columns)}")
return
# Add source column to identify origin of each data point
df1['source'] = label1
@@ -213,7 +227,10 @@ def visualize_csv_comparison(
# Combine dataframes
combined_df = pd.concat([df1, df2], ignore_index=True)
# Determine which column to use for text projection (moved before projection check)
# Add _row_index column for embedding-atlas DataSource (moved to earlier)
combined_df['_row_index'] = range(len(combined_df))
# Determine which column to use for text projection
text_column = column1 if column1 in combined_df.columns else column2
# Check if projection columns already exist
@@ -281,9 +298,6 @@ def visualize_csv_comparison(
print(f"Visualization saved to {output_path}")
# Add _row_index column for embedding-atlas DataSource
combined_df['_row_index'] = range(len(combined_df))
# Save dataset for interactive viewing
parquet_path = "combined_dataset.parquet"
combined_df.to_parquet(parquet_path, index=False)