修复 CLI 使用 props 字段、而脚本使用 columns 字段导致元数据结构不一致的问题,确保 Python 端能正确加载数据

This commit is contained in:
2025-10-23 20:49:22 +08:00
parent 60c5ce152b
commit 04d1106e53

View File

@@ -13,24 +13,30 @@ from typing import Optional, List, Dict, Any
import numpy as np
def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 5055, host: str = "0.0.0.0"):
"""使用Python API启动交互式服务器"""
"""使用Python API启动交互式服务器
Args:
df: 原始的pandas DataFrame包含投影坐标和邻居信息作为Python对象而非JSON字符串
text_column: 用作文本内容的列名
port: 启动服务器的端口
host: 服务器绑定的主机地址
"""
try:
from embedding_atlas.server import make_server
from embedding_atlas.data_source import DataSource
from embedding_atlas.utils import Hasher
import pathlib
# 创建metadata - 添加database配置
# 创建metadata - 使用与CLI一致的props字段结构
metadata = {
"columns": {
"props": {
"data": {
"id": "_row_index",
"text": text_column,
"embedding": {"x": "projection_x", "y": "projection_y"},
"neighbors": "__neighbors" # 使用双下划线
"projection": {"x": "projection_x", "y": "projection_y"},
"neighbors": "__neighbors"
},
"database": {
"type": "wasm",
"load": True # 关键: 告诉前端加载数据
"initialState": {"version": "0.0.0"}
}
}
@@ -40,7 +46,8 @@ def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 50
hasher.update(metadata)
identifier = hasher.hexdigest()
# 创建DataSource
# 创建DataSource - 直接使用传入的原始DataFrame
# 确保df['__neighbors'] 是 Python dict + numpy arrays 的结构,而不是 JSON 字符串
dataset = DataSource(identifier, df, metadata)
# 获取静态文件路径
@@ -49,7 +56,7 @@ def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 50
(pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
)
# 创建并启动服务器
# 创建并启动服务器 - 使用默认的duckdb_uri参数(即"wasm")
app = make_server(dataset, static_path=static_path)
import uvicorn
@@ -115,38 +122,55 @@ def create_embedding_service(
print(f"Error computing projections: {e}")
return
# 4. 创建metadata (修复: 添加database配置和正确的neighbors列名)
# 4. 创建metadata (使用与CLI一致的props字段结构)
metadata = {
"columns": {
"props": {
"data": {
"id": "_row_index",
"text": text_column,
"embedding": {
"projection": {
"x": "projection_x",
"y": "projection_y"
},
"neighbors": "__neighbors" # 修复: 使用双下划线
"neighbors": "__neighbors"
},
"database": {
"type": "wasm",
"load": True # 修复: 添加database配置
"initialState": {"version": "0.0.0"}
}
}
# 5. 生成数据集标识符
# 5. 保存原始DataFrame用于DataSource
df_for_datasource = combined_df.copy()
# 6. 转换neighbors列为JSON字符串(仅用于保存parquet)
if '__neighbors' in combined_df.columns:
import json
combined_df['__neighbors'] = combined_df['__neighbors'].apply(
lambda x: json.dumps({
'ids': x['ids'].tolist() if hasattr(x['ids'], 'tolist') else list(x['ids']),
'distances': x['distances'].tolist() if hasattr(x['distances'], 'tolist') else list(x['distances'])
}) if x is not None else None
)
# 7. 保存parquet文件
parquet_path = "combined_dataset.parquet"
combined_df.to_parquet(parquet_path, index=False)
print(f"Dataset saved to {parquet_path}")
# 8. 生成数据集标识符
hasher = Hasher()
hasher.update(["custom_dataset"])
hasher.update(metadata)
identifier = hasher.hexdigest()
# 6. 创建DataSource
dataset = DataSource(identifier, combined_df, metadata)
# 9. 使用原始DataFrame创建DataSource
dataset = DataSource(identifier, df_for_datasource, metadata)
# 7. 获取静态文件路径
# 10. 获取静态文件路径
static_path = str(
(pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
)
# 8. 创建并启动服务器
# 11. 创建并启动服务器 - 使用默认的duckdb_uri参数(即"wasm")
app = make_server(dataset, static_path=static_path)
import uvicorn
print(f"Starting interactive viewer on http://{host}:{port}")
@@ -298,14 +322,35 @@ def visualize_csv_comparison(
print(f"Visualization saved to {output_path}")
# Keep a copy of the original DataFrame for the interactive viewer
# The viewer expects Python objects, not JSON strings
df_for_viewer = combined_df.copy()
# Convert __neighbors column to JSON strings before saving
# This is needed because DuckDB-WASM expects JSON strings for complex data types
if '__neighbors' in combined_df.columns:
import json
# 为保存到 parquet 文件的数据转换为 JSON 字符串
combined_df_for_save = combined_df.copy()
combined_df_for_save['__neighbors'] = combined_df['__neighbors'].apply(
lambda x: json.dumps({
'ids': x['ids'].tolist() if hasattr(x['ids'], 'tolist') else list(x['ids']),
'distances': x['distances'].tolist() if hasattr(x['distances'], 'tolist') else list(x['distances'])
}) if x is not None else None
)
# Save dataset for interactive viewing
parquet_path = "combined_dataset.parquet"
# 保存转换为 JSON 字符串的 DataFrame
if '__neighbors' in combined_df.columns:
combined_df_for_save.to_parquet(parquet_path, index=False)
else:
combined_df.to_parquet(parquet_path, index=False)
print(f"Dataset saved to {parquet_path}")
# Launch interactive viewer
# Launch interactive viewer with original DataFrame (not converted to JSON)
if launch_interactive:
launch_interactive_viewer(combined_df, text_column, port, host)
launch_interactive_viewer(df_for_viewer, text_column, port, host)
print(f"Interactive viewer available at http://{host}:{port}")