# File: sqlmodel-pg-kit/examples/05_cheminformatics.py
# Snapshot: 2025-08-17 22:18:45 +08:00 (174 lines, 5.8 KiB, Python)
"""
Cheminformatics example: multi-table schema, dataclass interop, CRUD, joins.
Prereq:
- Export SQL_*/PG* env vars to point at your Postgres
- Run: uv run python examples/05_cheminformatics.py
This example does NOT require RDKit; it stores fields you can compute
externally (smiles, selfies, qed, sa_score). If RDKit is available, you
can compute those before inserting.
"""
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timezone
from typing import List, Optional

from sqlalchemy.orm import selectinload
from sqlmodel import SQLModel, Field, Relationship, select

from sqlmodel_pg_kit import create_all as _create_all  # reuse engine + metadata
from sqlmodel_pg_kit.db import get_session
# --- Models


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Drop-in replacement for ``datetime.utcnow()``, which is deprecated since
    Python 3.12 and warns on every call. The value is identical: it is taken
    from an aware ``datetime.now(timezone.utc)`` and then has its tzinfo
    stripped. Staying naive keeps the values written to the ``datetime``
    fields below exactly as they were before.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class MoleculeDataset(SQLModel, table=True):
    """Link table for the Molecule <-> Dataset many-to-many relationship.

    The composite primary key (molecule_id, dataset_id) allows a molecule to
    be linked to a given dataset at most once.
    """

    molecule_id: int = Field(foreign_key="molecule.id", primary_key=True)
    dataset_id: int = Field(foreign_key="dataset.id", primary_key=True)
    # Time the link row was created (naive UTC).
    added_at: datetime = Field(default_factory=_utcnow)


class Molecule(SQLModel, table=True):
    """A molecule plus externally computed descriptors.

    Only ``smiles`` is required. ``selfies``, ``qed`` and ``sa_score`` are
    optional, so they can be filled in by an external pipeline (e.g. RDKit)
    before or after insertion.
    """

    id: Optional[int] = Field(default=None, primary_key=True)
    smiles: str = Field(index=True)
    selfies: Optional[str] = Field(default=None)
    # Descriptor columns are indexed for range filters and ordering.
    qed: Optional[float] = Field(default=None, index=True)
    sa_score: Optional[float] = Field(default=None, index=True)
    created_at: datetime = Field(default_factory=_utcnow)
    # onupdate has SQLAlchemy fill this column whenever an UPDATE is emitted
    # for the row without an explicit value, so it stays accurate even when
    # callers forget to bump it. An explicitly assigned value still wins.
    updated_at: datetime = Field(
        default_factory=_utcnow,
        sa_column_kwargs={"onupdate": _utcnow},
    )
    datasets: List["Dataset"] = Relationship(
        back_populates="molecules", link_model=MoleculeDataset
    )
class Dataset(SQLModel, table=True):
    """A named collection of molecules (this example creates "train" and "holdout").

    Linked many-to-many to Molecule through the MoleculeDataset link table.
    """
    id: Optional[int] = Field(default=None, primary_key=True)
    # Indexed for lookups by name; there is no unique constraint, so names may repeat.
    name: str = Field(index=True)
    # Inverse side of Molecule.datasets; both sides share the same link table.
    molecules: List["Molecule"] = Relationship(back_populates="datasets", link_model=MoleculeDataset)
# --- Dataclass DTO (handy for RDKit pipelines)
@dataclass
class MoleculeDTO:
    """ORM-free molecule record.

    Meant to be produced by a descriptor pipeline (e.g. RDKit) and turned
    into a persistable row with ``to_model``.
    """

    smiles: str
    selfies: Optional[str] = None
    qed: Optional[float] = None
    sa_score: Optional[float] = None

    def to_model(self) -> Molecule:
        """Return a new, unsaved ``Molecule`` carrying this DTO's four fields."""
        copied = ("smiles", "selfies", "qed", "sa_score")
        return Molecule(**{attr: getattr(self, attr) for attr in copied})
def create_all():
    """Make sure the database tables used by this example exist.

    Thin wrapper that delegates to the kit's ``create_all`` (imported as
    ``_create_all``).

    NOTE(review): this presumes the kit creates tables from the shared
    SQLModel metadata, per its import comment. In that case, the model
    classes defined at import time in this module are included. Confirm
    against sqlmodel_pg_kit.
    """
    _create_all()
def _reset_tables() -> None:
    """Delete all rows from the example tables so every run starts clean."""
    with get_session() as s:
        # Link rows go first: they hold foreign keys into molecule and dataset.
        s.execute(MoleculeDataset.__table__.delete())
        s.execute(Molecule.__table__.delete())
        s.execute(Dataset.__table__.delete())
        s.commit()


