修复create_embedding_service和visualize_csv_comparison函数中的问题
1. 修复 create_embedding_service 函数:
   - 添加缺失的导入语句
   - 修正 metadata 中 neighbors 列名不一致的问题
   - 添加 database 配置,确保数据能正确加载
2. 优化 visualize_csv_comparison 函数:
   - 调整 _row_index 列的添加时机
   - 添加 CSV 文件读取错误处理
   - 添加列名验证功能
   - 保持与 create_embedding_service 一致的 neighbors 列名
This commit is contained in:
@@ -72,22 +72,13 @@ def create_embedding_service(
|
|||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
umap_args: Optional[Dict[str, Any]] = None
|
umap_args: Optional[Dict[str, Any]] = None
|
||||||
):
|
):
|
||||||
"""
|
"""从两组文本数据创建embedding可视化服务"""
|
||||||
从两组文本数据创建embedding可视化服务
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts1: 第一组文本列表
|
|
||||||
texts2: 第二组文本列表
|
|
||||||
labels: 两组数据的标签
|
|
||||||
port: 服务器端口
|
|
||||||
host: 服务器主机
|
|
||||||
text_column: 文本列名称
|
|
||||||
model: 使用的嵌入模型
|
|
||||||
batch_size: 批处理大小
|
|
||||||
umap_args: UMAP参数字典
|
|
||||||
"""
|
|
||||||
import pathlib
|
import pathlib
|
||||||
import embedding_atlas
|
import embedding_atlas
|
||||||
|
from embedding_atlas.projection import compute_text_projection
|
||||||
|
from embedding_atlas.server import make_server
|
||||||
|
from embedding_atlas.data_source import DataSource
|
||||||
|
from embedding_atlas.utils import Hasher
|
||||||
|
|
||||||
# 默认UMAP参数
|
# 默认UMAP参数
|
||||||
if umap_args is None:
|
if umap_args is None:
|
||||||
@@ -107,7 +98,6 @@ def create_embedding_service(
|
|||||||
|
|
||||||
# 3. 计算embeddings和投影
|
# 3. 计算embeddings和投影
|
||||||
try:
|
try:
|
||||||
from embedding_atlas.projection import compute_text_projection
|
|
||||||
compute_text_projection(
|
compute_text_projection(
|
||||||
combined_df,
|
combined_df,
|
||||||
text=text_column,
|
text=text_column,
|
||||||
@@ -125,7 +115,7 @@ def create_embedding_service(
|
|||||||
print(f"Error computing projections: {e}")
|
print(f"Error computing projections: {e}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 4. 创建metadata
|
# 4. 创建metadata (修复: 添加database配置和正确的neighbors列名)
|
||||||
metadata = {
|
metadata = {
|
||||||
"columns": {
|
"columns": {
|
||||||
"id": "_row_index",
|
"id": "_row_index",
|
||||||
@@ -134,7 +124,11 @@ def create_embedding_service(
|
|||||||
"x": "projection_x",
|
"x": "projection_x",
|
||||||
"y": "projection_y"
|
"y": "projection_y"
|
||||||
},
|
},
|
||||||
"neighbors": "neighbors"
|
"neighbors": "__neighbors" # 修复: 使用双下划线
|
||||||
|
},
|
||||||
|
"database": {
|
||||||
|
"type": "wasm",
|
||||||
|
"load": True # 修复: 添加database配置
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -198,9 +192,29 @@ def visualize_csv_comparison(
|
|||||||
if label2 is None:
|
if label2 is None:
|
||||||
label2 = os.path.splitext(os.path.basename(csv2_path))[0]
|
label2 = os.path.splitext(os.path.basename(csv2_path))[0]
|
||||||
|
|
||||||
# Read CSV files
|
# Read CSV files with error handling
|
||||||
df1 = pd.read_csv(csv1_path)
|
try:
|
||||||
df2 = pd.read_csv(csv2_path)
|
df1 = pd.read_csv(csv1_path)
|
||||||
|
df2 = pd.read_csv(csv2_path)
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
print(f"Error: CSV file not found - {e}")
|
||||||
|
return
|
||||||
|
except pd.errors.EmptyDataError:
|
||||||
|
print("Error: One of the CSV files is empty")
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading CSV files: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Validate columns
|
||||||
|
if column1 not in df1.columns:
|
||||||
|
print(f"Error: Column '{column1}' not found in {csv1_path}")
|
||||||
|
print(f"Available columns: {list(df1.columns)}")
|
||||||
|
return
|
||||||
|
if column2 not in df2.columns:
|
||||||
|
print(f"Error: Column '{column2}' not found in {csv2_path}")
|
||||||
|
print(f"Available columns: {list(df2.columns)}")
|
||||||
|
return
|
||||||
|
|
||||||
# Add source column to identify origin of each data point
|
# Add source column to identify origin of each data point
|
||||||
df1['source'] = label1
|
df1['source'] = label1
|
||||||
@@ -213,7 +227,10 @@ def visualize_csv_comparison(
|
|||||||
# Combine dataframes
|
# Combine dataframes
|
||||||
combined_df = pd.concat([df1, df2], ignore_index=True)
|
combined_df = pd.concat([df1, df2], ignore_index=True)
|
||||||
|
|
||||||
# Determine which column to use for text projection (moved before projection check)
|
# Add _row_index column for embedding-atlas DataSource (moved to earlier)
|
||||||
|
combined_df['_row_index'] = range(len(combined_df))
|
||||||
|
|
||||||
|
# Determine which column to use for text projection
|
||||||
text_column = column1 if column1 in combined_df.columns else column2
|
text_column = column1 if column1 in combined_df.columns else column2
|
||||||
|
|
||||||
# Check if projection columns already exist
|
# Check if projection columns already exist
|
||||||
@@ -281,9 +298,6 @@ def visualize_csv_comparison(
|
|||||||
|
|
||||||
print(f"Visualization saved to {output_path}")
|
print(f"Visualization saved to {output_path}")
|
||||||
|
|
||||||
# Add _row_index column for embedding-atlas DataSource
|
|
||||||
combined_df['_row_index'] = range(len(combined_df))
|
|
||||||
|
|
||||||
# Save dataset for interactive viewing
|
# Save dataset for interactive viewing
|
||||||
parquet_path = "combined_dataset.parquet"
|
parquet_path = "combined_dataset.parquet"
|
||||||
combined_df.to_parquet(parquet_path, index=False)
|
combined_df.to_parquet(parquet_path, index=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user