feat(validation): add main validator class

This commit is contained in:
2026-03-19 10:31:08 +08:00
parent 4e869bb693
commit e3c08ad8c0

View File

@@ -0,0 +1,390 @@
from __future__ import annotations
import json
from datetime import datetime
from pathlib import Path
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Descriptors
from sqlmodel import select
from macro_lactone_toolkit import MacroLactoneAnalyzer
from macro_lactone_toolkit._core import (
build_numbering_result,
collect_side_chain_atoms,
find_macrolactone_candidates,
is_intrinsic_lactone_neighbor,
)
from macro_lactone_toolkit.validation.database import get_engine, get_session, init_database
from macro_lactone_toolkit.validation.isotope_utils import build_fragment_with_isotope
from macro_lactone_toolkit.validation.models import (
ClassificationType,
ParentMolecule,
ProcessingStatus,
RingNumbering,
SideChainFragment,
)
from macro_lactone_toolkit.validation.sampling import stratified_sample_by_ring_size
from macro_lactone_toolkit.validation.visualization_output import (
get_output_paths,
save_fragment_images,
save_numbered_molecule,
)
class MacrolactoneValidator:
"""Validates macrolactone database with sampling and fragmentation."""
def __init__(
self,
output_dir: str | Path,
sample_ratio: float = 0.1,
smiles_col: str = "smiles",
id_col: str = "IDs",
):
self.output_dir = Path(output_dir)
self.sample_ratio = sample_ratio
self.smiles_col = smiles_col
self.id_col = id_col
self.analyzer = MacroLactoneAnalyzer()
# Initialize database
self.db_path = self.output_dir / "fragments.db"
self.engine = get_engine(self.db_path)
init_database(self.engine)
def run(self, input_csv: str | Path) -> dict:
"""Run validation on input CSV."""
# Load data
df = pd.read_csv(input_csv)
print(f"Loaded {len(df)} molecules from {input_csv}")
# Stratified sampling
print(f"Performing stratified sampling (ratio={self.sample_ratio})...")
sampled = stratified_sample_by_ring_size(df, self.sample_ratio, self.smiles_col)
print(f"Sampled {len(sampled)} molecules")
# Process each molecule
results = {"total": len(sampled), "success": 0, "failed": 0, "skipped": 0}
for idx, row in sampled.iterrows():
status = self._process_molecule(row)
results[status] += 1
if (idx + 1) % 100 == 0:
print(f"Processed {idx + 1}/{len(sampled)} molecules")
# Generate outputs
self._generate_readme()
self._generate_summary()
return results
def _process_molecule(self, row: pd.Series) -> str:
"""Process a single molecule. Returns status."""
source_id = str(row[self.id_col])
smiles = row[self.smiles_col]
name = row.get("molecule_pref_name", None)
# Classify
classification_result = self.analyzer.classify_macrocycle(smiles)
classification = classification_result.classification
ring_size = classification_result.ring_size
# Create parent record
parent = ParentMolecule(
source_id=source_id,
molecule_name=name,
smiles=smiles,
classification=classification,
ring_size=ring_size,
primary_reason_code=classification_result.primary_reason_code,
primary_reason_message=classification_result.primary_reason_message,
)
with get_session(self.engine) as session:
session.add(parent)
session.commit()
session.refresh(parent)
# Skip non-standard molecules
if classification != ClassificationType.STANDARD:
parent.processing_status = ProcessingStatus.SKIPPED
session.add(parent)
session.commit()
self._save_original_image(smiles, source_id, ring_size, classification)
return "skipped"
# Process standard macrolactone
try:
self._process_standard_macrolactone(session, parent, smiles)
return "success"
except Exception as e:
parent.processing_status = ProcessingStatus.FAILED
parent.error_message = str(e)
parent.processed_at = datetime.utcnow()
session.add(parent)
session.commit()
return "failed"
def _process_standard_macrolactone(self, session, parent: ParentMolecule, smiles: str):
"""Process a standard macrolactone."""
mol = Chem.MolFromSmiles(smiles)
if mol is None:
raise ValueError(f"Invalid SMILES: {smiles}")
# Find candidate
candidates = find_macrolactone_candidates(mol, ring_size=parent.ring_size)
if not candidates:
raise MacrolactoneDetectionError("No macrolactone candidate found")
if len(candidates) > 1:
# Pick the one matching ring_size
candidates = [c for c in candidates if c.ring_size == parent.ring_size]
if len(candidates) != 1:
raise ValueError("Ambiguous macrolactone candidates")
candidate = candidates[0]
# Get numbering
numbering = build_numbering_result(mol, candidate)
# Save numbering to database
numbering_record = RingNumbering(
parent_id=parent.id,
ring_size=numbering.ring_size,
carbonyl_carbon_idx=numbering.carbonyl_carbon_idx,
ester_oxygen_idx=numbering.ester_oxygen_idx,
position_to_atom=json.dumps(numbering.position_to_atom),
atom_to_position=json.dumps(numbering.atom_to_position),
)
session.add(numbering_record)
# Save numbered image
paths = get_output_paths(
self.output_dir, parent.source_id, parent.ring_size, "standard_macrolactone"
)
image_path = save_numbered_molecule(smiles, paths["numbered_image"], parent.ring_size)
if image_path:
parent.numbered_image_path = str(image_path.relative_to(self.output_dir))
# Fragment side chains
ring_atom_set = set(numbering.ring_atoms)
fragments = []
fragment_idx = 0
for position, ring_atom_idx in numbering.position_to_atom.items():
ring_atom = mol.GetAtomWithIdx(ring_atom_idx)
for neighbor in ring_atom.GetNeighbors():
neighbor_idx = neighbor.GetIdx()
# Skip ring atoms and intrinsic lactone neighbors
if neighbor_idx in ring_atom_set:
continue
if is_intrinsic_lactone_neighbor(mol, candidate, ring_atom_idx, neighbor_idx):
continue
# Collect side chain atoms
side_chain_atoms = collect_side_chain_atoms(mol, neighbor_idx, ring_atom_set)
if not side_chain_atoms:
continue
# Build fragment with isotope tagging
labeled_smiles, plain_smiles, bond_type = build_fragment_with_isotope(
mol, side_chain_atoms, neighbor_idx, ring_atom_idx, int(position)
)
# Calculate properties
plain_mol = Chem.MolFromSmiles(plain_smiles)
if plain_mol is None:
continue
atom_count = sum(1 for a in plain_mol.GetAtoms() if a.GetAtomicNum() != 0)
heavy_atom_count = sum(1 for a in plain_mol.GetAtoms() if a.GetAtomicNum() not in [0, 1])
mw = Descriptors.MolWt(plain_mol)
# Create fragment record
fragment = SideChainFragment(
parent_id=parent.id,
fragment_id=f"{parent.source_id}_frag_{fragment_idx}",
cleavage_position=int(position),
attachment_atom_idx=ring_atom_idx,
attachment_atom_symbol=ring_atom.GetSymbol(),
fragment_smiles_labeled=labeled_smiles,
fragment_smiles_plain=plain_smiles,
dummy_isotope=int(position),
atom_count=atom_count,
heavy_atom_count=heavy_atom_count,
molecular_weight=round(mw, 4),
original_bond_type=bond_type,
)
session.add(fragment)
fragments.append(fragment)
fragment_idx += 1
# Save fragment images
if fragments and paths["sidechains_dir"]:
image_paths = save_fragment_images(fragments, paths["sidechains_dir"], parent.source_id)
for frag, img_path in zip(fragments, image_paths):
frag.image_path = img_path
session.add(frag)
# Update parent record
parent.processing_status = ProcessingStatus.SUCCESS
parent.num_sidechains = len(fragments)
parent.cleavage_positions = json.dumps([f.cleavage_position for f in fragments])
parent.processed_at = datetime.utcnow()
session.add(parent)
session.commit()
def _save_original_image(self, smiles: str, source_id: str, ring_size: int, classification: str):
"""Save original image for non-standard molecules."""
paths = get_output_paths(self.output_dir, source_id, ring_size, classification)
try:
from rdkit.Chem import Draw
mol = Chem.MolFromSmiles(smiles)
if mol:
Draw.MolToFile(mol, str(paths["numbered_image"]), size=(400, 400))
except Exception:
pass
def _generate_readme(self):
"""Generate README explaining output structure."""
readme_content = """# MacrolactoneDB Validation Output
This directory contains validation results for MacrolactoneDB 12-20 membered rings.
## Directory Structure
```
validation_output/
├── README.md # This file
├── fragments.db # SQLite database with all data
├── summary.csv # Summary of all processed molecules
├── summary_statistics.json # Statistical summary
├── ring_size_12/ # 12-membered rings
├── ring_size_13/ # 13-membered rings
...
└── ring_size_20/ # 20-membered rings
├── molecules.csv # Molecules in this ring size
├── standard/ # Standard macrolactones
│ ├── numbered/ # Numbered ring images
│ │ └── {id}_numbered.png
│ └── sidechains/ # Fragment images
│ └── {id}/
│ └── {id}_frag_{n}_pos{pos}.png
├── non_standard/ # Non-standard macrocycles
│ └── original/
│ └── {id}_original.png
└── rejected/ # Not macrolactones
└── original/
└── {id}_original.png
```
## Database Schema
### Tables
- **parent_molecules**: Original molecule information
- **ring_numberings**: Ring atom numbering details
- **side_chain_fragments**: Fragmentation results with isotope tags
- **validation_results**: Manual validation records
### Key Fields
- `classification`: standard_macrolactone | non_standard_macrocycle | not_macrolactone
- `dummy_isotope`: Cleavage position stored as isotope value for reconstruction
- `cleavage_position`: Position on ring where side chain was attached
## Ring Numbering Convention
1. Position 1 = Lactone carbonyl carbon (C=O)
2. Position 2 = Ester oxygen (-O-)
3. Positions 3-N = Sequential around ring
## Isotope Tagging
Fragments use isotope values to mark cleavage position:
- `[5*]CCO` = Fragment from position 5, dummy atom has isotope=5
- This enables precise reconstruction during reassembly
## CSV Columns
### summary.csv
- `source_id`: Original molecule ID from MacrolactoneDB
- `classification`: Classification result
- `ring_size`: Detected ring size (12-20)
- `num_sidechains`: Number of side chains detected
- `cleavage_positions`: JSON array of cleavage positions
- `processing_status`: pending | success | failed | skipped
## Querying the Database
```bash
# List tables
sqlite3 fragments.db ".tables"
# Get standard macrolactones with fragments
sqlite3 fragments.db "SELECT * FROM parent_molecules WHERE classification='standard_macrolactone' LIMIT 5;"
# Get fragments for a specific molecule
sqlite3 fragments.db "SELECT * FROM side_chain_fragments WHERE parent_id=1;"
# Count by ring size
sqlite3 fragments.db "SELECT ring_size, COUNT(*) FROM parent_molecules GROUP BY ring_size;"
```
"""
readme_path = self.output_dir / "README.md"
readme_path.write_text(readme_content)
def _generate_summary(self):
"""Generate summary CSV and statistics."""
with get_session(self.engine) as session:
# Query all parents
statement = select(ParentMolecule)
parents = session.exec(statement).all()
# Convert to DataFrame
data = []
for p in parents:
data.append({
"id": p.id,
"source_id": p.source_id,
"molecule_name": p.molecule_name,
"smiles": p.smiles,
"classification": p.classification,
"ring_size": p.ring_size,
"primary_reason_code": p.primary_reason_code,
"primary_reason_message": p.primary_reason_message,
"processing_status": p.processing_status,
"error_message": p.error_message,
"num_sidechains": p.num_sidechains,
"cleavage_positions": p.cleavage_positions,
"numbered_image_path": p.numbered_image_path,
"processed_at": p.processed_at,
})
df = pd.DataFrame(data)
df.to_csv(self.output_dir / "summary.csv", index=False)
# Generate statistics
stats = {
"total_molecules": len(parents),
"by_classification": df["classification"].value_counts().to_dict() if len(df) > 0 else {},
"by_ring_size": df[df["ring_size"].notna()]["ring_size"].value_counts().to_dict() if len(df) > 0 else {},
"by_status": df["processing_status"].value_counts().to_dict() if len(df) > 0 else {},
}
with open(self.output_dir / "summary_statistics.json", "w") as f:
json.dump(stats, f, indent=2, default=str)
print(f"\nSummary saved to {self.output_dir / 'summary.csv'}")
print(f"Statistics: {stats}")
class MacrolactoneDetectionError(Exception):
"""Raised when macrolactone detection fails."""
pass