# 174 lines | 5.8 KiB | Python
"""
Cheminformatics example: multi-table schema, dataclass interop, CRUD, joins.

Prerequisites:
- Export SQL_*/PG* env vars to point at your Postgres
- Run: uv run python examples/05_cheminformatics.py

This example does NOT require RDKit; it stores fields you can compute
externally (smiles, selfies, qed, sa_score). If RDKit is available, you
can compute those before inserting.
"""
|
|
|
|
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timezone
from typing import List, Optional

from sqlalchemy.orm import selectinload
from sqlmodel import Field, Relationship, SQLModel, select

from sqlmodel_pg_kit import create_all as _create_all  # reuse engine + metadata
from sqlmodel_pg_kit.db import get_session
|
|
|
|
|
|
# --- Models
|
|
|
|
|
|
class MoleculeDataset(SQLModel, table=True):
    """Link table pairing molecules with datasets (many-to-many).

    Each row joins one molecule to one dataset. Both foreign keys are marked
    as primary-key columns, so together they form a composite primary key.
    """

    # References molecule.id; part of the composite primary key.
    molecule_id: int = Field(foreign_key="molecule.id", primary_key=True)
    # References dataset.id; part of the composite primary key.
    dataset_id: int = Field(foreign_key="dataset.id", primary_key=True)
    # Naive UTC timestamp taken when the row object is built.
    # datetime.utcnow() is deprecated since Python 3.12, so take an aware
    # "now" in UTC and strip tzinfo to keep producing the same naive value.
    added_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
|
|
|
|
|
|
class Molecule(SQLModel, table=True):
    """A molecule row: its SMILES string plus optional computed descriptors.

    ``selfies``, ``qed`` and ``sa_score`` are optional and default to None.
    """

    # Primary key; None until a value is assigned on insert.
    id: Optional[int] = Field(default=None, primary_key=True)
    # Required SMILES string, indexed for lookups and pattern filters.
    smiles: str = Field(index=True)
    selfies: Optional[str] = Field(default=None)
    # Indexed numeric descriptors used for threshold filters and ordering.
    qed: Optional[float] = Field(default=None, index=True)
    sa_score: Optional[float] = Field(default=None, index=True)
    # Naive UTC timestamps taken when the row object is built.
    # datetime.utcnow() is deprecated since Python 3.12, so take an aware
    # "now" in UTC and strip tzinfo to keep producing the same naive value.
    # NOTE: updated_at only gets an initial default here; nothing in this
    # class refreshes it on later updates, so callers must set it themselves.
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
    updated_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )

    # Many-to-many side towards Dataset, routed through the MoleculeDataset
    # link table; mirrors Dataset.molecules.
    datasets: List["Dataset"] = Relationship(back_populates="molecules", link_model=MoleculeDataset)
|
|
|
|
|
|
class Dataset(SQLModel, table=True):
    """A named collection of molecules (for example "train" or "holdout")."""

    # Primary key; None until a value is assigned on insert.
    id: Optional[int] = Field(primary_key=True, default=None)
    # Dataset name, indexed for lookups by name.
    name: str = Field(index=True)

    # Many-to-many side towards Molecule, routed through the MoleculeDataset
    # link table; mirrors Molecule.datasets.
    molecules: List["Molecule"] = Relationship(
        link_model=MoleculeDataset,
        back_populates="datasets",
    )
|
|
|
|
|
|
# --- Dataclass DTO (handy for RDKit pipelines)
|
|
|
|
|
|
@dataclass
class MoleculeDTO:
    """Plain dataclass holding molecule fields before they become a table row.

    Only ``smiles`` is required; the other descriptors default to None.
    """

    smiles: str
    selfies: Optional[str] = None
    qed: Optional[float] = None
    sa_score: Optional[float] = None

    def to_model(self) -> Molecule:
        """Return a new ``Molecule`` populated from this DTO's four fields."""
        column_values = {
            "smiles": self.smiles,
            "selfies": self.selfies,
            "qed": self.qed,
            "sa_score": self.sa_score,
        }
        return Molecule(**column_values)
|
|
|
|
|
|
def create_all():
    """Ensure the tables used by this example (plus the base kit's) exist.

    Thin wrapper that delegates to ``sqlmodel_pg_kit.create_all`` (imported
    in this file as ``_create_all``), called with no arguments.
    """
    _create_all()
|
|
|
|
|
|
def _reset_tables():
    """Delete every row from the three example tables for a repeatable run."""
    with get_session() as s:
        # Link rows go first: they hold foreign keys to both other tables.
        for model in (MoleculeDataset, Molecule, Dataset):
            s.execute(model.__table__.delete())
        s.commit()


def _insert_molecules(dtos):
    """Convert each DTO to a ``Molecule`` row and insert them in one commit."""
    with get_session() as s:
        s.add_all([dto.to_model() for dto in dtos])
        s.commit()


