feat(validation): add main validator class
This commit is contained in:
390
src/macro_lactone_toolkit/validation/validator.py
Normal file
390
src/macro_lactone_toolkit/validation/validator.py
Normal file
@@ -0,0 +1,390 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from rdkit import Chem
|
||||||
|
from rdkit.Chem import Descriptors
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
|
from macro_lactone_toolkit import MacroLactoneAnalyzer
|
||||||
|
from macro_lactone_toolkit._core import (
|
||||||
|
build_numbering_result,
|
||||||
|
collect_side_chain_atoms,
|
||||||
|
find_macrolactone_candidates,
|
||||||
|
is_intrinsic_lactone_neighbor,
|
||||||
|
)
|
||||||
|
from macro_lactone_toolkit.validation.database import get_engine, get_session, init_database
|
||||||
|
from macro_lactone_toolkit.validation.isotope_utils import build_fragment_with_isotope
|
||||||
|
from macro_lactone_toolkit.validation.models import (
|
||||||
|
ClassificationType,
|
||||||
|
ParentMolecule,
|
||||||
|
ProcessingStatus,
|
||||||
|
RingNumbering,
|
||||||
|
SideChainFragment,
|
||||||
|
)
|
||||||
|
from macro_lactone_toolkit.validation.sampling import stratified_sample_by_ring_size
|
||||||
|
from macro_lactone_toolkit.validation.visualization_output import (
|
||||||
|
get_output_paths,
|
||||||
|
save_fragment_images,
|
||||||
|
save_numbered_molecule,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MacrolactoneValidator:
|
||||||
|
"""Validates macrolactone database with sampling and fragmentation."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
output_dir: str | Path,
|
||||||
|
sample_ratio: float = 0.1,
|
||||||
|
smiles_col: str = "smiles",
|
||||||
|
id_col: str = "IDs",
|
||||||
|
):
|
||||||
|
self.output_dir = Path(output_dir)
|
||||||
|
self.sample_ratio = sample_ratio
|
||||||
|
self.smiles_col = smiles_col
|
||||||
|
self.id_col = id_col
|
||||||
|
|
||||||
|
self.analyzer = MacroLactoneAnalyzer()
|
||||||
|
|
||||||
|
# Initialize database
|
||||||
|
self.db_path = self.output_dir / "fragments.db"
|
||||||
|
self.engine = get_engine(self.db_path)
|
||||||
|
init_database(self.engine)
|
||||||
|
|
||||||
|
def run(self, input_csv: str | Path) -> dict:
|
||||||
|
"""Run validation on input CSV."""
|
||||||
|
# Load data
|
||||||
|
df = pd.read_csv(input_csv)
|
||||||
|
print(f"Loaded {len(df)} molecules from {input_csv}")
|
||||||
|
|
||||||
|
# Stratified sampling
|
||||||
|
print(f"Performing stratified sampling (ratio={self.sample_ratio})...")
|
||||||
|
sampled = stratified_sample_by_ring_size(df, self.sample_ratio, self.smiles_col)
|
||||||
|
print(f"Sampled {len(sampled)} molecules")
|
||||||
|
|
||||||
|
# Process each molecule
|
||||||
|
results = {"total": len(sampled), "success": 0, "failed": 0, "skipped": 0}
|
||||||
|
|
||||||
|
for idx, row in sampled.iterrows():
|
||||||
|
status = self._process_molecule(row)
|
||||||
|
results[status] += 1
|
||||||
|
if (idx + 1) % 100 == 0:
|
||||||
|
print(f"Processed {idx + 1}/{len(sampled)} molecules")
|
||||||
|
|
||||||
|
# Generate outputs
|
||||||
|
self._generate_readme()
|
||||||
|
self._generate_summary()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _process_molecule(self, row: pd.Series) -> str:
|
||||||
|
"""Process a single molecule. Returns status."""
|
||||||
|
source_id = str(row[self.id_col])
|
||||||
|
smiles = row[self.smiles_col]
|
||||||
|
name = row.get("molecule_pref_name", None)
|
||||||
|
|
||||||
|
# Classify
|
||||||
|
classification_result = self.analyzer.classify_macrocycle(smiles)
|
||||||
|
classification = classification_result.classification
|
||||||
|
ring_size = classification_result.ring_size
|
||||||
|
|
||||||
|
# Create parent record
|
||||||
|
parent = ParentMolecule(
|
||||||
|
source_id=source_id,
|
||||||
|
molecule_name=name,
|
||||||
|
smiles=smiles,
|
||||||
|
classification=classification,
|
||||||
|
ring_size=ring_size,
|
||||||
|
primary_reason_code=classification_result.primary_reason_code,
|
||||||
|
primary_reason_message=classification_result.primary_reason_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
with get_session(self.engine) as session:
|
||||||
|
session.add(parent)
|
||||||
|
session.commit()
|
||||||
|
session.refresh(parent)
|
||||||
|
|
||||||
|
# Skip non-standard molecules
|
||||||
|
if classification != ClassificationType.STANDARD:
|
||||||
|
parent.processing_status = ProcessingStatus.SKIPPED
|
||||||
|
session.add(parent)
|
||||||
|
session.commit()
|
||||||
|
self._save_original_image(smiles, source_id, ring_size, classification)
|
||||||
|
return "skipped"
|
||||||
|
|
||||||
|
# Process standard macrolactone
|
||||||
|
try:
|
||||||
|
self._process_standard_macrolactone(session, parent, smiles)
|
||||||
|
return "success"
|
||||||
|
except Exception as e:
|
||||||
|
parent.processing_status = ProcessingStatus.FAILED
|
||||||
|
parent.error_message = str(e)
|
||||||
|
parent.processed_at = datetime.utcnow()
|
||||||
|
session.add(parent)
|
||||||
|
session.commit()
|
||||||
|
return "failed"
|
||||||
|
|
||||||
|
def _process_standard_macrolactone(self, session, parent: ParentMolecule, smiles: str):
|
||||||
|
"""Process a standard macrolactone."""
|
||||||
|
mol = Chem.MolFromSmiles(smiles)
|
||||||
|
if mol is None:
|
||||||
|
raise ValueError(f"Invalid SMILES: {smiles}")
|
||||||
|
|
||||||
|
# Find candidate
|
||||||
|
candidates = find_macrolactone_candidates(mol, ring_size=parent.ring_size)
|
||||||
|
if not candidates:
|
||||||
|
raise MacrolactoneDetectionError("No macrolactone candidate found")
|
||||||
|
if len(candidates) > 1:
|
||||||
|
# Pick the one matching ring_size
|
||||||
|
candidates = [c for c in candidates if c.ring_size == parent.ring_size]
|
||||||
|
if len(candidates) != 1:
|
||||||
|
raise ValueError("Ambiguous macrolactone candidates")
|
||||||
|
candidate = candidates[0]
|
||||||
|
|
||||||
|
# Get numbering
|
||||||
|
numbering = build_numbering_result(mol, candidate)
|
||||||
|
|
||||||
|
# Save numbering to database
|
||||||
|
numbering_record = RingNumbering(
|
||||||
|
parent_id=parent.id,
|
||||||
|
ring_size=numbering.ring_size,
|
||||||
|
carbonyl_carbon_idx=numbering.carbonyl_carbon_idx,
|
||||||
|
ester_oxygen_idx=numbering.ester_oxygen_idx,
|
||||||
|
position_to_atom=json.dumps(numbering.position_to_atom),
|
||||||
|
atom_to_position=json.dumps(numbering.atom_to_position),
|
||||||
|
)
|
||||||
|
session.add(numbering_record)
|
||||||
|
|
||||||
|
# Save numbered image
|
||||||
|
paths = get_output_paths(
|
||||||
|
self.output_dir, parent.source_id, parent.ring_size, "standard_macrolactone"
|
||||||
|
)
|
||||||
|
image_path = save_numbered_molecule(smiles, paths["numbered_image"], parent.ring_size)
|
||||||
|
if image_path:
|
||||||
|
parent.numbered_image_path = str(image_path.relative_to(self.output_dir))
|
||||||
|
|
||||||
|
# Fragment side chains
|
||||||
|
ring_atom_set = set(numbering.ring_atoms)
|
||||||
|
fragments = []
|
||||||
|
fragment_idx = 0
|
||||||
|
|
||||||
|
for position, ring_atom_idx in numbering.position_to_atom.items():
|
||||||
|
ring_atom = mol.GetAtomWithIdx(ring_atom_idx)
|
||||||
|
|
||||||
|
for neighbor in ring_atom.GetNeighbors():
|
||||||
|
neighbor_idx = neighbor.GetIdx()
|
||||||
|
|
||||||
|
# Skip ring atoms and intrinsic lactone neighbors
|
||||||
|
if neighbor_idx in ring_atom_set:
|
||||||
|
continue
|
||||||
|
if is_intrinsic_lactone_neighbor(mol, candidate, ring_atom_idx, neighbor_idx):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Collect side chain atoms
|
||||||
|
side_chain_atoms = collect_side_chain_atoms(mol, neighbor_idx, ring_atom_set)
|
||||||
|
if not side_chain_atoms:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Build fragment with isotope tagging
|
||||||
|
labeled_smiles, plain_smiles, bond_type = build_fragment_with_isotope(
|
||||||
|
mol, side_chain_atoms, neighbor_idx, ring_atom_idx, int(position)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate properties
|
||||||
|
plain_mol = Chem.MolFromSmiles(plain_smiles)
|
||||||
|
if plain_mol is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
atom_count = sum(1 for a in plain_mol.GetAtoms() if a.GetAtomicNum() != 0)
|
||||||
|
heavy_atom_count = sum(1 for a in plain_mol.GetAtoms() if a.GetAtomicNum() not in [0, 1])
|
||||||
|
mw = Descriptors.MolWt(plain_mol)
|
||||||
|
|
||||||
|
# Create fragment record
|
||||||
|
fragment = SideChainFragment(
|
||||||
|
parent_id=parent.id,
|
||||||
|
fragment_id=f"{parent.source_id}_frag_{fragment_idx}",
|
||||||
|
cleavage_position=int(position),
|
||||||
|
attachment_atom_idx=ring_atom_idx,
|
||||||
|
attachment_atom_symbol=ring_atom.GetSymbol(),
|
||||||
|
fragment_smiles_labeled=labeled_smiles,
|
||||||
|
fragment_smiles_plain=plain_smiles,
|
||||||
|
dummy_isotope=int(position),
|
||||||
|
atom_count=atom_count,
|
||||||
|
heavy_atom_count=heavy_atom_count,
|
||||||
|
molecular_weight=round(mw, 4),
|
||||||
|
original_bond_type=bond_type,
|
||||||
|
)
|
||||||
|
session.add(fragment)
|
||||||
|
fragments.append(fragment)
|
||||||
|
fragment_idx += 1
|
||||||
|
|
||||||
|
# Save fragment images
|
||||||
|
if fragments and paths["sidechains_dir"]:
|
||||||
|
image_paths = save_fragment_images(fragments, paths["sidechains_dir"], parent.source_id)
|
||||||
|
for frag, img_path in zip(fragments, image_paths):
|
||||||
|
frag.image_path = img_path
|
||||||
|
session.add(frag)
|
||||||
|
|
||||||
|
# Update parent record
|
||||||
|
parent.processing_status = ProcessingStatus.SUCCESS
|
||||||
|
parent.num_sidechains = len(fragments)
|
||||||
|
parent.cleavage_positions = json.dumps([f.cleavage_position for f in fragments])
|
||||||
|
parent.processed_at = datetime.utcnow()
|
||||||
|
session.add(parent)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def _save_original_image(self, smiles: str, source_id: str, ring_size: int, classification: str):
|
||||||
|
"""Save original image for non-standard molecules."""
|
||||||
|
paths = get_output_paths(self.output_dir, source_id, ring_size, classification)
|
||||||
|
try:
|
||||||
|
from rdkit.Chem import Draw
|
||||||
|
|
||||||
|
mol = Chem.MolFromSmiles(smiles)
|
||||||
|
if mol:
|
||||||
|
Draw.MolToFile(mol, str(paths["numbered_image"]), size=(400, 400))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _generate_readme(self):
|
||||||
|
"""Generate README explaining output structure."""
|
||||||
|
readme_content = """# MacrolactoneDB Validation Output
|
||||||
|
|
||||||
|
This directory contains validation results for MacrolactoneDB 12-20 membered rings.
|
||||||
|
|
||||||
|
## Directory Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
validation_output/
|
||||||
|
├── README.md # This file
|
||||||
|
├── fragments.db # SQLite database with all data
|
||||||
|
├── summary.csv # Summary of all processed molecules
|
||||||
|
├── summary_statistics.json # Statistical summary
|
||||||
|
│
|
||||||
|
├── ring_size_12/ # 12-membered rings
|
||||||
|
├── ring_size_13/ # 13-membered rings
|
||||||
|
...
|
||||||
|
└── ring_size_20/ # 20-membered rings
|
||||||
|
├── molecules.csv # Molecules in this ring size
|
||||||
|
├── standard/ # Standard macrolactones
|
||||||
|
│ ├── numbered/ # Numbered ring images
|
||||||
|
│ │ └── {id}_numbered.png
|
||||||
|
│ └── sidechains/ # Fragment images
|
||||||
|
│ └── {id}/
|
||||||
|
│ └── {id}_frag_{n}_pos{pos}.png
|
||||||
|
├── non_standard/ # Non-standard macrocycles
|
||||||
|
│ └── original/
|
||||||
|
│ └── {id}_original.png
|
||||||
|
└── rejected/ # Not macrolactones
|
||||||
|
└── original/
|
||||||
|
└── {id}_original.png
|
||||||
|
```
|
||||||
|
|
||||||
|
## Database Schema
|
||||||
|
|
||||||
|
### Tables
|
||||||
|
|
||||||
|
- **parent_molecules**: Original molecule information
|
||||||
|
- **ring_numberings**: Ring atom numbering details
|
||||||
|
- **side_chain_fragments**: Fragmentation results with isotope tags
|
||||||
|
- **validation_results**: Manual validation records
|
||||||
|
|
||||||
|
### Key Fields
|
||||||
|
|
||||||
|
- `classification`: standard_macrolactone | non_standard_macrocycle | not_macrolactone
|
||||||
|
- `dummy_isotope`: Cleavage position stored as isotope value for reconstruction
|
||||||
|
- `cleavage_position`: Position on ring where side chain was attached
|
||||||
|
|
||||||
|
## Ring Numbering Convention
|
||||||
|
|
||||||
|
1. Position 1 = Lactone carbonyl carbon (C=O)
|
||||||
|
2. Position 2 = Ester oxygen (-O-)
|
||||||
|
3. Positions 3-N = Sequential around ring
|
||||||
|
|
||||||
|
## Isotope Tagging
|
||||||
|
|
||||||
|
Fragments use isotope values to mark cleavage position:
|
||||||
|
- `[5*]CCO` = Fragment from position 5, dummy atom has isotope=5
|
||||||
|
- This enables precise reconstruction during reassembly
|
||||||
|
|
||||||
|
## CSV Columns
|
||||||
|
|
||||||
|
### summary.csv
|
||||||
|
|
||||||
|
- `source_id`: Original molecule ID from MacrolactoneDB
|
||||||
|
- `classification`: Classification result
|
||||||
|
- `ring_size`: Detected ring size (12-20)
|
||||||
|
- `num_sidechains`: Number of side chains detected
|
||||||
|
- `cleavage_positions`: JSON array of cleavage positions
|
||||||
|
- `processing_status`: pending | success | failed | skipped
|
||||||
|
|
||||||
|
## Querying the Database
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# List tables
|
||||||
|
sqlite3 fragments.db ".tables"
|
||||||
|
|
||||||
|
# Get standard macrolactones with fragments
|
||||||
|
sqlite3 fragments.db "SELECT * FROM parent_molecules WHERE classification='standard_macrolactone' LIMIT 5;"
|
||||||
|
|
||||||
|
# Get fragments for a specific molecule
|
||||||
|
sqlite3 fragments.db "SELECT * FROM side_chain_fragments WHERE parent_id=1;"
|
||||||
|
|
||||||
|
# Count by ring size
|
||||||
|
sqlite3 fragments.db "SELECT ring_size, COUNT(*) FROM parent_molecules GROUP BY ring_size;"
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
readme_path = self.output_dir / "README.md"
|
||||||
|
readme_path.write_text(readme_content)
|
||||||
|
|
||||||
|
def _generate_summary(self):
|
||||||
|
"""Generate summary CSV and statistics."""
|
||||||
|
with get_session(self.engine) as session:
|
||||||
|
# Query all parents
|
||||||
|
statement = select(ParentMolecule)
|
||||||
|
parents = session.exec(statement).all()
|
||||||
|
|
||||||
|
# Convert to DataFrame
|
||||||
|
data = []
|
||||||
|
for p in parents:
|
||||||
|
data.append({
|
||||||
|
"id": p.id,
|
||||||
|
"source_id": p.source_id,
|
||||||
|
"molecule_name": p.molecule_name,
|
||||||
|
"smiles": p.smiles,
|
||||||
|
"classification": p.classification,
|
||||||
|
"ring_size": p.ring_size,
|
||||||
|
"primary_reason_code": p.primary_reason_code,
|
||||||
|
"primary_reason_message": p.primary_reason_message,
|
||||||
|
"processing_status": p.processing_status,
|
||||||
|
"error_message": p.error_message,
|
||||||
|
"num_sidechains": p.num_sidechains,
|
||||||
|
"cleavage_positions": p.cleavage_positions,
|
||||||
|
"numbered_image_path": p.numbered_image_path,
|
||||||
|
"processed_at": p.processed_at,
|
||||||
|
})
|
||||||
|
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
df.to_csv(self.output_dir / "summary.csv", index=False)
|
||||||
|
|
||||||
|
# Generate statistics
|
||||||
|
stats = {
|
||||||
|
"total_molecules": len(parents),
|
||||||
|
"by_classification": df["classification"].value_counts().to_dict() if len(df) > 0 else {},
|
||||||
|
"by_ring_size": df[df["ring_size"].notna()]["ring_size"].value_counts().to_dict() if len(df) > 0 else {},
|
||||||
|
"by_status": df["processing_status"].value_counts().to_dict() if len(df) > 0 else {},
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(self.output_dir / "summary_statistics.json", "w") as f:
|
||||||
|
json.dump(stats, f, indent=2, default=str)
|
||||||
|
|
||||||
|
print(f"\nSummary saved to {self.output_dir / 'summary.csv'}")
|
||||||
|
print(f"Statistics: {stats}")
|
||||||
|
|
||||||
|
|
||||||
|
class MacrolactoneDetectionError(Exception):
|
||||||
|
"""Raised when macrolactone detection fails."""
|
||||||
|
pass
|
||||||
Reference in New Issue
Block a user