""" Cheminformatics example: multi-table schema, dataclass interop, CRUD, joins. Prereq: - Export SQL_*/PG* env vars to point at your Postgres - Run: uv run python examples/05_cheminformatics.py This example does NOT require RDKit; it stores fields you can compute externally (smiles, selfies, qed, sa_score). If RDKit is available, you can compute those before inserting. """ from __future__ import annotations from dataclasses import dataclass from datetime import datetime from typing import List, Optional from sqlalchemy.orm import selectinload from sqlmodel import SQLModel, Field, Relationship, select from sqlmodel_pg_kit.db import get_session from sqlmodel_pg_kit import create_all as _create_all # reuse engine + metadata # --- Models class MoleculeDataset(SQLModel, table=True): molecule_id: int = Field(foreign_key="molecule.id", primary_key=True) dataset_id: int = Field(foreign_key="dataset.id", primary_key=True) added_at: datetime = Field(default_factory=datetime.utcnow) class Molecule(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) smiles: str = Field(index=True) selfies: Optional[str] = Field(default=None) qed: Optional[float] = Field(default=None, index=True) sa_score: Optional[float] = Field(default=None, index=True) created_at: datetime = Field(default_factory=datetime.utcnow) updated_at: datetime = Field(default_factory=datetime.utcnow) datasets: List["Dataset"] = Relationship(back_populates="molecules", link_model=MoleculeDataset) class Dataset(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) name: str = Field(index=True) molecules: List["Molecule"] = Relationship(back_populates="datasets", link_model=MoleculeDataset) # --- Dataclass DTO (handy for RDKit pipelines) @dataclass class MoleculeDTO: smiles: str selfies: Optional[str] = None qed: Optional[float] = None sa_score: Optional[float] = None def to_model(self) -> Molecule: return Molecule( smiles=self.smiles, selfies=self.selfies, qed=self.qed, sa_score=self.sa_score, ) def create_all(): # Ensure all tables in this example (plus base kit) exist _create_all() def main(): create_all() # Clean existing data for a repeatable run with get_session() as s: s.execute(MoleculeDataset.__table__.delete()) s.execute(Molecule.__table__.delete()) s.execute(Dataset.__table__.delete()) s.commit() # Create molecules from dataclass (as you would after RDKit computation) mols = [ MoleculeDTO(smiles="CCO", selfies=None, qed=0.45, sa_score=2.1), MoleculeDTO(smiles="c1ccccc1", selfies=None, qed=0.76, sa_score=3.5), MoleculeDTO(smiles="CCN(CC)CC", selfies=None, qed=0.62, sa_score=2.8), ] with get_session() as s: for dto in mols: s.add(dto.to_model()) s.commit() # Create datasets and link molecules (many-to-many) with get_session() as s: ds_train = Dataset(name="train") ds_holdout = Dataset(name="holdout") s.add(ds_train) s.add(ds_holdout) s.commit() s.refresh(ds_train) s.refresh(ds_holdout) # Link: first two in train, last one in holdout mol_list: List[Molecule] = s.exec(select(Molecule).order_by(Molecule.id.asc())).all() links = [ MoleculeDataset(molecule_id=mol_list[0].id, dataset_id=ds_train.id), MoleculeDataset(molecule_id=mol_list[1].id, dataset_id=ds_train.id), MoleculeDataset(molecule_id=mol_list[2].id, dataset_id=ds_holdout.id), ] s.add_all(links) s.commit() # CRUD: update a descriptor (e.g., refined QED) with get_session() as s: mol = s.exec(select(Molecule).where(Molecule.smiles == "CCO")).one() mol.qed = 0.50 mol.updated_at = datetime.utcnow() s.add(mol) s.commit() s.refresh(mol) print("Updated CCO ->", mol.qed) # Filtering: typical queries with get_session() as s: # QED threshold and order by SA score hi_qed = s.exec( select(Molecule).where(Molecule.qed >= 0.6).order_by(Molecule.sa_score.asc()) ).all() print("qed>=0.6 order by sa_score:", [(m.smiles, m.qed, m.sa_score) for m in hi_qed]) # Pattern search on SMILES (prefix demo; production use proper search) starts_with_cc = s.exec(select(Molecule).where(Molecule.smiles.like("CC%"))).all() print("SMILES like 'CC%':", [m.smiles for m in starts_with_cc]) # Joins: list molecules with dataset name (eager-load relationships) with get_session() as s: stmt = ( select(Molecule) .options(selectinload(Molecule.datasets)) .order_by(Molecule.id.asc()) ) molecules = s.exec(stmt).all() print("with datasets:", [(m.smiles, [d.name for d in m.datasets]) for m in molecules]) # Join filter: only molecules in 'train' with get_session() as s: stmt = ( select(Molecule) .join(MoleculeDataset, Molecule.id == MoleculeDataset.molecule_id) .join(Dataset, Dataset.id == MoleculeDataset.dataset_id) .where(Dataset.name == "train") .order_by(Molecule.id.asc()) ) train_mols = s.exec(stmt).all() print("in train:", [m.smiles for m in train_mols]) # Delete: drop a molecule with get_session() as s: target = s.exec(select(Molecule).where(Molecule.smiles == "CCN(CC)CC")).one() s.delete(target) s.commit() left = s.exec(select(Molecule).order_by(Molecule.id.asc())).all() print("after delete:", [m.smiles for m in left]) if __name__ == "__main__": main()