修复create_embedding_service和visualize_csv_comparison函数中的问题
1. 修复 create_embedding_service 函数:
   - 添加缺失的导入语句
   - 修正 metadata 中 neighbors 列名不一致问题
   - 添加 database 配置, 确保数据能正确加载
2. 优化 visualize_csv_comparison 函数:
   - 调整 _row_index 列的添加时机
   - 添加 CSV 文件读取错误处理
   - 添加列名验证功能
   - 保持与 create_embedding_service 一致的 neighbors 列名
This commit is contained in:
@@ -72,22 +72,13 @@ def create_embedding_service(
|
||||
batch_size: int = 32,
|
||||
umap_args: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
从两组文本数据创建embedding可视化服务
|
||||
|
||||
Args:
|
||||
texts1: 第一组文本列表
|
||||
texts2: 第二组文本列表
|
||||
labels: 两组数据的标签
|
||||
port: 服务器端口
|
||||
host: 服务器主机
|
||||
text_column: 文本列名称
|
||||
model: 使用的嵌入模型
|
||||
batch_size: 批处理大小
|
||||
umap_args: UMAP参数字典
|
||||
"""
|
||||
"""从两组文本数据创建embedding可视化服务"""
|
||||
import pathlib
|
||||
import embedding_atlas
|
||||
from embedding_atlas.projection import compute_text_projection
|
||||
from embedding_atlas.server import make_server
|
||||
from embedding_atlas.data_source import DataSource
|
||||
from embedding_atlas.utils import Hasher
|
||||
|
||||
# 默认UMAP参数
|
||||
if umap_args is None:
|
||||
@@ -107,7 +98,6 @@ def create_embedding_service(
|
||||
|
||||
# 3. 计算embeddings和投影
|
||||
try:
|
||||
from embedding_atlas.projection import compute_text_projection
|
||||
compute_text_projection(
|
||||
combined_df,
|
||||
text=text_column,
|
||||
@@ -125,7 +115,7 @@ def create_embedding_service(
|
||||
print(f"Error computing projections: {e}")
|
||||
return
|
||||
|
||||
# 4. 创建metadata
|
||||
# 4. 创建metadata (修复: 添加database配置和正确的neighbors列名)
|
||||
metadata = {
|
||||
"columns": {
|
||||
"id": "_row_index",
|
||||
@@ -134,7 +124,11 @@ def create_embedding_service(
|
||||
"x": "projection_x",
|
||||
"y": "projection_y"
|
||||
},
|
||||
"neighbors": "neighbors"
|
||||
"neighbors": "__neighbors" # 修复: 使用双下划线
|
||||
},
|
||||
"database": {
|
||||
"type": "wasm",
|
||||
"load": True # 修复: 添加database配置
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,9 +192,29 @@ def visualize_csv_comparison(
|
||||
if label2 is None:
|
||||
label2 = os.path.splitext(os.path.basename(csv2_path))[0]
|
||||
|
||||
# Read CSV files
|
||||
# Read CSV files with error handling
|
||||
try:
|
||||
df1 = pd.read_csv(csv1_path)
|
||||
df2 = pd.read_csv(csv2_path)
|
||||
except FileNotFoundError as e:
|
||||
print(f"Error: CSV file not found - {e}")
|
||||
return
|
||||
except pd.errors.EmptyDataError:
|
||||
print("Error: One of the CSV files is empty")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"Error reading CSV files: {e}")
|
||||
return
|
||||
|
||||
# Validate columns
|
||||
if column1 not in df1.columns:
|
||||
print(f"Error: Column '{column1}' not found in {csv1_path}")
|
||||
print(f"Available columns: {list(df1.columns)}")
|
||||
return
|
||||
if column2 not in df2.columns:
|
||||
print(f"Error: Column '{column2}' not found in {csv2_path}")
|
||||
print(f"Available columns: {list(df2.columns)}")
|
||||
return
|
||||
|
||||
# Add source column to identify origin of each data point
|
||||
df1['source'] = label1
|
||||
@@ -213,7 +227,10 @@ def visualize_csv_comparison(
|
||||
# Combine dataframes
|
||||
combined_df = pd.concat([df1, df2], ignore_index=True)
|
||||
|
||||
# Determine which column to use for text projection (moved before projection check)
|
||||
# Add _row_index column for embedding-atlas DataSource (moved to earlier)
|
||||
combined_df['_row_index'] = range(len(combined_df))
|
||||
|
||||
# Determine which column to use for text projection
|
||||
text_column = column1 if column1 in combined_df.columns else column2
|
||||
|
||||
# Check if projection columns already exist
|
||||
@@ -281,9 +298,6 @@ def visualize_csv_comparison(
|
||||
|
||||
print(f"Visualization saved to {output_path}")
|
||||
|
||||
# Add _row_index column for embedding-atlas DataSource
|
||||
combined_df['_row_index'] = range(len(combined_df))
|
||||
|
||||
# Save dataset for interactive viewing
|
||||
parquet_path = "combined_dataset.parquet"
|
||||
combined_df.to_parquet(parquet_path, index=False)
|
||||
|
||||
Reference in New Issue
Block a user