diff --git a/scripts/admet_screen.py b/scripts/admet_screen.py new file mode 100644 index 0000000..25c8462 --- /dev/null +++ b/scripts/admet_screen.py @@ -0,0 +1,105 @@ +import pandas as pd +import numpy as np +import os, sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# === 配置 === +csv_path = "data/trpe_preds.csv" +props_with_thresholds = { + "AMES": 0.5, + "Carcinogens_Lagunin": 0.5, + "ClinTox": 0.5, + "DILI": 0.5, + "hERG": 0.5, + "Bioavailability_Ma": 0.5, + "HIA_Hou": 0.5, +} + +# 判定符号:含等号可切换成 <= / >= +toxic_op_le = "<" # 或 "<=" +adme_op_ge = ">" # 或 ">=" + +df = pd.read_csv(csv_path) +total = len(df) + +toxic_cols = ["AMES","Carcinogens_Lagunin","ClinTox","DILI","hERG"] +adme_cols = ["Bioavailability_Ma","HIA_Hou"] + +# 统一转数值(非数值转为 NaN) +for col in toxic_cols + adme_cols: + if col in df.columns: + df[col] = pd.to_numeric(df[col], errors="coerce") + +results = [] +os.makedirs("failed_by_rule", exist_ok=True) + +def apply_rule(series, thr, op): + if op == "<": return series < thr + if op == "<=": return series <= thr + if op == ">": return series > thr + if op == ">=": return series >= thr + raise ValueError("Unsupported op") + +for col, thr in props_with_thresholds.items(): + if col not in df.columns: + continue + + # 打印该属性的范围,方便 sanity check + col_min = float(np.nanmin(df[col].values)) if df[col].notna().any() else np.nan + col_max = float(np.nanmax(df[col].values)) if df[col].notna().any() else np.nan + + if col in toxic_cols: + mask = apply_rule(df[col], thr, toxic_op_le) + rule = f"{col} {toxic_op_le} {thr}" + else: + mask = apply_rule(df[col], thr, adme_op_ge) + rule = f"{col} {adme_op_ge} {thr}" + + passed = int(mask.sum()) + failed = int(total - passed) + + # 导出未通过的分子(含该列原值) + failed_df = df.loc[~mask, [col]].copy() + failed_df.to_csv(f"failed_by_rule/{col}_failed.csv", index=False) + + results.append({ + "属性": col, + "阈值": thr, + "规则": rule, + "最小值": col_min, + "最大值": col_max, + "通过数": passed, + "淘汰数": failed, + "是否全部淘汰": failed == total + }) + +res_df = pd.DataFrame(results).sort_values("淘汰数", ascending=False) +print(res_df) +res_df.to_csv("trpe_filter_diagnostics.csv", index=False) + +# === 最终整体过滤(交集) === +mask_all = pd.Series(True, index=df.index) + +for col, thr in props_with_thresholds.items(): + if col not in df.columns: + continue + if col in toxic_cols: + mask_all &= df[col] < thr # 毒理类必须 < 阈值 + else: + mask_all &= df[col] > thr # 吸收类必须 > 阈值 + +final_candidates = df[mask_all].copy() + +print("\n=== 最终筛选结果 ===") +print(f"总数: {total}") +print(f"符合所有标准的分子数: {len(final_candidates)}") + +if len(final_candidates) > 0: + print("符合条件的分子索引(行号):") + print(final_candidates.index.tolist()) + final_candidates.to_csv("trpe_final_candidates.csv", index=False) + + print("✅ 已保存到 trpe_final_candidates.csv") +else: + print("⚠️ 没有分子同时满足所有条件。") diff --git a/scripts/admet_screen2.py b/scripts/admet_screen2.py new file mode 100644 index 0000000..3130e44 --- /dev/null +++ b/scripts/admet_screen2.py @@ -0,0 +1,105 @@ +import pandas as pd +import numpy as np +import os, sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# === 配置 === +csv_path = "data/trpe_preds.csv" +props_with_thresholds = { + "AMES": 0.5, + "Carcinogens_Lagunin": 0.5, + "ClinTox": 0.5, + "DILI": 0.5, + "hERG": 0.5, + "Bioavailability_Ma": 0, + "HIA_Hou": 0, +} + +# 判定符号:含等号可切换成 <= / >= +toxic_op_le = "<" # 或 "<=" +adme_op_ge = ">" # 或 ">=" + +df = pd.read_csv(csv_path) +total = len(df) + +toxic_cols = ["AMES","Carcinogens_Lagunin","ClinTox","DILI","hERG"] +adme_cols = ["Bioavailability_Ma","HIA_Hou"] + +# 统一转数值(非数值转为 NaN) +for col in toxic_cols + adme_cols: + if col in df.columns: + df[col] = pd.to_numeric(df[col], errors="coerce") + +results = [] +os.makedirs("failed_by_rule", exist_ok=True) + +def apply_rule(series, thr, op): + if op == "<": return series < thr + if op == "<=": return series <= thr + if op == ">": return series > thr + if op == ">=": return series >= thr + raise ValueError("Unsupported op") + +for col, thr in props_with_thresholds.items(): + if col not in df.columns: + continue + + # 打印该属性的范围,方便 sanity check + col_min = float(np.nanmin(df[col].values)) if df[col].notna().any() else np.nan + col_max = float(np.nanmax(df[col].values)) if df[col].notna().any() else np.nan + + if col in toxic_cols: + mask = apply_rule(df[col], thr, toxic_op_le) + rule = f"{col} {toxic_op_le} {thr}" + else: + mask = apply_rule(df[col], thr, adme_op_ge) + rule = f"{col} {adme_op_ge} {thr}" + + passed = int(mask.sum()) + failed = int(total - passed) + + # 导出未通过的分子(含该列原值) + failed_df = df.loc[~mask, [col]].copy() + failed_df.to_csv(f"failed_by_rule/{col}_failed.csv", index=False) + + results.append({ + "属性": col, + "阈值": thr, + "规则": rule, + "最小值": col_min, + "最大值": col_max, + "通过数": passed, + "淘汰数": failed, + "是否全部淘汰": failed == total + }) + +res_df = pd.DataFrame(results).sort_values("淘汰数", ascending=False) +print(res_df) +res_df.to_csv("trpe_filter_diagnostics.csv", index=False) + +# === 最终整体过滤(交集) === +mask_all = pd.Series(True, index=df.index) + +for col, thr in props_with_thresholds.items(): + if col not in df.columns: + continue + if col in toxic_cols: + mask_all &= df[col] < thr # 毒理类必须 < 阈值 + else: + mask_all &= df[col] > thr # 吸收类必须 > 阈值 + +final_candidates = df[mask_all].copy() + +print("\n=== 最终筛选结果 ===") +print(f"总数: {total}") +print(f"符合所有标准的分子数: {len(final_candidates)}") + +if len(final_candidates) > 0: + print("符合条件的分子索引(行号):") + print(final_candidates.index.tolist()) + final_candidates.to_csv("trpe_final_candidates.csv", index=False) + + print("✅ 已保存到 trpe_final_candidates.csv") +else: + print("⚠️ 没有分子同时满足所有条件。")