# Source-viewer metadata captured with this paste: 106 lines, 3.0 KiB, Python.
import os
import sys

import numpy as np
import pandas as pd

# Append the directory two levels above this file (the parent of the script's
# own folder) to the import search path.
# NOTE(review): nothing visible in this script imports a project module from
# there; confirm this is still needed.
_project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(_project_root)

# === Configuration ===
csv_path = "data/trpe_preds.csv"

# Cut-off for each predicted property, in evaluation order. Toxicity endpoints
# are compared with toxic_op_le and ADME endpoints with adme_op_ge (by default:
# toxicity must be below its cut-off, ADME above). The ADME cut-offs are the
# int 0, so their rule strings render as "... > 0".
props_with_thresholds = dict(
    AMES=0.5,
    Carcinogens_Lagunin=0.5,
    ClinTox=0.5,
    DILI=0.5,
    hERG=0.5,
    Bioavailability_Ma=0,
    HIA_Hou=0,
)

# Comparison operators used by the per-property diagnostics. Switch to the
# inclusive variants to let values equal to the cut-off pass.
toxic_op_le = "<"   # or "<="
adme_op_ge = ">"    # or ">="

# Load the predictions; `total` is the row count every rule is measured against.
df = pd.read_csv(csv_path)
total = len(df)

# Property groups: toxicity endpoints vs. ADME (absorption) endpoints.
toxic_cols = ["AMES", "Carcinogens_Lagunin", "ClinTox", "DILI", "hERG"]
adme_cols = ["Bioavailability_Ma", "HIA_Hou"]

# Force every rule column that exists in the CSV to numeric; anything that
# cannot be parsed becomes NaN. Columns absent from the CSV are left alone.
rule_columns = [c for c in (*toxic_cols, *adme_cols) if c in df.columns]
for column in rule_columns:
    df[column] = pd.to_numeric(df[column], errors="coerce")

# Accumulator for the per-property diagnostic rows, and the folder that
# receives each rule's rejected rows.
results = []
os.makedirs("failed_by_rule", exist_ok=True)

def apply_rule(series, thr, op):
    """Return the element-wise pass mask of ``series <op> thr``.

    Args:
        series: Values to test (a pandas Series in this script; anything that
            supports the rich comparison operators works).
        thr: Threshold the values are compared against.
        op: One of ``"<"``, ``"<="``, ``">"``, ``">="``.

    Returns:
        The comparison result, e.g. a boolean Series. NaN entries compare as
        False, so missing or unparseable values never pass a rule.

    Raises:
        ValueError: If ``op`` is not a supported operator; the message names
            the offending value.
    """
    import operator

    comparators = {
        "<": operator.lt,
        "<=": operator.le,
        ">": operator.gt,
        ">=": operator.ge,
    }
    try:
        compare = comparators[op]
    except (KeyError, TypeError):  # TypeError covers an unhashable `op`
        raise ValueError(f"Unsupported op: {op!r}") from None
    return compare(series, thr)


# === Per-property diagnostics ===
# For each configured property: record its value range, count passes and
# failures under its rule, and export the rejected rows.
for col, thr in props_with_thresholds.items():
    if col not in df.columns:
        # Make a missing (e.g. misspelled) column visible instead of silently
        # dropping it from the diagnostics.
        print(f"⚠️ 列 {col} 不在输入数据中,已跳过。")
        continue

    values = df[col]

    # Observed value range, for a sanity check; NaN when the column holds no
    # numeric values at all.
    has_values = values.notna().any()
    col_min = float(np.nanmin(values.to_numpy())) if has_values else np.nan
    col_max = float(np.nanmax(values.to_numpy())) if has_values else np.nan

    # Toxicity columns use toxic_op_le; every other property uses adme_op_ge.
    op = toxic_op_le if col in toxic_cols else adme_op_ge
    mask = apply_rule(values, thr, op)
    rule = f"{col} {op} {thr}"

    passed = int(mask.sum())
    failed = int(total - passed)  # NaN values never pass, so they count here

    # Export the rejected rows (with their numeric-coerced value) keyed by
    # their row index in the input CSV. With index=False the file would hold
    # bare values with no way to trace them back to a molecule.
    df.loc[~mask, [col]].to_csv(
        f"failed_by_rule/{col}_failed.csv",
        index=True,
        index_label="row_index",
    )

    results.append({
        "属性": col,
        "阈值": thr,
        "规则": rule,
        "最小值": col_min,
        "最大值": col_max,
        "通过数": passed,
        "淘汰数": failed,
        "是否全部淘汰": failed == total,
    })

# === Diagnostics summary ===
# Pass the column list explicitly. When `results` is empty (no configured
# property was found in the CSV), a bare pd.DataFrame(results) has no "淘汰数"
# column and sort_values would raise KeyError. The order matches the keys
# each diagnostic row is built with.
summary_columns = [
    "属性", "阈值", "规则", "最小值", "最大值", "通过数", "淘汰数", "是否全部淘汰",
]
res_df = pd.DataFrame(results, columns=summary_columns).sort_values(
    "淘汰数", ascending=False
)
print(res_df)
res_df.to_csv("trpe_filter_diagnostics.csv", index=False)

# === Final combined filter (intersection of all rules) ===
# A row survives only if it passes every configured rule. The same configured
# operators (toxic_op_le / adme_op_ge) are used as in the per-property
# diagnostics, so switching them to "<=" / ">=" changes both consistently.
# Previously this loop hard-coded "<" and ">".
mask_all = pd.Series(True, index=df.index)

for col, thr in props_with_thresholds.items():
    if col not in df.columns:
        continue
    op = toxic_op_le if col in toxic_cols else adme_op_ge
    mask_all &= apply_rule(df[col], thr, op)

final_candidates = df[mask_all].copy()

# === Report ===
# Print the totals. If any row survived every rule, list its row indices and
# save the full rows to CSV; otherwise print a warning.
n_final = len(final_candidates)

print("\n=== 最终筛选结果 ===")
print(f"总数: {total}")
print(f"符合所有标准的分子数: {n_final}")

if n_final == 0:
    print("⚠️ 没有分子同时满足所有条件。")
else:
    print("符合条件的分子索引(行号):")
    print(final_candidates.index.tolist())
    final_candidates.to_csv("trpe_final_candidates.csv", index=False)
    print("✅ 已保存到 trpe_final_candidates.csv")