feat(toolkit): add classification and migration
Implement the standard/non-standard/not-macrolactone classification layer and integrate it into analyzer, fragmenter, and CLI outputs. Port the remaining legacy package capabilities into new visualization and workflow modules, restore batch/statistics/SDF scripts on top of the flat CSV workflow, and update active docs to the new package API.
This commit is contained in:
@@ -7,19 +7,45 @@ from .errors import (
|
||||
RingNumberingError,
|
||||
)
|
||||
from .fragmenter import MacrolactoneFragmenter
|
||||
from .models import FragmentationResult, RingNumberingResult, SideChainFragment
|
||||
from .models import (
|
||||
FragmentationResult,
|
||||
MacrocycleClassificationResult,
|
||||
RingNumberingResult,
|
||||
SideChainFragment,
|
||||
)
|
||||
from .visualization import (
|
||||
fragment_svg,
|
||||
numbered_molecule_svg,
|
||||
save_fragment_png,
|
||||
save_numbered_molecule_png,
|
||||
)
|
||||
from .workflows import (
|
||||
export_numbered_macrolactone_csv,
|
||||
fragment_csv,
|
||||
results_to_dataframe,
|
||||
write_result_json,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AmbiguousMacrolactoneError",
|
||||
"FragmentationError",
|
||||
"FragmentationResult",
|
||||
"fragment_csv",
|
||||
"fragment_svg",
|
||||
"MacroLactoneAnalyzer",
|
||||
"MacrolactoneDetectionError",
|
||||
"MacrolactoneError",
|
||||
"MacrolactoneFragmenter",
|
||||
"MacrocycleClassificationResult",
|
||||
"numbered_molecule_svg",
|
||||
"RingNumberingError",
|
||||
"RingNumberingResult",
|
||||
"results_to_dataframe",
|
||||
"save_fragment_png",
|
||||
"save_numbered_molecule_png",
|
||||
"SideChainFragment",
|
||||
"export_numbered_macrolactone_csv",
|
||||
"write_result_json",
|
||||
]
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
@@ -6,11 +6,21 @@ from typing import Iterable
|
||||
|
||||
from rdkit import Chem
|
||||
|
||||
from .errors import MacrolactoneDetectionError, RingNumberingError
|
||||
from .models import RingNumberingResult
|
||||
from .errors import AmbiguousMacrolactoneError, MacrolactoneDetectionError, RingNumberingError
|
||||
from .models import MacrocycleClassificationResult, RingNumberingResult
|
||||
|
||||
|
||||
# Ring sizes covered by the classifier: 12 through 20 members inclusive.
VALID_RING_SIZES = tuple(range(12, 21))

