Keep key validation outputs and analysis tables tracked directly, package analysis plot PNGs into a small tar.gz backup, and add analysis scripts plus tests so the stored results remain reproducible without flooding git with large image trees.
106 lines
3.1 KiB
Python
106 lines
3.1 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
分析 fragment_library.csv 中碎片的原子数分布。
|
||
"""
|
||
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
import seaborn as sns
|
||
from rdkit import Chem
|
||
import os
|
||
import sys
|
||
|
||
|
||
def count_atoms(smiles):
    """Return the number of non-dummy atoms in a SMILES string.

    Atoms with atomic number 0 (dummy / attachment-point atoms such as ``*``)
    are excluded. Implicit hydrogens are not atoms in the RDKit molecule
    graph, so they are not counted.

    Args:
        smiles: A SMILES string. Non-string values (e.g. the float ``NaN``
            pandas produces for an empty CSV cell, or ``None``) are treated
            as invalid input.

    Returns:
        The atom count as an ``int``, or ``None`` if ``smiles`` is not a
        string or RDKit cannot parse it.
    """
    # Reject non-strings up front: RDKit raises on them, and the previous
    # bare ``except:`` used to hide that (along with Ctrl-C and real bugs).
    if not isinstance(smiles, str):
        return None

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:  # RDKit reports a parse/sanitization failure as None
        return None

    # Exclude dummy atoms (atomic number 0) from the count.
    return sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() != 0)
|
||
|
||
|
||
def _load_fragment_table(csv_path, smiles_column):
    """Read the fragment CSV, exiting with status 1 if it or its SMILES column is missing."""
    # isfile (not exists) so a directory path is reported instead of crashing read_csv.
    if not os.path.isfile(csv_path):
        print(f"文件不存在: {csv_path}", file=sys.stderr)
        sys.exit(1)

    df = pd.read_csv(csv_path)
    if smiles_column not in df.columns:
        print(f"缺少列: {smiles_column}", file=sys.stderr)
        sys.exit(1)
    return df


def _plot_distribution(atom_counts, output_dir):
    """Save a one-bar-per-integer histogram of ``atom_counts``; return ``(figure, png_path)``."""
    fig, ax = plt.subplots(figsize=(10, 6))
    # Bin edges 0, 1, ..., max+1 give each integer atom count its own bar.
    sns.histplot(
        atom_counts,
        bins=range(0, int(atom_counts.max()) + 2),
        kde=False,
        ax=ax,
    )
    ax.set_title("Fragment Atom Count Distribution")
    ax.set_xlabel("Atom Count")
    ax.set_ylabel("Frequency")
    ax.grid(True, alpha=0.3)

    os.makedirs(output_dir, exist_ok=True)
    plot_path = os.path.join(output_dir, "fragment_atom_count_distribution.png")
    fig.savefig(plot_path, dpi=300, bbox_inches="tight")
    return fig, plot_path


def _report_distribution(atom_counts, min_atom_count):
    """Print per-count, cumulative and threshold-filter summaries (``atom_counts`` must be non-empty)."""
    total_fragments = len(atom_counts)
    per_count = atom_counts.value_counts().sort_index()

    print("\n不同原子数的碎片数量:")
    for atoms, count in per_count.items():
        print(f" 原子数 {atoms}: {count} 个碎片")

    print("\n累积分布:")
    for atoms, cumulative in per_count.cumsum().items():
        percentage = (cumulative / total_fragments) * 100
        print(f" 原子数 <= {atoms}: {cumulative} 个碎片 ({percentage:.2f}%)")

    print("\n建议筛选标准:")
    filtered_count = int((atom_counts >= min_atom_count).sum())
    filtered_percentage = (filtered_count / total_fragments) * 100
    print(
        f" 如果过滤掉原子数 < {min_atom_count} 的碎片,剩余 {filtered_count} 个碎片 ({filtered_percentage:.2f}%)"
    )


def main(
    csv_path="validation_output/fragment_library.csv",
    output_dir="validation_output",
    min_atom_count=3,
    smiles_column="fragment_smiles_plain",
):
    """Analyze the atom-count distribution of the fragments in a fragment-library CSV.

    Reads ``csv_path`` and counts atoms in ``smiles_column`` with
    :func:`count_atoms`. It then prints summary statistics and saves a
    histogram to ``<output_dir>/fragment_atom_count_distribution.png``.
    Finally it prints the per-count and cumulative distributions plus the
    effect of dropping fragments with fewer than ``min_atom_count`` atoms.

    Args:
        csv_path: Path to the fragment library CSV.
        output_dir: Directory for the histogram PNG (created if missing).
        min_atom_count: Threshold used for the suggested-filter report.
        smiles_column: Column holding the fragment SMILES.

    Exits with status 1 if ``csv_path`` is not a file or lacks ``smiles_column``.
    """
    df = _load_fragment_table(csv_path, smiles_column)
    print(f"总碎片数: {len(df)}")

    atom_counts = df[smiles_column].apply(count_atoms).rename("atom_count")

    invalid_count = int(atom_counts.isna().sum())
    if invalid_count > 0:
        print(f"无效 SMILES 数量: {invalid_count}")

    # Invalid rows become NaN, which promotes the column to float; cast the
    # surviving counts back to int so reports print "3" rather than "3.0".
    atom_counts = atom_counts.dropna().astype(int)
    print(f"有效碎片数: {len(atom_counts)}")

    # Without valid fragments, max() is NaN (int() would raise) and every
    # percentage would divide by zero, so stop here.
    if atom_counts.empty:
        print("没有有效碎片,无法分析原子数分布。")
        return

    print("\n原子数分布统计:")
    print(atom_counts.describe())

    fig, plot_path = _plot_distribution(atom_counts, output_dir)
    print(f"\n分布图已保存至: {plot_path}")

    _report_distribution(atom_counts, min_atom_count)

    # Display last, so an interactive window does not block the text report.
    # Display is best-effort: the PNG has already been saved.
    try:
        plt.show()
    except Exception as exc:
        print(f"无法显示分布图: {exc}", file=sys.stderr)
    finally:
        plt.close(fig)
|
||
|
||
|
||
# Script entry point: run the analysis only when executed directly, not on import.
if __name__ == "__main__":
    main()
|