refactor: update migration paths and improve database operation handling
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Iterable, Mapping
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel import SQLModel, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
ModelT = TypeVar("ModelT", bound=SQLModel)
|
||||
|
||||
@@ -18,15 +19,19 @@ class MultipleObjectsReturned(LookupError):
|
||||
pass
|
||||
|
||||
|
||||
def _lookup_statement(model: type[ModelT], lookup: Mapping[str, Any]) -> SelectOfScalar[ModelT]:
|
||||
stmt = select(model)
|
||||
for key, value in lookup.items():
|
||||
stmt = stmt.where(getattr(model, key) == value)
|
||||
return stmt
|
||||
|
||||
|
||||
async def get_by_id(session: AsyncSession, model: type[ModelT], obj_id: Any) -> ModelT | None:
|
||||
return await session.get(model, obj_id)
|
||||
|
||||
|
||||
async def get(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT:
|
||||
stmt = select(model)
|
||||
for key, value in lookup.items():
|
||||
stmt = stmt.where(getattr(model, key) == value)
|
||||
stmt = stmt.limit(2)
|
||||
stmt = _lookup_statement(model, lookup).limit(2)
|
||||
items = (await session.exec(stmt)).all()
|
||||
if not items:
|
||||
raise DoesNotExist(f"{model.__name__} matching query does not exist.")
|
||||
@@ -38,9 +43,7 @@ async def get(session: AsyncSession, model: type[ModelT], **lookup: Any) -> Mode
|
||||
|
||||
|
||||
async def get_one_by(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT | None:
|
||||
stmt = select(model)
|
||||
for key, value in lookup.items():
|
||||
stmt = stmt.where(getattr(model, key) == value)
|
||||
stmt = _lookup_statement(model, lookup)
|
||||
return (await session.exec(stmt)).first()
|
||||
|
||||
|
||||
@@ -84,6 +87,64 @@ async def delete(session: AsyncSession, obj: ModelT, *, commit: bool = True) ->
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def list_by(
|
||||
session: AsyncSession,
|
||||
model: type[ModelT],
|
||||
*,
|
||||
order_by: Iterable[Any] = (),
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
**lookup: Any,
|
||||
) -> list[ModelT]:
|
||||
stmt = _lookup_statement(model, lookup)
|
||||
for ordering in order_by:
|
||||
stmt = stmt.order_by(ordering)
|
||||
if offset is not None:
|
||||
stmt = stmt.offset(offset)
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
return list(await session.exec(stmt))
|
||||
|
||||
|
||||
async def exists(session: AsyncSession, model: type[ModelT], **lookup: Any) -> bool:
|
||||
return (await session.exec(_lookup_statement(model, lookup).limit(1))).first() is not None
|
||||
|
||||
|
||||
def apply_updates(
|
||||
obj: ModelT,
|
||||
updates: Mapping[str, Any],
|
||||
*,
|
||||
exclude_none: bool = False,
|
||||
allowed_fields: set[str] | None = None,
|
||||
) -> ModelT:
|
||||
for key, value in updates.items():
|
||||
if allowed_fields is not None and key not in allowed_fields:
|
||||
continue
|
||||
if exclude_none and value is None:
|
||||
continue
|
||||
setattr(obj, key, value)
|
||||
return obj
|
||||
|
||||
|
||||
async def patch(
|
||||
session: AsyncSession,
|
||||
obj: ModelT,
|
||||
updates: Mapping[str, Any],
|
||||
*,
|
||||
exclude_none: bool = False,
|
||||
allowed_fields: set[str] | None = None,
|
||||
commit: bool = True,
|
||||
refresh: bool = True,
|
||||
) -> ModelT:
|
||||
apply_updates(
|
||||
obj,
|
||||
updates,
|
||||
exclude_none=exclude_none,
|
||||
allowed_fields=allowed_fields,
|
||||
)
|
||||
return await save(session, obj, commit=commit, refresh=refresh)
|
||||
|
||||
|
||||
async def get_or_create(
|
||||
session: AsyncSession,
|
||||
model: type[ModelT],
|
||||
@@ -93,9 +154,7 @@ async def get_or_create(
|
||||
refresh: bool = True,
|
||||
**lookup: Any,
|
||||
) -> tuple[ModelT, bool]:
|
||||
stmt = select(model)
|
||||
for key, value in lookup.items():
|
||||
stmt = stmt.where(getattr(model, key) == value)
|
||||
stmt = _lookup_statement(model, lookup)
|
||||
|
||||
existing = (await session.exec(stmt)).first()
|
||||
if existing is not None:
|
||||
|
||||
43
backend/app/db/queryset.py
Normal file
43
backend/app/db/queryset.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
ModelT = TypeVar("ModelT")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class QuerySet(Generic[ModelT]):
|
||||
statement: SelectOfScalar[ModelT]
|
||||
|
||||
def filter(self, *criteria: Any) -> QuerySet[ModelT]:
|
||||
return replace(self, statement=self.statement.where(*criteria))
|
||||
|
||||
def order_by(self, *ordering: Any) -> QuerySet[ModelT]:
|
||||
return replace(self, statement=self.statement.order_by(*ordering))
|
||||
|
||||
def limit(self, value: int) -> QuerySet[ModelT]:
|
||||
return replace(self, statement=self.statement.limit(value))
|
||||
|
||||
def offset(self, value: int) -> QuerySet[ModelT]:
|
||||
return replace(self, statement=self.statement.offset(value))
|
||||
|
||||
async def all(self, session: AsyncSession) -> list[ModelT]:
|
||||
return list(await session.exec(self.statement))
|
||||
|
||||
async def first(self, session: AsyncSession) -> ModelT | None:
|
||||
return (await session.exec(self.statement)).first()
|
||||
|
||||
async def one_or_none(self, session: AsyncSession) -> ModelT | None:
|
||||
return (await session.exec(self.statement)).one_or_none()
|
||||
|
||||
async def exists(self, session: AsyncSession) -> bool:
|
||||
return await self.limit(1).first(session) is not None
|
||||
|
||||
|
||||
def qs(model: type[ModelT]) -> QuerySet[ModelT]:
|
||||
return QuerySet(select(model))
|
||||
@@ -5,12 +5,12 @@ from collections.abc import AsyncGenerator
|
||||
from pathlib import Path
|
||||
|
||||
import anyio
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from app import models # noqa: F401
|
||||
from app.core.config import settings
|
||||
|
||||
@@ -51,12 +51,12 @@ def run_migrations() -> None:
|
||||
|
||||
async def init_db() -> None:
|
||||
if settings.db_auto_migrate:
|
||||
versions_dir = Path(__file__).resolve().parents[2] / "alembic" / "versions"
|
||||
versions_dir = Path(__file__).resolve().parents[2] / "migrations" / "versions"
|
||||
if any(versions_dir.glob("*.py")):
|
||||
logger.info("Running Alembic migrations on startup")
|
||||
logger.info("Running migrations on startup")
|
||||
await anyio.to_thread.run_sync(run_migrations)
|
||||
return
|
||||
logger.warning("No Alembic revisions found; falling back to create_all")
|
||||
logger.warning("No migration revisions found; falling back to create_all")
|
||||
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
9
backend/app/db/sqlmodel_exec.py
Normal file
9
backend/app/db/sqlmodel_exec.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.sql.base import Executable
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
async def exec_dml(session: AsyncSession, statement: Executable) -> None:
|
||||
# SQLModel's AsyncSession typing only overloads exec() for SELECT statements.
|
||||
await session.exec(statement) # type: ignore[call-overload]
|
||||
Reference in New Issue
Block a user