重构项目结构并更新README.md
1. 重构目录结构: - 创建src/visualization模块用于存放可视化相关功能 - 移动script/visualize_csv_comparison.py到src/visualization/comparison.py - 创建src/visualization/__init__.py导出主要函数 - 整理script目录,按功能分类存放脚本文件 2. 更新README.md: - 添加CSV文件比较可视化部分 - 提供Python API和命令行使用方法说明 - 描述功能特点和使用示例 3. 更新模块引用: - 修正comparison.py中的模块引用路径 - 更新命令行帮助信息中的使用示例
This commit is contained in:
127
script/data_processing/add_macrocycle_columns.py
Normal file
127
script/data_processing/add_macrocycle_columns.py
Normal file
@@ -0,0 +1,127 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Add generated IDs and macrolactone ring size annotations to a CSV file."""
|
||||
|
||||
import argparse
|
||||
import pathlib
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import pandas as pd
|
||||
from rdkit import Chem
|
||||
|
||||
|
||||
@dataclass
|
||||
class MacrolactoneRingDetector:
|
||||
"""Detect macrolactone rings in SMILES strings via SMARTS patterns."""
|
||||
|
||||
min_size: int = 12
|
||||
max_size: int = 20
|
||||
patterns: Dict[int, Optional[Chem.Mol]] = field(init=False, default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.patterns = {
|
||||
size: Chem.MolFromSmarts(f"[r{size}]([#8][#6](=[#8]))")
|
||||
for size in range(self.min_size, self.max_size + 1)
|
||||
}
|
||||
|
||||
def ring_sizes(self, smiles: str) -> List[int]:
|
||||
"""Return a sorted list of macrolactone ring sizes present in the SMILES."""
|
||||
if not smiles:
|
||||
return []
|
||||
|
||||
mol = Chem.MolFromSmiles(smiles)
|
||||
if mol is None:
|
||||
return []
|
||||
|
||||
ring_atoms = mol.GetRingInfo().AtomRings()
|
||||
if not ring_atoms:
|
||||
return []
|
||||
|
||||
matched_rings: Set[Tuple[int, ...]] = set()
|
||||
matched_sizes: Set[int] = set()
|
||||
|
||||
for size, query in self.patterns.items():
|
||||
if query is None:
|
||||
continue
|
||||
|
||||
for match in mol.GetSubstructMatches(query, uniquify=True):
|
||||
ring = self._pick_ring(match, ring_atoms, size)
|
||||
if ring and ring not in matched_rings:
|
||||
matched_rings.add(ring)
|
||||
matched_sizes.add(size)
|
||||
|
||||
sizes = sorted(matched_sizes)
|
||||
|
||||
if len(sizes) > 1:
|
||||
warnings.warn(
|
||||
"Multiple macrolactone ring sizes detected",
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
print(f"Multiple ring sizes {sizes} for SMILES: {smiles}")
|
||||
|
||||
return sizes
|
||||
|
||||
@staticmethod
|
||||
def _pick_ring(
|
||||
match: Tuple[int, ...], rings: Iterable[Tuple[int, ...]], expected_size: int
|
||||
) -> Optional[Tuple[int, ...]]:
|
||||
ring_atom = match[0]
|
||||
for ring in rings:
|
||||
if len(ring) == expected_size and ring_atom in ring:
|
||||
return tuple(sorted(ring))
|
||||
return None
|
||||
|
||||
|
||||
def add_columns(df: pd.DataFrame, detector: MacrolactoneRingDetector) -> pd.DataFrame:
|
||||
result = df.copy()
|
||||
result["generated_id"] = [f"D{index:06d}" for index in range(1, len(result) + 1)]
|
||||
|
||||
smiles_series = result["smiles"].fillna("") if "smiles" in result.columns else pd.Series(
|
||||
[""] * len(result), index=result.index
|
||||
)
|
||||
|
||||
def format_sizes(smiles: str) -> str:
|
||||
sizes = detector.ring_sizes(smiles)
|
||||
return ";".join(str(size) for size in sizes) if sizes else ""
|
||||
|
||||
result["macrocycle_ring_sizes"] = smiles_series.apply(format_sizes)
|
||||
return result
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("input_csv", type=pathlib.Path, help="Path to the source CSV file")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=pathlib.Path,
|
||||
default=None,
|
||||
help="Destination for the augmented CSV (default: <input>_with_macrocycles.csv)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
if not args.input_csv.exists():
|
||||
raise FileNotFoundError(f"Input file not found: {args.input_csv}")
|
||||
|
||||
output_path = args.output or args.input_csv.with_name(
|
||||
f"{args.input_csv.stem}_with_macrocycles.csv"
|
||||
)
|
||||
|
||||
if args.input_csv.resolve() == output_path.resolve():
|
||||
raise ValueError("Output path must differ from input path to avoid overwriting.")
|
||||
|
||||
df = pd.read_csv(args.input_csv)
|
||||
detector = MacrolactoneRingDetector()
|
||||
augmented = add_columns(df, detector)
|
||||
augmented.to_csv(output_path, index=False)
|
||||
|
||||
print(f"Wrote augmented data with {len(augmented)} rows to {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user