142 lines
4.6 KiB
Python
142 lines
4.6 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any, Dict, Generic, Iterable, List, Optional, Sequence, Tuple, Type, TypeVar
|
|
|
|
from sqlalchemy import insert
|
|
from sqlmodel import SQLModel, select
|
|
|
|
# Type parameter for the concrete model class a repository manages. The
# ``bound`` restricts it to SQLModel subclasses; Repository and
# AsyncRepository are generic over it.
T = TypeVar("T", bound=SQLModel)
|
|
|
|
|
|
class Repository(Generic[T]):
    """Generic synchronous repository for SQLModel models.

    Wraps the common CRUD operations for a single model class. Every mutating
    method commits the session it is given; if that commit fails the session
    is rolled back before the error is re-raised, so the caller can keep
    using the session.

    Example:
        class User(SQLModel, table=True):
            ...

        repo = Repository(User)
        with get_session() as s:
            user = repo.create(s, {"name": "Alice"})
    """

    def __init__(self, model: Type[T]):
        """Bind the repository to *model*, the SQLModel class it manages."""
        self.model = model

    # Internal -------------------------------------------------------------

    def _commit(self, session) -> None:
        """Commit *session*; on failure roll it back and re-raise.

        SQLAlchemy refuses further work on a session whose flush/commit
        failed until it has been rolled back, so undo the transaction here
        instead of handing the caller an unusable session.
        """
        try:
            session.commit()
        except Exception:
            session.rollback()
            raise

    # Create ---------------------------------------------------------------

    def create(self, session, data: T | Dict[str, Any]) -> T:
        """Persist a new row and return it refreshed from the database.

        Args:
            session: Open SQLModel ``Session``.
            data: Either an instance of ``self.model`` (persisted as-is) or a
                mapping of field names to values used to construct one.

        Returns:
            The persisted instance, reloaded via ``session.refresh`` so any
            database-generated values are populated.
        """
        obj = data if isinstance(data, self.model) else self.model(**data)  # type: ignore[arg-type]
        session.add(obj)
        self._commit(session)
        session.refresh(obj)
        return obj

    def bulk_insert(self, session, rows: Iterable[Dict[str, Any]]) -> int:
        """Insert many rows with one multi-row ``INSERT`` and commit.

        Args:
            session: Open SQLModel ``Session``.
            rows: One column-name -> value mapping per row. The iterable is
                consumed exactly once.

        Returns:
            The number of rows inserted; ``0`` when *rows* is empty.
        """
        # ``list`` is the builtin here: method bodies do not see class
        # attributes, so the ``list`` method of this class does not shadow it.
        payload = list(rows)
        if not payload:
            # ``insert(...).values([])`` is NOT a no-op: SQLAlchemy treats an
            # empty sequence as "no explicit values" and would attempt to
            # insert a single default row. Short-circuit instead.
            return 0
        result = session.exec(insert(self.model).values(payload))
        # Read the count before committing, while the cursor result is fresh.
        count = result.rowcount
        self._commit(session)
        # Some DBAPI drivers report -1 when the count is unknown. A successful
        # single multi-row INSERT (no ON CONFLICT clause here) stores every
        # row, so fall back to the payload size in that case.
        return count if count is not None and count >= 0 else len(payload)

    # Read -----------------------------------------------------------------

    def get(self, session, pk: Any) -> Optional[T]:
        """Return the instance whose primary key is *pk*, or ``None``."""
        return session.get(self.model, pk)

    def list(
        self,
        session,
        where: Optional[Any] = None,
        order_by: Optional[Sequence[Any]] = None,
        page: int = 1,
        size: int = 20,
    ) -> List[T]:
        """Return one page of instances, optionally filtered and ordered.

        Args:
            session: Open SQLModel ``Session``.
            where: Optional SQL expression passed to ``Select.where``.
            order_by: Optional sequence of ORDER BY clauses, applied in order.
            page: 1-based page number.
            size: Maximum number of rows per page.

        Returns:
            The instances on the requested page (possibly empty).

        Raises:
            ValueError: If *page* is less than 1 or *size* is negative, which
                would otherwise yield a negative OFFSET/LIMIT.
        """
        if page < 1:
            raise ValueError(f"page must be >= 1, got {page}")
        if size < 0:
            raise ValueError(f"size must be >= 0, got {size}")
        stmt = select(self.model)
        if where is not None:
            stmt = stmt.where(where)
        if order_by:
            stmt = stmt.order_by(*order_by)
        stmt = stmt.offset((page - 1) * size).limit(size)
        # Builtin ``list`` (see bulk_insert); guarantees the declared List[T].
        return list(session.exec(stmt).all())

    # Update ---------------------------------------------------------------

    def update(self, session, obj_or_pk: Any, **fields) -> Optional[T]:
        """Set *fields* on an instance (or on the row with that primary key).

        Args:
            session: Open SQLModel ``Session``.
            obj_or_pk: An instance of ``self.model`` or a primary-key value.
            **fields: Attribute names and their new values.

        Returns:
            The updated instance refreshed from the database, or ``None`` if
            no row has the given primary key.
        """
        obj = obj_or_pk if isinstance(obj_or_pk, self.model) else session.get(self.model, obj_or_pk)
        # Identity check, not truthiness: a model may define __bool__/__len__.
        if obj is None:
            return None
        for name, value in fields.items():
            setattr(obj, name, value)
        session.add(obj)
        self._commit(session)
        session.refresh(obj)
        return obj

    # Delete ---------------------------------------------------------------

    def delete(self, session, obj_or_pk: Any) -> bool:
        """Delete an instance (or the row with that primary key) and commit.

        Returns:
            ``True`` once the delete has been committed, ``False`` if no row
            has the given primary key.
        """
        obj = obj_or_pk if isinstance(obj_or_pk, self.model) else session.get(self.model, obj_or_pk)
        if obj is None:
            return False
        session.delete(obj)
        self._commit(session)
        return True
