From deecbfe0fc05f9fae360f5004a0c74f32a97584a Mon Sep 17 00:00:00 2001 From: lingyuzeng Date: Thu, 23 Oct 2025 18:09:33 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dcreate=5Fembedding=5Fservice?= =?UTF-8?q?=E5=92=8Cvisualize=5Fcsv=5Fcomparison=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 修复create_embedding_service函数: - 添加缺失的导入语句 - 修正metadata中neighbors列名不一致问题 - 添加database配置确保数据能正确加载 2. 优化visualize_csv_comparison函数: - 调整_row_index列添加时机 - 添加CSV文件读取错误处理 - 添加列名验证功能 - 保持与create_embedding_service一致的neighbors列名 --- src/visualization/comparison.py | 62 ++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 24 deletions(-) diff --git a/src/visualization/comparison.py b/src/visualization/comparison.py index 0182a15..ce47564 100644 --- a/src/visualization/comparison.py +++ b/src/visualization/comparison.py @@ -72,22 +72,13 @@ def create_embedding_service( batch_size: int = 32, umap_args: Optional[Dict[str, Any]] = None ): - """ - 从两组文本数据创建embedding可视化服务 - - Args: - texts1: 第一组文本列表 - texts2: 第二组文本列表 - labels: 两组数据的标签 - port: 服务器端口 - host: 服务器主机 - text_column: 文本列名称 - model: 使用的嵌入模型 - batch_size: 批处理大小 - umap_args: UMAP参数字典 - """ + """从两组文本数据创建embedding可视化服务""" import pathlib import embedding_atlas + from embedding_atlas.projection import compute_text_projection + from embedding_atlas.server import make_server + from embedding_atlas.data_source import DataSource + from embedding_atlas.utils import Hasher # 默认UMAP参数 if umap_args is None: @@ -107,7 +98,6 @@ def create_embedding_service( # 3. 计算embeddings和投影 try: - from embedding_atlas.projection import compute_text_projection compute_text_projection( combined_df, text=text_column, @@ -125,7 +115,7 @@ def create_embedding_service( print(f"Error computing projections: {e}") return - # 4. 创建metadata + # 4. 创建metadata (修复: 添加database配置和正确的neighbors列名) metadata = { "columns": { "id": "_row_index", @@ -134,7 +124,11 @@ def create_embedding_service( "x": "projection_x", "y": "projection_y" }, - "neighbors": "neighbors" + "neighbors": "__neighbors" # 修复: 使用双下划线 + }, + "database": { + "type": "wasm", + "load": True # 修复: 添加database配置 } } @@ -198,9 +192,29 @@ def visualize_csv_comparison( if label2 is None: label2 = os.path.splitext(os.path.basename(csv2_path))[0] - # Read CSV files - df1 = pd.read_csv(csv1_path) - df2 = pd.read_csv(csv2_path) + # Read CSV files with error handling + try: + df1 = pd.read_csv(csv1_path) + df2 = pd.read_csv(csv2_path) + except FileNotFoundError as e: + print(f"Error: CSV file not found - {e}") + return + except pd.errors.EmptyDataError: + print("Error: One of the CSV files is empty") + return + except Exception as e: + print(f"Error reading CSV files: {e}") + return + + # Validate columns + if column1 not in df1.columns: + print(f"Error: Column '{column1}' not found in {csv1_path}") + print(f"Available columns: {list(df1.columns)}") + return + if column2 not in df2.columns: + print(f"Error: Column '{column2}' not found in {csv2_path}") + print(f"Available columns: {list(df2.columns)}") + return # Add source column to identify origin of each data point df1['source'] = label1 @@ -213,7 +227,10 @@ def visualize_csv_comparison( # Combine dataframes combined_df = pd.concat([df1, df2], ignore_index=True) - # Determine which column to use for text projection (moved before projection check) + # Add _row_index column for embedding-atlas DataSource (moved to earlier) + combined_df['_row_index'] = range(len(combined_df)) + + # Determine which column to use for text projection text_column = column1 if column1 in combined_df.columns else column2 # Check if projection columns already exist @@ -281,9 +298,6 @@ def visualize_csv_comparison( print(f"Visualization saved to {output_path}") - # Add _row_index column for embedding-atlas DataSource - combined_df['_row_index'] = range(len(combined_df)) - # Save dataset for interactive viewing parquet_path = "combined_dataset.parquet" combined_df.to_parquet(parquet_path, index=False)