feat(toolkit): add classification and migration

Implement the standard/non-standard/not-macrolactone classification layer
and integrate it into analyzer, fragmenter, and CLI outputs.

Port the remaining legacy package capabilities into new visualization and
workflow modules, restore batch/statistics/SDF scripts on top of the flat
CSV workflow, and update active docs to the new package API.
This commit is contained in:
2026-03-18 23:56:41 +08:00
parent 9ccbcfcd04
commit c0ead42384
24 changed files with 1497 additions and 313 deletions

View File

@@ -7,19 +7,45 @@ from .errors import (
RingNumberingError,
)
from .fragmenter import MacrolactoneFragmenter
from .models import FragmentationResult, RingNumberingResult, SideChainFragment
from .models import (
FragmentationResult,
MacrocycleClassificationResult,
RingNumberingResult,
SideChainFragment,
)
from .visualization import (
fragment_svg,
numbered_molecule_svg,
save_fragment_png,
save_numbered_molecule_png,
)
from .workflows import (
export_numbered_macrolactone_csv,
fragment_csv,
results_to_dataframe,
write_result_json,
)
__all__ = [
"AmbiguousMacrolactoneError",
"FragmentationError",
"FragmentationResult",
"fragment_csv",
"fragment_svg",
"MacroLactoneAnalyzer",
"MacrolactoneDetectionError",
"MacrolactoneError",
"MacrolactoneFragmenter",
"MacrocycleClassificationResult",
"numbered_molecule_svg",
"RingNumberingError",
"RingNumberingResult",
"results_to_dataframe",
"save_fragment_png",
"save_numbered_molecule_png",
"SideChainFragment",
"export_numbered_macrolactone_csv",
"write_result_json",
]
__version__ = "0.1.0"

View File

