修复CLI使用的是props字段,而你的脚本使用的是columns字段的问题确保正确 python 能正确加载
This commit is contained in:
@@ -13,24 +13,30 @@ from typing import Optional, List, Dict, Any
|
||||
import numpy as np
|
||||
|
||||
def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 5055, host: str = "0.0.0.0"):
|
||||
"""使用Python API启动交互式服务器"""
|
||||
"""使用Python API启动交互式服务器
|
||||
|
||||
Args:
|
||||
df: 原始的pandas DataFrame,包含投影坐标和邻居信息(作为Python对象,而非JSON字符串)
|
||||
text_column: 用作文本内容的列名
|
||||
port: 启动服务器的端口
|
||||
host: 服务器绑定的主机地址
|
||||
"""
|
||||
try:
|
||||
from embedding_atlas.server import make_server
|
||||
from embedding_atlas.data_source import DataSource
|
||||
from embedding_atlas.utils import Hasher
|
||||
import pathlib
|
||||
|
||||
# 创建metadata - 添加database配置
|
||||
# 创建metadata - 使用与CLI一致的props字段结构
|
||||
metadata = {
|
||||
"columns": {
|
||||
"id": "_row_index",
|
||||
"text": text_column,
|
||||
"embedding": {"x": "projection_x", "y": "projection_y"},
|
||||
"neighbors": "__neighbors" # 使用双下划线
|
||||
},
|
||||
"database": {
|
||||
"type": "wasm",
|
||||
"load": True # 关键: 告诉前端加载数据
|
||||
"props": {
|
||||
"data": {
|
||||
"id": "_row_index",
|
||||
"text": text_column,
|
||||
"projection": {"x": "projection_x", "y": "projection_y"},
|
||||
"neighbors": "__neighbors"
|
||||
},
|
||||
"initialState": {"version": "0.0.0"}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,7 +46,8 @@ def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 50
|
||||
hasher.update(metadata)
|
||||
identifier = hasher.hexdigest()
|
||||
|
||||
# 创建DataSource
|
||||
# 创建DataSource - 直接使用传入的原始DataFrame
|
||||
# 确保df['__neighbors'] 是 Python dict + numpy arrays 的结构,而不是 JSON 字符串
|
||||
dataset = DataSource(identifier, df, metadata)
|
||||
|
||||
# 获取静态文件路径
|
||||
@@ -49,7 +56,7 @@ def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 50
|
||||
(pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
|
||||
)
|
||||
|
||||
# 创建并启动服务器
|
||||
# 创建并启动服务器 - 使用默认的duckdb_uri参数(即"wasm")
|
||||
app = make_server(dataset, static_path=static_path)
|
||||
|
||||
import uvicorn
|
||||
@@ -115,38 +122,55 @@ def create_embedding_service(
|
||||
print(f"Error computing projections: {e}")
|
||||
return
|
||||
|
||||
# 4. 创建metadata (修复: 添加database配置和正确的neighbors列名)
|
||||
# 4. 创建metadata (使用与CLI一致的props字段结构)
|
||||
metadata = {
|
||||
"columns": {
|
||||
"id": "_row_index",
|
||||
"text": text_column,
|
||||
"embedding": {
|
||||
"x": "projection_x",
|
||||
"y": "projection_y"
|
||||
"props": {
|
||||
"data": {
|
||||
"id": "_row_index",
|
||||
"text": text_column,
|
||||
"projection": {
|
||||
"x": "projection_x",
|
||||
"y": "projection_y"
|
||||
},
|
||||
"neighbors": "__neighbors"
|
||||
},
|
||||
"neighbors": "__neighbors" # 修复: 使用双下划线
|
||||
},
|
||||
"database": {
|
||||
"type": "wasm",
|
||||
"load": True # 修复: 添加database配置
|
||||
"initialState": {"version": "0.0.0"}
|
||||
}
|
||||
}
|
||||
|
||||
# 5. 生成数据集标识符
|
||||
# 5. 保存原始DataFrame用于DataSource
|
||||
df_for_datasource = combined_df.copy()
|
||||
|
||||
# 6. 转换neighbors列为JSON字符串(仅用于保存parquet)
|
||||
if '__neighbors' in combined_df.columns:
|
||||
import json
|
||||
combined_df['__neighbors'] = combined_df['__neighbors'].apply(
|
||||
lambda x: json.dumps({
|
||||
'ids': x['ids'].tolist() if hasattr(x['ids'], 'tolist') else list(x['ids']),
|
||||
'distances': x['distances'].tolist() if hasattr(x['distances'], 'tolist') else list(x['distances'])
|
||||
}) if x is not None else None
|
||||
)
|
||||
|
||||
# 7. 保存parquet文件
|
||||
parquet_path = "combined_dataset.parquet"
|
||||
combined_df.to_parquet(parquet_path, index=False)
|
||||
print(f"Dataset saved to {parquet_path}")
|
||||
|
||||
# 8. 生成数据集标识符
|
||||
hasher = Hasher()
|
||||
hasher.update(["custom_dataset"])
|
||||
hasher.update(metadata)
|
||||
identifier = hasher.hexdigest()
|
||||
|
||||
# 6. 创建DataSource
|
||||
dataset = DataSource(identifier, combined_df, metadata)
|
||||
# 9. 使用原始DataFrame创建DataSource
|
||||
dataset = DataSource(identifier, df_for_datasource, metadata)
|
||||
|
||||
# 7. 获取静态文件路径
|
||||
# 10. 获取静态文件路径
|
||||
static_path = str(
|
||||
(pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
|
||||
)
|
||||
|
||||
# 8. 创建并启动服务器
|
||||
# 11. 创建并启动服务器 - 使用默认的duckdb_uri参数(即"wasm")
|
||||
app = make_server(dataset, static_path=static_path)
|
||||
import uvicorn
|
||||
print(f"Starting interactive viewer on http://{host}:{port}")
|
||||
@@ -298,14 +322,35 @@ def visualize_csv_comparison(
|
||||
|
||||
print(f"Visualization saved to {output_path}")
|
||||
|
||||
# Keep a copy of the original DataFrame for the interactive viewer
|
||||
# The viewer expects Python objects, not JSON strings
|
||||
df_for_viewer = combined_df.copy()
|
||||
|
||||
# Convert __neighbors column to JSON strings before saving
|
||||
# This is needed because DuckDB-WASM expects JSON strings for complex data types
|
||||
if '__neighbors' in combined_df.columns:
|
||||
import json
|
||||
# 为保存到 parquet 文件的数据转换为 JSON 字符串
|
||||
combined_df_for_save = combined_df.copy()
|
||||
combined_df_for_save['__neighbors'] = combined_df['__neighbors'].apply(
|
||||
lambda x: json.dumps({
|
||||
'ids': x['ids'].tolist() if hasattr(x['ids'], 'tolist') else list(x['ids']),
|
||||
'distances': x['distances'].tolist() if hasattr(x['distances'], 'tolist') else list(x['distances'])
|
||||
}) if x is not None else None
|
||||
)
|
||||
|
||||
# Save dataset for interactive viewing
|
||||
parquet_path = "combined_dataset.parquet"
|
||||
combined_df.to_parquet(parquet_path, index=False)
|
||||
# 保存转换为 JSON 字符串的 DataFrame
|
||||
if '__neighbors' in combined_df.columns:
|
||||
combined_df_for_save.to_parquet(parquet_path, index=False)
|
||||
else:
|
||||
combined_df.to_parquet(parquet_path, index=False)
|
||||
print(f"Dataset saved to {parquet_path}")
|
||||
|
||||
# Launch interactive viewer
|
||||
# Launch interactive viewer with original DataFrame (not converted to JSON)
|
||||
if launch_interactive:
|
||||
launch_interactive_viewer(combined_df, text_column, port, host)
|
||||
launch_interactive_viewer(df_for_viewer, text_column, port, host)
|
||||
print(f"Interactive viewer available at http://{host}:{port}")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user