Restructure the project layout and update README.md

1. Restructure the directory layout:
   - Create a src/visualization module for visualization-related functionality
   - Move script/visualize_csv_comparison.py to src/visualization/comparison.py
   - Create src/visualization/__init__.py to export the main functions
   - Organize the script directory, grouping scripts by function

2. Update README.md:
   - Add a section on CSV file comparison visualization
   - Document both the Python API and command-line usage (see the sketch after this list)
   - Describe the feature set and give usage examples

3. Update module references:
   - Fix the module import paths in comparison.py
   - Update the usage examples in the command-line help text
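A minimal sketch of the Python API now described in the README; the export name
compare_csv is hypothetical, since src/visualization/__init__.py is not shown in
this excerpt:

# Hypothetical usage sketch; the real export names in src/visualization
# may differ from "compare_csv".
from src.visualization import compare_csv  # assumed export from comparison.py

# Compare two CSV files and save the resulting figure.
compare_csv("baseline.csv", "candidate.csv", output="comparison.png")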
2025-10-23 17:55:36 +08:00
parent 9f0a0fbcdc
commit bbf1746046
7 changed files with 358 additions and 0 deletions

View File

@@ -0,0 +1,186 @@
#!/usr/bin/env python3
"""Augment a CSV with ECFP4 binary fingerprints and Tanimoto neighbor summaries."""
from __future__ import annotations
import argparse
import pathlib
from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Tuple
import numpy as np
import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import rdFingerprintGenerator
@dataclass
class ECFP4Generator:
"""Generate ECFP4 (Morgan radius 2) fingerprints as RDKit bit vectors."""
n_bits: int = 2048
radius: int = 2
include_chirality: bool = True
generator: rdFingerprintGenerator.MorganGenerator = field(init=False)
def __post_init__(self) -> None:
self.generator = rdFingerprintGenerator.GetMorganGenerator(
radius=self.radius,
fpSize=self.n_bits,
includeChirality=self.include_chirality,
)
def fingerprint(self, smiles: str) -> Optional[DataStructs.ExplicitBitVect]:
if not smiles:
return None
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
try:
return self.generator.GetFingerprint(mol)
except Exception:
return None
def to_binary_string(self, fp: DataStructs.ExplicitBitVect) -> str:
arr = np.zeros((self.n_bits,), dtype=np.uint8)
DataStructs.ConvertToNumpyArray(fp, arr)
return ''.join(arr.astype(str))
def tanimoto_top_k(
fingerprints: Sequence[Optional[DataStructs.ExplicitBitVect]],
ids: Sequence[str],
top_k: int = 5,
) -> List[str]:
"""Return semicolon-delimited Tanimoto summaries for each fingerprint."""
valid_indices = [idx for idx, fp in enumerate(fingerprints) if fp is not None]
valid_fps = [fingerprints[idx] for idx in valid_indices]
index_lookup = {pos: original for pos, original in enumerate(valid_indices)}
summaries = [''] * len(fingerprints)
if not valid_fps:
return summaries
for pos, fp in enumerate(valid_fps):
sims = DataStructs.BulkTanimotoSimilarity(fp, valid_fps)
ranked: List[Tuple[int, float]] = []
for other_pos, score in enumerate(sims):
if other_pos == pos:
continue
ranked.append((other_pos, score))
ranked.sort(key=lambda item: item[1], reverse=True)
top_entries = []
for other_pos, score in ranked[:top_k]:
original_idx = index_lookup[other_pos]
top_entries.append(f"{ids[original_idx]}:{score:.4f}")
summaries[index_lookup[pos]] = ';'.join(top_entries)
return summaries
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("input_csv", type=pathlib.Path, help="Source CSV with SMILES")
parser.add_argument(
"--output",
type=pathlib.Path,
default=None,
help="Destination file (default: <input>_with_ecfp4.parquet)",
)
parser.add_argument(
"--smiles-column",
default="smiles",
help="Name of the column containing SMILES (default: smiles)",
)
parser.add_argument(
"--id-column",
default="generated_id",
help="Column used to label Tanimoto neighbors (default: generated_id)",
)
parser.add_argument(
"--top-k",
type=int,
default=5,
help="Number of nearest neighbors to report in Tanimoto summaries (default: 5)",
)
parser.add_argument(
"--format",
choices=("parquet", "csv", "auto"),
default="parquet",
help="Output format; defaults to parquet unless overridden or inferred from --output",
)
return parser.parse_args()
def main() -> None:
args = parse_args()
if not args.input_csv.exists():
raise FileNotFoundError(f"Input file not found: {args.input_csv}")
def resolve_output_path() -> pathlib.Path:
if args.output is not None:
return args.output
suffix = ".parquet" if args.format in ("parquet", "auto") else ".csv"
return args.input_csv.with_name(f"{args.input_csv.stem}_with_ecfp4{suffix}")
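    # An explicit .parquet/.csv suffix on --output takes precedence over --format.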
def resolve_format(path: pathlib.Path) -> str:
suffix = path.suffix.lower()
if suffix in {".parquet", ".pq"}:
return "parquet"
if suffix == ".csv":
return "csv"
if args.format == "parquet":
return "parquet"
if args.format == "csv":
return "csv"
        raise ValueError(
            "Cannot infer the output format from the output path; give --output a "
            ".parquet/.csv suffix or set --format explicitly",
        )
output_path = resolve_output_path()
output_format = resolve_format(output_path)
if args.input_csv.resolve() == output_path.resolve():
raise ValueError("Output path must differ from input path to avoid overwriting input.")
df = pd.read_csv(args.input_csv)
if args.smiles_column not in df.columns:
raise ValueError(f"Column '{args.smiles_column}' not found in input data")
smiles_series = df[args.smiles_column].fillna('')
if args.id_column in df.columns:
ids = df[args.id_column].astype(str).tolist()
else:
ids = [f"D{idx:06d}" for idx in range(1, len(df) + 1)]
df[args.id_column] = ids
generator = ECFP4Generator()
fingerprints: List[Optional[DataStructs.ExplicitBitVect]] = []
binary_repr: List[str] = []
for smiles in smiles_series:
fp = generator.fingerprint(smiles)
fingerprints.append(fp)
binary_repr.append(generator.to_binary_string(fp) if fp is not None else '')
df['ecfp4_binary'] = binary_repr
df['tanimoto_top_neighbors'] = tanimoto_top_k(fingerprints, ids, top_k=args.top_k)
if output_format == "parquet":
df.to_parquet(output_path, index=False)
else:
df.to_csv(output_path, index=False)
print(f"Wrote augmented data with {len(df)} rows to {output_path} ({output_format})")
if __name__ == "__main__":
main()
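For reference, the helpers above can also be used programmatically. A minimal
sketch, assuming the file is importable as a module named ecfp4_augment (the
actual filename is not shown in this diff):

# Programmatic sketch; the module name "ecfp4_augment" is an assumption.
from ecfp4_augment import ECFP4Generator, tanimoto_top_k

gen = ECFP4Generator()                      # 2048-bit Morgan radius-2 fingerprints
smiles = ["CCO", "CCN", "c1ccccc1O"]
fps = [gen.fingerprint(s) for s in smiles]  # None for unparsable SMILES
bits = [gen.to_binary_string(fp) if fp is not None else "" for fp in fps]
# Top-2 Tanimoto neighbors per molecule, formatted as "id:score;id:score"
print(tanimoto_top_k(fps, ids=["D000001", "D000002", "D000003"], top_k=2))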

View File

@@ -0,0 +1,127 @@
#!/usr/bin/env python3
"""Add generated IDs and macrolactone ring size annotations to a CSV file."""
import argparse
import pathlib
import warnings
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Optional, Set, Tuple
import pandas as pd
from rdkit import Chem
@dataclass
class MacrolactoneRingDetector:
"""Detect macrolactone rings in SMILES strings via SMARTS patterns."""
min_size: int = 12
max_size: int = 20
patterns: Dict[int, Optional[Chem.Mol]] = field(init=False, default_factory=dict)
def __post_init__(self) -> None:
self.patterns = {
size: Chem.MolFromSmarts(f"[r{size}]([#8][#6](=[#8]))")
for size in range(self.min_size, self.max_size + 1)
}
def ring_sizes(self, smiles: str) -> List[int]:
"""Return a sorted list of macrolactone ring sizes present in the SMILES."""
if not smiles:
return []
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return []
ring_atoms = mol.GetRingInfo().AtomRings()
if not ring_atoms:
return []
matched_rings: Set[Tuple[int, ...]] = set()
matched_sizes: Set[int] = set()
for size, query in self.patterns.items():
if query is None:
continue
for match in mol.GetSubstructMatches(query, uniquify=True):
ring = self._pick_ring(match, ring_atoms, size)
if ring and ring not in matched_rings:
matched_rings.add(ring)
matched_sizes.add(size)
sizes = sorted(matched_sizes)
if len(sizes) > 1:
warnings.warn(
"Multiple macrolactone ring sizes detected",
RuntimeWarning,
stacklevel=2,
)
print(f"Multiple ring sizes {sizes} for SMILES: {smiles}")
return sizes
@staticmethod
def _pick_ring(
match: Tuple[int, ...], rings: Iterable[Tuple[int, ...]], expected_size: int
) -> Optional[Tuple[int, ...]]:
ring_atom = match[0]
for ring in rings:
if len(ring) == expected_size and ring_atom in ring:
return tuple(sorted(ring))
return None
def add_columns(df: pd.DataFrame, detector: MacrolactoneRingDetector) -> pd.DataFrame:
result = df.copy()
result["generated_id"] = [f"D{index:06d}" for index in range(1, len(result) + 1)]
smiles_series = result["smiles"].fillna("") if "smiles" in result.columns else pd.Series(
[""] * len(result), index=result.index
)
def format_sizes(smiles: str) -> str:
sizes = detector.ring_sizes(smiles)
return ";".join(str(size) for size in sizes) if sizes else ""
result["macrocycle_ring_sizes"] = smiles_series.apply(format_sizes)
return result
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("input_csv", type=pathlib.Path, help="Path to the source CSV file")
parser.add_argument(
"--output",
type=pathlib.Path,
default=None,
help="Destination for the augmented CSV (default: <input>_with_macrocycles.csv)",
)
return parser.parse_args()
def main() -> None:
args = parse_args()
if not args.input_csv.exists():
raise FileNotFoundError(f"Input file not found: {args.input_csv}")
output_path = args.output or args.input_csv.with_name(
f"{args.input_csv.stem}_with_macrocycles.csv"
)
if args.input_csv.resolve() == output_path.resolve():
raise ValueError("Output path must differ from input path to avoid overwriting.")
df = pd.read_csv(args.input_csv)
detector = MacrolactoneRingDetector()
augmented = add_columns(df, detector)
augmented.to_csv(output_path, index=False)
print(f"Wrote augmented data with {len(augmented)} rows to {output_path}")
if __name__ == "__main__":
main()
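The detector above can be exercised on its own. A minimal sketch, assuming the
file is importable as add_macrocycle_columns (the module name is an assumption):

# Sketch only; the module name is assumed, not shown in the diff.
from add_macrocycle_columns import MacrolactoneRingDetector

detector = MacrolactoneRingDetector(min_size=12, max_size=20)
print(detector.ring_sizes("c1ccccc1"))          # benzene: no macrolactone -> []
print(detector.ring_sizes("O=C1CCCCCCCCCCO1"))  # 12-membered lactone -> [12]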

View File

@@ -0,0 +1,70 @@
#!/usr/bin/env python3
"""Merge split CSVs into a single file with split source labels."""
import argparse
from pathlib import Path
import pandas as pd
# Mapping of split CSV filenames to numeric labels for the source column
SPLIT_LABELS = {
"split_test.csv": 1,
"split_train.csv": 2,
"split_val.csv": 3,
}
DEFAULT_COLUMN_NAME = "split_source"
def parse_args() -> argparse.Namespace:
repo_root = Path(__file__).resolve().parent.parent
    parser = argparse.ArgumentParser(
        description="Combine split_*.csv files from splits_v2 and label their origin with integers."
    )
parser.add_argument(
"--input-dir",
type=Path,
default=repo_root / "splits_v2",
help="Directory containing split_*.csv files (default: %(default)s)",
)
parser.add_argument(
"--output",
type=Path,
default=repo_root / "data" / "merged_splits.csv",
help="Destination CSV path (default: %(default)s)",
)
parser.add_argument(
"--column-name",
default=DEFAULT_COLUMN_NAME,
help="Name for the source column (default: %(default)s)",
)
return parser.parse_args()
def main() -> None:
args = parse_args()
if not args.input_dir.is_dir():
raise SystemExit(f"Input directory not found: {args.input_dir}")
frames = []
for filename, label in SPLIT_LABELS.items():
csv_path = args.input_dir / filename
if not csv_path.is_file():
raise SystemExit(f"Missing expected split file: {csv_path}")
df = pd.read_csv(csv_path)
df[args.column_name] = label
frames.append(df)
if not frames:
raise SystemExit("No split CSV files were loaded.")
merged = pd.concat(frames, ignore_index=True)
args.output.parent.mkdir(parents=True, exist_ok=True)
merged.to_csv(args.output, index=False)
print(f"Merged {len(frames)} files with {len(merged)} rows into {args.output}")
if __name__ == "__main__":
main()
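The merged file keeps every original column and appends the integer label column.
A quick sanity check (a sketch, assuming the default --output path relative to
the repo root):

# Inspect the merged file produced with the defaults.
import pandas as pd

merged = pd.read_csv("data/merged_splits.csv")  # default --output location
# Label mapping from SPLIT_LABELS: 1 = test, 2 = train, 3 = val
print(merged["split_source"].value_counts())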

View File

@@ -0,0 +1,304 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Structure-aware, distribution-aligned splitting script for the DrugBank dataset (revised).
- Reads a CSV containing at least one of SELFIES or SMILES; a qed column is recommended
- Computes SMILES (decoding SELFIES when that is all that is available) and builds RDKit Mols
- Computes Bemis-Murcko scaffolds, molecular weight, and other basic properties
- Performs a stratified split (train/val/test) with the scaffold as the grouping unit
- Aligns the QED/MW distributions (binning + greedy balancing)
- Writes three CSVs: split_train.csv / split_val.csv / split_test.csv
- Key fix in this revision: the allocation strategy is now capacity-first +
  distribution-aligned, which avoids the TRAIN=0 bias of the previous version
"""
import argparse
from pathlib import Path
import math
import hashlib
from collections import defaultdict, Counter
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit.Chem import Descriptors
try:
import selfies as sf
except Exception:
sf = None
def to_smiles_from_selfies(s):
if pd.isna(s) or not isinstance(s, str) or not s.strip():
return None
if sf is None:
return None
try:
smi = sf.decoder(s)
mol = Chem.MolFromSmiles(smi)
if mol is None:
return None
return Chem.MolToSmiles(mol)
except Exception:
return None
def canonicalize_smiles(s):
if pd.isna(s) or not isinstance(s, str) or not s.strip():
return None
mol = Chem.MolFromSmiles(s)
if mol is None:
return None
return Chem.MolToSmiles(mol)
def bemis_murcko_scaffold(smiles):
if smiles is None:
return None
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
try:
scaf = MurckoScaffold.GetScaffoldForMol(mol)
if scaf is None:
return None
return Chem.MolToSmiles(scaf)
except Exception:
return None
def calc_basic_props(smiles):
mol = Chem.MolFromSmiles(smiles) if smiles else None
if mol is None:
return (np.nan, np.nan, np.nan, np.nan)
mw = Descriptors.MolWt(mol)
tpsa = Descriptors.TPSA(mol)
logp = Descriptors.MolLogP(mol)
heavy = mol.GetNumHeavyAtoms()
return (mw, tpsa, logp, heavy)
def hash_str(x):
return int(hashlib.md5(x.encode('utf-8')).hexdigest(), 16)
def bin_by_edges(x, edges):
if isinstance(x, float) and math.isnan(x):
return -1
for i in range(1, len(edges)):
if x <= edges[i]:
return i-1
return len(edges) - 2
def hist_distance_L2(h1, h2):
keys = set(h1.keys()) | set(h2.keys())
return math.sqrt(sum((h1.get(k,0) - h2.get(k,0))**2 for k in keys))
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--in-csv", required=True)
ap.add_argument("--out-dir", default=".")
ap.add_argument("--seed", type=int, default=2025)
ap.add_argument("--train-ratio", type=float, default=0.8)
ap.add_argument("--val-ratio", type=float, default=0.1)
ap.add_argument("--test-ratio", type=float, default=0.1)
ap.add_argument("--n_qed_bins", type=int, default=5)
ap.add_argument("--n_mw_bins", type=int, default=5)
    # Whether to assign scaffold groups in descending size order (recommended: True;
    # BooleanOptionalAction also provides --no-largest-first to turn it off).
    ap.add_argument("--largest-first", action=argparse.BooleanOptionalAction, default=True)
    args = ap.parse_args()
np.random.seed(args.seed)
out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
df = pd.read_csv(args.in_csv)
def col_like(name_candidates):
for nc in name_candidates:
idx = [i for i,c in enumerate(df.columns) if c.lower()==nc.lower()]
if idx:
return df.columns[idx[0]]
return None
col_smiles = col_like(["smiles", "canonical_smiles"])
col_selfies = col_like(["selfies"])
col_id = col_like(["id","molid","drug_id"])
col_qed = col_like(["qed"])
if col_smiles is None:
if col_selfies is None:
            raise SystemExit("The input CSV must contain at least one of a SMILES or SELFIES column.")
df["SMILES"] = df[col_selfies].map(to_smiles_from_selfies)
col_smiles = "SMILES"
else:
df["SMILES"] = df[col_smiles]
df["SMILES"] = df["SMILES"].map(canonicalize_smiles)
df = df[~df["SMILES"].isna()].copy()
    # Reset the index so the positional indices built below line up with .loc lookups
    df = df.drop_duplicates(subset=["SMILES"]).reset_index(drop=True)
props = df["SMILES"].map(calc_basic_props)
df["MW"] = [p[0] for p in props]
df["TPSA"] = [p[1] for p in props]
df["LogP"] = [p[2] for p in props]
df["HeavyAtoms"] = [p[3] for p in props]
df["Scaffold"] = df["SMILES"].map(bemis_murcko_scaffold)
df["GroupKey"] = np.where(df["Scaffold"].isna(), df["SMILES"], df["Scaffold"])
if col_qed is None:
df["qed"] = np.nan
else:
df["qed"] = pd.to_numeric(df[col_qed], errors="coerce")
def compute_edges(series, n_bins):
ser = series.dropna()
if len(ser) < n_bins:
return np.linspace(ser.min() if len(ser)>0 else 0,
ser.max() if len(ser)>0 else 1,
n_bins+1)
qs = np.linspace(0,1,n_bins+1)
return np.quantile(ser, qs)
qed_edges = compute_edges(df["qed"], args.n_qed_bins)
mw_edges = compute_edges(df["MW"], args.n_mw_bins)
df["QED_bin"] = df["qed"].map(lambda x: bin_by_edges(x, qed_edges))
df["MW_bin"] = df["MW"].map(lambda x: bin_by_edges(x, mw_edges))
df["Strata"] = list(zip(df["QED_bin"], df["MW_bin"]))
group2idx = defaultdict(list)
for i, g in enumerate(df["GroupKey"]):
group2idx[g].append(i)
group_hist = {g: Counter(df.loc[idxs, "Strata"].tolist())
for g, idxs in group2idx.items()}
global_hist = Counter(df["Strata"].tolist())
total = len(df)
target_counts = {
"train": int(round(total * args.train_ratio)),
"val": int(round(total * args.val_ratio)),
"test": int(round(total * args.test_ratio)),
}
diff = total - sum(target_counts.values())
if diff != 0:
target_counts["train"] += diff
global_ratio = {
"train": args.train_ratio,
"val": args.val_ratio,
"test": args.test_ratio,
}
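    # Per-split target histogram: the global strata histogram scaled by each split's ratio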
target_hist = {split: {k: v * r for k, v in global_hist.items()}
for split, r in global_ratio.items()}
splits = {"train": [], "val": [], "test": []}
split_counts = {"train": 0, "val": 0, "test": 0}
split_hist = {"train": Counter(), "val": Counter(), "test": Counter()}
scaffolds = list(group2idx.keys())
rng = np.random.RandomState(args.seed)
    # Key step: assign the largest scaffold groups first
if args.largest_first:
scaffolds.sort(key=lambda g: len(group2idx[g]), reverse=True)
else:
rng.shuffle(scaffolds)
def remaining_capacity(split):
return target_counts[split] - split_counts[split]
for g in scaffolds:
idxs = group2idx[g]
g_size = len(idxs)
g_hist = group_hist[g]
        # Remaining capacity of each split
        caps = {s: remaining_capacity(s) for s in ["train","val","test"]}
        # Prefer splits that still have positive capacity
        positive_splits = [s for s,c in caps.items() if c > 0]
        # If every capacity is <= 0, the quotas are already full; fall back to "least overflow first"
        candidate_splits = positive_splits if positive_splits else ["train","val","test"]
        # Among the candidates, combine capacity-first with distribution alignment (smaller L2 wins)
best = None
for s in candidate_splits:
cap = caps[s]
proj_hist = split_hist[s] + g_hist
l2 = hist_distance_L2(proj_hist, target_hist[s])
            # Sort key:
            #   1) -max(cap, 0): larger positive capacity is better; the term is 0 when every capacity is negative
            #   2) abs(min(cap, 0)): smaller overflow (less negative capacity) is better
            #   3) l2: smaller distribution distance is better
            key = (-max(cap, 0), abs(min(cap, 0)), l2)
if (best is None) or (key < best[0]):
best = (key, s)
chosen = best[1]
splits[chosen].append(g)
split_counts[chosen] += g_size
split_hist[chosen] += g_hist
    # Collect the row indices for each split
idx_train, idx_val, idx_test = [], [], []
for g in splits["train"]:
idx_train.extend(group2idx[g])
for g in splits["val"]:
idx_val.extend(group2idx[g])
for g in splits["test"]:
idx_test.extend(group2idx[g])
df_train = df.iloc[sorted(idx_train)].copy()
df_val = df.iloc[sorted(idx_val)].copy()
df_test = df.iloc[sorted(idx_test)].copy()
def report(name, sub):
print(f"\n== {name} ==")
print(f"size = {len(sub)}")
h = Counter(sub["Strata"].tolist())
for k, v in h.most_common(8):
print(f" {k}: {v}")
print(" QED mean/std:", np.nanmean(sub["qed"]), np.nanstd(sub["qed"]))
print(" MW mean/std:", np.nanmean(sub["MW"]), np.nanstd(sub["MW"]))
print(f"Total: {len(df)} (targets: {target_counts})")
report("TRAIN", df_train)
report("VAL", df_val)
report("TEST", df_test)
out_train = out_dir / "split_train.csv"
out_val = out_dir / "split_val.csv"
out_test = out_dir / "split_test.csv"
df_train.to_csv(out_train, index=False)
df_val.to_csv(out_val, index=False)
df_test.to_csv(out_test, index=False)
edges = pd.DataFrame({
"qed_edges": list(qed_edges) + [np.nan]*(max(len(mw_edges),len(qed_edges))-len(qed_edges)),
"mw_edges": list(mw_edges) + [np.nan]*(max(len(mw_edges),len(qed_edges))-len(mw_edges)),
})
edges.to_csv(out_dir / "bin_edges_qed_mw.csv", index=False)
    # Also report how far each split deviates from its capacity target
print("\nFinal counts vs targets:")
for s in ["train","val","test"]:
print(f" {s}: {split_counts[s]} / {target_counts[s]} (delta={split_counts[s]-target_counts[s]})")
print(f"\nSaved:\n {out_train}\n {out_val}\n {out_test}\n {out_dir/'bin_edges_qed_mw.csv'}")
print("\n验证/测试阶段,仅使用 split_val / split_test 作为 feed-chemical。")
if __name__ == "__main__":
main()
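The capacity-first sort key is easiest to see on a toy example (a minimal sketch
with invented numbers):

# Toy illustration of the allocation key used in the loop above (numbers invented).
caps = {"train": 50, "val": 0, "test": -3}    # remaining capacity per split
l2 = {"train": 4.0, "val": 1.0, "test": 0.5}  # projected L2 distance to each target histogram
keys = {s: (-max(c, 0), abs(min(c, 0)), l2[s]) for s, c in caps.items()}
# train -> (-50, 0, 4.0), val -> (0, 0, 1.0), test -> (0, 3, 0.5)
print(min(keys, key=keys.get))  # 'train': positive capacity dominates the distance term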