# Human-readable message reported for each classification reason code.
REASON_MESSAGES = {
    "contains_non_carbon_ring_atoms_outside_positions_1_2": (
        "Ring positions 3..N contain non-carbon atoms."
    ),
    "multiple_overlapping_macrocycle_candidates": (
        "Overlapping macrolactone candidate rings were detected."
    ),
    "no_lactone_ring_in_12_to_20_range": "No 12-20 membered lactone ring was detected.",
    "requested_ring_size_not_found": "The requested ring size was not detected as a lactone ring.",
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -73,6 +83,62 @@ def find_macrolactone_candidates(
|
||||
)
|
||||
|
||||
|
||||
def classify_macrolactone(
    mol: Chem.Mol,
    smiles: str,
    ring_size: int | None = None,
) -> MacrocycleClassificationResult:
    """Classify ``mol`` as a standard macrolactone, a non-standard macrocycle, or neither.

    Args:
        mol: Molecule to inspect.
        smiles: SMILES string copied verbatim into the result.
        ring_size: Optional ring size forwarded to
            ``find_macrolactone_candidates``.

    Returns:
        A ``MacrocycleClassificationResult`` whose ``classification`` is
        ``"not_macrolactone"`` when no candidate is found,
        ``"non_standard_macrocycle"`` when candidate rings overlap or when ring
        positions 3..N hold a non-carbon atom, and ``"standard_macrolactone"``
        otherwise.

    Raises:
        AmbiguousMacrolactoneError: If several non-overlapping candidates are
            detected.
    """
    detected = find_macrolactone_candidates(mol, ring_size=ring_size)
    distinct_sizes = sorted({entry.ring_size for entry in detected})

    if not detected:
        # The reason depends on whether the caller constrained the ring size.
        if ring_size is None:
            missing_reason = "no_lactone_ring_in_12_to_20_range"
        else:
            missing_reason = "requested_ring_size_not_found"
        return _build_classification_result(
            smiles=smiles,
            classification="not_macrolactone",
            ring_size=None,
            candidate_ring_sizes=[],
            reason_codes=[missing_reason],
        )

    if _has_overlapping_candidates(detected):
        # A ring size is only reported when every candidate agrees on it.
        shared_size = distinct_sizes[0] if len(distinct_sizes) == 1 else None
        return _build_classification_result(
            smiles=smiles,
            classification="non_standard_macrocycle",
            ring_size=shared_size,
            candidate_ring_sizes=distinct_sizes,
            reason_codes=["multiple_overlapping_macrocycle_candidates"],
        )

    if len(detected) > 1 or len(distinct_sizes) > 1:
        raise AmbiguousMacrolactoneError(
            "Multiple valid macrolactone candidates were detected. Pass ring_size explicitly."
        )

    sole_candidate = detected[0]
    numbering = build_numbering_result(mol, sole_candidate)
    if _contains_non_carbon_atoms_outside_positions_1_2(mol, numbering):
        label = "non_standard_macrocycle"
        reasons = ["contains_non_carbon_ring_atoms_outside_positions_1_2"]
    else:
        label = "standard_macrolactone"
        reasons = []

    return _build_classification_result(
        smiles=smiles,
        classification=label,
        ring_size=sole_candidate.ring_size,
        candidate_ring_sizes=distinct_sizes,
        reason_codes=reasons,
    )
|
||||
|
||||
|
||||
def build_numbering_result(mol: Chem.Mol, candidate: DetectedMacrolactone) -> RingNumberingResult:
|
||||
ring_atoms = list(candidate.ring_atoms)
|
||||
ring_atom_set = set(ring_atoms)
|
||||
@@ -120,6 +186,66 @@ def build_numbering_result(mol: Chem.Mol, candidate: DetectedMacrolactone) -> Ri
|
||||
)
|
||||
|
||||
|
||||
def _build_classification_result(
    smiles: str,
    classification: str,
    ring_size: int | None,
    candidate_ring_sizes: list[int],
    reason_codes: list[str],
) -> MacrocycleClassificationResult:
    """Assemble a ``MacrocycleClassificationResult`` from reason codes.

    Each reason code is translated through ``REASON_MESSAGES``; the first code
    and message (if any) become the primary reason, otherwise both are ``None``.

    Raises:
        KeyError: If a reason code has no entry in ``REASON_MESSAGES``.
    """
    codes = list(reason_codes)
    messages = [REASON_MESSAGES[code] for code in codes]

    if codes:
        primary_code, primary_message = codes[0], messages[0]
    else:
        primary_code = primary_message = None

    return MacrocycleClassificationResult(
        smiles=smiles,
        classification=classification,
        ring_size=ring_size,
        primary_reason_code=primary_code,
        primary_reason_message=primary_message,
        all_reason_codes=codes,
        all_reason_messages=messages,
        # Copy so the result does not alias the caller's list.
        candidate_ring_sizes=list(candidate_ring_sizes),
    )
|
||||
|
||||
|
||||
def _contains_non_carbon_atoms_outside_positions_1_2(
    mol: Chem.Mol,
    numbering: RingNumberingResult,
) -> bool:
    """Return True if any ring position from 3 to ``ring_size`` is not carbon.

    Positions 1 and 2 are never inspected. Each position is resolved to an atom
    index through ``numbering.position_to_atom`` and compared against atomic
    number 6.
    """
    return any(
        mol.GetAtomWithIdx(numbering.position_to_atom[ring_position]).GetAtomicNum() != 6
        for ring_position in range(3, numbering.ring_size + 1)
    )
|
||||
|
||||
|
||||
def _has_overlapping_candidates(candidates: list[DetectedMacrolactone]) -> bool:
    """Return True when at least two candidate rings share an atom.

    A connected component of the ring-overlap graph has more than one member
    exactly when some pair of rings intersects, so a direct pairwise check
    gives the same answer as a component search and can stop at the first
    shared atom.

    Args:
        candidates: Detected rings; only each candidate's ``ring_atoms`` is read.

    Returns:
        ``True`` if any two candidates have a ring atom in common, else ``False``
        (including for zero or one candidate).
    """
    ring_sets = [set(candidate.ring_atoms) for candidate in candidates]
    for position, ring in enumerate(ring_sets):
        # Only compare with later rings so every unordered pair is visited once.
        for other_ring in ring_sets[position + 1:]:
            if not ring.isdisjoint(other_ring):
                return True
    return False