@@ -6,11 +6,21 @@ from typing import Iterable
from rdkit import Chem
from .errors import MacrolactoneDetectionError, RingNumberingError
from .models import RingNumberingResult
from .errors import AmbiguousMacrolactoneError, MacrolactoneDetectionError, RingNumberingError
from .models import MacrocycleClassificationResult, RingNumberingResult
VALID_RING_SIZES = tuple(range(12, 21))
REASON_MESSAGES = {
"contains_non_carbon_ring_atoms_outside_positions_1_2": (
"Ring positions 3..N contain non-carbon atoms."
),
"multiple_overlapping_macrocycle_candidates": (
"Overlapping macrolactone candidate rings were detected."
),
"no_lactone_ring_in_12_to_20_range": "No 12-20 membered lactone ring was detected.",
"requested_ring_size_not_found": "The requested ring size was not detected as a lactone ring.",
}
@dataclass(frozen=True)
@@ -73,6 +83,62 @@ def find_macrolactone_candidates(
)
def classify_macrolactone(
mol: Chem.Mol,
smiles: str,
ring_size: int | None = None,
) -> MacrocycleClassificationResult:
candidates = find_macrolactone_candidates(mol, ring_size=ring_size)
candidate_ring_sizes = sorted({candidate.ring_size for candidate in candidates})
if not candidates:
reason_code = (
"requested_ring_size_not_found"
if ring_size is not None
else "no_lactone_ring_in_12_to_20_range"
)
return _build_classification_result(
smiles=smiles,
classification="not_macrolactone",
ring_size=None,
candidate_ring_sizes=[],
reason_codes=[reason_code],
)
if _has_overlapping_candidates(candidates):
return _build_classification_result(
smiles=smiles,
classification="non_standard_macrocycle",
ring_size=candidate_ring_sizes[0] if len(candidate_ring_sizes) == 1 else None,
candidate_ring_sizes=candidate_ring_sizes,
reason_codes=["multiple_overlapping_macrocycle_candidates"],
)
if len(candidates) > 1 or len(candidate_ring_sizes) > 1:
raise AmbiguousMacrolactoneError(
"Multiple valid macrolactone candidates were detected. Pass ring_size explicitly."
)
candidate = candidates[0]
numbering = build_numbering_result(mol, candidate)
if _contains_non_carbon_atoms_outside_positions_1_2(mol, numbering):
return _build_classification_result(
smiles=smiles,
classification="non_standard_macrocycle",
ring_size=candidate.ring_size,
candidate_ring_sizes=candidate_ring_sizes,
reason_codes=["contains_non_carbon_ring_atoms_outside_positions_1_2"],
)
return _build_classification_result(
smiles=smiles,
classification="standard_macrolactone",
ring_size=candidate.ring_size,
candidate_ring_sizes=candidate_ring_sizes,
reason_codes=[],
)
def build_numbering_result(mol: Chem.Mol, candidate: DetectedMacrolactone) -> RingNumberingResult:
ring_atoms = list(candidate.ring_atoms)
ring_atom_set = set(ring_atoms)
@@ -120,6 +186,66 @@ def build_numbering_result(mol: Chem.Mol, candidate: DetectedMacrolactone) -> Ri
)
def _build_classification_result(
smiles: str,
classification: str,
ring_size: int | None,
candidate_ring_sizes: list[int],
reason_codes: list[str],
) -> MacrocycleClassificationResult:
reason_messages = [REASON_MESSAGES[reason_code] for reason_code in reason_codes]
return MacrocycleClassificationResult(
smiles=smiles,
classification=classification,
ring_size=ring_size,
primary_reason_code=reason_codes[0] if reason_codes else None,
primary_reason_message=reason_messages[0] if reason_messages else None,
all_reason_codes=list(reason_codes),
all_reason_messages=reason_messages,
candidate_ring_sizes=list(candidate_ring_sizes),
)
def _contains_non_carbon_atoms_outside_positions_1_2(
mol: Chem.Mol,
numbering: RingNumberingResult,
) -> bool:
for position in range(3, numbering.ring_size + 1):
atom_idx = numbering.position_to_atom[position]
if mol.GetAtomWithIdx(atom_idx).GetAtomicNum() != 6:
return True
return False
def _has_overlapping_candidates(candidates: list[DetectedMacrolactone]) -> bool:
ring_sets = [set(candidate.ring_atoms) for candidate in candidates]
visited: set[int] = set()
for start_index in range(len(candidates)):
if start_index in visited:
continue
queue = deque([start_index])
component_size = 0
while queue:
candidate_index = queue.popleft()
if candidate_index in visited:
continue
visited.add(candidate_index)
component_size += 1
for neighbor_index in range(len(candidates)):
if neighbor_index == candidate_index or neighbor_index in visited:
continue
if ring_sets[candidate_index].intersection(ring_sets[neighbor_index]):
queue.append(neighbor_index)
if component_size > 1:
return True
return False
def collect_side_chain_atoms(
mol: Chem.Mol,
start_atom_idx: int,

View File

@@ -1,8 +1,11 @@
from __future__ import annotations
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Crippen, Descriptors, Lipinski, QED
from ._core import ensure_mol, find_macrolactone_candidates
from ._core import classify_macrolactone, ensure_mol, find_macrolactone_candidates
from .models import MacrocycleClassificationResult
class MacroLactoneAnalyzer:
@@ -13,15 +16,108 @@ class MacroLactoneAnalyzer:
candidates = find_macrolactone_candidates(mol)
return sorted({candidate.ring_size for candidate in candidates})
def analyze_molecule(self, mol_input: str | Chem.Mol) -> dict:
def classify_macrocycle(
self,
mol_input: str | Chem.Mol,
ring_size: int | None = None,
) -> MacrocycleClassificationResult:
mol, smiles = ensure_mol(mol_input)
candidates = find_macrolactone_candidates(mol)
valid_ring_sizes = sorted({candidate.ring_size for candidate in candidates})
is_ambiguous = len(valid_ring_sizes) > 1 or len(candidates) > 1
return {
"smiles": smiles,
"valid_ring_sizes": valid_ring_sizes,
"candidate_count": len(candidates),
"is_ambiguous": is_ambiguous,
"selected_ring_size": valid_ring_sizes[0] if len(valid_ring_sizes) == 1 and len(candidates) == 1 else None,
return classify_macrolactone(mol, smiles=smiles, ring_size=ring_size)
def analyze_molecule(
self,
mol_input: str | Chem.Mol,
ring_size: int | None = None,
) -> dict:
return self.classify_macrocycle(mol_input, ring_size=ring_size).to_dict()
def analyze_many(
self,
smiles_list: list[str],
ring_range: range = range(12, 21),
) -> dict:
classification_counts = {
"standard_macrolactone": 0,
"non_standard_macrocycle": 0,
"not_macrolactone": 0,
}
ring_size_counts = {ring_size: 0 for ring_size in ring_range}
results: list[dict] = []
for smiles in smiles_list:
classification = self.classify_macrocycle(smiles)
classification_counts[classification.classification] += 1
if (
classification.classification == "standard_macrolactone"
and classification.ring_size in ring_size_counts
):
ring_size_counts[classification.ring_size] += 1
results.append(classification.to_dict())
return {
"total": len(smiles_list),
"classification_counts": classification_counts,
"ring_size_counts": ring_size_counts,
"results": results,
}
def classify_dataframe(
self,
dataframe: pd.DataFrame,
smiles_column: str = "smiles",
id_column: str | None = None,
ring_range: range = range(12, 21),
) -> tuple[dict[int, pd.DataFrame], pd.DataFrame]:
grouped_rows: dict[int, list[dict]] = {ring_size: [] for ring_size in ring_range}
rejected_rows: list[dict] = []
for index, row in dataframe.iterrows():
base_row = row.to_dict()
if id_column is None and "id" not in base_row:
base_row["id"] = f"row_{index}"
classification = self.classify_macrocycle(base_row[smiles_column])
enriched_row = {
**base_row,
**classification.to_dict(),
}
if (
classification.classification == "standard_macrolactone"
and classification.ring_size in grouped_rows
):
grouped_rows[classification.ring_size].append(enriched_row)
else:
rejected_rows.append(enriched_row)
grouped_frames = {
ring_size: pd.DataFrame(rows)
for ring_size, rows in grouped_rows.items()
if rows
}
return grouped_frames, pd.DataFrame(rejected_rows)
def match_dynamic_smarts(self, smiles: str, ring_size: int) -> list[int] | None:
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
query = Chem.MolFromSmarts(f"[r{ring_size}]([#8][#6](=[#8]))")
if query is None:
return None
matches = mol.GetSubstructMatches(query)
return list(matches[0]) if matches else None
def calculate_properties(self, smiles: str) -> dict[str, float] | None:
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
return {
"molecular_weight": Descriptors.MolWt(mol),
"logp": Crippen.MolLogP(mol),
"qed": QED.qed(mol),
"tpsa": Descriptors.TPSA(mol),
"num_atoms": float(mol.GetNumAtoms()),
"num_heavy_atoms": float(mol.GetNumHeavyAtoms()),
"num_h_donors": float(Lipinski.NumHDonors(mol)),
"num_h_acceptors": float(Lipinski.NumHAcceptors(mol)),
"num_rotatable_bonds": float(Lipinski.NumRotatableBonds(mol)),
}

View File

@@ -48,14 +48,14 @@ def build_parser() -> argparse.ArgumentParser:
def run_analyze(args: argparse.Namespace) -> None:
analyzer = MacroLactoneAnalyzer()
if args.smiles:
payload = analyzer.analyze_molecule(args.smiles)
payload = analyzer.analyze_molecule(args.smiles, ring_size=args.ring_size)
_write_output(payload, args.output)
return
rows = _read_csv_rows(args.input, args.smiles_column, args.id_column)
payload = []
for row in rows:
analysis = analyzer.analyze_molecule(row["smiles"])
analysis = analyzer.analyze_molecule(row["smiles"], ring_size=args.ring_size)
analysis["parent_id"] = row["parent_id"]
payload.append(analysis)
_write_output(payload, args.output)

View File

@@ -101,11 +101,15 @@ class MacrolactoneFragmenter:
)
def _select_candidate(self, mol: Chem.Mol):
candidates = find_macrolactone_candidates(mol, ring_size=self.ring_size)
if not candidates:
requested = f"{self.ring_size}-membered " if self.ring_size is not None else ""
raise MacrolactoneDetectionError(f"No valid {requested}macrolactone was detected.")
classification = self.analyzer.classify_macrocycle(mol, ring_size=self.ring_size)
if classification.classification != "standard_macrolactone":
raise MacrolactoneDetectionError(
"Macrolactone rejected: "
f"classification={classification.classification} "
f"primary_reason_code={classification.primary_reason_code}"
)
candidates = find_macrolactone_candidates(mol, ring_size=self.ring_size)
valid_ring_sizes = sorted({candidate.ring_size for candidate in candidates})
if len(candidates) > 1 or len(valid_ring_sizes) > 1:
raise AmbiguousMacrolactoneError(

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from dataclasses import asdict, dataclass, field
from typing import Literal
@dataclass(frozen=True)
@@ -50,3 +51,22 @@ class FragmentationResult:
"fragments": [fragment.to_dict() for fragment in self.fragments],
"warnings": list(self.warnings),
}
@dataclass(frozen=True)
class MacrocycleClassificationResult:
smiles: str
classification: Literal[
"standard_macrolactone",
"non_standard_macrocycle",
"not_macrolactone",
]
ring_size: int | None
primary_reason_code: str | None
primary_reason_message: str | None
all_reason_codes: list[str] = field(default_factory=list)
all_reason_messages: list[str] = field(default_factory=list)
candidate_ring_sizes: list[int] = field(default_factory=list)
def to_dict(self) -> dict:
return asdict(self)

View File

@@ -0,0 +1,180 @@
from __future__ import annotations
from pathlib import Path
from typing import Iterable
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
from ._core import build_numbering_result, ensure_mol, find_macrolactone_candidates
from .errors import AmbiguousMacrolactoneError, MacrolactoneDetectionError
def numbered_molecule_svg(
mol_input: str | Chem.Mol,
ring_size: int | None = None,
size: tuple[int, int] = (800, 800),
allowed_ring_atom_types: list[str] | None = None,
show_atom_labels: bool = True,
) -> str:
mol, _ = ensure_mol(mol_input)
numbering = _get_visualization_numbering(
mol,
ring_size=ring_size,
allowed_ring_atom_types=allowed_ring_atom_types,
)
drawer = rdMolDraw2D.MolDraw2DSVG(*size)
_draw_numbered_molecule(
mol=mol,
drawer=drawer,
position_to_atom=numbering.position_to_atom,
show_atom_labels=show_atom_labels,
)
return drawer.GetDrawingText()
def save_numbered_molecule_png(
mol_input: str | Chem.Mol,
output_path: str | Path,
ring_size: int | None = None,
size: tuple[int, int] = (800, 800),
allowed_ring_atom_types: list[str] | None = None,
show_atom_labels: bool = True,
dpi: int = 600,
) -> Path:
del dpi
mol, _ = ensure_mol(mol_input)
numbering = _get_visualization_numbering(
mol,
ring_size=ring_size,
allowed_ring_atom_types=allowed_ring_atom_types,
)
drawer = rdMolDraw2D.MolDraw2DCairo(*size)
_draw_numbered_molecule(
mol=mol,
drawer=drawer,
position_to_atom=numbering.position_to_atom,
show_atom_labels=show_atom_labels,
)
path = Path(output_path)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(drawer.GetDrawingText())
return path
def fragment_svg(fragment_or_smiles: str | Chem.Mol, size: tuple[int, int] = (400, 400)) -> str:
mol, _ = ensure_mol(fragment_or_smiles)
drawer = rdMolDraw2D.MolDraw2DSVG(*size)
_draw_plain_molecule(mol, drawer)
return drawer.GetDrawingText()
def save_fragment_png(
fragment_or_smiles: str | Chem.Mol,
output_path: str | Path,
size: tuple[int, int] = (400, 400),
dpi: int = 600,
) -> Path:
del dpi
mol, _ = ensure_mol(fragment_or_smiles)
drawer = rdMolDraw2D.MolDraw2DCairo(*size)
_draw_plain_molecule(mol, drawer)
path = Path(output_path)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(drawer.GetDrawingText())
return path
def _draw_numbered_molecule(
mol: Chem.Mol,
drawer: rdMolDraw2D.MolDraw2DSVG | rdMolDraw2D.MolDraw2DCairo,
position_to_atom: dict[int, int],
show_atom_labels: bool,
) -> None:
draw_mol = Chem.Mol(mol)
draw_options = drawer.drawOptions()
if show_atom_labels:
for position, atom_idx in position_to_atom.items():
draw_options.atomLabels[atom_idx] = str(position)
highlight_atoms = list(position_to_atom.values())
highlight_colors = {
atom_idx: (0.96, 0.84, 0.48)
for atom_idx in highlight_atoms
}
rdMolDraw2D.PrepareAndDrawMolecule(
drawer,
draw_mol,
highlightAtoms=highlight_atoms,
highlightAtomColors=highlight_colors,
)
drawer.FinishDrawing()
def _draw_plain_molecule(
mol: Chem.Mol,
drawer: rdMolDraw2D.MolDraw2DSVG | rdMolDraw2D.MolDraw2DCairo,
) -> None:
rdMolDraw2D.PrepareAndDrawMolecule(drawer, Chem.Mol(mol))
drawer.FinishDrawing()
def _get_visualization_numbering(
mol: Chem.Mol,
ring_size: int | None,
allowed_ring_atom_types: list[str] | None,
):
candidates = find_macrolactone_candidates(mol, ring_size=ring_size)
if not candidates:
requested = f"{ring_size}-membered " if ring_size is not None else ""
raise MacrolactoneDetectionError(f"No valid {requested}macrolactone was detected.")
if allowed_ring_atom_types is not None:
allowed_atomic_numbers = _normalize_allowed_ring_atom_types(allowed_ring_atom_types)
candidates = [
candidate
for candidate in candidates
if _candidate_matches_allowed_ring_atom_types(
mol,
candidate,
allowed_atomic_numbers,
)
]
if not candidates:
raise ValueError("No macrolactone candidate matched the allowed ring atom types.")
valid_ring_sizes = sorted({candidate.ring_size for candidate in candidates})
if len(candidates) > 1 or len(valid_ring_sizes) > 1:
raise AmbiguousMacrolactoneError(
"Multiple valid macrolactone candidates were detected. Pass ring_size explicitly."
)
return build_numbering_result(mol, candidates[0])
def _normalize_allowed_ring_atom_types(atom_types: Iterable[str]) -> set[int]:
periodic_table = Chem.GetPeriodicTable()
normalized: set[int] = set()
for atom_type in atom_types:
if atom_type.isdigit():
normalized.add(int(atom_type))
continue
atomic_number = periodic_table.GetAtomicNumber(atom_type.capitalize())
if atomic_number <= 0:
raise ValueError(f"Unsupported atom type: {atom_type}")
normalized.add(atomic_number)
return normalized
def _candidate_matches_allowed_ring_atom_types(
mol: Chem.Mol,
candidate,
allowed_atomic_numbers: set[int],
) -> bool:
numbering = build_numbering_result(mol, candidate)
for position, atom_idx in numbering.position_to_atom.items():
if position == 2:
continue
if mol.GetAtomWithIdx(atom_idx).GetAtomicNum() not in allowed_atomic_numbers:
return False
return True

View File

@@ -0,0 +1,180 @@
from __future__ import annotations
import json
from pathlib import Path
import pandas as pd
from .analyzer import MacroLactoneAnalyzer
from .errors import MacrolactoneError
from .fragmenter import MacrolactoneFragmenter
from .models import FragmentationResult
from .visualization import save_numbered_molecule_png
def fragment_csv(
input_csv: str | Path,
smiles_column: str = "smiles",
id_column: str = "id",
ring_size: int | None = None,
max_rows: int | None = None,
) -> list[FragmentationResult]:
results, errors = _fragment_csv_with_errors(
input_csv=input_csv,
smiles_column=smiles_column,
id_column=id_column,
ring_size=ring_size,
max_rows=max_rows,
)
if errors:
first_error = errors[0]
raise first_error["exception"]
return results
def results_to_dataframe(results: list[FragmentationResult]) -> pd.DataFrame:
rows: list[dict] = []
for result in results:
for fragment in result.fragments:
rows.append(
{
"parent_id": result.parent_id,
"parent_smiles": result.parent_smiles,
"ring_size": result.ring_size,
**fragment.to_dict(),
}
)
return pd.DataFrame(rows)
def write_result_json(result: FragmentationResult, output_path: str | Path) -> Path:
path = Path(output_path)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(result.to_dict(), indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
return path
def export_numbered_macrolactone_csv(
input_csv: str | Path,
output_dir: str | Path,
smiles_column: str = "smiles",
id_column: str = "id",
output_csv_name: str = "numbered_macrolactones.csv",
ring_size: int | None = None,
allowed_ring_atom_types: list[str] | None = None,
image_size: tuple[int, int] = (800, 800),
dpi: int = 600,
) -> Path:
analyzer = MacroLactoneAnalyzer()
output_dir = Path(output_dir)
images_dir = output_dir / "images"
output_dir.mkdir(parents=True, exist_ok=True)
images_dir.mkdir(parents=True, exist_ok=True)
rows = _read_csv_rows(
input_csv=input_csv,
smiles_column=smiles_column,
id_column=id_column,
)
export_rows: list[dict] = []
for row in rows:
record = {
"parent_id": row["parent_id"],
"smiles": row["smiles"],
"status": "success",
"image_path": "",
"classification": None,
"primary_reason_code": None,
"ring_size": None,
"candidate_ring_sizes": [],
"error_type": None,
"error_message": None,
}
try:
classification = analyzer.classify_macrocycle(row["smiles"], ring_size=ring_size)
record.update(
{
"classification": classification.classification,
"primary_reason_code": classification.primary_reason_code,
"ring_size": classification.ring_size,
"candidate_ring_sizes": classification.candidate_ring_sizes,
}
)
image_path = images_dir / f"{row['parent_id']}.png"
save_numbered_molecule_png(
row["smiles"],
image_path,
ring_size=ring_size,
size=image_size,
allowed_ring_atom_types=allowed_ring_atom_types,
dpi=dpi,
)
record["image_path"] = str(image_path.relative_to(output_dir.parent))
except Exception as exc: # pragma: no cover - surfaced in CSV
record.update(
{
"status": "error",
"error_type": type(exc).__name__,
"error_message": str(exc),
}
)
export_rows.append(record)
output_path = output_dir / output_csv_name
pd.DataFrame(export_rows).to_csv(output_path, index=False)
return output_path
def _fragment_csv_with_errors(
input_csv: str | Path,
smiles_column: str = "smiles",
id_column: str = "id",
ring_size: int | None = None,
max_rows: int | None = None,
) -> tuple[list[FragmentationResult], list[dict]]:
fragmenter = MacrolactoneFragmenter(ring_size=ring_size)
rows = _read_csv_rows(
input_csv=input_csv,
smiles_column=smiles_column,
id_column=id_column,
max_rows=max_rows,
)
results: list[FragmentationResult] = []
errors: list[dict] = []
for row in rows:
try:
results.append(fragmenter.fragment_molecule(row["smiles"], parent_id=row["parent_id"]))
except MacrolactoneError as exc:
errors.append(
{
"parent_id": row["parent_id"],
"smiles": row["smiles"],
"error_type": type(exc).__name__,
"error_message": str(exc),
"exception": exc,
}
)
return results, errors
def _read_csv_rows(
input_csv: str | Path,
smiles_column: str = "smiles",
id_column: str = "id",
max_rows: int | None = None,
) -> list[dict]:
dataframe = pd.read_csv(input_csv)
if max_rows is not None:
dataframe = dataframe.head(max_rows)
rows = []
for index, row in dataframe.iterrows():
parent_id = row[id_column] if id_column in dataframe.columns else f"row_{index}"
rows.append(
{
"parent_id": str(parent_id),
"smiles": row[smiles_column],
}
)
return rows