first add
This commit is contained in:
173
examples/05_cheminformatics.py
Normal file
173
examples/05_cheminformatics.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
Cheminformatics example: multi-table schema, dataclass interop, CRUD, joins.
|
||||
|
||||
Prereq:
|
||||
- Export SQL_*/PG* env vars to point at your Postgres
|
||||
- Run: uv run python examples/05_cheminformatics.py
|
||||
|
||||
This example does NOT require RDKit; it stores fields you can compute
|
||||
externally (smiles, selfies, qed, sa_score). If RDKit is available, you
|
||||
can compute those before inserting.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import SQLModel, Field, Relationship, select
|
||||
|
||||
from sqlmodel_pg_kit.db import get_session
|
||||
from sqlmodel_pg_kit import create_all as _create_all # reuse engine + metadata
|
||||
|
||||
|
||||
# --- Models
|
||||
|
||||
|
||||
class MoleculeDataset(SQLModel, table=True):
|
||||
molecule_id: int = Field(foreign_key="molecule.id", primary_key=True)
|
||||
dataset_id: int = Field(foreign_key="dataset.id", primary_key=True)
|
||||
added_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class Molecule(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
smiles: str = Field(index=True)
|
||||
selfies: Optional[str] = Field(default=None)
|
||||
qed: Optional[float] = Field(default=None, index=True)
|
||||
sa_score: Optional[float] = Field(default=None, index=True)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
datasets: List["Dataset"] = Relationship(back_populates="molecules", link_model=MoleculeDataset)
|
||||
|
||||
|
||||
class Dataset(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
name: str = Field(index=True)
|
||||
|
||||
molecules: List["Molecule"] = Relationship(back_populates="datasets", link_model=MoleculeDataset)
|
||||
|
||||
|
||||
# --- Dataclass DTO (handy for RDKit pipelines)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MoleculeDTO:
|
||||
smiles: str
|
||||
selfies: Optional[str] = None
|
||||
qed: Optional[float] = None
|
||||
sa_score: Optional[float] = None
|
||||
|
||||
def to_model(self) -> Molecule:
|
||||
return Molecule(
|
||||
smiles=self.smiles,
|
||||
selfies=self.selfies,
|
||||
qed=self.qed,
|
||||
sa_score=self.sa_score,
|
||||
)
|
||||
|
||||
|
||||
def create_all():
|
||||
# Ensure all tables in this example (plus base kit) exist
|
||||
_create_all()
|
||||
|
||||
|
||||
def main():
|
||||
create_all()
|
||||
|
||||
# Clean existing data for a repeatable run
|
||||
with get_session() as s:
|
||||
s.execute(MoleculeDataset.__table__.delete())
|
||||
s.execute(Molecule.__table__.delete())
|
||||
s.execute(Dataset.__table__.delete())
|
||||
s.commit()
|
||||
|
||||
# Create molecules from dataclass (as you would after RDKit computation)
|
||||
mols = [
|
||||
MoleculeDTO(smiles="CCO", selfies=None, qed=0.45, sa_score=2.1),
|
||||
MoleculeDTO(smiles="c1ccccc1", selfies=None, qed=0.76, sa_score=3.5),
|
||||
MoleculeDTO(smiles="CCN(CC)CC", selfies=None, qed=0.62, sa_score=2.8),
|
||||
]
|
||||
with get_session() as s:
|
||||
for dto in mols:
|
||||
s.add(dto.to_model())
|
||||
s.commit()
|
||||
|
||||
# Create datasets and link molecules (many-to-many)
|
||||
with get_session() as s:
|
||||
ds_train = Dataset(name="train")
|
||||
ds_holdout = Dataset(name="holdout")
|
||||
s.add(ds_train)
|
||||
s.add(ds_holdout)
|
||||
s.commit()
|
||||
s.refresh(ds_train)
|
||||
s.refresh(ds_holdout)
|
||||
|
||||
# Link: first two in train, last one in holdout
|
||||
mol_list: List[Molecule] = s.exec(select(Molecule).order_by(Molecule.id.asc())).all()
|
||||
links = [
|
||||
MoleculeDataset(molecule_id=mol_list[0].id, dataset_id=ds_train.id),
|
||||
MoleculeDataset(molecule_id=mol_list[1].id, dataset_id=ds_train.id),
|
||||
MoleculeDataset(molecule_id=mol_list[2].id, dataset_id=ds_holdout.id),
|
||||
]
|
||||
s.add_all(links)
|
||||
s.commit()
|
||||
|
||||
# CRUD: update a descriptor (e.g., refined QED)
|
||||
with get_session() as s:
|
||||
mol = s.exec(select(Molecule).where(Molecule.smiles == "CCO")).one()
|
||||
mol.qed = 0.50
|
||||
mol.updated_at = datetime.utcnow()
|
||||
s.add(mol)
|
||||
s.commit()
|
||||
s.refresh(mol)
|
||||
print("Updated CCO ->", mol.qed)
|
||||
|
||||
# Filtering: typical queries
|
||||
with get_session() as s:
|
||||
# QED threshold and order by SA score
|
||||
hi_qed = s.exec(
|
||||
select(Molecule).where(Molecule.qed >= 0.6).order_by(Molecule.sa_score.asc())
|
||||
).all()
|
||||
print("qed>=0.6 order by sa_score:", [(m.smiles, m.qed, m.sa_score) for m in hi_qed])
|
||||
|
||||
# Pattern search on SMILES (prefix demo; production use proper search)
|
||||
starts_with_cc = s.exec(select(Molecule).where(Molecule.smiles.like("CC%"))).all()
|
||||
print("SMILES like 'CC%':", [m.smiles for m in starts_with_cc])
|
||||
|
||||
# Joins: list molecules with dataset name (eager-load relationships)
|
||||
with get_session() as s:
|
||||
stmt = (
|
||||
select(Molecule)
|
||||
.options(selectinload(Molecule.datasets))
|
||||
.order_by(Molecule.id.asc())
|
||||
)
|
||||
molecules = s.exec(stmt).all()
|
||||
print("with datasets:", [(m.smiles, [d.name for d in m.datasets]) for m in molecules])
|
||||
|
||||
# Join filter: only molecules in 'train'
|
||||
with get_session() as s:
|
||||
stmt = (
|
||||
select(Molecule)
|
||||
.join(MoleculeDataset, Molecule.id == MoleculeDataset.molecule_id)
|
||||
.join(Dataset, Dataset.id == MoleculeDataset.dataset_id)
|
||||
.where(Dataset.name == "train")
|
||||
.order_by(Molecule.id.asc())
|
||||
)
|
||||
train_mols = s.exec(stmt).all()
|
||||
print("in train:", [m.smiles for m in train_mols])
|
||||
|
||||
# Delete: drop a molecule
|
||||
with get_session() as s:
|
||||
target = s.exec(select(Molecule).where(Molecule.smiles == "CCN(CC)CC")).one()
|
||||
s.delete(target)
|
||||
s.commit()
|
||||
left = s.exec(select(Molecule).order_by(Molecule.id.asc())).all()
|
||||
print("after delete:", [m.smiles for m in left])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user