|
||||
|
||||
|
||||
def collect_side_chain_atoms(
|
||||
mol: Chem.Mol,
|
||||
start_atom_idx: int,
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pandas as pd
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem import Crippen, Descriptors, Lipinski, QED
|
||||
|
||||
from ._core import ensure_mol, find_macrolactone_candidates
|
||||
from ._core import classify_macrolactone, ensure_mol, find_macrolactone_candidates
|
||||
from .models import MacrocycleClassificationResult
|
||||
|
||||
|
||||
class MacroLactoneAnalyzer:
|
||||
@@ -13,15 +16,108 @@ class MacroLactoneAnalyzer:
|
||||
candidates = find_macrolactone_candidates(mol)
|
||||
return sorted({candidate.ring_size for candidate in candidates})
|
||||
|
||||
def classify_macrocycle(
    self,
    mol_input: str | Chem.Mol,
    ring_size: int | None = None,
) -> MacrocycleClassificationResult:
    """Classify one molecule given as a SMILES string or an RDKit molecule.

    The input is normalised with ``ensure_mol`` and the resulting molecule and
    SMILES are handed to ``classify_macrolactone``.

    Args:
        mol_input: SMILES string or molecule object.
        ring_size: Optional ring size forwarded to ``classify_macrolactone``.

    Returns:
        The ``MacrocycleClassificationResult`` produced by
        ``classify_macrolactone``.
    """
    mol, smiles = ensure_mol(mol_input)
    return classify_macrolactone(mol, smiles=smiles, ring_size=ring_size)
|
||||
|
||||
def analyze_molecule(
    self,
    mol_input: str | Chem.Mol,
    ring_size: int | None = None,
) -> dict:
    """Classify one molecule and return the result as a plain dictionary.

    Delegates to ``classify_macrocycle`` and converts its result with
    ``to_dict()``.
    """
    outcome = self.classify_macrocycle(mol_input, ring_size=ring_size)
    return outcome.to_dict()
|
||||
|
||||
def analyze_many(
    self,
    smiles_list: list[str],
    ring_range: range = range(12, 21),
) -> dict:
    """Classify every SMILES string and summarise the outcomes.

    Args:
        smiles_list: SMILES strings, each passed to ``classify_macrocycle``.
        ring_range: Ring sizes that get a counter in ``ring_size_counts``.

    Returns:
        A dict with ``total`` (number of inputs), ``classification_counts``
        (one counter per classification label), ``ring_size_counts`` (standard
        macrolactones per ring size in ``ring_range``) and ``results`` (each
        classification as a dict, in input order).
    """
    label_tallies = dict.fromkeys(
        ("standard_macrolactone", "non_standard_macrocycle", "not_macrolactone"),
        0,
    )
    size_tallies = dict.fromkeys(ring_range, 0)
    records: list[dict] = []

    for entry in smiles_list:
        outcome = self.classify_macrocycle(entry)
        label = outcome.classification
        label_tallies[label] += 1
        # Only standard macrolactones whose size is tracked count per ring size.
        if label == "standard_macrolactone" and outcome.ring_size in size_tallies:
            size_tallies[outcome.ring_size] += 1
        records.append(outcome.to_dict())

    return {
        "total": len(smiles_list),
        "classification_counts": label_tallies,
        "ring_size_counts": size_tallies,
        "results": records,
    }
|
||||
|
||||
def classify_dataframe(
    self,
    dataframe: pd.DataFrame,
    smiles_column: str = "smiles",
    id_column: str | None = None,
    ring_range: range = range(12, 21),
) -> tuple[dict[int, pd.DataFrame], pd.DataFrame]:
    """Classify each row of ``dataframe`` and split the rows by outcome.

    Every row is merged with its classification dict. Rows classified as
    standard macrolactones with a ring size in ``ring_range`` are grouped by
    ring size; all other rows are collected as rejected.

    Args:
        dataframe: Input table; ``smiles_column`` must be present.
        smiles_column: Column holding the SMILES strings.
        id_column: When ``None`` and the row has no ``"id"`` key, an id of the
            form ``row_<index>`` is added.
        ring_range: Ring sizes eligible for grouping.

    Returns:
        ``(grouped, rejected)`` where ``grouped`` maps ring size to a DataFrame
        (sizes without rows are omitted) and ``rejected`` is a DataFrame of the
        remaining rows.
    """
    buckets: dict[int, list[dict]] = {size: [] for size in ring_range}
    rejects: list[dict] = []

    for label, series in dataframe.iterrows():
        record = series.to_dict()
        if id_column is None and "id" not in record:
            record["id"] = f"row_{label}"

        outcome = self.classify_macrocycle(record[smiles_column])
        # Classification fields win over same-named input columns.
        record.update(outcome.to_dict())

        accepted = (
            outcome.classification == "standard_macrolactone"
            and outcome.ring_size in buckets
        )
        destination = buckets[outcome.ring_size] if accepted else rejects
        destination.append(record)

    frames = {size: pd.DataFrame(rows) for size, rows in buckets.items() if rows}
    return frames, pd.DataFrame(rejects)
|
||||
|
||||
def match_dynamic_smarts(self, smiles: str, ring_size: int) -> list[int] | None:
    """Return the atom indices of the first match of a ring-size-specific SMARTS.

    The pattern is built from ``ring_size`` and matched against the molecule
    parsed from ``smiles``.

    Returns:
        The first match as a list of atom indices, or ``None`` when the SMILES
        or the SMARTS cannot be parsed or nothing matches.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    pattern = Chem.MolFromSmarts(f"[r{ring_size}]([#8][#6](=[#8]))")
    if pattern is None:
        return None

    hits = mol.GetSubstructMatches(pattern)
    if not hits:
        return None
    return list(hits[0])
|
||||
|
||||
def calculate_properties(self, smiles: str) -> dict[str, float] | None:
    """Compute a fixed set of RDKit descriptors for a SMILES string.

    Returns:
        A dict of descriptor name to value (count-based descriptors are
        converted to ``float``), or ``None`` when the SMILES cannot be parsed.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    properties = {
        "molecular_weight": Descriptors.MolWt(mol),
        "logp": Crippen.MolLogP(mol),
        "qed": QED.qed(mol),
        "tpsa": Descriptors.TPSA(mol),
    }
    # The remaining descriptors are counts; report them as floats as well.
    properties["num_atoms"] = float(mol.GetNumAtoms())
    properties["num_heavy_atoms"] = float(mol.GetNumHeavyAtoms())
    properties["num_h_donors"] = float(Lipinski.NumHDonors(mol))
    properties["num_h_acceptors"] = float(Lipinski.NumHAcceptors(mol))
    properties["num_rotatable_bonds"] = float(Lipinski.NumRotatableBonds(mol))
    return properties
|
||||
|
||||
@@ -48,14 +48,14 @@ def build_parser() -> argparse.ArgumentParser:
|
||||
def run_analyze(args: argparse.Namespace) -> None:
    """Run the analysis for a single SMILES string or for every row of a CSV.

    When ``args.smiles`` is set, that molecule alone is analysed and written to
    ``args.output``. Otherwise the rows read from ``args.input`` are analysed
    one by one, each result is tagged with the row's ``parent_id``, and the
    list of results is written to ``args.output``.

    ``args.ring_size`` is forwarded to every ``analyze_molecule`` call; each
    molecule is analysed exactly once.
    """
    analyzer = MacroLactoneAnalyzer()
    if args.smiles:
        payload = analyzer.analyze_molecule(args.smiles, ring_size=args.ring_size)
        _write_output(payload, args.output)
        return

    rows = _read_csv_rows(args.input, args.smiles_column, args.id_column)
    payload = []
    for row in rows:
        analysis = analyzer.analyze_molecule(row["smiles"], ring_size=args.ring_size)
        analysis["parent_id"] = row["parent_id"]
        payload.append(analysis)
    _write_output(payload, args.output)
|
||||
|
||||
@@ -101,11 +101,15 @@ class MacrolactoneFragmenter:
|
||||
)
|
||||
|
||||
def _select_candidate(self, mol: Chem.Mol):
|
||||
candidates = find_macrolactone_candidates(mol, ring_size=self.ring_size)
|
||||
if not candidates:
|
||||
requested = f"{self.ring_size}-membered " if self.ring_size is not None else ""
|
||||
raise MacrolactoneDetectionError(f"No valid {requested}macrolactone was detected.")
|
||||
classification = self.analyzer.classify_macrocycle(mol, ring_size=self.ring_size)
|
||||
if classification.classification != "standard_macrolactone":
|
||||
raise MacrolactoneDetectionError(
|
||||
"Macrolactone rejected: "
|
||||
f"classification={classification.classification} "
|
||||
f"primary_reason_code={classification.primary_reason_code}"
|
||||
)
|
||||
|
||||
candidates = find_macrolactone_candidates(mol, ring_size=self.ring_size)
|
||||
valid_ring_sizes = sorted({candidate.ring_size for candidate in candidates})
|
||||
if len(candidates) > 1 or len(valid_ring_sizes) > 1:
|
||||
raise AmbiguousMacrolactoneError(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -50,3 +51,22 @@ class FragmentationResult:
|
||||
"fragments": [fragment.to_dict() for fragment in self.fragments],
|
||||
"warnings": list(self.warnings),
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MacrocycleClassificationResult:
    """Immutable outcome of classifying one molecule."""

    # SMILES string of the classified molecule.
    smiles: str
    # One of the three supported classification labels.
    classification: Literal[
        "standard_macrolactone",
        "non_standard_macrocycle",
        "not_macrolactone",
    ]
    # Ring size attached to the classification, if any.
    ring_size: int | None
    # First reason code and its message, or None when there is no reason.
    primary_reason_code: str | None
    primary_reason_message: str | None
    # Every reason code and message; default to fresh empty lists.
    all_reason_codes: list[str] = field(default_factory=list)
    all_reason_messages: list[str] = field(default_factory=list)
    # Ring sizes of the candidates considered; defaults to an empty list.
    candidate_ring_sizes: list[int] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return the fields as a plain dictionary via ``dataclasses.asdict``."""
        return asdict(self)
|
||||
|
||||
180
src/macro_lactone_toolkit/visualization.py
Normal file
180
src/macro_lactone_toolkit/visualization.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem.Draw import rdMolDraw2D
|
||||
|
||||
from ._core import build_numbering_result, ensure_mol, find_macrolactone_candidates
|
||||
from .errors import AmbiguousMacrolactoneError, MacrolactoneDetectionError
|
||||
|
||||
|
||||
def numbered_molecule_svg(
    mol_input: str | Chem.Mol,
    ring_size: int | None = None,
    size: tuple[int, int] = (800, 800),
    allowed_ring_atom_types: list[str] | None = None,
    show_atom_labels: bool = True,
) -> str:
    """Render a molecule with its macrolactone ring highlighted, as SVG text.

    Args:
        mol_input: SMILES string or molecule object.
        ring_size: Optional ring size used when selecting the ring.
        size: Values unpacked into the SVG drawer constructor.
        allowed_ring_atom_types: Optional filter applied to candidate rings.
        show_atom_labels: Whether ring atoms are labelled with their position.

    Returns:
        The drawing text produced by the SVG drawer.
    """
    parsed_mol, _ = ensure_mol(mol_input)
    ring_numbering = _get_visualization_numbering(
        parsed_mol,
        ring_size=ring_size,
        allowed_ring_atom_types=allowed_ring_atom_types,
    )

    svg_drawer = rdMolDraw2D.MolDraw2DSVG(*size)
    _draw_numbered_molecule(
        mol=parsed_mol,
        drawer=svg_drawer,
        position_to_atom=ring_numbering.position_to_atom,
        show_atom_labels=show_atom_labels,
    )
    return svg_drawer.GetDrawingText()
|
||||
|
||||
|
||||
def save_numbered_molecule_png(
    mol_input: str | Chem.Mol,
    output_path: str | Path,
    ring_size: int | None = None,
    size: tuple[int, int] = (800, 800),
    allowed_ring_atom_types: list[str] | None = None,
    show_atom_labels: bool = True,
    dpi: int = 600,
) -> Path:
    """Render a ring-numbered molecule with the Cairo drawer and write it to disk.

    Args:
        mol_input: SMILES string or molecule object.
        output_path: Destination file; missing parent directories are created.
        ring_size: Optional ring size used when selecting the ring.
        size: Values unpacked into the Cairo drawer constructor.
        allowed_ring_atom_types: Optional filter applied to candidate rings.
        show_atom_labels: Whether ring atoms are labelled with their position.
        dpi: Accepted for interface compatibility; the value is discarded.

    Returns:
        The path that was written.
    """
    # The parameter is part of the signature but is intentionally unused.
    del dpi

    parsed_mol, _ = ensure_mol(mol_input)
    ring_numbering = _get_visualization_numbering(
        parsed_mol,
        ring_size=ring_size,
        allowed_ring_atom_types=allowed_ring_atom_types,
    )

    cairo_drawer = rdMolDraw2D.MolDraw2DCairo(*size)
    _draw_numbered_molecule(
        mol=parsed_mol,
        drawer=cairo_drawer,
        position_to_atom=ring_numbering.position_to_atom,
        show_atom_labels=show_atom_labels,
    )

    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_bytes(cairo_drawer.GetDrawingText())
    return destination
|
||||
|
||||
|
||||
def fragment_svg(fragment_or_smiles: str | Chem.Mol, size: tuple[int, int] = (400, 400)) -> str:
    """Render a fragment (SMILES string or molecule) without highlights as SVG text."""
    fragment_mol, _ = ensure_mol(fragment_or_smiles)
    svg_drawer = rdMolDraw2D.MolDraw2DSVG(*size)
    _draw_plain_molecule(fragment_mol, svg_drawer)
    return svg_drawer.GetDrawingText()
|
||||
|
||||
|
||||
def save_fragment_png(
    fragment_or_smiles: str | Chem.Mol,
    output_path: str | Path,
    size: tuple[int, int] = (400, 400),
    dpi: int = 600,
) -> Path:
    """Render a fragment with the Cairo drawer and write the image to disk.

    Args:
        fragment_or_smiles: SMILES string or molecule object.
        output_path: Destination file; missing parent directories are created.
        size: Values unpacked into the Cairo drawer constructor.
        dpi: Accepted for interface compatibility; the value is discarded.

    Returns:
        The path that was written.
    """
    # The parameter is part of the signature but is intentionally unused.
    del dpi

    fragment_mol, _ = ensure_mol(fragment_or_smiles)
    cairo_drawer = rdMolDraw2D.MolDraw2DCairo(*size)
    _draw_plain_molecule(fragment_mol, cairo_drawer)

    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_bytes(cairo_drawer.GetDrawingText())
    return destination
|
||||
|
||||
|
||||
def _draw_numbered_molecule(
    mol: Chem.Mol,
    drawer: rdMolDraw2D.MolDraw2DSVG | rdMolDraw2D.MolDraw2DCairo,
    position_to_atom: dict[int, int],
    show_atom_labels: bool,
) -> None:
    """Draw a copy of ``mol`` on ``drawer`` with the numbered ring atoms highlighted.

    Args:
        mol: Molecule to draw; a copy is drawn, not the object itself.
        drawer: SVG or Cairo drawer that receives the drawing.
        position_to_atom: Mapping of ring position to atom index.
        show_atom_labels: When true, each ring atom's label is replaced by its
            ring position.
    """
    molecule_copy = Chem.Mol(mol)
    options = drawer.drawOptions()
    if show_atom_labels:
        for ring_position, atom_index in position_to_atom.items():
            options.atomLabels[atom_index] = str(ring_position)

    ring_atom_indices = list(position_to_atom.values())
    # Every ring atom gets the same highlight colour.
    ring_colors = dict.fromkeys(ring_atom_indices, (0.96, 0.84, 0.48))

    rdMolDraw2D.PrepareAndDrawMolecule(
        drawer,
        molecule_copy,
        highlightAtoms=ring_atom_indices,
        highlightAtomColors=ring_colors,
    )
    drawer.FinishDrawing()
|
||||
|
||||
|
||||
def _draw_plain_molecule(
    mol: Chem.Mol,
    drawer: rdMolDraw2D.MolDraw2DSVG | rdMolDraw2D.MolDraw2DCairo,
) -> None:
    """Draw a copy of ``mol`` on ``drawer`` without highlights and finish the drawing."""
    molecule_copy = Chem.Mol(mol)
    rdMolDraw2D.PrepareAndDrawMolecule(drawer, molecule_copy)
    drawer.FinishDrawing()
|
||||
|
||||
|
||||
def _get_visualization_numbering(
    mol: Chem.Mol,
    ring_size: int | None,
    allowed_ring_atom_types: list[str] | None,
):
    """Select the single macrolactone candidate to draw and return its numbering.

    Args:
        mol: Molecule to search.
        ring_size: Optional ring size forwarded to the candidate search.
        allowed_ring_atom_types: Optional atom types; when given, candidates
            that do not match them are discarded.

    Returns:
        The result of ``build_numbering_result`` for the remaining candidate.

    Raises:
        MacrolactoneDetectionError: If no candidate is detected.
        ValueError: If the atom-type filter removes every candidate.
        AmbiguousMacrolactoneError: If more than one candidate remains.
    """
    remaining = find_macrolactone_candidates(mol, ring_size=ring_size)
    if not remaining:
        size_prefix = "" if ring_size is None else f"{ring_size}-membered "
        raise MacrolactoneDetectionError(f"No valid {size_prefix}macrolactone was detected.")

    if allowed_ring_atom_types is not None:
        permitted_numbers = _normalize_allowed_ring_atom_types(allowed_ring_atom_types)
        remaining = [
            option
            for option in remaining
            if _candidate_matches_allowed_ring_atom_types(mol, option, permitted_numbers)
        ]
        if not remaining:
            raise ValueError("No macrolactone candidate matched the allowed ring atom types.")

    distinct_sizes = sorted({option.ring_size for option in remaining})
    if len(remaining) > 1 or len(distinct_sizes) > 1:
        raise AmbiguousMacrolactoneError(
            "Multiple valid macrolactone candidates were detected. Pass ring_size explicitly."
        )

    return build_numbering_result(mol, remaining[0])
|
||||
|
||||
|
||||
def _normalize_allowed_ring_atom_types(atom_types: Iterable[str]) -> set[int]:
    """Convert element symbols or numeric strings into a set of atomic numbers.

    Args:
        atom_types: Entries that are either digit strings (used directly as
            atomic numbers) or element symbols in any letter case.

    Returns:
        The set of atomic numbers named by ``atom_types``.

    Raises:
        ValueError: If an entry is not a digit string and the RDKit periodic
            table does not recognise it as an element symbol.
    """
    periodic_table = Chem.GetPeriodicTable()
    normalized: set[int] = set()
    for atom_type in atom_types:
        if atom_type.isdigit():
            normalized.add(int(atom_type))
            continue
        try:
            atomic_number = periodic_table.GetAtomicNumber(atom_type.capitalize())
        except RuntimeError as exc:
            # RDKit reports an unknown symbol by raising rather than returning
            # a sentinel, so translate it into the documented ValueError.
            raise ValueError(f"Unsupported atom type: {atom_type}") from exc
        if atomic_number <= 0:
            raise ValueError(f"Unsupported atom type: {atom_type}")
        normalized.add(atomic_number)
    return normalized
|
||||
|
||||
|
||||
def _candidate_matches_allowed_ring_atom_types(
    mol: Chem.Mol,
    candidate,
    allowed_atomic_numbers: set[int],
) -> bool:
    """Return True if every checked ring atom has an allowed atomic number.

    The candidate is numbered with ``build_numbering_result``; every position
    except position 2 is checked against ``allowed_atomic_numbers``.
    """
    numbering = build_numbering_result(mol, candidate)
    return all(
        mol.GetAtomWithIdx(atom_index).GetAtomicNum() in allowed_atomic_numbers
        for ring_position, atom_index in numbering.position_to_atom.items()
        # Position 2 is exempt from the filter.
        if ring_position != 2
    )
|
||||
180
src/macro_lactone_toolkit/workflows.py
Normal file
180
src/macro_lactone_toolkit/workflows.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from .analyzer import MacroLactoneAnalyzer
|
||||
from .errors import MacrolactoneError
|
||||
from .fragmenter import MacrolactoneFragmenter
|
||||
from .models import FragmentationResult
|
||||
from .visualization import save_numbered_molecule_png
|
||||
|
||||
|
||||
def fragment_csv(
    input_csv: str | Path,
    smiles_column: str = "smiles",
    id_column: str = "id",
    ring_size: int | None = None,
    max_rows: int | None = None,
) -> list[FragmentationResult]:
    """Fragment every molecule in a CSV file, failing on the first recorded error.

    Args:
        input_csv: CSV file to read.
        smiles_column: Column holding the SMILES strings.
        id_column: Column holding the molecule identifiers.
        ring_size: Optional ring size passed to the fragmenter.
        max_rows: Optional cap on the number of rows read.

    Returns:
        One ``FragmentationResult`` per row.

    Raises:
        MacrolactoneError: The exception stored for the first row that failed,
            raised after all rows have been processed.
    """
    fragmented, failures = _fragment_csv_with_errors(
        input_csv=input_csv,
        smiles_column=smiles_column,
        id_column=id_column,
        ring_size=ring_size,
        max_rows=max_rows,
    )
    if not failures:
        return fragmented
    raise failures[0]["exception"]
|
||||
|
||||
|
||||
def results_to_dataframe(results: list[FragmentationResult]) -> pd.DataFrame:
    """Flatten fragmentation results into one DataFrame row per fragment.

    Each row carries the parent's id, SMILES and ring size followed by the
    fragment's own ``to_dict()`` fields. Results without fragments contribute
    no rows.
    """
    flattened = [
        {
            "parent_id": parent.parent_id,
            "parent_smiles": parent.parent_smiles,
            "ring_size": parent.ring_size,
            **piece.to_dict(),
        }
        for parent in results
        for piece in parent.fragments
    ]
    return pd.DataFrame(flattened)
|
||||
|
||||
|
||||
def write_result_json(result: FragmentationResult, output_path: str | Path) -> Path:
    """Write ``result.to_dict()`` as indented UTF-8 JSON and return the path.

    Missing parent directories are created. Non-ASCII characters are written
    unescaped and the file ends with a newline.
    """
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(result.to_dict(), indent=2, ensure_ascii=False)
    destination.write_text(f"{serialized}\n", encoding="utf-8")
    return destination
|
||||
|
||||
|
||||
def export_numbered_macrolactone_csv(
    input_csv: str | Path,
    output_dir: str | Path,
    smiles_column: str = "smiles",
    id_column: str = "id",
    output_csv_name: str = "numbered_macrolactones.csv",
    ring_size: int | None = None,
    allowed_ring_atom_types: list[str] | None = None,
    image_size: tuple[int, int] = (800, 800),
    dpi: int = 600,
) -> Path:
    """Classify each CSV row, render a numbered image, and write a summary CSV.

    For every input row the molecule is classified and a PNG named after the
    row's ``parent_id`` is written to ``<output_dir>/images``. Any exception
    raised for a row is recorded in that row's ``status``, ``error_type`` and
    ``error_message`` columns instead of aborting the export.

    Args:
        input_csv: CSV file to read.
        output_dir: Directory receiving the summary CSV and the ``images``
            sub-directory; both are created if missing.
        smiles_column: Column holding the SMILES strings.
        id_column: Column holding the molecule identifiers.
        output_csv_name: File name of the summary CSV inside ``output_dir``.
        ring_size: Optional ring size used for classification and rendering.
        allowed_ring_atom_types: Optional atom-type filter passed to rendering.
        image_size: Size passed to the image renderer.
        dpi: Passed through to ``save_numbered_molecule_png``.

    Returns:
        Path of the summary CSV.
    """
    analyzer = MacroLactoneAnalyzer()
    output_dir = Path(output_dir)
    images_dir = output_dir / "images"
    output_dir.mkdir(parents=True, exist_ok=True)
    images_dir.mkdir(parents=True, exist_ok=True)

    source_rows = _read_csv_rows(
        input_csv=input_csv,
        smiles_column=smiles_column,
        id_column=id_column,
    )

    export_rows: list[dict] = []
    for source_row in source_rows:
        parent_id = source_row["parent_id"]
        smiles = source_row["smiles"]
        # Start from a "success" record; failures overwrite the status below.
        record = {
            "parent_id": parent_id,
            "smiles": smiles,
            "status": "success",
            "image_path": "",
            "classification": None,
            "primary_reason_code": None,
            "ring_size": None,
            "candidate_ring_sizes": [],
            "error_type": None,
            "error_message": None,
        }
        try:
            outcome = analyzer.classify_macrocycle(smiles, ring_size=ring_size)
            record["classification"] = outcome.classification
            record["primary_reason_code"] = outcome.primary_reason_code
            record["ring_size"] = outcome.ring_size
            record["candidate_ring_sizes"] = outcome.candidate_ring_sizes

            png_path = images_dir / f"{parent_id}.png"
            save_numbered_molecule_png(
                smiles,
                png_path,
                ring_size=ring_size,
                size=image_size,
                allowed_ring_atom_types=allowed_ring_atom_types,
                dpi=dpi,
            )
            # Stored relative to the parent of the output directory.
            record["image_path"] = str(png_path.relative_to(output_dir.parent))
        except Exception as exc:  # pragma: no cover - surfaced in CSV
            record["status"] = "error"
            record["error_type"] = type(exc).__name__
            record["error_message"] = str(exc)
        export_rows.append(record)

    output_path = output_dir / output_csv_name
    pd.DataFrame(export_rows).to_csv(output_path, index=False)
    return output_path
|
||||
|
||||
|
||||
def _fragment_csv_with_errors(
    input_csv: str | Path,
    smiles_column: str = "smiles",
    id_column: str = "id",
    ring_size: int | None = None,
    max_rows: int | None = None,
) -> tuple[list[FragmentationResult], list[dict]]:
    """Fragment every CSV row, collecting failures instead of raising.

    Args:
        input_csv: CSV file to read.
        smiles_column: Column holding the SMILES strings.
        id_column: Column holding the molecule identifiers.
        ring_size: Optional ring size passed to ``MacrolactoneFragmenter``.
        max_rows: Optional cap on the number of rows read.

    Returns:
        ``(results, errors)``: the successful fragmentation results, and one
        dict per row that raised ``MacrolactoneError`` holding the row's id and
        SMILES, the error type name and message, and the exception object.
    """
    fragmenter = MacrolactoneFragmenter(ring_size=ring_size)
    source_rows = _read_csv_rows(
        input_csv=input_csv,
        smiles_column=smiles_column,
        id_column=id_column,
        max_rows=max_rows,
    )

    succeeded: list[FragmentationResult] = []
    failed: list[dict] = []
    for source_row in source_rows:
        try:
            fragmented = fragmenter.fragment_molecule(
                source_row["smiles"],
                parent_id=source_row["parent_id"],
            )
        except MacrolactoneError as exc:
            failed.append(
                {
                    "parent_id": source_row["parent_id"],
                    "smiles": source_row["smiles"],
                    "error_type": type(exc).__name__,
                    "error_message": str(exc),
                    "exception": exc,
                }
            )
        else:
            succeeded.append(fragmented)
    return succeeded, failed
|
||||
|
||||
|
||||
def _read_csv_rows(
    input_csv: str | Path,
    smiles_column: str = "smiles",
    id_column: str = "id",
    max_rows: int | None = None,
) -> list[dict]:
    """Read a CSV file into a list of ``{"parent_id", "smiles"}`` dicts.

    Args:
        input_csv: CSV file to read with pandas.
        smiles_column: Column holding the SMILES strings.
        id_column: Column holding the identifiers; when the column is absent,
            ids of the form ``row_<index>`` are generated.
        max_rows: When given, only the first ``max_rows`` rows are kept.

    Returns:
        One dict per row; ``parent_id`` is always a string.
    """
    frame = pd.read_csv(input_csv)
    if max_rows is not None:
        frame = frame.head(max_rows)

    has_id_column = id_column in frame.columns
    return [
        {
            "parent_id": str(record[id_column] if has_id_column else f"row_{label}"),
            "smiles": record[smiles_column],
        }
        for label, record in frame.iterrows()
    ]
|
||||
Reference in New Issue
Block a user