|
|
|
|
|
|
class AsyncRepository(Generic[T]):
    """Generic asynchronous repository for SQLModel models.

    Async counterpart of the synchronous repository: the session's
    ``execute``/``get``/``commit``/``refresh``/``delete``/``rollback`` are
    awaited (presumably an SQLAlchemy/SQLModel ``AsyncSession``). Every
    mutating method commits; a failed commit is rolled back before the error
    is re-raised.
    """

    def __init__(self, model: Type[T]):
        """Bind the repository to *model*, the SQLModel class it manages."""
        self.model = model

    async def _commit(self, session) -> None:
        """Commit *session*; on failure roll it back and re-raise.

        SQLAlchemy refuses further work on a session whose flush/commit
        failed until it has been rolled back, so undo the transaction here
        instead of handing the caller an unusable session.
        """
        try:
            await session.commit()
        except Exception:
            await session.rollback()
            raise

    async def create(self, session, data: T | Dict[str, Any]) -> T:
        """Persist a new row and return it refreshed from the database.

        Args:
            session: Open async session.
            data: Either an instance of ``self.model`` (persisted as-is) or a
                mapping of field names to values used to construct one.

        Returns:
            The persisted instance, reloaded via ``session.refresh``.
        """
        obj = data if isinstance(data, self.model) else self.model(**data)  # type: ignore[arg-type]
        session.add(obj)
        await self._commit(session)
        await session.refresh(obj)
        return obj

    async def bulk_insert(self, session, rows: Iterable[Dict[str, Any]]) -> int:
        """Insert many rows with one multi-row ``INSERT`` and commit.

        Args:
            session: Open async session.
            rows: One column-name -> value mapping per row. The iterable is
                consumed exactly once.

        Returns:
            The number of rows inserted; ``0`` when *rows* is empty.
        """
        # ``list`` is the builtin here: method bodies do not see class
        # attributes, so the ``list`` method of this class does not shadow it.
        payload = list(rows)
        if not payload:
            # ``insert(...).values([])`` is NOT a no-op: SQLAlchemy treats an
            # empty sequence as "no explicit values" and would attempt to
            # insert a single default row. Short-circuit instead.
            return 0
        result = await session.execute(insert(self.model).values(payload))
        # Read the count before committing, while the cursor result is fresh.
        count = result.rowcount
        await self._commit(session)
        # Some DBAPI drivers report -1 when the count is unknown. A successful
        # single multi-row INSERT (no ON CONFLICT clause here) stores every
        # row, so fall back to the payload size in that case.
        return count if count is not None and count >= 0 else len(payload)

    async def get(self, session, pk: Any) -> Optional[T]:
        """Return the instance whose primary key is *pk*, or ``None``."""
        return await session.get(self.model, pk)

    async def list(
        self,
        session,
        where: Optional[Any] = None,
        order_by: Optional[Sequence[Any]] = None,
        page: int = 1,
        size: int = 20,
    ) -> List[T]:
        """Return one page of instances, optionally filtered and ordered.

        Args:
            session: Open async session.
            where: Optional SQL expression passed to ``Select.where``.
            order_by: Optional sequence of ORDER BY clauses, applied in order.
            page: 1-based page number.
            size: Maximum number of rows per page.

        Returns:
            The instances on the requested page (possibly empty).

        Raises:
            ValueError: If *page* is less than 1 or *size* is negative, which
                would otherwise yield a negative OFFSET/LIMIT.
        """
        if page < 1:
            raise ValueError(f"page must be >= 1, got {page}")
        if size < 0:
            raise ValueError(f"size must be >= 0, got {size}")
        stmt = select(self.model)
        if where is not None:
            stmt = stmt.where(where)
        if order_by:
            stmt = stmt.order_by(*order_by)
        stmt = stmt.offset((page - 1) * size).limit(size)
        result = await session.execute(stmt)
        # Builtin ``list`` (see bulk_insert).
        return list(result.scalars().all())

    async def update(self, session, obj_or_pk: Any, **fields) -> Optional[T]:
        """Set *fields* on an instance (or on the row with that primary key).

        Args:
            session: Open async session.
            obj_or_pk: An instance of ``self.model`` or a primary-key value.
            **fields: Attribute names and their new values.

        Returns:
            The updated instance refreshed from the database, or ``None`` if
            no row has the given primary key.
        """
        obj = obj_or_pk if isinstance(obj_or_pk, self.model) else await session.get(self.model, obj_or_pk)
        # Identity check, not truthiness: a model may define __bool__/__len__.
        if obj is None:
            return None
        for name, value in fields.items():
            setattr(obj, name, value)
        session.add(obj)
        await self._commit(session)
        await session.refresh(obj)
        return obj

    async def delete(self, session, obj_or_pk: Any) -> bool:
        """Delete an instance (or the row with that primary key) and commit.

        Returns:
            ``True`` once the delete has been committed, ``False`` if no row
            has the given primary key.
        """
        obj = obj_or_pk if isinstance(obj_or_pk, self.model) else await session.get(self.model, obj_or_pk)
        if obj is None:
            return False
        await session.delete(obj)
        await self._commit(session)
        return True
|