refactor: update module docstrings for clarity and consistency
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Database helpers and abstractions for backend persistence."""
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
"""Typed wrapper around fastapi-pagination for backend query helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from typing import Any, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||
|
||||
from fastapi_pagination.ext.sqlalchemy import paginate as _paginate
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import Select, SelectOfScalar
|
||||
|
||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import Select, SelectOfScalar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
Transformer = Callable[[Sequence[Any]], Sequence[Any] | Awaitable[Sequence[Any]]]
|
||||
@@ -20,8 +24,10 @@ async def paginate(
|
||||
*,
|
||||
transformer: Transformer | None = None,
|
||||
) -> DefaultLimitOffsetPage[T]:
|
||||
# fastapi-pagination is not fully typed (it returns Any), but response_model validation
|
||||
# ensures runtime correctness. Centralize casts here to keep strict mypy clean.
|
||||
"""Execute a paginated query and cast to the project page type alias."""
|
||||
# fastapi-pagination is not fully typed (it returns Any), but response_model
|
||||
# validation ensures runtime correctness. Centralize casts here to keep strict
|
||||
# mypy clean.
|
||||
return cast(
|
||||
DefaultLimitOffsetPage[T],
|
||||
await _paginate(session, statement, transformer=transformer),
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""Model manager descriptor utilities for query-set style access."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from sqlalchemy import false
|
||||
from sqlmodel import SQLModel, col
|
||||
@@ -13,41 +15,55 @@ ModelT = TypeVar("ModelT", bound=SQLModel)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelManager(Generic[ModelT]):
|
||||
"""Convenience query manager bound to a SQLModel class."""
|
||||
|
||||
model: type[ModelT]
|
||||
id_field: str = "id"
|
||||
|
||||
def all(self) -> QuerySet[ModelT]:
|
||||
"""Return an unfiltered queryset for the bound model."""
|
||||
return qs(self.model)
|
||||
|
||||
def none(self) -> QuerySet[ModelT]:
|
||||
"""Return a queryset that yields no rows."""
|
||||
return qs(self.model).filter(false())
|
||||
|
||||
def filter(self, *criteria: Any) -> QuerySet[ModelT]:
|
||||
def filter(self, *criteria: object) -> QuerySet[ModelT]:
|
||||
"""Return queryset filtered by SQL criteria expressions."""
|
||||
return self.all().filter(*criteria)
|
||||
|
||||
def where(self, *criteria: Any) -> QuerySet[ModelT]:
|
||||
def where(self, *criteria: object) -> QuerySet[ModelT]:
|
||||
"""Alias for `filter`."""
|
||||
return self.filter(*criteria)
|
||||
|
||||
def filter_by(self, **kwargs: Any) -> QuerySet[ModelT]:
|
||||
def filter_by(self, **kwargs: object) -> QuerySet[ModelT]:
|
||||
"""Return queryset filtered by model field equality values."""
|
||||
queryset = self.all()
|
||||
for field_name, value in kwargs.items():
|
||||
queryset = queryset.filter(col(getattr(self.model, field_name)) == value)
|
||||
return queryset
|
||||
|
||||
def by_id(self, obj_id: Any) -> QuerySet[ModelT]:
|
||||
def by_id(self, obj_id: object) -> QuerySet[ModelT]:
|
||||
"""Return queryset filtered by primary identifier field."""
|
||||
return self.by_field(self.id_field, obj_id)
|
||||
|
||||
def by_ids(self, obj_ids: list[Any] | tuple[Any, ...] | set[Any]) -> QuerySet[ModelT]:
|
||||
def by_ids(
|
||||
self,
|
||||
obj_ids: list[object] | tuple[object, ...] | set[object],
|
||||
) -> QuerySet[ModelT]:
|
||||
"""Return queryset filtered by a set/list/tuple of identifiers."""
|
||||
return self.by_field_in(self.id_field, obj_ids)
|
||||
|
||||
def by_field(self, field_name: str, value: Any) -> QuerySet[ModelT]:
|
||||
def by_field(self, field_name: str, value: object) -> QuerySet[ModelT]:
|
||||
"""Return queryset filtered by a single field equality check."""
|
||||
return self.filter(col(getattr(self.model, field_name)) == value)
|
||||
|
||||
def by_field_in(
|
||||
self,
|
||||
field_name: str,
|
||||
values: list[Any] | tuple[Any, ...] | set[Any],
|
||||
values: list[object] | tuple[object, ...] | set[object],
|
||||
) -> QuerySet[ModelT]:
|
||||
"""Return queryset filtered by `field IN values` semantics."""
|
||||
seq = tuple(values)
|
||||
if not seq:
|
||||
return self.none()
|
||||
@@ -55,5 +71,8 @@ class ModelManager(Generic[ModelT]):
|
||||
|
||||
|
||||
class ManagerDescriptor(Generic[ModelT]):
|
||||
"""Descriptor that exposes a model-bound `ModelManager` as `.objects`."""
|
||||
|
||||
def __get__(self, instance: object, owner: type[ModelT]) -> ModelManager[ModelT]:
|
||||
"""Return a fresh manager bound to the owning model class."""
|
||||
return ModelManager(owner)
|
||||
|
||||
@@ -1,50 +1,67 @@
|
||||
"""Lightweight immutable query-set wrapper for SQLModel statements."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
ModelT = TypeVar("ModelT")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class QuerySet(Generic[ModelT]):
|
||||
"""Composable immutable wrapper around a SQLModel scalar select statement."""
|
||||
|
||||
statement: SelectOfScalar[ModelT]
|
||||
|
||||
def filter(self, *criteria: Any) -> QuerySet[ModelT]:
|
||||
def filter(self, *criteria: object) -> QuerySet[ModelT]:
|
||||
"""Return a new queryset with additional SQL criteria."""
|
||||
return replace(self, statement=self.statement.where(*criteria))
|
||||
|
||||
def where(self, *criteria: Any) -> QuerySet[ModelT]:
|
||||
def where(self, *criteria: object) -> QuerySet[ModelT]:
|
||||
"""Alias for `filter` to mirror SQLAlchemy naming."""
|
||||
return self.filter(*criteria)
|
||||
|
||||
def filter_by(self, **kwargs: Any) -> QuerySet[ModelT]:
|
||||
def filter_by(self, **kwargs: object) -> QuerySet[ModelT]:
|
||||
"""Return a new queryset filtered by keyword-equality criteria."""
|
||||
statement = self.statement.filter_by(**kwargs)
|
||||
return replace(self, statement=statement)
|
||||
|
||||
def order_by(self, *ordering: Any) -> QuerySet[ModelT]:
|
||||
def order_by(self, *ordering: object) -> QuerySet[ModelT]:
|
||||
"""Return a new queryset with ordering clauses applied."""
|
||||
return replace(self, statement=self.statement.order_by(*ordering))
|
||||
|
||||
def limit(self, value: int) -> QuerySet[ModelT]:
|
||||
"""Return a new queryset with a SQL row limit."""
|
||||
return replace(self, statement=self.statement.limit(value))
|
||||
|
||||
def offset(self, value: int) -> QuerySet[ModelT]:
|
||||
"""Return a new queryset with a SQL row offset."""
|
||||
return replace(self, statement=self.statement.offset(value))
|
||||
|
||||
async def all(self, session: AsyncSession) -> list[ModelT]:
|
||||
"""Execute and return all rows for the current queryset."""
|
||||
return list(await session.exec(self.statement))
|
||||
|
||||
async def first(self, session: AsyncSession) -> ModelT | None:
|
||||
"""Execute and return the first row, if available."""
|
||||
return (await session.exec(self.statement)).first()
|
||||
|
||||
async def one_or_none(self, session: AsyncSession) -> ModelT | None:
|
||||
"""Execute and return one row or `None`."""
|
||||
return (await session.exec(self.statement)).one_or_none()
|
||||
|
||||
async def exists(self, session: AsyncSession) -> bool:
|
||||
"""Return whether the queryset yields at least one row."""
|
||||
return await self.limit(1).first(session) is not None
|
||||
|
||||
|
||||
def qs(model: type[ModelT]) -> QuerySet[ModelT]:
|
||||
"""Create a base queryset for a SQLModel class."""
|
||||
return QuerySet(select(model))
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"""Database engine, session factory, and startup migration helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import anyio
|
||||
from alembic import command
|
||||
@@ -15,6 +17,9 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from app import models as _models
|
||||
from app.core.config import settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
# Import model modules so SQLModel metadata is fully registered at startup.
|
||||
_MODEL_REGISTRY = _models
|
||||
|
||||
@@ -48,12 +53,14 @@ def _alembic_config() -> Config:
|
||||
|
||||
|
||||
def run_migrations() -> None:
|
||||
"""Apply Alembic migrations to the latest revision."""
|
||||
logger.info("Running database migrations.")
|
||||
command.upgrade(_alembic_config(), "head")
|
||||
logger.info("Database migrations complete.")
|
||||
|
||||
|
||||
async def init_db() -> None:
|
||||
"""Initialize database schema, running migrations when configured."""
|
||||
if settings.db_auto_migrate:
|
||||
versions_dir = Path(__file__).resolve().parents[2] / "migrations" / "versions"
|
||||
if any(versions_dir.glob("*.py")):
|
||||
@@ -67,6 +74,7 @@ async def init_db() -> None:
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Yield a request-scoped async DB session with safe rollback on errors."""
|
||||
async with async_session_maker() as session:
|
||||
try:
|
||||
yield session
|
||||
|
||||
Reference in New Issue
Block a user