# Source: sqlmodel-pg-kit/src/sqlmodel_pg_kit/crud.py
# Snapshot: 2025-08-17 22:18:45 +08:00 (142 lines, 4.6 KiB, Python)

from __future__ import annotations
from typing import Any, Dict, Generic, Iterable, List, Optional, Sequence, Tuple, Type, TypeVar
from sqlalchemy import insert
from sqlmodel import SQLModel, select
# Type variable for the concrete SQLModel subclass a repository manages.
T = TypeVar("T", bound=SQLModel)
class Repository(Generic[T]):
    """Generic synchronous repository for SQLModel models.

    Every mutating method (``create``, ``bulk_insert``, ``update``,
    ``delete``) commits the session before returning.

    Example::

        class User(SQLModel, table=True):
            ...

        repo = Repository(User)
        with get_session() as s:
            user = repo.create(s, {"name": "Alice"})
    """

    def __init__(self, model: Type[T]):
        # The SQLModel table class this repository reads and writes.
        self.model = model

    # Create ---------------------------------------------------------------
    def create(self, session, data: T | Dict[str, Any]) -> T:
        """Persist one row, commit, and return it refreshed from the database.

        Args:
            session: Open session used for the write.
            data: Either an instance of ``self.model`` or a mapping of field
                values used as keyword arguments to construct one.

        Returns:
            The persisted instance after ``session.refresh``.
        """
        obj = data if isinstance(data, self.model) else self.model(**data)  # type: ignore[arg-type]
        session.add(obj)
        session.commit()
        session.refresh(obj)
        return obj

    def bulk_insert(self, session, rows: Iterable[Dict[str, Any]]) -> int:
        """Insert many rows with one multi-row ``INSERT`` statement and commit.

        Args:
            session: Open session used for the write.
            rows: Field-value mappings, one per row. The iterable is consumed.

        Returns:
            The statement's ``rowcount`` (``0`` if the driver reports a falsy
            value), or ``0`` without touching the database when ``rows`` is
            empty.
        """
        payload = list(rows)
        if not payload:
            # Never hand an empty list to ``.values()``: SQLAlchemy does not
            # reject it and would emit an INSERT of column defaults instead
            # of inserting nothing.
            return 0
        stmt = insert(self.model).values(payload)
        res = session.exec(stmt)
        session.commit()
        return res.rowcount or 0

    # Read ----------------------------------------------------------------
    def get(self, session, pk: Any) -> Optional[T]:
        """Return the row with primary key ``pk``, or ``None`` if absent."""
        return session.get(self.model, pk)

    def list(
        self,
        session,
        where: Optional[Any] = None,
        order_by: Optional[Sequence[Any]] = None,
        page: int = 1,
        size: int = 20,
    ) -> List[T]:
        """Return one page of rows, optionally filtered and ordered.

        Args:
            session: Open session used for the query.
            where: Optional SQL expression passed to ``Select.where``.
            order_by: Optional sequence of ORDER BY expressions, applied in
                the given order.
            page: 1-based page number.
            size: Maximum rows per page (``0`` yields an empty page).

        Raises:
            ValueError: If ``page < 1`` or ``size < 0``; these would
                otherwise produce a negative OFFSET or LIMIT.
        """
        if page < 1:
            raise ValueError(f"page must be >= 1, got {page}")
        if size < 0:
            raise ValueError(f"size must be >= 0, got {size}")
        stmt = select(self.model)
        if where is not None:
            stmt = stmt.where(where)
        if order_by:
            stmt = stmt.order_by(*order_by)
        stmt = stmt.offset((page - 1) * size).limit(size)
        # ``list`` here is the builtin: class attributes are not in scope
        # inside method bodies.
        return list(session.exec(stmt).all())

    # Update ---------------------------------------------------------------
    def update(self, session, obj_or_pk: Any, **fields) -> Optional[T]:
        """Assign attributes on a row, commit, and return it refreshed.

        Args:
            session: Open session used for the write.
            obj_or_pk: An instance of ``self.model``, or a primary key to load.
            **fields: Attribute names and the values to assign. Names that
                collide with this method's parameters (``session``,
                ``obj_or_pk``) cannot be passed this way.

        Returns:
            The updated instance, or ``None`` if no row matches ``obj_or_pk``.
        """
        obj = obj_or_pk if isinstance(obj_or_pk, self.model) else session.get(self.model, obj_or_pk)
        if obj is None:
            return None
        for name, value in fields.items():
            setattr(obj, name, value)
        session.add(obj)
        session.commit()
        session.refresh(obj)
        return obj

    # Delete ---------------------------------------------------------------
    def delete(self, session, obj_or_pk: Any) -> bool:
        """Delete a row and commit.

        Args:
            session: Open session used for the write.
            obj_or_pk: An instance of ``self.model``, or a primary key to load.

        Returns:
            ``True`` if a row was deleted, ``False`` if none matched.
        """
        obj = obj_or_pk if isinstance(obj_or_pk, self.model) else session.get(self.model, obj_or_pk)
        if obj is None:
            return False
        session.delete(obj)
        session.commit()
        return True
