refactor: replace DefaultLimitOffsetPage with LimitOffsetPage in multiple files and update timezone handling to use UTC
This commit is contained in:
@@ -132,7 +132,7 @@ async def save(
|
||||
return obj
|
||||
|
||||
|
||||
async def delete(session: AsyncSession, obj: ModelT, *, commit: bool = True) -> None:
|
||||
async def delete(session: AsyncSession, obj: SQLModel, *, commit: bool = True) -> None:
|
||||
"""Delete an object with optional commit."""
|
||||
await session.delete(obj)
|
||||
if commit:
|
||||
|
||||
@@ -3,19 +3,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
from fastapi_pagination.ext.sqlalchemy import paginate as _paginate
|
||||
|
||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi_pagination.limit_offset import LimitOffsetPage
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import Select, SelectOfScalar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
Transformer = Callable[[Sequence[Any]], Sequence[Any] | Awaitable[Sequence[Any]]]
|
||||
Transformer = Callable[
|
||||
[Sequence[Any]],
|
||||
Sequence[Any] | Awaitable[Sequence[Any]],
|
||||
]
|
||||
|
||||
|
||||
async def paginate(
|
||||
@@ -23,12 +27,7 @@ async def paginate(
|
||||
statement: Select[Any] | SelectOfScalar[Any],
|
||||
*,
|
||||
transformer: Transformer | None = None,
|
||||
) -> DefaultLimitOffsetPage[T]:
|
||||
) -> LimitOffsetPage[T]:
|
||||
"""Execute a paginated query and cast to the project page type alias."""
|
||||
# fastapi-pagination is not fully typed (it returns Any), but response_model
|
||||
# validation ensures runtime correctness. Centralize casts here to keep strict
|
||||
# mypy clean.
|
||||
return cast(
|
||||
DefaultLimitOffsetPage[T],
|
||||
await _paginate(session, statement, transformer=transformer),
|
||||
)
|
||||
page = await _paginate(session, statement, transformer=transformer)
|
||||
return DefaultLimitOffsetPage[T].model_validate(page)
|
||||
|
||||
@@ -13,6 +13,8 @@ from app.db.queryset import QuerySet, qs
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
|
||||
ModelT = TypeVar("ModelT", bound=SQLModel)
|
||||
|
||||
|
||||
@@ -31,11 +33,17 @@ class ModelManager(Generic[ModelT]):
|
||||
"""Return a queryset that yields no rows."""
|
||||
return qs(self.model).filter(false())
|
||||
|
||||
def filter(self, *criteria: object) -> QuerySet[ModelT]:
|
||||
def filter(
|
||||
self,
|
||||
*criteria: ColumnElement[bool] | bool,
|
||||
) -> QuerySet[ModelT]:
|
||||
"""Return queryset filtered by SQL criteria expressions."""
|
||||
return self.all().filter(*criteria)
|
||||
|
||||
def where(self, *criteria: object) -> QuerySet[ModelT]:
|
||||
def where(
|
||||
self,
|
||||
*criteria: ColumnElement[bool] | bool,
|
||||
) -> QuerySet[ModelT]:
|
||||
"""Alias for `filter`."""
|
||||
return self.filter(*criteria)
|
||||
|
||||
@@ -76,6 +84,7 @@ class ModelManager(Generic[ModelT]):
|
||||
class ManagerDescriptor(Generic[ModelT]):
|
||||
"""Descriptor that exposes a model-bound `ModelManager` as `.objects`."""
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
def __get__(self, instance: object, owner: type[ModelT]) -> ModelManager[ModelT]:
|
||||
"""Return a fresh manager bound to the owning model class."""
|
||||
return ModelManager(owner)
|
||||
|
||||
@@ -3,11 +3,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
@@ -20,15 +22,18 @@ class QuerySet(Generic[ModelT]):
|
||||
|
||||
statement: SelectOfScalar[ModelT]
|
||||
|
||||
def filter(self, *criteria: object) -> QuerySet[ModelT]:
|
||||
def filter(
|
||||
self,
|
||||
*criteria: ColumnElement[bool] | bool,
|
||||
) -> QuerySet[ModelT]:
|
||||
"""Return a new queryset with additional SQL criteria."""
|
||||
statement = cast(
|
||||
"SelectOfScalar[ModelT]",
|
||||
cast(Any, self.statement).where(*criteria),
|
||||
)
|
||||
statement = self.statement.where(*criteria)
|
||||
return replace(self, statement=statement)
|
||||
|
||||
def where(self, *criteria: object) -> QuerySet[ModelT]:
|
||||
def where(
|
||||
self,
|
||||
*criteria: ColumnElement[bool] | bool,
|
||||
) -> QuerySet[ModelT]:
|
||||
"""Alias for `filter` to mirror SQLAlchemy naming."""
|
||||
return self.filter(*criteria)
|
||||
|
||||
@@ -37,12 +42,12 @@ class QuerySet(Generic[ModelT]):
|
||||
statement = self.statement.filter_by(**kwargs)
|
||||
return replace(self, statement=statement)
|
||||
|
||||
def order_by(self, *ordering: object) -> QuerySet[ModelT]:
|
||||
def order_by(
|
||||
self,
|
||||
*ordering: Mapped[Any] | ColumnElement[Any] | str,
|
||||
) -> QuerySet[ModelT]:
|
||||
"""Return a new queryset with ordering clauses applied."""
|
||||
statement = cast(
|
||||
"SelectOfScalar[ModelT]",
|
||||
cast(Any, self.statement).order_by(*ordering),
|
||||
)
|
||||
statement = self.statement.order_by(*ordering)
|
||||
return replace(self, statement=statement)
|
||||
|
||||
def limit(self, value: int) -> QuerySet[ModelT]:
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import anyio
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
@@ -65,11 +65,11 @@ async def init_db() -> None:
|
||||
versions_dir = Path(__file__).resolve().parents[2] / "migrations" / "versions"
|
||||
if any(versions_dir.glob("*.py")):
|
||||
logger.info("Running migrations on startup")
|
||||
await anyio.to_thread.run_sync(run_migrations)
|
||||
await asyncio.to_thread(run_migrations)
|
||||
return
|
||||
logger.warning("No migration revisions found; falling back to create_all")
|
||||
|
||||
async with async_engine.begin() as conn:
|
||||
async with async_engine.connect() as conn, conn.begin():
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user