refactor: replace generic Exception handling with SQLAlchemyError in CRUD and session management
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import Any, TypeVar, cast
|
|||||||
|
|
||||||
from sqlalchemy import delete as sql_delete
|
from sqlalchemy import delete as sql_delete
|
||||||
from sqlalchemy import update as sql_update
|
from sqlalchemy import update as sql_update
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||||
from sqlmodel import SQLModel, select
|
from sqlmodel import SQLModel, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
from sqlmodel.sql.expression import SelectOfScalar
|
from sqlmodel.sql.expression import SelectOfScalar
|
||||||
@@ -24,7 +24,7 @@ class MultipleObjectsReturned(LookupError):
|
|||||||
async def _flush_or_rollback(session: AsyncSession) -> None:
|
async def _flush_or_rollback(session: AsyncSession) -> None:
|
||||||
try:
|
try:
|
||||||
await session.flush()
|
await session.flush()
|
||||||
except Exception:
|
except SQLAlchemyError:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -32,7 +32,7 @@ async def _flush_or_rollback(session: AsyncSession) -> None:
|
|||||||
async def _commit_or_rollback(session: AsyncSession) -> None:
|
async def _commit_or_rollback(session: AsyncSession) -> None:
|
||||||
try:
|
try:
|
||||||
await session.commit()
|
await session.commit()
|
||||||
except Exception:
|
except SQLAlchemyError:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -268,7 +268,7 @@ async def get_or_create(
|
|||||||
if existing is not None:
|
if existing is not None:
|
||||||
return existing, False
|
return existing, False
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except SQLAlchemyError:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from pathlib import Path
|
|||||||
import anyio
|
import anyio
|
||||||
from alembic import command
|
from alembic import command
|
||||||
from alembic.config import Config
|
from alembic.config import Config
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
|
||||||
from sqlmodel import SQLModel
|
from sqlmodel import SQLModel
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
@@ -69,9 +70,15 @@ async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
|||||||
async with async_session_maker() as session:
|
async with async_session_maker() as session:
|
||||||
try:
|
try:
|
||||||
yield session
|
yield session
|
||||||
except Exception:
|
finally:
|
||||||
|
try:
|
||||||
|
in_txn = bool(session.in_transaction())
|
||||||
|
except SQLAlchemyError:
|
||||||
|
logger.exception("Failed to inspect session transaction state.")
|
||||||
|
return
|
||||||
|
if not in_txn:
|
||||||
|
return
|
||||||
try:
|
try:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
except Exception:
|
except SQLAlchemyError:
|
||||||
logger.exception("Failed to rollback session after request error.")
|
logger.exception("Failed to rollback session after request error.")
|
||||||
raise
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from types import SimpleNamespace
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from sqlmodel import Field, SQLModel
|
from sqlmodel import Field, SQLModel
|
||||||
|
|
||||||
from app.db import crud
|
from app.db import crud
|
||||||
@@ -43,12 +44,17 @@ class _Maker:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_session_rolls_back_on_dependency_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_get_session_rolls_back_on_dependency_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
fake_session = SimpleNamespace(rollbacks=0)
|
@dataclass
|
||||||
|
class _FakeDependencySession:
|
||||||
|
rollbacks: int = 0
|
||||||
|
|
||||||
async def _rollback() -> None:
|
def in_transaction(self) -> bool:
|
||||||
fake_session.rollbacks += 1
|
return True
|
||||||
|
|
||||||
fake_session.rollback = _rollback
|
async def rollback(self) -> None:
|
||||||
|
self.rollbacks += 1
|
||||||
|
|
||||||
|
fake_session = _FakeDependencySession()
|
||||||
ctx = _SessionCtx(fake_session)
|
ctx = _SessionCtx(fake_session)
|
||||||
monkeypatch.setattr(db_session, "async_session_maker", _Maker(ctx))
|
monkeypatch.setattr(db_session, "async_session_maker", _Maker(ctx))
|
||||||
|
|
||||||
@@ -66,6 +72,9 @@ async def test_get_session_rolls_back_on_dependency_error(monkeypatch: pytest.Mo
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_rolls_back_when_commit_fails() -> None:
|
async def test_create_rolls_back_when_commit_fails() -> None:
|
||||||
|
class _CommitError(SQLAlchemyError):
|
||||||
|
pass
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _FailCommitSession:
|
class _FailCommitSession:
|
||||||
rollback_calls: int = 0
|
rollback_calls: int = 0
|
||||||
@@ -82,7 +91,7 @@ async def test_create_rolls_back_when_commit_fails() -> None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def commit(self) -> None:
|
async def commit(self) -> None:
|
||||||
raise RuntimeError("commit failed")
|
raise _CommitError("commit failed")
|
||||||
|
|
||||||
async def rollback(self) -> None:
|
async def rollback(self) -> None:
|
||||||
self.rollback_calls += 1
|
self.rollback_calls += 1
|
||||||
@@ -92,7 +101,7 @@ async def test_create_rolls_back_when_commit_fails() -> None:
|
|||||||
|
|
||||||
session = _FailCommitSession()
|
session = _FailCommitSession()
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match="commit failed"):
|
with pytest.raises(SQLAlchemyError, match="commit failed"):
|
||||||
await crud.create(session, _Thing, name="demo")
|
await crud.create(session, _Thing, name="demo")
|
||||||
|
|
||||||
assert session.rollback_calls == 1
|
assert session.rollback_calls == 1
|
||||||
@@ -101,6 +110,9 @@ async def test_create_rolls_back_when_commit_fails() -> None:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_where_rolls_back_when_commit_fails() -> None:
|
async def test_delete_where_rolls_back_when_commit_fails() -> None:
|
||||||
|
class _CommitError(SQLAlchemyError):
|
||||||
|
pass
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _FailCommitDmlSession:
|
class _FailCommitDmlSession:
|
||||||
rollback_calls: int = 0
|
rollback_calls: int = 0
|
||||||
@@ -111,14 +123,14 @@ async def test_delete_where_rolls_back_when_commit_fails() -> None:
|
|||||||
return SimpleNamespace(rowcount=3)
|
return SimpleNamespace(rowcount=3)
|
||||||
|
|
||||||
async def commit(self) -> None:
|
async def commit(self) -> None:
|
||||||
raise RuntimeError("commit failed")
|
raise _CommitError("commit failed")
|
||||||
|
|
||||||
async def rollback(self) -> None:
|
async def rollback(self) -> None:
|
||||||
self.rollback_calls += 1
|
self.rollback_calls += 1
|
||||||
|
|
||||||
session = _FailCommitDmlSession()
|
session = _FailCommitDmlSession()
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match="commit failed"):
|
with pytest.raises(SQLAlchemyError, match="commit failed"):
|
||||||
await crud.delete_where(session, _Thing, commit=True)
|
await crud.delete_where(session, _Thing, commit=True)
|
||||||
|
|
||||||
assert session.exec_calls == 1
|
assert session.exec_calls == 1
|
||||||
|
|||||||
Reference in New Issue
Block a user