修复元数据字段不一致的问题:CLI 使用的是 props 字段,而脚本使用的是 columns 字段;现统一改用 props 结构,确保通过 Python 能正确加载数据
This commit is contained in:
@@ -13,24 +13,30 @@ from typing import Optional, List, Dict, Any
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 5055, host: str = "0.0.0.0"):
|
def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 5055, host: str = "0.0.0.0"):
|
||||||
"""使用Python API启动交互式服务器"""
|
"""使用Python API启动交互式服务器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: 原始的pandas DataFrame,包含投影坐标和邻居信息(作为Python对象,而非JSON字符串)
|
||||||
|
text_column: 用作文本内容的列名
|
||||||
|
port: 启动服务器的端口
|
||||||
|
host: 服务器绑定的主机地址
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
from embedding_atlas.server import make_server
|
from embedding_atlas.server import make_server
|
||||||
from embedding_atlas.data_source import DataSource
|
from embedding_atlas.data_source import DataSource
|
||||||
from embedding_atlas.utils import Hasher
|
from embedding_atlas.utils import Hasher
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
# 创建metadata - 添加database配置
|
# 创建metadata - 使用与CLI一致的props字段结构
|
||||||
metadata = {
|
metadata = {
|
||||||
"columns": {
|
"props": {
|
||||||
"id": "_row_index",
|
"data": {
|
||||||
"text": text_column,
|
"id": "_row_index",
|
||||||
"embedding": {"x": "projection_x", "y": "projection_y"},
|
"text": text_column,
|
||||||
"neighbors": "__neighbors" # 使用双下划线
|
"projection": {"x": "projection_x", "y": "projection_y"},
|
||||||
},
|
"neighbors": "__neighbors"
|
||||||
"database": {
|
},
|
||||||
"type": "wasm",
|
"initialState": {"version": "0.0.0"}
|
||||||
"load": True # 关键: 告诉前端加载数据
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -40,7 +46,8 @@ def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 50
|
|||||||
hasher.update(metadata)
|
hasher.update(metadata)
|
||||||
identifier = hasher.hexdigest()
|
identifier = hasher.hexdigest()
|
||||||
|
|
||||||
# 创建DataSource
|
# 创建DataSource - 直接使用传入的原始DataFrame
|
||||||
|
# 确保df['__neighbors'] 是 Python dict + numpy arrays 的结构,而不是 JSON 字符串
|
||||||
dataset = DataSource(identifier, df, metadata)
|
dataset = DataSource(identifier, df, metadata)
|
||||||
|
|
||||||
# 获取静态文件路径
|
# 获取静态文件路径
|
||||||
@@ -49,7 +56,7 @@ def launch_interactive_viewer(df: pd.DataFrame, text_column: str, port: int = 50
|
|||||||
(pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
|
(pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建并启动服务器
|
# 创建并启动服务器 - 使用默认的duckdb_uri参数(即"wasm")
|
||||||
app = make_server(dataset, static_path=static_path)
|
app = make_server(dataset, static_path=static_path)
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
@@ -115,38 +122,55 @@ def create_embedding_service(
|
|||||||
print(f"Error computing projections: {e}")
|
print(f"Error computing projections: {e}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 4. 创建metadata (修复: 添加database配置和正确的neighbors列名)
|
# 4. 创建metadata (使用与CLI一致的props字段结构)
|
||||||
metadata = {
|
metadata = {
|
||||||
"columns": {
|
"props": {
|
||||||
"id": "_row_index",
|
"data": {
|
||||||
"text": text_column,
|
"id": "_row_index",
|
||||||
"embedding": {
|
"text": text_column,
|
||||||
"x": "projection_x",
|
"projection": {
|
||||||
"y": "projection_y"
|
"x": "projection_x",
|
||||||
|
"y": "projection_y"
|
||||||
|
},
|
||||||
|
"neighbors": "__neighbors"
|
||||||
},
|
},
|
||||||
"neighbors": "__neighbors" # 修复: 使用双下划线
|
"initialState": {"version": "0.0.0"}
|
||||||
},
|
|
||||||
"database": {
|
|
||||||
"type": "wasm",
|
|
||||||
"load": True # 修复: 添加database配置
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# 5. 生成数据集标识符
|
# 5. 保存原始DataFrame用于DataSource
|
||||||
|
df_for_datasource = combined_df.copy()
|
||||||
|
|
||||||
|
# 6. 转换neighbors列为JSON字符串(仅用于保存parquet)
|
||||||
|
if '__neighbors' in combined_df.columns:
|
||||||
|
import json
|
||||||
|
combined_df['__neighbors'] = combined_df['__neighbors'].apply(
|
||||||
|
lambda x: json.dumps({
|
||||||
|
'ids': x['ids'].tolist() if hasattr(x['ids'], 'tolist') else list(x['ids']),
|
||||||
|
'distances': x['distances'].tolist() if hasattr(x['distances'], 'tolist') else list(x['distances'])
|
||||||
|
}) if x is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# 7. 保存parquet文件
|
||||||
|
parquet_path = "combined_dataset.parquet"
|
||||||
|
combined_df.to_parquet(parquet_path, index=False)
|
||||||
|
print(f"Dataset saved to {parquet_path}")
|
||||||
|
|
||||||
|
# 8. 生成数据集标识符
|
||||||
hasher = Hasher()
|
hasher = Hasher()
|
||||||
hasher.update(["custom_dataset"])
|
hasher.update(["custom_dataset"])
|
||||||
hasher.update(metadata)
|
hasher.update(metadata)
|
||||||
identifier = hasher.hexdigest()
|
identifier = hasher.hexdigest()
|
||||||
|
|
||||||
# 6. 创建DataSource
|
# 9. 使用原始DataFrame创建DataSource
|
||||||
dataset = DataSource(identifier, combined_df, metadata)
|
dataset = DataSource(identifier, df_for_datasource, metadata)
|
||||||
|
|
||||||
# 7. 获取静态文件路径
|
# 10. 获取静态文件路径
|
||||||
static_path = str(
|
static_path = str(
|
||||||
(pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
|
(pathlib.Path(embedding_atlas.__file__).parent / "static").resolve()
|
||||||
)
|
)
|
||||||
|
|
||||||
# 8. 创建并启动服务器
|
# 11. 创建并启动服务器 - 使用默认的duckdb_uri参数(即"wasm")
|
||||||
app = make_server(dataset, static_path=static_path)
|
app = make_server(dataset, static_path=static_path)
|
||||||
import uvicorn
|
import uvicorn
|
||||||
print(f"Starting interactive viewer on http://{host}:{port}")
|
print(f"Starting interactive viewer on http://{host}:{port}")
|
||||||
@@ -298,14 +322,35 @@ def visualize_csv_comparison(
|
|||||||
|
|
||||||
print(f"Visualization saved to {output_path}")
|
print(f"Visualization saved to {output_path}")
|
||||||
|
|
||||||
|
# Keep a copy of the original DataFrame for the interactive viewer
|
||||||
|
# The viewer expects Python objects, not JSON strings
|
||||||
|
df_for_viewer = combined_df.copy()
|
||||||
|
|
||||||
|
# Convert __neighbors column to JSON strings before saving
|
||||||
|
# This is needed because DuckDB-WASM expects JSON strings for complex data types
|
||||||
|
if '__neighbors' in combined_df.columns:
|
||||||
|
import json
|
||||||
|
# 为保存到 parquet 文件的数据转换为 JSON 字符串
|
||||||
|
combined_df_for_save = combined_df.copy()
|
||||||
|
combined_df_for_save['__neighbors'] = combined_df['__neighbors'].apply(
|
||||||
|
lambda x: json.dumps({
|
||||||
|
'ids': x['ids'].tolist() if hasattr(x['ids'], 'tolist') else list(x['ids']),
|
||||||
|
'distances': x['distances'].tolist() if hasattr(x['distances'], 'tolist') else list(x['distances'])
|
||||||
|
}) if x is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
# Save dataset for interactive viewing
|
# Save dataset for interactive viewing
|
||||||
parquet_path = "combined_dataset.parquet"
|
parquet_path = "combined_dataset.parquet"
|
||||||
combined_df.to_parquet(parquet_path, index=False)
|
# 保存转换为 JSON 字符串的 DataFrame
|
||||||
|
if '__neighbors' in combined_df.columns:
|
||||||
|
combined_df_for_save.to_parquet(parquet_path, index=False)
|
||||||
|
else:
|
||||||
|
combined_df.to_parquet(parquet_path, index=False)
|
||||||
print(f"Dataset saved to {parquet_path}")
|
print(f"Dataset saved to {parquet_path}")
|
||||||
|
|
||||||
# Launch interactive viewer
|
# Launch interactive viewer with original DataFrame (not converted to JSON)
|
||||||
if launch_interactive:
|
if launch_interactive:
|
||||||
launch_interactive_viewer(combined_df, text_column, port, host)
|
launch_interactive_viewer(df_for_viewer, text_column, port, host)
|
||||||
print(f"Interactive viewer available at http://{host}:{port}")
|
print(f"Interactive viewer available at http://{host}:{port}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user