sqlmodel-pg-kit/notebooks/05_cheminformatics.ipynb
2025-08-17 22:18:45 +08:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Cheminformatics Tutorial — Molecules, Datasets, CRUD & Joins\n\n",
"This notebook is a teaching version of `examples/05_cheminformatics.py`.\n",
"It demonstrates:\n",
"- Modeling molecules with descriptors (smiles, selfies, qed, sa_score)\n",
"- Linking molecules to datasets (many-to-many)\n",
"- Dataclass interop for bringing precomputed descriptors into SQLModel\n",
"- Common CRUD, filtering, eager loading, and joins\n",
"- Optional RDKit + Mordred descriptor computation (if installed)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 0. Environment (micromamba)\n",
"In your shell, activate the env before launching Jupyter:\n",
"```bash\n",
"micromamba activate sqlmodel\n",
"jupyter lab # or jupyter notebook\n",
"```\n\n",
"Optional installs inside Jupyter (uncomment to run):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# %pip install -e . pytest\n",
"# Optional cheminformatics packages:\n",
"# %pip install rdkit mordred\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Configure database connection\n",
"- For a quick smoke test with in-memory SQLite, run the cell below as is.\n",
"- For PostgreSQL, set the `SQL_*` or `PG*` environment variables before starting Jupyter, and comment out the two SQLite lines in the cell below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sqlmodel_pg_kit import db, create_all as _create_all\n",
"\n",
"# QUICK OPTION: Use SQLite in-memory for learning/demo.\n",
"# Comment these two lines out if you prefer to use Postgres via environment variables.\n",
"db.cfg = db.DatabaseConfig(host='', port=0, user='', password='', database=':memory:', sslmode='disable')\n",
"db.engine = db.create_engine('sqlite:///:memory:', echo=False)\n",
"_create_all() # create base kit models if any\n"
]
},
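{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why the in-memory option needs care: each new connection to `:memory:` opens its own, empty database. The stdlib-only sketch below (plain `sqlite3`, independent of `sqlmodel_pg_kit`; `memory_dbs_are_per_connection` is an illustrative helper) creates a table on one connection and checks that a second connection cannot see it. SQLAlchemy's default pool for an in-memory SQLite URL keeps one connection per thread, and `poolclass=StaticPool` is a common way to share a single connection across threads; which of these the kit's `get_session` ends up relying on is an assumption not verified here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sqlite3\n",
"\n",
"def memory_dbs_are_per_connection() -> bool:\n",
"    # Illustrative helper: two connections to ':memory:' get two separate databases\n",
"    conn_a = sqlite3.connect(':memory:')\n",
"    conn_b = sqlite3.connect(':memory:')\n",
"    try:\n",
"        conn_a.execute('CREATE TABLE molecule (id INTEGER PRIMARY KEY, smiles TEXT)')\n",
"        # sqlite_master lists the schema objects of the connection's own database\n",
"        tables_b = conn_b.execute(\n",
"            'SELECT name FROM sqlite_master WHERE type = ?', ('table',)\n",
"        ).fetchall()\n",
"        return tables_b == []\n",
"    finally:\n",
"        conn_a.close()\n",
"        conn_b.close()\n",
"\n",
"memory_dbs_are_per_connection()\n"
]
},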
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Define models & dataclass\n",
"We define `Molecule`, `Dataset`, and the link table `MoleculeDataset`.\n",
"We also provide a `MoleculeDTO` dataclass to show how to bring computed values\n",
"(e.g., from RDKit/Mordred pipelines) into SQLModel quickly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"from datetime import datetime\n",
"from typing import List, Optional\n",
"\n",
"from sqlalchemy.orm import selectinload\n",
"from sqlmodel import SQLModel, Field, Relationship, select\n",
"from sqlmodel_pg_kit.db import get_session, engine\n",
"\n",
"class MoleculeDataset(SQLModel, table=True):\n",
"    molecule_id: int = Field(foreign_key='molecule.id', primary_key=True)\n",
"    dataset_id: int = Field(foreign_key='dataset.id', primary_key=True)\n",
"    added_at: datetime = Field(default_factory=datetime.utcnow)\n",
"\n",
"class Molecule(SQLModel, table=True):\n",
"    id: Optional[int] = Field(default=None, primary_key=True)\n",
"    smiles: str = Field(index=True)\n",
"    selfies: Optional[str] = Field(default=None)\n",
"    qed: Optional[float] = Field(default=None, index=True)\n",
"    sa_score: Optional[float] = Field(default=None, index=True)\n",
"    created_at: datetime = Field(default_factory=datetime.utcnow)\n",
"    updated_at: datetime = Field(default_factory=datetime.utcnow)\n",
"    datasets: List['Dataset'] = Relationship(back_populates='molecules', link_model=MoleculeDataset)\n",
"\n",
"class Dataset(SQLModel, table=True):\n",
"    id: Optional[int] = Field(default=None, primary_key=True)\n",
"    name: str = Field(index=True)\n",
"    molecules: List['Molecule'] = Relationship(back_populates='datasets', link_model=MoleculeDataset)\n",
"\n",
"@dataclass\n",
"class MoleculeDTO:\n",
"    smiles: str\n",
"    selfies: Optional[str] = None\n",
"    qed: Optional[float] = None\n",
"    sa_score: Optional[float] = None\n",
"\n",
"    def to_model(self) -> Molecule:\n",
"        return Molecule(**vars(self))\n",
"\n",
"# Create the tables defined above\n",
"SQLModel.metadata.create_all(engine)\n"
]
},
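{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition, the three models correspond roughly to the SQLite schema in the sketch below, written with plain `sqlite3`. It is hand-written and simplified (timestamp columns and indexes are left out), not the DDL SQLModel actually emits, and `duplicate_link_is_rejected` is an illustrative helper. It checks that the composite primary key on `(molecule_id, dataset_id)` rejects linking the same molecule to the same dataset twice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sqlite3\n",
"\n",
"def duplicate_link_is_rejected() -> bool:\n",
"    # Illustrative helper on a simplified, hand-written schema (SQLModel's DDL may differ)\n",
"    conn = sqlite3.connect(':memory:')\n",
"    try:\n",
"        conn.executescript('''\n",
"            CREATE TABLE molecule (\n",
"                id INTEGER PRIMARY KEY,\n",
"                smiles TEXT NOT NULL,\n",
"                selfies TEXT,\n",
"                qed REAL,\n",
"                sa_score REAL\n",
"            );\n",
"            CREATE TABLE dataset (id INTEGER PRIMARY KEY, name TEXT NOT NULL);\n",
"            CREATE TABLE moleculedataset (\n",
"                molecule_id INTEGER NOT NULL REFERENCES molecule (id),\n",
"                dataset_id INTEGER NOT NULL REFERENCES dataset (id),\n",
"                PRIMARY KEY (molecule_id, dataset_id)\n",
"            );\n",
"            INSERT INTO molecule (id, smiles) VALUES (1, 'CCO');\n",
"            INSERT INTO dataset VALUES (1, 'train');\n",
"            INSERT INTO moleculedataset VALUES (1, 1);\n",
"        ''')\n",
"        try:\n",
"            # The same (molecule_id, dataset_id) pair violates the composite primary key\n",
"            conn.execute('INSERT INTO moleculedataset VALUES (1, 1)')\n",
"        except sqlite3.IntegrityError:\n",
"            return True\n",
"        return False\n",
"    finally:\n",
"        conn.close()\n",
"\n",
"duplicate_link_is_rejected()\n"
]
},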
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Clean slate (idempotent runs)\n",
"We delete existing rows to make this notebook repeatable."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with get_session() as s:\n",
"    s.execute(MoleculeDataset.__table__.delete())\n",
"    s.execute(Molecule.__table__.delete())\n",
"    s.execute(Dataset.__table__.delete())\n",
"    s.commit()\n",
"'cleaned'\n"
]
},
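{
"cell_type": "markdown",
"metadata": {},
"source": [
"The clean-up cell empties the link table before the parent tables. The sketch below shows what the opposite order does once foreign keys are enforced: deleting a molecule that a link row still references fails. It uses plain `sqlite3` on a simplified, hand-written schema, and `delete_order_demo` is an illustrative helper. SQLite enforces foreign keys only after `PRAGMA foreign_keys = ON` on each connection, whereas PostgreSQL always enforces them, so a wrong order can pass unnoticed on the SQLite quick option and then fail on PostgreSQL."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sqlite3\n",
"\n",
"def delete_order_demo() -> tuple:\n",
"    # Illustrative helper on a simplified, hand-written schema\n",
"    # Returns (parent_first_failed, child_first_succeeded)\n",
"    conn = sqlite3.connect(':memory:')\n",
"    try:\n",
"        conn.execute('PRAGMA foreign_keys = ON')  # off by default; must be enabled per connection\n",
"        conn.executescript('''\n",
"            CREATE TABLE molecule (id INTEGER PRIMARY KEY, smiles TEXT NOT NULL);\n",
"            CREATE TABLE dataset (id INTEGER PRIMARY KEY, name TEXT NOT NULL);\n",
"            CREATE TABLE moleculedataset (\n",
"                molecule_id INTEGER NOT NULL REFERENCES molecule (id),\n",
"                dataset_id INTEGER NOT NULL REFERENCES dataset (id),\n",
"                PRIMARY KEY (molecule_id, dataset_id)\n",
"            );\n",
"            INSERT INTO molecule VALUES (1, 'CCO');\n",
"            INSERT INTO dataset VALUES (1, 'train');\n",
"            INSERT INTO moleculedataset VALUES (1, 1);\n",
"        ''')\n",
"        try:\n",
"            conn.execute('DELETE FROM molecule')  # parent first: a link row still references it\n",
"            parent_first_failed = False\n",
"        except sqlite3.IntegrityError:\n",
"            parent_first_failed = True\n",
"        # Child first, as in the clean-up cell: link rows, then molecules and datasets\n",
"        conn.execute('DELETE FROM moleculedataset')\n",
"        conn.execute('DELETE FROM molecule')\n",
"        conn.execute('DELETE FROM dataset')\n",
"        conn.commit()\n",
"        remaining = conn.execute('SELECT COUNT(*) FROM molecule').fetchone()[0]\n",
"        return parent_first_failed, remaining == 0\n",
"    finally:\n",
"        conn.close()\n",
"\n",
"delete_order_demo()\n"
]
},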
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Insert molecules via dataclass\n",
"Create a few molecules as you would after computing descriptors upstream."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mols = [\n",
" MoleculeDTO(smiles='CCO', qed=0.45, sa_score=2.1),\n",
" MoleculeDTO(smiles='c1ccccc1', qed=0.76, sa_score=3.5),\n",
" MoleculeDTO(smiles='CCN(CC)CC', qed=0.62, sa_score=2.8),\n",
"]\n",
"with get_session() as s:\n",
"    for dto in mols:\n",
"        s.add(dto.to_model())\n",
"    s.commit()\n",
"\n",
"with get_session() as s:\n",
"    inserted = s.exec(select(Molecule).order_by(Molecule.id.asc())).all()\n",
"[(m.id, m.smiles, m.qed, m.sa_score) for m in inserted]\n"
]
},
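{
"cell_type": "markdown",
"metadata": {},
"source": [
"Converting a DTO into a row is plain field-to-column mapping. As a stdlib-only sketch (no SQLModel; `MoleculeRecord`, `molecule_demo`, and `bulk_insert` are hypothetical names chosen to avoid clashing with the cells above), `dataclasses.asdict` turns each record into a dict whose keys fill the named placeholders of a single `executemany` call."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sqlite3\n",
"from dataclasses import asdict, dataclass\n",
"from typing import Optional\n",
"\n",
"@dataclass\n",
"class MoleculeRecord:  # hypothetical stand-in for MoleculeDTO\n",
"    smiles: str\n",
"    selfies: Optional[str] = None\n",
"    qed: Optional[float] = None\n",
"    sa_score: Optional[float] = None\n",
"\n",
"def bulk_insert(records: list) -> list:\n",
"    # Illustrative helper; molecule_demo is a hypothetical table, not the SQLModel one\n",
"    conn = sqlite3.connect(':memory:')\n",
"    try:\n",
"        conn.execute('CREATE TABLE molecule_demo (id INTEGER PRIMARY KEY, smiles TEXT NOT NULL, '\n",
"                     'selfies TEXT, qed REAL, sa_score REAL)')\n",
"        # Each dict from asdict() fills the :smiles, :selfies, :qed, :sa_score placeholders\n",
"        conn.executemany(\n",
"            'INSERT INTO molecule_demo (smiles, selfies, qed, sa_score) '\n",
"            'VALUES (:smiles, :selfies, :qed, :sa_score)',\n",
"            [asdict(r) for r in records],\n",
"        )\n",
"        conn.commit()\n",
"        return conn.execute('SELECT id, smiles, qed, sa_score FROM molecule_demo ORDER BY id').fetchall()\n",
"    finally:\n",
"        conn.close()\n",
"\n",
"bulk_insert([\n",
"    MoleculeRecord(smiles='CCO', qed=0.45, sa_score=2.1),\n",
"    MoleculeRecord(smiles='c1ccccc1', qed=0.76, sa_score=3.5),\n",
"])\n"
]
},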
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Create datasets and link molecules\n",
"Use a many-to-many link table to assign molecules to `train` or `holdout`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with get_session() as s:\n",
"    ds_train = Dataset(name='train')\n",
"    ds_holdout = Dataset(name='holdout')\n",
"    s.add(ds_train); s.add(ds_holdout); s.commit()\n",
"    s.refresh(ds_train); s.refresh(ds_holdout)\n",
"    mol_list: List[Molecule] = s.exec(select(Molecule).order_by(Molecule.id.asc())).all()\n",
"    links = [\n",
"        MoleculeDataset(molecule_id=mol_list[0].id, dataset_id=ds_train.id),\n",
"        MoleculeDataset(molecule_id=mol_list[1].id, dataset_id=ds_train.id),\n",
"        MoleculeDataset(molecule_id=mol_list[2].id, dataset_id=ds_holdout.id),\n",
"    ]\n",
"    s.add_all(links)\n",
"    # Capture the ids before committing: with the default expire_on_commit=True, commit expires\n",
"    # the objects, and expired attributes cannot be reloaded after the session closes.\n",
"    link_view = [(l.molecule_id, l.dataset_id) for l in links]\n",
"    s.commit()\n",
"link_view\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Update a descriptor (refined QED)\n",
"Typical pattern: load → modify → commit → refresh."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datetime import datetime\n",
"\n",
"with get_session() as s:\n",
"    mol = s.exec(select(Molecule).where(Molecule.smiles == 'CCO')).one()\n",
"    mol.qed = 0.50\n",
"    mol.updated_at = datetime.utcnow()\n",
"    s.add(mol); s.commit(); s.refresh(mol)\n",
"(mol.id, mol.smiles, mol.qed)\n"
]
},
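{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load, modify, commit, refresh costs a SELECT to load the row, an UPDATE when the session flushes at commit, and another SELECT for the refresh. When the updated object is not needed in Python, a single set-based `UPDATE` makes the same change; in SQLAlchemy that is `update(Molecule).where(Molecule.smiles == 'CCO').values(qed=0.50)` (with `from sqlalchemy import update`) run through the session. The sketch below shows the SQL form with plain `sqlite3` on a simplified, hand-written table; `update_qed` is an illustrative helper, and `rowcount` reports how many rows matched."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sqlite3\n",
"from datetime import datetime, timezone\n",
"\n",
"def update_qed(smiles: str, new_qed: float) -> tuple:\n",
"    # Illustrative helper on a simplified, hand-written table\n",
"    conn = sqlite3.connect(':memory:')\n",
"    try:\n",
"        conn.execute('CREATE TABLE molecule (id INTEGER PRIMARY KEY, smiles TEXT NOT NULL, '\n",
"                     'qed REAL, updated_at TEXT)')\n",
"        conn.execute('INSERT INTO molecule (smiles, qed) VALUES (?, ?)', ('CCO', 0.45))\n",
"        # One statement sets the new value and the timestamp on every matching row\n",
"        cur = conn.execute(\n",
"            'UPDATE molecule SET qed = ?, updated_at = ? WHERE smiles = ?',\n",
"            (new_qed, datetime.now(timezone.utc).isoformat(), smiles),\n",
"        )\n",
"        conn.commit()\n",
"        row = conn.execute('SELECT smiles, qed FROM molecule WHERE smiles = ?', (smiles,)).fetchone()\n",
"        return cur.rowcount, row\n",
"    finally:\n",
"        conn.close()\n",
"\n",
"update_qed('CCO', 0.50)\n"
]
},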
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Filtering and ordering\n",
"Examples: filter on a `qed` threshold and order by `sa_score`; prefix search on `smiles` with `LIKE`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with get_session() as s:\n",
"    hi_qed = s.exec(select(Molecule).where(Molecule.qed >= 0.6).order_by(Molecule.sa_score.asc())).all()\n",
"    hi_qed_view = [(m.smiles, m.qed, m.sa_score) for m in hi_qed]\n",
"\n",
"with get_session() as s:\n",
"    starts_with_cc = s.exec(select(Molecule).where(Molecule.smiles.like('CC%'))).all()\n",
"    starts_with_cc_view = [m.smiles for m in starts_with_cc]\n",
"\n",
"hi_qed_view, starts_with_cc_view\n"
]
},
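{
"cell_type": "markdown",
"metadata": {},
"source": [
"A caveat for SMILES prefix search: letter case carries meaning (`c` is an aromatic carbon, `C` an aliphatic one), but SQLite's `LIKE` is case-insensitive for ASCII letters by default, while PostgreSQL's `LIKE` is case-sensitive. On the SQLite quick option, `LIKE 'CC%'` would therefore also match toluene, `Cc1ccccc1`. The plain-`sqlite3` sketch below (simplified table; `prefix_matches` is an illustrative helper) compares `LIKE`, SQLite's case-sensitive `GLOB`, and a portable `substr` comparison. The `substr` form also has no wildcard characters, which matters because SMILES can contain `%` (two-digit ring closures) and `[` (bracket atoms), which `LIKE` and `GLOB` respectively treat as special."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sqlite3\n",
"\n",
"def prefix_matches(prefix: str) -> dict:\n",
"    # Illustrative helper on a simplified, hand-written table\n",
"    conn = sqlite3.connect(':memory:')\n",
"    try:\n",
"        conn.execute('CREATE TABLE molecule (smiles TEXT NOT NULL)')\n",
"        conn.executemany('INSERT INTO molecule VALUES (?)',\n",
"                         [('CCO',), ('Cc1ccccc1',), ('c1ccccc1',)])\n",
"\n",
"        def run(sql, params):\n",
"            return sorted(row[0] for row in conn.execute(sql, params))\n",
"\n",
"        return {\n",
"            # LIKE: SQLite ignores ASCII letter case by default\n",
"            'like': run('SELECT smiles FROM molecule WHERE smiles LIKE ?', (prefix + '%',)),\n",
"            # GLOB: case-sensitive, with * as its wildcard\n",
"            'glob': run('SELECT smiles FROM molecule WHERE smiles GLOB ?', (prefix + '*',)),\n",
"            # substr: exact comparison of the leading characters (BINARY collation), no wildcards\n",
"            'substr': run('SELECT smiles FROM molecule WHERE substr(smiles, 1, length(?)) = ?',\n",
"                          (prefix, prefix)),\n",
"        }\n",
"    finally:\n",
"        conn.close()\n",
"\n",
"prefix_matches('CC')\n"
]
},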
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. Eager loading relationships (avoid N+1)\n",
"Read molecules with their datasets efficiently using `selectinload`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with get_session() as s:\n",
"    stmt = select(Molecule).options(selectinload(Molecule.datasets)).order_by(Molecule.id.asc())\n",
"    molecules = s.exec(stmt).all()\n",
"[(m.smiles, [d.name for d in m.datasets]) for m in molecules]\n"
]
},
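{
"cell_type": "markdown",
"metadata": {},
"source": [
"What `selectinload` saves, counted on plain `sqlite3`: lazy loading issues one query for the molecules plus one per molecule for its datasets (the N+1 pattern), whereas the select-in strategy issues one query for the molecules and one `WHERE molecule_id IN (...)` query for all their datasets. `load_lazy` and `load_selectin` are hypothetical hand-written equivalents of the two strategies on a simplified schema, not the SQL SQLAlchemy generates; `set_trace_callback` records every statement SQLite executes so the SELECTs can be counted."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sqlite3\n",
"\n",
"# Illustrative helpers on a simplified, hand-written schema\n",
"def setup_demo() -> sqlite3.Connection:\n",
"    conn = sqlite3.connect(':memory:')\n",
"    conn.executescript('''\n",
"        CREATE TABLE molecule (id INTEGER PRIMARY KEY, smiles TEXT NOT NULL);\n",
"        CREATE TABLE dataset (id INTEGER PRIMARY KEY, name TEXT NOT NULL);\n",
"        CREATE TABLE moleculedataset (molecule_id INTEGER, dataset_id INTEGER,\n",
"                                      PRIMARY KEY (molecule_id, dataset_id));\n",
"        INSERT INTO molecule VALUES (1, 'CCO'), (2, 'c1ccccc1'), (3, 'CCN(CC)CC');\n",
"        INSERT INTO dataset VALUES (1, 'train'), (2, 'holdout');\n",
"        INSERT INTO moleculedataset VALUES (1, 1), (2, 1), (3, 2);\n",
"    ''')\n",
"    return conn\n",
"\n",
"DATASETS_SQL = ('SELECT md.molecule_id, d.name FROM moleculedataset md '\n",
"                'JOIN dataset d ON d.id = md.dataset_id WHERE md.molecule_id IN ({})')\n",
"\n",
"def load_lazy(conn):\n",
"    mols = conn.execute('SELECT id, smiles FROM molecule ORDER BY id').fetchall()\n",
"    # One extra query per molecule: the N in N+1\n",
"    return {smiles: [name for _, name in conn.execute(DATASETS_SQL.format('?'), (mid,))]\n",
"            for mid, smiles in mols}\n",
"\n",
"def load_selectin(conn):\n",
"    mols = conn.execute('SELECT id, smiles FROM molecule ORDER BY id').fetchall()\n",
"    ids = [mid for mid, _ in mols]\n",
"    by_id = {mid: [] for mid in ids}\n",
"    # A single IN (...) query fetches the datasets of every molecule at once\n",
"    placeholders = ', '.join('?' for _ in ids)\n",
"    for mid, name in conn.execute(DATASETS_SQL.format(placeholders), ids):\n",
"        by_id[mid].append(name)\n",
"    return {smiles: by_id[mid] for mid, smiles in mols}\n",
"\n",
"def count_selects(loader):\n",
"    conn = setup_demo()\n",
"    statements = []\n",
"    conn.set_trace_callback(statements.append)  # called with each SQL statement SQLite runs\n",
"    try:\n",
"        result = loader(conn)\n",
"    finally:\n",
"        conn.close()\n",
"    selects = [stmt for stmt in statements if stmt.lstrip().upper().startswith('SELECT')]\n",
"    return result, len(selects)\n",
"\n",
"count_selects(load_lazy), count_selects(load_selectin)\n"
]
},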
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9. Join filtering\n",
"Return only molecules that belong to the `train` dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with get_session() as s:\n",
"    stmt = (select(Molecule)\n",
"            .join(MoleculeDataset, Molecule.id == MoleculeDataset.molecule_id)\n",
"            .join(Dataset, Dataset.id == MoleculeDataset.dataset_id)\n",
"            .where(Dataset.name == 'train')\n",
"            .order_by(Molecule.id.asc()))\n",
"    train_mols = s.exec(stmt).all()\n",
"[m.smiles for m in train_mols]\n"
]
},
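{
"cell_type": "markdown",
"metadata": {},
"source": [
"`Dataset.name` is indexed but not unique, so if two datasets are both named `train` and a molecule belongs to both, the join above returns that molecule twice. A semi-join with `EXISTS` returns each molecule at most once; in SQLAlchemy it can be written as `select(Molecule).where(Molecule.datasets.any(Dataset.name == 'train'))`, and adding `.distinct()` to the join is another option. The plain-`sqlite3` sketch below reproduces the duplicate on a simplified, hand-written schema; `train_members` is an illustrative helper."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sqlite3\n",
"\n",
"def train_members() -> tuple:\n",
"    # Illustrative helper on a simplified, hand-written schema\n",
"    conn = sqlite3.connect(':memory:')\n",
"    try:\n",
"        conn.executescript('''\n",
"            CREATE TABLE molecule (id INTEGER PRIMARY KEY, smiles TEXT NOT NULL);\n",
"            CREATE TABLE dataset (id INTEGER PRIMARY KEY, name TEXT NOT NULL);\n",
"            CREATE TABLE moleculedataset (molecule_id INTEGER, dataset_id INTEGER,\n",
"                                          PRIMARY KEY (molecule_id, dataset_id));\n",
"            INSERT INTO molecule VALUES (1, 'CCO'), (2, 'c1ccccc1');\n",
"            -- two different datasets share the name 'train'\n",
"            INSERT INTO dataset VALUES (1, 'train'), (2, 'train');\n",
"            INSERT INTO moleculedataset VALUES (1, 1), (1, 2), (2, 1);\n",
"        ''')\n",
"        # Join: one output row per matching link row, so CCO appears twice\n",
"        join_rows = conn.execute('''\n",
"            SELECT m.smiles FROM molecule m\n",
"            JOIN moleculedataset md ON m.id = md.molecule_id\n",
"            JOIN dataset d ON d.id = md.dataset_id\n",
"            WHERE d.name = 'train' ORDER BY m.id\n",
"        ''').fetchall()\n",
"        # Semi-join: each molecule is returned at most once\n",
"        exists_rows = conn.execute('''\n",
"            SELECT m.smiles FROM molecule m\n",
"            WHERE EXISTS (\n",
"                SELECT 1 FROM moleculedataset md\n",
"                JOIN dataset d ON d.id = md.dataset_id\n",
"                WHERE md.molecule_id = m.id AND d.name = 'train'\n",
"            )\n",
"            ORDER BY m.id\n",
"        ''').fetchall()\n",
"        return [r[0] for r in join_rows], [r[0] for r in exists_rows]\n",
"    finally:\n",
"        conn.close()\n",
"\n",
"train_members()\n"
]
},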
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 10. Delete a molecule\n",
"Load → delete → commit; verify remaining molecules."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with get_session() as s:\n",
"    target = s.exec(select(Molecule).where(Molecule.smiles == 'CCN(CC)CC')).one()\n",
"    s.delete(target); s.commit()\n",
"    left = s.exec(select(Molecule).order_by(Molecule.id.asc())).all()\n",
"[m.smiles for m in left]\n"
]
},
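{
"cell_type": "markdown",
"metadata": {},
"source": [
"The ORM delete above also removes the molecule's `moleculedataset` rows, because `Molecule.datasets` uses that table as its secondary table; bulk deletes that bypass the ORM, such as the clean-up cell's `__table__.delete()`, get no such help. A database-level alternative, which the models in this notebook do not declare, is `ON DELETE CASCADE` on the link table's foreign keys (in SQLAlchemy terms, `ForeignKey('molecule.id', ondelete='CASCADE')`). The plain-`sqlite3` sketch below shows the effect on a simplified, hand-written schema; `cascade_demo` is an illustrative helper."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sqlite3\n",
"\n",
"def cascade_demo() -> tuple:\n",
"    # Illustrative helper on a simplified, hand-written schema with ON DELETE CASCADE\n",
"    conn = sqlite3.connect(':memory:')\n",
"    try:\n",
"        conn.execute('PRAGMA foreign_keys = ON')  # SQLite only cascades when FK enforcement is on\n",
"        conn.executescript('''\n",
"            CREATE TABLE molecule (id INTEGER PRIMARY KEY, smiles TEXT NOT NULL);\n",
"            CREATE TABLE dataset (id INTEGER PRIMARY KEY, name TEXT NOT NULL);\n",
"            CREATE TABLE moleculedataset (\n",
"                molecule_id INTEGER NOT NULL REFERENCES molecule (id) ON DELETE CASCADE,\n",
"                dataset_id INTEGER NOT NULL REFERENCES dataset (id) ON DELETE CASCADE,\n",
"                PRIMARY KEY (molecule_id, dataset_id)\n",
"            );\n",
"            INSERT INTO molecule VALUES (1, 'CCO'), (3, 'CCN(CC)CC');\n",
"            INSERT INTO dataset VALUES (1, 'train'), (2, 'holdout');\n",
"            INSERT INTO moleculedataset VALUES (1, 1), (3, 2);\n",
"        ''')\n",
"        # Deleting the parent row also deletes the link rows that reference it\n",
"        conn.execute('DELETE FROM molecule WHERE smiles = ?', ('CCN(CC)CC',))\n",
"        conn.commit()\n",
"        links = conn.execute('SELECT molecule_id, dataset_id FROM moleculedataset ORDER BY molecule_id').fetchall()\n",
"        left = [r[0] for r in conn.execute('SELECT smiles FROM molecule ORDER BY id')]\n",
"        return links, left\n",
"    finally:\n",
"        conn.close()\n",
"\n",
"cascade_demo()\n"
]
},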
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 11. Optional: RDKit + Mordred computation\n",
"If RDKit and Mordred are installed, compute QED and the Mordred 2D descriptor set for phenol (`c1ccccc1O`), then update its row, or insert one if it does not exist yet."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
"    from rdkit import Chem\n",
"    from rdkit.Chem import QED\n",
"    from mordred import Calculator, descriptors\n",
"    ok = True\n",
"except Exception as e:\n",
"    ok = False\n",
"    print('RDKit/Mordred not available; skipping.\\n', e)\n",
"\n",
"if ok:\n",
"    mol = Chem.MolFromSmiles('c1ccccc1O')  # phenol\n",
"    qed_val = float(QED.qed(mol))\n",
"    calc = Calculator(descriptors, ignore_3D=True)  # 2D descriptors only\n",
"    md = calc(mol)\n",
"    num_desc = sum(1 for _ in md.items())\n",
"    print('Computed QED:', qed_val, 'Mordred descriptors:', num_desc)\n",
"    with get_session() as s:\n",
"        # Update the row for this SMILES if it exists, otherwise insert a new one\n",
"        m = s.exec(select(Molecule).where(Molecule.smiles == 'c1ccccc1O')).first()\n",
"        if m is None:\n",
"            m = Molecule(smiles='c1ccccc1O', qed=qed_val)\n",
"        else:\n",
"            m.qed = qed_val\n",
"        s.add(m); s.commit(); s.refresh(m)\n",
"        # Jupyter only displays a cell's final top-level expression, so print inside the block\n",
"        print((m.id, m.smiles, m.qed))\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}