def _create_datasets_and_links():
    """Create the 'train' and 'holdout' datasets and attach the molecules."""
    with get_session() as s:
        ds_train = Dataset(name="train")
        ds_holdout = Dataset(name="holdout")
        s.add_all([ds_train, ds_holdout])
        s.commit()
        # Reload after the commit so the generated ids can be read below.
        s.refresh(ds_train)
        s.refresh(ds_holdout)

        # Link: first two molecules (ordered by id) in train, third in holdout.
        mol_list: List[Molecule] = s.exec(select(Molecule).order_by(Molecule.id.asc())).all()
        assignment = (ds_train, ds_train, ds_holdout)
        links = [
            MoleculeDataset(molecule_id=mol_list[position].id, dataset_id=dataset.id)
            for position, dataset in enumerate(assignment)
        ]
        s.add_all(links)
        s.commit()


def _refine_qed(smiles, qed):
    """Set a new QED on the molecule with this SMILES and print the result."""
    with get_session() as s:
        # .one() expects exactly one matching row.
        mol = s.exec(select(Molecule).where(Molecule.smiles == smiles)).one()
        mol.qed = qed
        # Naive UTC "now"; datetime.utcnow() is deprecated since Python 3.12.
        mol.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
        s.add(mol)
        s.commit()
        s.refresh(mol)
        print(f"Updated {smiles} ->", mol.qed)


def _print_filter_queries():
    """Print two typical filters: a QED threshold and a SMILES prefix match."""
    with get_session() as s:
        # QED threshold and order by SA score
        hi_qed = s.exec(
            select(Molecule).where(Molecule.qed >= 0.6).order_by(Molecule.sa_score.asc())
        ).all()
        print("qed>=0.6 order by sa_score:", [(m.smiles, m.qed, m.sa_score) for m in hi_qed])

        # Pattern search on SMILES (prefix demo; production use proper search)
        starts_with_cc = s.exec(select(Molecule).where(Molecule.smiles.like("CC%"))).all()
        print("SMILES like 'CC%':", [m.smiles for m in starts_with_cc])


def _print_molecules_with_datasets():
    """Print every molecule with the names of the datasets it belongs to."""
    with get_session() as s:
        # selectinload eager-loads Molecule.datasets alongside the molecules.
        stmt = (
            select(Molecule)
            .options(selectinload(Molecule.datasets))
            .order_by(Molecule.id.asc())
        )
        molecules = s.exec(stmt).all()
        print("with datasets:", [(m.smiles, [d.name for d in m.datasets]) for m in molecules])


def _print_dataset_members(dataset_name):
    """Print the SMILES of the molecules linked to the named dataset."""
    with get_session() as s:
        stmt = (
            select(Molecule)
            .join(MoleculeDataset, Molecule.id == MoleculeDataset.molecule_id)
            .join(Dataset, Dataset.id == MoleculeDataset.dataset_id)
            .where(Dataset.name == dataset_name)
            .order_by(Molecule.id.asc())
        )
        members = s.exec(stmt).all()
        print(f"in {dataset_name}:", [m.smiles for m in members])


def _delete_molecule(smiles):
    """Delete the molecule with this SMILES and print the molecules left."""
    with get_session() as s:
        # .one() expects exactly one matching row.
        target = s.exec(select(Molecule).where(Molecule.smiles == smiles)).one()
        s.delete(target)
        s.commit()
        left = s.exec(select(Molecule).order_by(Molecule.id.asc())).all()
        print("after delete:", [m.smiles for m in left])


def main():
    """Run the end-to-end demo: reset, seed, link, update, query, delete.

    Each step opens its own session via ``get_session()`` and prints its
    results to stdout. Existing rows in the three example tables are deleted
    first so repeated runs start from the same state.
    """
    create_all()

    # Clean existing data for a repeatable run
    _reset_tables()

    # Create molecules from dataclass (as you would after RDKit computation)
    _insert_molecules(
        [
            MoleculeDTO(smiles="CCO", selfies=None, qed=0.45, sa_score=2.1),
            MoleculeDTO(smiles="c1ccccc1", selfies=None, qed=0.76, sa_score=3.5),
            MoleculeDTO(smiles="CCN(CC)CC", selfies=None, qed=0.62, sa_score=2.8),
        ]
    )

    # Create datasets and link molecules (many-to-many)
    _create_datasets_and_links()

    # CRUD: update a descriptor (e.g., refined QED)
    _refine_qed("CCO", 0.50)

    # Filtering: typical queries
    _print_filter_queries()

    # Joins: list molecules with dataset name (eager-load relationships)
    _print_molecules_with_datasets()

    # Join filter: only molecules in 'train'
    _print_dataset_members("train")

    # Delete: drop a molecule
    _delete_molecule("CCN(CC)CC")
|
|
|
|
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|