"""Screen TRPE property predictions against toxicity and ADME thresholds.

Reads per-molecule predictions from ``csv_path``, reports per-rule pass/fail
diagnostics, exports the rows failing each rule, and saves the rows that
satisfy every rule simultaneously.
"""
import operator
import os
import sys

import numpy as np
import pandas as pd

# Make the parent of this script's directory importable.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# === Configuration ===
csv_path = "data/trpe_preds.csv"
props_with_thresholds = {
    "AMES": 0.5,
    "Carcinogens_Lagunin": 0.5,
    "ClinTox": 0.5,
    "DILI": 0.5,
    "hERG": 0.5,
    "Bioavailability_Ma": 0.5,
    "HIA_Hou": 0.5,
}

# Comparison operators; switch to "<=" / ">=" to make the threshold inclusive.
# Both the per-rule diagnostics and the final combined filter use these.
toxic_op_le = "<"  # toxicity endpoints must be below the threshold (or "<=")
adme_op_ge = ">"  # absorption endpoints must be above the threshold (or ">=")

toxic_cols = ["AMES", "Carcinogens_Lagunin", "ClinTox", "DILI", "hERG"]
adme_cols = ["Bioavailability_Ma", "HIA_Hou"]

_OPS = {
    "<": operator.lt,
    "<=": operator.le,
    ">": operator.gt,
    ">=": operator.ge,
}

# Column order of the diagnostics table (also used when it is empty).
_DIAG_COLUMNS = ["属性", "阈值", "规则", "最小值", "最大值", "通过数", "淘汰数", "是否全部淘汰"]


def apply_rule(series, thr, op):
    """Return the element-wise result of ``series <op> thr``.

    Args:
        series: Values to test (typically a pandas Series).
        thr: Threshold to compare against.
        op: One of ``"<"``, ``"<="``, ``">"``, ``">="``.

    Returns:
        The comparison result; for a pandas Series, NaN entries compare as
        False (so non-numeric values coerced to NaN fail the rule).

    Raises:
        ValueError: If ``op`` is not a supported operator.
    """
    compare = _OPS.get(op) if isinstance(op, str) else None
    if compare is None:
        raise ValueError(f"Unsupported op: {op!r}")
    return compare(series, thr)


def _rule_op(col):
    """Return the configured operator for *col*: toxicity rule or ADME rule."""
    return toxic_op_le if col in toxic_cols else adme_op_ge


def _coerce_numeric(df, cols):
    """Convert the listed columns of *df* to numbers in place (unparsable -> NaN)."""
    for col in cols:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce")
    return df


def _diagnose_rules(df, out_dir="failed_by_rule"):
    """Evaluate each configured rule on its own and export the failing rows.

    For every property in ``props_with_thresholds`` that is present in *df*,
    writes the full rows failing that rule to ``<out_dir>/<col>_failed.csv``
    and records min/max, pass/fail counts and whether every row failed.

    Returns:
        A DataFrame with one row per evaluated rule, sorted by failure count
        (descending). It has the ``_DIAG_COLUMNS`` columns even when no rule
        could be evaluated.
    """
    os.makedirs(out_dir, exist_ok=True)
    total = len(df)
    results = []
    for col, thr in props_with_thresholds.items():
        if col not in df.columns:
            continue
        op = _rule_op(col)
        values = df[col]
        mask = apply_rule(values, thr, op)
        passed = int(mask.sum())
        failed = total - passed

        # Export whole failing rows so each molecule stays identifiable;
        # a lone value column (written without index) cannot be traced back.
        df.loc[~mask].to_csv(os.path.join(out_dir, f"{col}_failed.csv"), index=False)

        results.append({
            "属性": col,
            "阈值": thr,
            "规则": f"{col} {op} {thr}",
            # Series.min/max skip NaN; an all-NaN or empty column yields NaN.
            "最小值": float(values.min()),
            "最大值": float(values.max()),
            "通过数": passed,
            "淘汰数": failed,
            "是否全部淘汰": failed == total,
        })
    return pd.DataFrame(results, columns=_DIAG_COLUMNS).sort_values("淘汰数", ascending=False)


def _final_mask(df):
    """Return a boolean mask of the rows of *df* that pass every configured rule.

    Uses the same operators as the per-rule diagnostics, so switching a rule
    to an inclusive comparison affects both. Properties missing from *df*
    are skipped.
    """
    mask_all = pd.Series(True, index=df.index)
    for col, thr in props_with_thresholds.items():
        if col in df.columns:
            mask_all &= apply_rule(df[col], thr, _rule_op(col))
    return mask_all


def main():
    """Run the diagnostics and the combined (intersection) filter on ``csv_path``."""
    df = pd.read_csv(csv_path)
    total = len(df)

    # Coerce every rule column to numeric; non-numeric values become NaN
    # and therefore fail their rule.
    _coerce_numeric(df, set(toxic_cols) | set(adme_cols) | set(props_with_thresholds))

    # Warn instead of silently dropping a criterion.
    missing = [col for col in props_with_thresholds if col not in df.columns]
    if missing:
        print(f"⚠️ 以下属性列不存在,相关规则已跳过: {missing}")

    res_df = _diagnose_rules(df)
    print(res_df)
    res_df.to_csv("trpe_filter_diagnostics.csv", index=False)

    # === Final combined filter (intersection of all rules) ===
    final_candidates = df[_final_mask(df)].copy()

    print("\n=== 最终筛选结果 ===")
    print(f"总数: {total}")
    print(f"符合所有标准的分子数: {len(final_candidates)}")
    if len(final_candidates) > 0:
        print("符合条件的分子索引(行号):")
        print(final_candidates.index.tolist())
        final_candidates.to_csv("trpe_final_candidates.csv", index=False)
        print("✅ 已保存到 trpe_final_candidates.csv")
    else:
        print("⚠️ 没有分子同时满足所有条件。")


if __name__ == "__main__":
    main()