refactor: replace generic Exception handling with SQLAlchemyError in CRUD and session management

This commit is contained in:
Abhimanyu Saharan
2026-02-09 02:24:16 +05:30
parent fafcac1e16
commit 9340a74c42
3 changed files with 34 additions and 15 deletions

View File

@@ -5,7 +5,7 @@ from typing import Any, TypeVar, cast
from sqlalchemy import delete as sql_delete from sqlalchemy import delete as sql_delete
from sqlalchemy import update as sql_update from sqlalchemy import update as sql_update
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlmodel import SQLModel, select from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar from sqlmodel.sql.expression import SelectOfScalar
@@ -24,7 +24,7 @@ class MultipleObjectsReturned(LookupError):
async def _flush_or_rollback(session: AsyncSession) -> None: async def _flush_or_rollback(session: AsyncSession) -> None:
try: try:
await session.flush() await session.flush()
except Exception: except SQLAlchemyError:
await session.rollback() await session.rollback()
raise raise
@@ -32,7 +32,7 @@ async def _flush_or_rollback(session: AsyncSession) -> None:
async def _commit_or_rollback(session: AsyncSession) -> None: async def _commit_or_rollback(session: AsyncSession) -> None:
try: try:
await session.commit() await session.commit()
except Exception: except SQLAlchemyError:
await session.rollback() await session.rollback()
raise raise
@@ -268,7 +268,7 @@ async def get_or_create(
if existing is not None: if existing is not None:
return existing, False return existing, False
raise raise
except Exception: except SQLAlchemyError:
await session.rollback() await session.rollback()
raise raise

View File

@@ -7,6 +7,7 @@ from pathlib import Path
import anyio import anyio
from alembic import command from alembic import command
from alembic.config import Config from alembic.config import Config
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
from sqlmodel import SQLModel from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -69,9 +70,15 @@ async def get_session() -> AsyncGenerator[AsyncSession, None]:
async with async_session_maker() as session: async with async_session_maker() as session:
try: try:
yield session yield session
except Exception: finally:
try:
in_txn = bool(session.in_transaction())
except SQLAlchemyError:
logger.exception("Failed to inspect session transaction state.")
return
if not in_txn:
return
try: try:
await session.rollback() await session.rollback()
except Exception: except SQLAlchemyError:
logger.exception("Failed to rollback session after request error.") logger.exception("Failed to rollback session after request error.")
raise

View File

@@ -5,6 +5,7 @@ from types import SimpleNamespace
from typing import Any from typing import Any
import pytest import pytest
from sqlalchemy.exc import SQLAlchemyError
from sqlmodel import Field, SQLModel from sqlmodel import Field, SQLModel
from app.db import crud from app.db import crud
@@ -43,12 +44,17 @@ class _Maker:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_session_rolls_back_on_dependency_error(monkeypatch: pytest.MonkeyPatch) -> None: async def test_get_session_rolls_back_on_dependency_error(monkeypatch: pytest.MonkeyPatch) -> None:
fake_session = SimpleNamespace(rollbacks=0) @dataclass
class _FakeDependencySession:
rollbacks: int = 0
async def _rollback() -> None: def in_transaction(self) -> bool:
fake_session.rollbacks += 1 return True
fake_session.rollback = _rollback async def rollback(self) -> None:
self.rollbacks += 1
fake_session = _FakeDependencySession()
ctx = _SessionCtx(fake_session) ctx = _SessionCtx(fake_session)
monkeypatch.setattr(db_session, "async_session_maker", _Maker(ctx)) monkeypatch.setattr(db_session, "async_session_maker", _Maker(ctx))
@@ -66,6 +72,9 @@ async def test_get_session_rolls_back_on_dependency_error(monkeypatch: pytest.Mo
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_rolls_back_when_commit_fails() -> None: async def test_create_rolls_back_when_commit_fails() -> None:
class _CommitError(SQLAlchemyError):
pass
@dataclass @dataclass
class _FailCommitSession: class _FailCommitSession:
rollback_calls: int = 0 rollback_calls: int = 0
@@ -82,7 +91,7 @@ async def test_create_rolls_back_when_commit_fails() -> None:
return None return None
async def commit(self) -> None: async def commit(self) -> None:
raise RuntimeError("commit failed") raise _CommitError("commit failed")
async def rollback(self) -> None: async def rollback(self) -> None:
self.rollback_calls += 1 self.rollback_calls += 1
@@ -92,7 +101,7 @@ async def test_create_rolls_back_when_commit_fails() -> None:
session = _FailCommitSession() session = _FailCommitSession()
with pytest.raises(RuntimeError, match="commit failed"): with pytest.raises(SQLAlchemyError, match="commit failed"):
await crud.create(session, _Thing, name="demo") await crud.create(session, _Thing, name="demo")
assert session.rollback_calls == 1 assert session.rollback_calls == 1
@@ -101,6 +110,9 @@ async def test_create_rolls_back_when_commit_fails() -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_where_rolls_back_when_commit_fails() -> None: async def test_delete_where_rolls_back_when_commit_fails() -> None:
class _CommitError(SQLAlchemyError):
pass
@dataclass @dataclass
class _FailCommitDmlSession: class _FailCommitDmlSession:
rollback_calls: int = 0 rollback_calls: int = 0
@@ -111,14 +123,14 @@ async def test_delete_where_rolls_back_when_commit_fails() -> None:
return SimpleNamespace(rowcount=3) return SimpleNamespace(rowcount=3)
async def commit(self) -> None: async def commit(self) -> None:
raise RuntimeError("commit failed") raise _CommitError("commit failed")
async def rollback(self) -> None: async def rollback(self) -> None:
self.rollback_calls += 1 self.rollback_calls += 1
session = _FailCommitDmlSession() session = _FailCommitDmlSession()
with pytest.raises(RuntimeError, match="commit failed"): with pytest.raises(SQLAlchemyError, match="commit failed"):
await crud.delete_where(session, _Thing, commit=True) await crud.delete_where(session, _Thing, commit=True)
assert session.exec_calls == 1 assert session.exec_calls == 1