feat(validation): archive key result assets
Keep key validation outputs and analysis tables tracked directly, package analysis plot PNGs into a small tar.gz backup, and add analysis scripts plus tests so the stored results remain reproducible without flooding git with large image trees.
This commit is contained in:
105
scripts/analyze_fragment_atom_counts.py
Normal file
105
scripts/analyze_fragment_atom_counts.py
Normal file
@@ -0,0 +1,105 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
Analyze the atom-count distribution of the fragments in fragment_library.csv.
"""
|
||||
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from rdkit import Chem
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def count_atoms(smiles):
    """Count the atoms in a SMILES string, excluding dummy atoms.

    Args:
        smiles: SMILES text to parse with RDKit.

    Returns:
        The number of atoms whose atomic number is non-zero, or ``None``
        when ``smiles`` is not a string, RDKit cannot parse it, or the
        parsing/counting step raises an error.
    """
    # Empty CSV cells arrive as NaN (a float) when this is used through
    # ``Series.apply``; reject anything that is not text up front instead of
    # relying on RDKit to raise an argument error.
    if not isinstance(smiles, str):
        return None

    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        # Count every atom except dummy atoms (atomic number 0).
        return sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() != 0)
    except Exception:
        # Best-effort: one unparsable entry must not abort the whole run.
        # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt and
        # SystemExit propagate.
        return None
|
||||
|
||||
|
||||
def main():
    """Report the atom-count distribution of the fragment library.

    Reads ``validation_output/fragment_library.csv``, counts the atoms of
    every ``fragment_smiles_plain`` entry with ``count_atoms``, prints summary
    statistics, saves a histogram PNG into ``validation_output``, and prints
    the per-count, cumulative, and suggested-filter figures.

    Exits with status 1 when the CSV file is missing, when it lacks the
    SMILES column, or when no fragment could be parsed.
    """
    # Load the fragment library.
    csv_path = "validation_output/fragment_library.csv"
    if not os.path.exists(csv_path):
        print(f"文件不存在: {csv_path}")
        sys.exit(1)

    df = pd.read_csv(csv_path)
    print(f"总碎片数: {len(df)}")

    # Fail with a readable message instead of a raw KeyError traceback.
    smiles_column = "fragment_smiles_plain"
    if smiles_column not in df.columns:
        print(f"CSV 缺少列: {smiles_column}")
        sys.exit(1)

    # Count atoms per fragment; unparsable entries become missing values.
    df["atom_count"] = df[smiles_column].apply(count_atoms)

    # Report how many SMILES could not be parsed.
    invalid_count = df["atom_count"].isna().sum()
    if invalid_count > 0:
        print(f"无效 SMILES 数量: {invalid_count}")

    # Keep only the fragments with a usable atom count.
    df_valid = df.dropna(subset=["atom_count"]).copy()
    print(f"有效碎片数: {len(df_valid)}")

    total_fragments = len(df_valid)
    if total_fragments == 0:
        # Without this guard the histogram bins would call int(NaN) and the
        # percentages below would divide by zero.
        print("没有有效的碎片可供分析")
        sys.exit(1)

    # A single missing value turns the whole column into floats; restore
    # integers so counts are printed as "3" rather than "3.0".
    df_valid["atom_count"] = df_valid["atom_count"].astype(int)

    # Descriptive statistics.
    print("\n原子数分布统计:")
    print(df_valid["atom_count"].describe())

    # Histogram with one bin per integer atom count.
    fig = plt.figure(figsize=(10, 6))
    sns.histplot(
        df_valid["atom_count"],
        bins=range(0, int(df_valid["atom_count"].max()) + 2),
        kde=False,
    )
    plt.title("Fragment Atom Count Distribution")
    plt.xlabel("Atom Count")
    plt.ylabel("Frequency")
    plt.grid(True, alpha=0.3)

    # Save the figure.
    output_dir = "validation_output"
    os.makedirs(output_dir, exist_ok=True)
    plot_path = os.path.join(output_dir, "fragment_atom_count_distribution.png")
    plt.savefig(plot_path, dpi=300, bbox_inches="tight")
    print(f"\n分布图已保存至: {plot_path}")

    # Showing the figure is best-effort (it only matters in an interactive
    # session), so a display failure is reported but never aborts the run.
    try:
        plt.show()
    except Exception as exc:
        print(f"无法显示图片: {exc}", file=sys.stderr)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    # Number of fragments for each atom count, in ascending order.
    atom_count_stats = df_valid["atom_count"].value_counts().sort_index()
    print("\n不同原子数的碎片数量:")
    for atom_count, count in atom_count_stats.items():
        print(f"  原子数 {atom_count}: {count} 个碎片")

    # Cumulative distribution.
    print("\n累积分布:")
    for atom_count, cumulative in atom_count_stats.cumsum().items():
        percentage = (cumulative / total_fragments) * 100
        print(f"  原子数 <= {atom_count}: {cumulative} 个碎片 ({percentage:.2f}%)")

    # Suggested filter: drop fragments with fewer than ``min_atom_count`` atoms.
    print("\n建议筛选标准:")
    min_atom_count = 3
    filtered_count = int((df_valid["atom_count"] >= min_atom_count).sum())
    filtered_percentage = (filtered_count / total_fragments) * 100
    print(
        f"  如果过滤掉原子数 < {min_atom_count} 的碎片,剩余 {filtered_count} 个碎片 ({filtered_percentage:.2f}%)"
    )
|
||||
|
||||
|
||||
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user