def _insert_molecules() -> None:
    """Insert the demo molecules, built from DTOs as an RDKit pipeline would."""
    mols = [
        MoleculeDTO(smiles="CCO", selfies=None, qed=0.45, sa_score=2.1),
        MoleculeDTO(smiles="c1ccccc1", selfies=None, qed=0.76, sa_score=3.5),
        MoleculeDTO(smiles="CCN(CC)CC", selfies=None, qed=0.62, sa_score=2.8),
    ]
    with get_session() as s:
        s.add_all([dto.to_model() for dto in mols])
        s.commit()


def _link_datasets() -> None:
    """Create the "train"/"holdout" datasets and link molecules to them (many-to-many)."""
    with get_session() as s:
        ds_train = Dataset(name="train")
        ds_holdout = Dataset(name="holdout")
        s.add_all([ds_train, ds_holdout])
        # flush() assigns the primary keys without committing. The datasets
        # and their link rows are then committed together in one transaction,
        # so a failure while linking cannot leave orphan datasets behind.
        s.flush()
        # Link: first two (ordered by id) in train, last one in holdout
        mol_list: List[Molecule] = s.exec(select(Molecule).order_by(Molecule.id.asc())).all()
        s.add_all(
            [
                MoleculeDataset(molecule_id=mol_list[0].id, dataset_id=ds_train.id),
                MoleculeDataset(molecule_id=mol_list[1].id, dataset_id=ds_train.id),
                MoleculeDataset(molecule_id=mol_list[2].id, dataset_id=ds_holdout.id),
            ]
        )
        s.commit()


def _update_descriptor() -> None:
    """Update one descriptor in place (e.g. a refined QED) and bump updated_at."""
    with get_session() as s:
        mol = s.exec(select(Molecule).where(Molecule.smiles == "CCO")).one()
        mol.qed = 0.50
        # Naive UTC: the same value datetime.utcnow() returned, without using
        # that API (deprecated since Python 3.12).
        mol.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
        s.add(mol)
        s.commit()
        s.refresh(mol)
        print("Updated CCO ->", mol.qed)


def _run_filters() -> None:
    """Demonstrate typical filters: a numeric threshold with ordering, and a SMILES prefix."""
    with get_session() as s:
        # QED threshold and order by SA score
        hi_qed = s.exec(
            select(Molecule).where(Molecule.qed >= 0.6).order_by(Molecule.sa_score.asc())
        ).all()
        print("qed>=0.6 order by sa_score:", [(m.smiles, m.qed, m.sa_score) for m in hi_qed])
        # Pattern search on SMILES (prefix demo; production use proper search)
        starts_with_cc = s.exec(select(Molecule).where(Molecule.smiles.like("CC%"))).all()
        print("SMILES like 'CC%':", [m.smiles for m in starts_with_cc])


def _run_joins() -> None:
    """Demonstrate eager-loading a relationship and filtering through an explicit join."""
    # List molecules with their dataset names. selectinload fetches every
    # molecule's datasets in one extra query instead of one lazy load per row.
    with get_session() as s:
        stmt = (
            select(Molecule)
            .options(selectinload(Molecule.datasets))
            .order_by(Molecule.id.asc())
        )
        molecules = s.exec(stmt).all()
        print("with datasets:", [(m.smiles, [d.name for d in m.datasets]) for m in molecules])
    # Join filter: only molecules in 'train'
    with get_session() as s:
        stmt = (
            select(Molecule)
            .join(MoleculeDataset, Molecule.id == MoleculeDataset.molecule_id)
            .join(Dataset, Dataset.id == MoleculeDataset.dataset_id)
            .where(Dataset.name == "train")
            .order_by(Molecule.id.asc())
        )
        train_mols = s.exec(stmt).all()
        print("in train:", [m.smiles for m in train_mols])


def _delete_molecule() -> None:
    """Delete one molecule, then print the molecules that remain."""
    with get_session() as s:
        target = s.exec(select(Molecule).where(Molecule.smiles == "CCN(CC)CC")).one()
        # Molecule.datasets is a link_model (secondary) relationship. On
        # delete, SQLAlchemy loads it and removes the matching link rows.
        s.delete(target)
        s.commit()
        left = s.exec(select(Molecule).order_by(Molecule.id.asc())).all()
        print("after delete:", [m.smiles for m in left])


def main():
    """Run the example end to end: schema, reset, insert, link, update, query, delete."""
    create_all()
    _reset_tables()
    _insert_molecules()
    _link_datasets()
    _update_descriptor()
    _run_filters()
    _run_joins()
    _delete_molecule()
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()