Files
macro_split/scripts/analyze_fragment_atom_counts.py
lingyuzeng 8071a141ee feat(validation): archive key result assets
Keep key validation outputs and analysis tables tracked directly,
package analysis plot PNGs into a small tar.gz backup, and add
analysis scripts plus tests so the stored results remain
reproducible without flooding git with large image trees.
2026-03-19 21:34:27 +08:00

106 lines
3.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
分析 fragment_library.csv 中碎片的原子数分布。
"""
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from rdkit import Chem
import os
import sys
def count_atoms(smiles):
    """Count the atoms in a SMILES string, excluding dummy atoms.

    Dummy atoms are those whose atomic number is 0; they are not counted.

    Args:
        smiles: SMILES string to parse.

    Returns:
        The number of atoms with a non-zero atomic number, or ``None`` when
        *smiles* is not a string or cannot be parsed into a molecule.
    """
    # Non-string values (e.g. NaN for an empty CSV cell) cannot be parsed;
    # report them as invalid up front instead of relying on the parser raising.
    if not isinstance(smiles, str):
        return None

    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception:
        # Best-effort: a parser failure marks the entry invalid. Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate instead of being silently swallowed.
        return None

    if mol is None:
        return None

    # Count every atom except dummy atoms (atomic number 0).
    return sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() != 0)
def main():
    """Report the atom-count distribution of the fragment library.

    Reads ``validation_output/fragment_library.csv``, computes an atom count
    for every ``fragment_smiles_plain`` entry with ``count_atoms``, prints
    summary statistics plus per-count and cumulative tallies, saves a
    histogram PNG into ``validation_output``, and prints how many fragments
    would remain after dropping those with fewer than 3 atoms.

    Exits with status 1 when the CSV file is missing, when it lacks the
    SMILES column, or when no fragment could be parsed.
    """
    csv_path = "validation_output/fragment_library.csv"
    smiles_column = "fragment_smiles_plain"

    if not os.path.exists(csv_path):
        print(f"文件不存在: {csv_path}")
        sys.exit(1)

    df = pd.read_csv(csv_path)
    print(f"总碎片数: {len(df)}")

    # Fail with a clear message rather than a raw KeyError traceback.
    if smiles_column not in df.columns:
        print(f"缺少列: {smiles_column}")
        sys.exit(1)

    # Atom count per fragment; unparsable entries come back as None/NaN.
    df["atom_count"] = df[smiles_column].apply(count_atoms)

    invalid_count = df["atom_count"].isna().sum()
    if invalid_count > 0:
        print(f"无效 SMILES 数量: {invalid_count}")

    # Keep only the fragments that could be parsed.
    df_valid = df.dropna(subset=["atom_count"])
    total_fragments = len(df_valid)
    print(f"有效碎片数: {total_fragments}")

    # Everything below needs at least one valid fragment: the histogram bin
    # range calls int() on the maximum and the percentages divide by the total.
    if total_fragments == 0:
        print("没有有效碎片, 无法分析原子数分布")
        sys.exit(1)

    # Missing values can leave the column as float/object; cast back to int
    # so the tallies print as "3" rather than "3.0".
    atom_counts = df_valid["atom_count"].astype(int)

    print("\n原子数分布统计:")
    print(atom_counts.describe())

    # Histogram with one bin per integer atom count.
    fig = plt.figure(figsize=(10, 6))
    sns.histplot(
        atom_counts,
        bins=range(0, int(atom_counts.max()) + 2),
        kde=False,
    )
    plt.title("Fragment Atom Count Distribution")
    plt.xlabel("Atom Count")
    plt.ylabel("Frequency")
    plt.grid(True, alpha=0.3)

    output_dir = "validation_output"
    os.makedirs(output_dir, exist_ok=True)
    plot_path = os.path.join(output_dir, "fragment_atom_count_distribution.png")
    plt.savefig(plot_path, dpi=300, bbox_inches="tight")
    print(f"\n分布图已保存至: {plot_path}")

    # Showing the figure is best-effort (it only matters interactively); the
    # PNG is already on disk, so a display failure is reported, not fatal.
    try:
        plt.show()
    except Exception as exc:
        print(f"无法显示图片: {exc}", file=sys.stderr)
    finally:
        # Release the figure so it does not linger in pyplot's registry.
        plt.close(fig)

    # Number of fragments for each distinct atom count, in ascending order.
    atom_count_stats = atom_counts.value_counts().sort_index()
    print("\n不同原子数的碎片数量:")
    for atom_count, count in atom_count_stats.items():
        print(f" 原子数 {atom_count}: {count} 个碎片")

    # Cumulative tally: fragments with at most the given atom count.
    print("\n累积分布:")
    for atom_count, cumulative in atom_count_stats.cumsum().items():
        percentage = (cumulative / total_fragments) * 100
        print(f" 原子数 <= {atom_count}: {cumulative} 个碎片 ({percentage:.2f}%)")

    # Suggested filter: show what remains after dropping very small fragments.
    print("\n建议筛选标准:")
    min_atom_count = 3
    filtered_count = int((atom_counts >= min_atom_count).sum())
    filtered_percentage = (filtered_count / total_fragments) * 100
    print(
        f" 如果过滤掉原子数 < {min_atom_count} 的碎片,剩余 {filtered_count} 个碎片 ({filtered_percentage:.2f}%)"
    )
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()