#!/usr/bin/env python3 """ 分析 fragment_library.csv 中碎片的原子数分布。 """ import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from rdkit import Chem import os import sys def count_atoms(smiles): """计算 SMILES 中的原子数(不包括 dummy atom)""" try: mol = Chem.MolFromSmiles(smiles) if mol is None: return None # 计算所有原子,但排除 dummy atom(原子序数为 0) atom_count = sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() != 0) return atom_count except: return None def main(): # 读取 CSV 文件 csv_path = "validation_output/fragment_library.csv" if not os.path.exists(csv_path): print(f"文件不存在: {csv_path}") sys.exit(1) df = pd.read_csv(csv_path) print(f"总碎片数: {len(df)}") # 计算原子数 df["atom_count"] = df["fragment_smiles_plain"].apply(count_atoms) # 检查是否有无效的 SMILES invalid_count = df["atom_count"].isna().sum() if invalid_count > 0: print(f"无效 SMILES 数量: {invalid_count}") # 过滤掉无效的 SMILES df_valid = df.dropna(subset=["atom_count"]) print(f"有效碎片数: {len(df_valid)}") # 统计描述 print("\n原子数分布统计:") print(df_valid["atom_count"].describe()) # 绘制分布图 plt.figure(figsize=(10, 6)) sns.histplot( df_valid["atom_count"], bins=range(0, int(df_valid["atom_count"].max()) + 2), kde=False, ) plt.title("Fragment Atom Count Distribution") plt.xlabel("Atom Count") plt.ylabel("Frequency") plt.grid(True, alpha=0.3) # 保存图片 output_dir = "validation_output" os.makedirs(output_dir, exist_ok=True) plot_path = os.path.join(output_dir, "fragment_atom_count_distribution.png") plt.savefig(plot_path, dpi=300, bbox_inches="tight") print(f"\n分布图已保存至: {plot_path}") # 显示图片(如果在交互式环境中) try: plt.show() except: pass # 分析不同原子数的碎片数量 atom_count_stats = df_valid["atom_count"].value_counts().sort_index() print("\n不同原子数的碎片数量:") for atom_count, count in atom_count_stats.items(): print(f" 原子数 {atom_count}: {count} 个碎片") # 计算累积百分比 total_fragments = len(df_valid) cumulative = 0 print("\n累积分布:") for atom_count, count in atom_count_stats.items(): cumulative += count percentage = (cumulative / total_fragments) * 100 print(f" 原子数 <= {atom_count}: {cumulative} 个碎片 ({percentage:.2f}%)") # 建议筛选标准 print("\n建议筛选标准:") # 例如,过滤掉原子数小于 3 的碎片 min_atom_count = 3 filtered_count = len(df_valid[df_valid["atom_count"] >= min_atom_count]) filtered_percentage = (filtered_count / total_fragments) * 100 print( f" 如果过滤掉原子数 < {min_atom_count} 的碎片,剩余 {filtered_count} 个碎片 ({filtered_percentage:.2f}%)" ) if __name__ == "__main__": main()