# Source: admet-ai/scripts/admet_screen.py
# Snapshot: 2025-08-28 22:43:59 +08:00 (106 lines, 3.0 KiB, Python)
# Note: the hosting page flagged this file as containing ambiguous Unicode
# characters, i.e. characters that might be confused with other characters
# (likely the full-width punctuation used in comments and strings).
import pandas as pd
import numpy as np
import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# === Configuration ===
csv_path = "data/trpe_preds.csv"

# Threshold applied to each prediction column that is screened.
props_with_thresholds = {
    "AMES": 0.5,
    "Carcinogens_Lagunin": 0.5,
    "ClinTox": 0.5,
    "DILI": 0.5,
    "hERG": 0.5,
    "Bioavailability_Ma": 0.5,
    "HIA_Hou": 0.5,
}

# Comparison operators used by the rules; switch to the inclusive form to
# also accept values equal to the threshold.
toxic_op_le = "<"  # or "<="
adme_op_ge = ">"  # or ">="

# Columns screened with the toxicity operator vs. the ADME operator.
toxic_cols = ["AMES", "Carcinogens_Lagunin", "ClinTox", "DILI", "hERG"]
adme_cols = ["Bioavailability_Ma", "HIA_Hou"]

# === Load predictions ===
df = pd.read_csv(csv_path)
total = len(df)

# Coerce every screened column that is present to a numeric dtype
# (values that cannot be parsed become NaN).
for col in toxic_cols + adme_cols:
    if col not in df.columns:
        continue
    df[col] = pd.to_numeric(df[col], errors="coerce")

# Per-rule diagnostics are collected here; failed rows are exported under
# the "failed_by_rule" directory.
results = []
os.makedirs("failed_by_rule", exist_ok=True)
# Comparison callables keyed by the operator strings accepted by apply_rule.
_COMPARATORS = {
    "<": lambda values, limit: values < limit,
    "<=": lambda values, limit: values <= limit,
    ">": lambda values, limit: values > limit,
    ">=": lambda values, limit: values >= limit,
}


def apply_rule(series, thr, op):
    """Compare ``series`` against ``thr`` using the operator named by ``op``.

    Args:
        series: Left-hand operand of the comparison (this script passes a
            DataFrame column).
        thr: Threshold used as the right-hand operand.
        op: One of "<", "<=", ">", ">=".

    Returns:
        The result of ``series <op> thr``.

    Raises:
        ValueError: If ``op`` is not one of the supported operator strings.
    """
    try:
        compare = _COMPARATORS[op]
    except (KeyError, TypeError):
        # TypeError covers unhashable ``op`` values, which must still be
        # reported as an unsupported operator rather than leak a dict error.
        raise ValueError(f"Unsupported op: {op!r}") from None
    return compare(series, thr)
# Column order of the diagnostics table. Passing it explicitly keeps the
# table well-formed (and sortable) even when no configured column exists in
# the input, in which case `results` is empty.
diagnostic_columns = [
    "属性",
    "阈值",
    "规则",
    "最小值",
    "最大值",
    "通过数",
    "淘汰数",
    "是否全部淘汰",
]

# === Per-rule diagnostics ===
for col, thr in props_with_thresholds.items():
    if col not in df.columns:
        continue

    # Value range of the column, reported as a sanity check; NaN when the
    # column holds no non-missing value at all.
    has_values = df[col].notna().any()
    col_min = float(np.nanmin(df[col].values)) if has_values else np.nan
    col_max = float(np.nanmax(df[col].values)) if has_values else np.nan

    # Toxicity columns use the toxicity operator; every other configured
    # column uses the ADME operator.
    op = toxic_op_le if col in toxic_cols else adme_op_ge
    mask = apply_rule(df[col], thr, op)
    rule = f"{col} {op} {thr}"

    # Rows whose value is NaN compare as False, so they count as eliminated.
    passed = int(mask.sum())
    failed = int(total - passed)

    # Export the rows that did not pass this rule (this column's values only).
    # NOTE(review): with index=False and a single column, the exported file
    # cannot be mapped back to individual molecules — confirm this is intended.
    failed_df = df.loc[~mask, [col]].copy()
    failed_df.to_csv(f"failed_by_rule/{col}_failed.csv", index=False)

    results.append({
        "属性": col,
        "阈值": thr,
        "规则": rule,
        "最小值": col_min,
        "最大值": col_max,
        "通过数": passed,
        "淘汰数": failed,
        "是否全部淘汰": failed == total,
    })

# Rules that eliminated the most rows come first.
res_df = pd.DataFrame(results, columns=diagnostic_columns).sort_values(
    "淘汰数", ascending=False
)
print(res_df)
res_df.to_csv("trpe_filter_diagnostics.csv", index=False)
# === Final overall filter (intersection of all rules) ===
mask_all = pd.Series(True, index=df.index)
for col, thr in props_with_thresholds.items():
    # Columns missing from the input are skipped, i.e. they do not restrict
    # the final selection.
    if col not in df.columns:
        continue
    # Use the same configurable operators as the per-rule diagnostics so that
    # switching to "<=" / ">=" in the configuration also affects the final
    # selection (previously "<" and ">" were hard-coded here).
    op = toxic_op_le if col in toxic_cols else adme_op_ge
    mask_all &= apply_rule(df[col], thr, op)

final_candidates = df[mask_all].copy()

print("\n=== 最终筛选结果 ===")
print(f"总数: {total}")
print(f"符合所有标准的分子数: {len(final_candidates)}")
if len(final_candidates) > 0:
    # Report the surviving rows by their DataFrame index and save them.
    print("符合条件的分子索引(行号):")
    print(final_candidates.index.tolist())
    final_candidates.to_csv("trpe_final_candidates.csv", index=False)
    print("✅ 已保存到 trpe_final_candidates.csv")
else:
    print("⚠️ 没有分子同时满足所有条件。")