refactor: update migration paths and improve database operation handling

This commit is contained in:
Abhimanyu Saharan
2026-02-09 00:51:26 +05:30
parent 8c4bcca603
commit f6bcd1ca5f
43 changed files with 1175 additions and 1445 deletions

View File

@@ -1,11 +1,12 @@
from __future__ import annotations
from collections.abc import Mapping
from collections.abc import Iterable, Mapping
from typing import Any, TypeVar
from sqlalchemy.exc import IntegrityError
from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar
ModelT = TypeVar("ModelT", bound=SQLModel)
@@ -18,15 +19,19 @@ class MultipleObjectsReturned(LookupError):
pass
def _lookup_statement(model: type[ModelT], lookup: Mapping[str, Any]) -> SelectOfScalar[ModelT]:
stmt = select(model)
for key, value in lookup.items():
stmt = stmt.where(getattr(model, key) == value)
return stmt
async def get_by_id(session: AsyncSession, model: type[ModelT], obj_id: Any) -> ModelT | None:
return await session.get(model, obj_id)
async def get(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT:
stmt = select(model)
for key, value in lookup.items():
stmt = stmt.where(getattr(model, key) == value)
stmt = stmt.limit(2)
stmt = _lookup_statement(model, lookup).limit(2)
items = (await session.exec(stmt)).all()
if not items:
raise DoesNotExist(f"{model.__name__} matching query does not exist.")
@@ -38,9 +43,7 @@ async def get(session: AsyncSession, model: type[ModelT], **lookup: Any) -> Mode
async def get_one_by(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT | None:
stmt = select(model)
for key, value in lookup.items():
stmt = stmt.where(getattr(model, key) == value)
stmt = _lookup_statement(model, lookup)
return (await session.exec(stmt)).first()
@@ -84,6 +87,64 @@ async def delete(session: AsyncSession, obj: ModelT, *, commit: bool = True) ->
await session.commit()
async def list_by(
session: AsyncSession,
model: type[ModelT],
*,
order_by: Iterable[Any] = (),
limit: int | None = None,
offset: int | None = None,
**lookup: Any,
) -> list[ModelT]:
stmt = _lookup_statement(model, lookup)
for ordering in order_by:
stmt = stmt.order_by(ordering)
if offset is not None:
stmt = stmt.offset(offset)
if limit is not None:
stmt = stmt.limit(limit)
return list(await session.exec(stmt))
async def exists(session: AsyncSession, model: type[ModelT], **lookup: Any) -> bool:
return (await session.exec(_lookup_statement(model, lookup).limit(1))).first() is not None
def apply_updates(
obj: ModelT,
updates: Mapping[str, Any],
*,
exclude_none: bool = False,
allowed_fields: set[str] | None = None,
) -> ModelT:
for key, value in updates.items():
if allowed_fields is not None and key not in allowed_fields:
continue
if exclude_none and value is None:
continue
setattr(obj, key, value)
return obj
async def patch(
session: AsyncSession,
obj: ModelT,
updates: Mapping[str, Any],
*,
exclude_none: bool = False,
allowed_fields: set[str] | None = None,
commit: bool = True,
refresh: bool = True,
) -> ModelT:
apply_updates(
obj,
updates,
exclude_none=exclude_none,
allowed_fields=allowed_fields,
)
return await save(session, obj, commit=commit, refresh=refresh)
async def get_or_create(
session: AsyncSession,
model: type[ModelT],
@@ -93,9 +154,7 @@ async def get_or_create(
refresh: bool = True,
**lookup: Any,
) -> tuple[ModelT, bool]:
stmt = select(model)
for key, value in lookup.items():
stmt = stmt.where(getattr(model, key) == value)
stmt = _lookup_statement(model, lookup)
existing = (await session.exec(stmt)).first()
if existing is not None:

View File

@@ -0,0 +1,43 @@
from __future__ import annotations
from dataclasses import dataclass, replace
from typing import Any, Generic, TypeVar
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar
ModelT = TypeVar("ModelT")
@dataclass(frozen=True)
class QuerySet(Generic[ModelT]):
statement: SelectOfScalar[ModelT]
def filter(self, *criteria: Any) -> QuerySet[ModelT]:
return replace(self, statement=self.statement.where(*criteria))
def order_by(self, *ordering: Any) -> QuerySet[ModelT]:
return replace(self, statement=self.statement.order_by(*ordering))
def limit(self, value: int) -> QuerySet[ModelT]:
return replace(self, statement=self.statement.limit(value))
def offset(self, value: int) -> QuerySet[ModelT]:
return replace(self, statement=self.statement.offset(value))
async def all(self, session: AsyncSession) -> list[ModelT]:
return list(await session.exec(self.statement))
async def first(self, session: AsyncSession) -> ModelT | None:
return (await session.exec(self.statement)).first()
async def one_or_none(self, session: AsyncSession) -> ModelT | None:
return (await session.exec(self.statement)).one_or_none()
async def exists(self, session: AsyncSession) -> bool:
return await self.limit(1).first(session) is not None
def qs(model: type[ModelT]) -> QuerySet[ModelT]:
return QuerySet(select(model))

View File

@@ -5,12 +5,12 @@ from collections.abc import AsyncGenerator
from pathlib import Path
import anyio
from alembic import command
from alembic.config import Config
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
from alembic import command
from alembic.config import Config
from app import models # noqa: F401
from app.core.config import settings
@@ -51,12 +51,12 @@ def run_migrations() -> None:
async def init_db() -> None:
if settings.db_auto_migrate:
versions_dir = Path(__file__).resolve().parents[2] / "alembic" / "versions"
versions_dir = Path(__file__).resolve().parents[2] / "migrations" / "versions"
if any(versions_dir.glob("*.py")):
logger.info("Running Alembic migrations on startup")
logger.info("Running migrations on startup")
await anyio.to_thread.run_sync(run_migrations)
return
logger.warning("No Alembic revisions found; falling back to create_all")
logger.warning("No migration revisions found; falling back to create_all")
async with async_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)

View File

@@ -0,0 +1,9 @@
from __future__ import annotations
from sqlalchemy.sql.base import Executable
from sqlmodel.ext.asyncio.session import AsyncSession
async def exec_dml(session: AsyncSession, statement: Executable) -> None:
# SQLModel's AsyncSession typing only overloads exec() for SELECT statements.
await session.exec(statement) # type: ignore[call-overload]