From 04d1106e534f83cfa1ffb662da901f8b2fe96535 Mon Sep 17 00:00:00 2001 From: lingyuzeng Date: Thu, 23 Oct 2025 20:49:22 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8DCLI=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E7=9A=84=E6=98=AFprops=E5=AD=97=E6=AE=B5=EF=BC=8C=E8=80=8C?= =?UTF-8?q?=E4=BD=A0=E7=9A=84=E8=84=9A=E6=9C=AC=E4=BD=BF=E7=94=A8=E7=9A=84?= =?UTF-8?q?=E6=98=AFcolumns=E5=AD=97=E6=AE=B5=E7=9A=84=E9=97=AE=E9=A2=98?= =?UTF-8?q?=E7=A1=AE=E4=BF=9D=E6=AD=A3=E7=A1=AE=20python=20=E8=83=BD?= =?UTF-8?q?=E6=AD=A3=E7=A1=AE=E5=8A=A0=E8=BD=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/visualization/comparison.py | 111 ++++++++++++++++++++++---------- 1 file changed, 78 insertions(+), 33 deletions(-) diff --git a/src/visualization/comparison.py b/src/visualization/comparison.py index 62ff6bf..f86b852 100644 --- a/src/visualization/comparison.py +++ b/src/visualization/comparison.py @@ -13,24 +13,30 @@ from typing import Optional, List, Dict, Any import numpy as np def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 5055, host: str = "0.0.0.0"): - """使用Python API启动交互式服务器""" + """使用Python API启动交互式服务器 + + Args: + df: 原始的pandas DataFrame,包含投影坐标和邻居信息(作为Python对象,而非JSON字符串) + text_column: 用作文本内容的列名 + port: 启动服务器的端口 + host: 服务器绑定的主机地址 + """ try: from embedding_atlas.server import make_server from embedding_atlas.data_source import DataSource from embedding_atlas.utils import Hasher import pathlib - # 创建metadata - 添加database配置 + # 创建metadata - 使用与CLI一致的props字段结构 metadata = { - "columns": { - "id": "_row_index", - "text": text_column, - "embedding": {"x": "projection_x", "y": "projection_y"}, - "neighbors": "__neighbors" # 使用双下划线 - }, - "database": { - "type": "wasm", - "load": True # 关键: 告诉前端加载数据 + "props": { + "data": { + "id": "_row_index", + "text": text_column, + "projection": {"x": "projection_x", "y": "projection_y"}, + "neighbors": "__neighbors" + }, + "initialState": {"version": "0.0.0"} } } @@ -40,7 +46,8 @@ def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 50 hasher.update(metadata) identifier = hasher.hexdigest() - # 创建DataSource + # 创建DataSource - 直接使用传入的原始DataFrame + # 确保df['__neighbors'] 是 Python dict + numpy arrays 的结构,而不是 JSON 字符串 dataset = DataSource(identifier, df, metadata) # 获取静态文件路径 @@ -49,7 +56,7 @@ def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 50 (pathlib.Path(embedding_atlas.__file__).parent / "static").resolve() ) - # 创建并启动服务器 + # 创建并启动服务器 - 使用默认的duckdb_uri参数(即"wasm") app = make_server(dataset, static_path=static_path) import uvicorn @@ -115,38 +122,55 @@ def create_embedding_service( print(f"Error computing projections: {e}") return - # 4. 创建metadata (修复: 添加database配置和正确的neighbors列名) + # 4. 创建metadata (使用与CLI一致的props字段结构) metadata = { - "columns": { - "id": "_row_index", - "text": text_column, - "embedding": { - "x": "projection_x", - "y": "projection_y" + "props": { + "data": { + "id": "_row_index", + "text": text_column, + "projection": { + "x": "projection_x", + "y": "projection_y" + }, + "neighbors": "__neighbors" }, - "neighbors": "__neighbors" # 修复: 使用双下划线 - }, - "database": { - "type": "wasm", - "load": True # 修复: 添加database配置 + "initialState": {"version": "0.0.0"} } } - # 5. 生成数据集标识符 + # 5. 保存原始DataFrame用于DataSource + df_for_datasource = combined_df.copy() + + # 6. 转换neighbors列为JSON字符串(仅用于保存parquet) + if '__neighbors' in combined_df.columns: + import json + combined_df['__neighbors'] = combined_df['__neighbors'].apply( + lambda x: json.dumps({ + 'ids': x['ids'].tolist() if hasattr(x['ids'], 'tolist') else list(x['ids']), + 'distances': x['distances'].tolist() if hasattr(x['distances'], 'tolist') else list(x['distances']) + }) if x is not None else None + ) + + # 7. 保存parquet文件 + parquet_path = "combined_dataset.parquet" + combined_df.to_parquet(parquet_path, index=False) + print(f"Dataset saved to {parquet_path}") + + # 8. 生成数据集标识符 hasher = Hasher() hasher.update(["custom_dataset"]) hasher.update(metadata) identifier = hasher.hexdigest() - # 6. 创建DataSource - dataset = DataSource(identifier, combined_df, metadata) + # 9. 使用原始DataFrame创建DataSource + dataset = DataSource(identifier, df_for_datasource, metadata) - # 7. 获取静态文件路径 + # 10. 获取静态文件路径 static_path = str( (pathlib.Path(embedding_atlas.__file__).parent / "static").resolve() ) - # 8. 创建并启动服务器 + # 11. 创建并启动服务器 - 使用默认的duckdb_uri参数(即"wasm") app = make_server(dataset, static_path=static_path) import uvicorn print(f"Starting interactive viewer on http://{host}:{port}") @@ -298,14 +322,35 @@ def visualize_csv_comparison( print(f"Visualization saved to {output_path}") + # Keep a copy of the original DataFrame for the interactive viewer + # The viewer expects Python objects, not JSON strings + df_for_viewer = combined_df.copy() + + # Convert __neighbors column to JSON strings before saving + # This is needed because DuckDB-WASM expects JSON strings for complex data types + if '__neighbors' in combined_df.columns: + import json + # 为保存到 parquet 文件的数据转换为 JSON 字符串 + combined_df_for_save = combined_df.copy() + combined_df_for_save['__neighbors'] = combined_df['__neighbors'].apply( + lambda x: json.dumps({ + 'ids': x['ids'].tolist() if hasattr(x['ids'], 'tolist') else list(x['ids']), + 'distances': x['distances'].tolist() if hasattr(x['distances'], 'tolist') else list(x['distances']) + }) if x is not None else None + ) + # Save dataset for interactive viewing parquet_path = "combined_dataset.parquet" - combined_df.to_parquet(parquet_path, index=False) + # 保存转换为 JSON 字符串的 DataFrame + if '__neighbors' in combined_df.columns: + combined_df_for_save.to_parquet(parquet_path, index=False) + else: + combined_df.to_parquet(parquet_path, index=False) print(f"Dataset saved to {parquet_path}") - # Launch interactive viewer + # Launch interactive viewer with original DataFrame (not converted to JSON) if launch_interactive: - launch_interactive_viewer(combined_df, text_column, port, host) + launch_interactive_viewer(df_for_viewer, text_column, port, host) print(f"Interactive viewer available at http://{host}:{port}")