class AsyncRepository(Generic[T]):
    """Generic asynchronous repository for SQLModel models.

    Mirrors ``Repository`` but awaits the session's coroutine methods.
    Every mutating method commits the session before returning.
    """

    def __init__(self, model: Type[T]):
        # The SQLModel table class this repository reads and writes.
        self.model = model

    async def create(self, session, data: T | Dict[str, Any]) -> T:
        """Persist one row, commit, and return it refreshed from the database.

        Args:
            session: Open async session used for the write.
            data: Either an instance of ``self.model`` or a mapping of field
                values used as keyword arguments to construct one.

        Returns:
            The persisted instance after ``session.refresh``.
        """
        obj = data if isinstance(data, self.model) else self.model(**data)  # type: ignore[arg-type]
        session.add(obj)
        await session.commit()
        await session.refresh(obj)
        return obj

    async def bulk_insert(self, session, rows: Iterable[Dict[str, Any]]) -> int:
        """Insert many rows with one multi-row ``INSERT`` statement and commit.

        Args:
            session: Open async session used for the write.
            rows: Field-value mappings, one per row. The iterable is consumed.

        Returns:
            The statement's ``rowcount`` (``0`` if the driver reports a falsy
            value), or ``0`` without touching the database when ``rows`` is
            empty.
        """
        payload = list(rows)
        if not payload:
            # Never hand an empty list to ``.values()``: SQLAlchemy does not
            # reject it and would emit an INSERT of column defaults instead
            # of inserting nothing.
            return 0
        stmt = insert(self.model).values(payload)
        res = await session.execute(stmt)
        await session.commit()
        return res.rowcount or 0

    async def get(self, session, pk: Any) -> Optional[T]:
        """Return the row with primary key ``pk``, or ``None`` if absent."""
        return await session.get(self.model, pk)

    async def list(
        self,
        session,
        where: Optional[Any] = None,
        order_by: Optional[Sequence[Any]] = None,
        page: int = 1,
        size: int = 20,
    ) -> List[T]:
        """Return one page of rows, optionally filtered and ordered.

        Args:
            session: Open async session used for the query.
            where: Optional SQL expression passed to ``Select.where``.
            order_by: Optional sequence of ORDER BY expressions, applied in
                the given order.
            page: 1-based page number.
            size: Maximum rows per page (``0`` yields an empty page).

        Raises:
            ValueError: If ``page < 1`` or ``size < 0``; these would
                otherwise produce a negative OFFSET or LIMIT.
        """
        if page < 1:
            raise ValueError(f"page must be >= 1, got {page}")
        if size < 0:
            raise ValueError(f"size must be >= 0, got {size}")
        stmt = select(self.model)
        if where is not None:
            stmt = stmt.where(where)
        if order_by:
            stmt = stmt.order_by(*order_by)
        stmt = stmt.offset((page - 1) * size).limit(size)
        res = await session.execute(stmt)
        # ``list`` here is the builtin: class attributes are not in scope
        # inside method bodies.
        return list(res.scalars().all())

    async def update(self, session, obj_or_pk: Any, **fields) -> Optional[T]:
        """Assign attributes on a row, commit, and return it refreshed.

        Args:
            session: Open async session used for the write.
            obj_or_pk: An instance of ``self.model``, or a primary key to load.
            **fields: Attribute names and the values to assign. Names that
                collide with this method's parameters (``session``,
                ``obj_or_pk``) cannot be passed this way.

        Returns:
            The updated instance, or ``None`` if no row matches ``obj_or_pk``.
        """
        obj = obj_or_pk if isinstance(obj_or_pk, self.model) else await session.get(self.model, obj_or_pk)
        if obj is None:
            return None
        for name, value in fields.items():
            setattr(obj, name, value)
        session.add(obj)
        await session.commit()
        await session.refresh(obj)
        return obj

    async def delete(self, session, obj_or_pk: Any) -> bool:
        """Delete a row and commit.

        Args:
            session: Open async session used for the write.
            obj_or_pk: An instance of ``self.model``, or a primary key to load.

        Returns:
            ``True`` if a row was deleted, ``False`` if none matched.
        """
        obj = obj_or_pk if isinstance(obj_or_pk, self.model) else await session.get(self.model, obj_or_pk)
        if obj is None:
            return False
        await session.delete(obj)
        await session.commit()
        return True