From 228b99bc9b6cc8fcc7fac6c2e4f82c508ea51cfa Mon Sep 17 00:00:00 2001 From: Abhimanyu Saharan Date: Mon, 9 Feb 2026 02:04:14 +0530 Subject: [PATCH] refactor: replace SQLModel with QueryModel in various models and update query methods --- backend/app/api/agent.py | 24 ++- backend/app/api/agents.py | 24 ++- backend/app/api/approvals.py | 15 +- backend/app/api/board_group_memory.py | 53 ++++--- backend/app/api/board_groups.py | 18 +-- backend/app/api/board_memory.py | 28 ++-- backend/app/api/board_onboarding.py | 66 ++++---- backend/app/api/boards.py | 4 +- backend/app/api/deps.py | 14 +- backend/app/api/gateway.py | 6 +- backend/app/api/gateways.py | 40 +++-- backend/app/api/organizations.py | 143 ++++++++---------- backend/app/api/tasks.py | 61 ++++---- backend/app/db/crud.py | 3 +- backend/app/db/query_manager.py | 59 ++++++++ backend/app/db/queryset.py | 7 + backend/app/models/activity_events.py | 5 +- backend/app/models/agents.py | 5 +- backend/app/models/approvals.py | 5 +- backend/app/models/base.py | 11 ++ backend/app/models/board_group_memory.py | 5 +- backend/app/models/board_memory.py | 5 +- backend/app/models/board_onboarding.py | 5 +- backend/app/models/gateways.py | 5 +- .../app/models/organization_board_access.py | 5 +- .../organization_invite_board_access.py | 5 +- backend/app/models/organization_invites.py | 5 +- backend/app/models/organization_members.py | 5 +- backend/app/models/organizations.py | 5 +- backend/app/models/task_fingerprints.py | 5 +- backend/app/models/tenancy.py | 4 +- backend/app/models/users.py | 6 +- backend/app/queries/__init__.py | 1 - backend/app/queries/organizations.py | 50 ------ backend/app/services/board_group_snapshot.py | 4 +- backend/app/services/board_snapshot.py | 48 +++--- backend/app/services/organizations.py | 24 ++- backend/app/services/template_sync.py | 20 ++- .../test_organizations_member_remove_api.py | 30 ++-- .../test_task_dependencies_integration.py | 4 +- 40 files changed, 413 insertions(+), 419 deletions(-) create mode 100644 backend/app/db/query_manager.py create mode 100644 backend/app/models/base.py delete mode 100644 backend/app/queries/__init__.py delete mode 100644 backend/app/queries/organizations.py diff --git a/backend/app/api/agent.py b/backend/app/api/agent.py index 9fbc027..5f0068b 100644 --- a/backend/app/api/agent.py +++ b/backend/app/api/agent.py @@ -102,7 +102,7 @@ def _guard_board_access(agent_ctx: AgentAuthContext, board: Board) -> None: async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig: if not board.gateway_id: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) - gateway = await session.get(Gateway, board.gateway_id) + gateway = await Gateway.objects.by_id(board.gateway_id).first(session) if gateway is None or not gateway.url: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) return GatewayClientConfig(url=gateway.url, token=gateway.token) @@ -117,9 +117,7 @@ async def _require_gateway_main( raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Agent missing session key" ) - gateway = ( - await session.exec(select(Gateway).where(col(Gateway.main_session_key) == session_key)) - ).first() + gateway = await Gateway.objects.filter_by(main_session_key=session_key).first(session) if gateway is None: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -139,7 +137,7 @@ async def _require_gateway_board( gateway: Gateway, board_id: UUID | str, ) -> Board: - board = await session.get(Board, board_id) + board = await Board.objects.by_id(board_id).first(session) if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found") if board.gateway_id != gateway.id: @@ -254,7 +252,7 @@ async def create_task( }, ) if task.assigned_agent_id: - agent = await session.get(Agent, task.assigned_agent_id) + agent = await Agent.objects.by_id(task.assigned_agent_id).first(session) if agent is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) if agent.is_board_lead: @@ -286,7 +284,7 @@ async def create_task( ) await session.commit() if task.assigned_agent_id: - assigned_agent = await session.get(Agent, task.assigned_agent_id) + assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session) if assigned_agent: await tasks_api._notify_agent_on_task_assign( session=session, @@ -466,7 +464,7 @@ async def nudge_agent( _guard_board_access(agent_ctx, board) if not agent_ctx.agent.is_board_lead: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - target = await session.get(Agent, agent_id) + target = await Agent.objects.by_id(agent_id).first(session) if target is None or (target.board_id and target.board_id != board.id): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) if not target.openclaw_session_id: @@ -528,7 +526,7 @@ async def get_agent_soul( _guard_board_access(agent_ctx, board) if not agent_ctx.agent.is_board_lead and str(agent_ctx.agent.id) != agent_id: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - target = await session.get(Agent, agent_id) + target = await Agent.objects.by_id(agent_id).first(session) if target is None or (target.board_id and target.board_id != board.id): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) config = await _gateway_config(session, board) @@ -566,7 +564,7 @@ async def update_agent_soul( _guard_board_access(agent_ctx, board) if not agent_ctx.agent.is_board_lead: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - target = await session.get(Agent, agent_id) + target = await Agent.objects.by_id(agent_id).first(session) if target is None or (target.board_id and target.board_id != board.id): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) config = await _gateway_config(session, board) @@ -629,7 +627,7 @@ async def ask_user_via_gateway_main( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Board is not attached to a gateway", ) - gateway = await session.get(Gateway, board.gateway_id) + gateway = await Gateway.objects.by_id(board.gateway_id).first(session) if gateway is None or not gateway.url: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, @@ -689,9 +687,7 @@ async def ask_user_via_gateway_main( agent_id=agent_ctx.agent.id, ) - main_agent = ( - await session.exec(select(Agent).where(col(Agent.openclaw_session_id) == main_session_key)) - ).first() + main_agent = await Agent.objects.filter_by(openclaw_session_id=main_session_key).first(session) await session.commit() diff --git a/backend/app/api/agents.py b/backend/app/api/agents.py index d44a83c..4bd5893 100644 --- a/backend/app/api/agents.py +++ b/backend/app/api/agents.py @@ -109,7 +109,7 @@ async def _require_board( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="board_id is required", ) - board = await session.get(Board, board_id) + board = await Board.objects.by_id(board_id).first(session) if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found") if user is not None: @@ -125,7 +125,7 @@ async def _require_gateway( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Board gateway_id is required", ) - gateway = await session.get(Gateway, board.gateway_id) + gateway = await Gateway.objects.by_id(board.gateway_id).first(session) if gateway is None: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, @@ -182,9 +182,7 @@ async def _find_gateway_for_main_session( ) -> Gateway | None: if not session_key: return None - return ( - await session.exec(select(Gateway).where(Gateway.main_session_key == session_key)) - ).first() + return await Gateway.objects.filter_by(main_session_key=session_key).first(session) async def _ensure_gateway_session( @@ -237,7 +235,7 @@ async def _require_user_context(session: AsyncSession, user: User | None) -> Org member = await get_active_membership(session, user) if member is None: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - organization = await session.get(Organization, member.organization_id) + organization = await Organization.objects.by_id(member.organization_id).first(session) if organization is None: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) return OrganizationContext(organization=organization, member=member) @@ -258,7 +256,7 @@ async def _require_agent_access( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) return - board = await session.get(Board, agent.board_id) + board = await Board.objects.by_id(agent.board_id).first(session) if board is None or board.organization_id != ctx.organization.id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) if not await has_board_access(session, member=ctx.member, board=board, write=write): @@ -323,7 +321,7 @@ async def list_agents( if board_id is not None: statement = statement.where(col(Agent.board_id) == board_id) if gateway_id is not None: - gateway = await session.get(Gateway, gateway_id) + gateway = await Gateway.objects.by_id(gateway_id).first(session) if gateway is None or gateway.organization_id != ctx.organization.id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) statement = statement.join(Board, col(Agent.board_id) == col(Board.id)).where( @@ -532,7 +530,7 @@ async def get_agent( session: AsyncSession = Depends(get_session), ctx: OrganizationContext = Depends(require_org_admin), ) -> AgentRead: - agent = await session.get(Agent, agent_id) + agent = await Agent.objects.by_id(agent_id).first(session) if agent is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) await _require_agent_access(session, agent=agent, ctx=ctx, write=False) @@ -549,7 +547,7 @@ async def update_agent( auth: AuthContext = Depends(get_auth_context), ctx: OrganizationContext = Depends(require_org_admin), ) -> AgentRead: - agent = await session.get(Agent, agent_id) + agent = await Agent.objects.by_id(agent_id).first(session) if agent is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) await _require_agent_access(session, agent=agent, ctx=ctx, write=True) @@ -728,7 +726,7 @@ async def heartbeat_agent( session: AsyncSession = Depends(get_session), actor: ActorContext = Depends(require_admin_or_agent), ) -> AgentRead: - agent = await session.get(Agent, agent_id) + agent = await Agent.objects.by_id(agent_id).first(session) if agent is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) if actor.actor_type == "agent" and actor.agent and actor.agent.id != agent.id: @@ -767,7 +765,7 @@ async def heartbeat_or_create_agent( actor=actor, ) - statement = select(Agent).where(Agent.name == payload.name) + statement = Agent.objects.filter_by(name=payload.name).statement if payload.board_id is not None: statement = statement.where(Agent.board_id == payload.board_id) agent = (await session.exec(statement)).first() @@ -943,7 +941,7 @@ async def delete_agent( session: AsyncSession = Depends(get_session), ctx: OrganizationContext = Depends(require_org_admin), ) -> OkResponse: - agent = await session.get(Agent, agent_id) + agent = await Agent.objects.by_id(agent_id).first(session) if agent is None: return OkResponse() await _require_agent_access(session, agent=agent, ctx=ctx, write=True) diff --git a/backend/app/api/approvals.py b/backend/app/api/approvals.py index dcb99aa..a98f207 100644 --- a/backend/app/api/approvals.py +++ b/backend/app/api/approvals.py @@ -77,9 +77,8 @@ async def _fetch_approval_events( since: datetime, ) -> list[Approval]: statement = ( - select(Approval) - .where(col(Approval.board_id) == board_id) - .where( + Approval.objects.filter_by(board_id=board_id) + .filter( or_( col(Approval.created_at) >= since, col(Approval.resolved_at) >= since, @@ -87,7 +86,7 @@ async def _fetch_approval_events( ) .order_by(asc(col(Approval.created_at))) ) - return list(await session.exec(statement)) + return await statement.all(session) @router.get("", response_model=DefaultLimitOffsetPage[ApprovalRead]) @@ -97,11 +96,11 @@ async def list_approvals( session: AsyncSession = Depends(get_session), actor: ActorContext = Depends(require_admin_or_agent), ) -> DefaultLimitOffsetPage[ApprovalRead]: - statement = select(Approval).where(col(Approval.board_id) == board.id) + statement = Approval.objects.filter_by(board_id=board.id) if status_filter: - statement = statement.where(col(Approval.status) == status_filter) + statement = statement.filter(col(Approval.status) == status_filter) statement = statement.order_by(col(Approval.created_at).desc()) - return await paginate(session, statement) + return await paginate(session, statement.statement) @router.get("/stream") @@ -207,7 +206,7 @@ async def update_approval( board: Board = Depends(get_board_for_user_write), session: AsyncSession = Depends(get_session), ) -> Approval: - approval = await session.get(Approval, approval_id) + approval = await Approval.objects.by_id(approval_id).first(session) if approval is None or approval.board_id != board.id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) updates = payload.model_dump(exclude_unset=True) diff --git a/backend/app/api/board_group_memory.py b/backend/app/api/board_group_memory.py index 975ebb1..ab1aade 100644 --- a/backend/app/api/board_group_memory.py +++ b/backend/app/api/board_group_memory.py @@ -8,7 +8,7 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from sqlalchemy import func -from sqlmodel import col, select +from sqlmodel import col from sqlmodel.ext.asyncio.session import AsyncSession from sse_starlette.sse import EventSourceResponse @@ -71,7 +71,7 @@ def _serialize_memory(memory: BoardGroupMemory) -> dict[str, object]: async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None: if board.gateway_id is None: return None - gateway = await session.get(Gateway, board.gateway_id) + gateway = await Gateway.objects.by_id(board.gateway_id).first(session) if gateway is None or not gateway.url: return None return GatewayClientConfig(url=gateway.url, token=gateway.token) @@ -96,17 +96,17 @@ async def _fetch_memory_events( is_chat: bool | None = None, ) -> list[BoardGroupMemory]: statement = ( - select(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id) == board_group_id) + BoardGroupMemory.objects.filter_by(board_group_id=board_group_id) # Old/invalid rows (empty/whitespace-only content) can exist; exclude them to # satisfy the NonEmptyStr response schema. - .where(func.length(func.trim(col(BoardGroupMemory.content))) > 0) + .filter(func.length(func.trim(col(BoardGroupMemory.content))) > 0) ) if is_chat is not None: - statement = statement.where(col(BoardGroupMemory.is_chat) == is_chat) - statement = statement.where(col(BoardGroupMemory.created_at) >= since).order_by( + statement = statement.filter(col(BoardGroupMemory.is_chat) == is_chat) + statement = statement.filter(col(BoardGroupMemory.created_at) >= since).order_by( col(BoardGroupMemory.created_at) ) - return list(await session.exec(statement)) + return await statement.all(session) async def _require_group_access( @@ -116,7 +116,7 @@ async def _require_group_access( ctx: OrganizationContext, write: bool, ) -> BoardGroup: - group = await session.get(BoardGroup, group_id) + group = await BoardGroup.objects.by_id(group_id).first(session) if group is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) if group.organization_id != ctx.member.organization_id: @@ -127,9 +127,9 @@ async def _require_group_access( if not write and member_all_boards_read(ctx.member): return group - board_ids = list( - await session.exec(select(Board.id).where(col(Board.board_group_id) == group_id)) - ) + board_ids = [ + board.id for board in await Board.objects.filter_by(board_group_id=group_id).all(session) + ] if not board_ids: if is_org_admin(ctx.member): return group @@ -156,12 +156,12 @@ async def _notify_group_memory_targets( is_broadcast = "broadcast" in tags or "all" in mentions # Fetch group boards + agents. - boards = list(await session.exec(select(Board).where(col(Board.board_group_id) == group.id))) + boards = await Board.objects.filter_by(board_group_id=group.id).all(session) if not boards: return board_by_id = {board.id: board for board in boards} board_ids = list(board_by_id.keys()) - agents = list(await session.exec(select(Agent).where(col(Agent.board_id).in_(board_ids)))) + agents = await Agent.objects.by_field_in("board_id", board_ids).all(session) targets: dict[str, Agent] = {} for agent in agents: @@ -242,15 +242,15 @@ async def list_board_group_memory( ) -> DefaultLimitOffsetPage[BoardGroupMemoryRead]: await _require_group_access(session, group_id=group_id, ctx=ctx, write=False) statement = ( - select(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id) == group_id) + BoardGroupMemory.objects.filter_by(board_group_id=group_id) # Old/invalid rows (empty/whitespace-only content) can exist; exclude them to # satisfy the NonEmptyStr response schema. - .where(func.length(func.trim(col(BoardGroupMemory.content))) > 0) + .filter(func.length(func.trim(col(BoardGroupMemory.content))) > 0) ) if is_chat is not None: - statement = statement.where(col(BoardGroupMemory.is_chat) == is_chat) + statement = statement.filter(col(BoardGroupMemory.is_chat) == is_chat) statement = statement.order_by(col(BoardGroupMemory.created_at).desc()) - return await paginate(session, statement) + return await paginate(session, statement.statement) @group_router.get("/stream") @@ -297,7 +297,7 @@ async def create_board_group_memory( ) -> BoardGroupMemory: group = await _require_group_access(session, group_id=group_id, ctx=ctx, write=True) - user = await session.get(User, ctx.member.user_id) + user = await User.objects.by_id(ctx.member.user_id).first(session) actor = ActorContext(actor_type="user", user=user) tags = set(payload.tags or []) is_chat = "chat" in tags @@ -332,19 +332,18 @@ async def list_board_group_memory_for_board( ) -> DefaultLimitOffsetPage[BoardGroupMemoryRead]: group_id = board.board_group_id if group_id is None: - statement = select(BoardGroupMemory).where(col(BoardGroupMemory.id).is_(None)) - return await paginate(session, statement) + return await paginate(session, BoardGroupMemory.objects.by_ids([]).statement) - statement = ( - select(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id) == group_id) + queryset = ( + BoardGroupMemory.objects.filter_by(board_group_id=group_id) # Old/invalid rows (empty/whitespace-only content) can exist; exclude them to # satisfy the NonEmptyStr response schema. - .where(func.length(func.trim(col(BoardGroupMemory.content))) > 0) + .filter(func.length(func.trim(col(BoardGroupMemory.content))) > 0) ) if is_chat is not None: - statement = statement.where(col(BoardGroupMemory.is_chat) == is_chat) - statement = statement.order_by(col(BoardGroupMemory.created_at).desc()) - return await paginate(session, statement) + queryset = queryset.filter(col(BoardGroupMemory.is_chat) == is_chat) + queryset = queryset.order_by(col(BoardGroupMemory.created_at).desc()) + return await paginate(session, queryset.statement) @board_router.get("/stream") @@ -396,7 +395,7 @@ async def create_board_group_memory_for_board( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Board is not in a board group", ) - group = await session.get(BoardGroup, group_id) + group = await BoardGroup.objects.by_id(group_id).first(session) if group is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) diff --git a/backend/app/api/board_groups.py b/backend/app/api/board_groups.py index 4ee39a4..a64e151 100644 --- a/backend/app/api/board_groups.py +++ b/backend/app/api/board_groups.py @@ -56,7 +56,7 @@ async def _require_group_access( member: OrganizationMember, write: bool, ) -> BoardGroup: - group = await session.get(BoardGroup, group_id) + group = await BoardGroup.objects.by_id(group_id).first(session) if group is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) if group.organization_id != member.organization_id: @@ -67,9 +67,9 @@ async def _require_group_access( if not write and member_all_boards_read(member): return group - board_ids = list( - await session.exec(select(Board.id).where(col(Board.board_group_id) == group_id)) - ) + board_ids = [ + board.id for board in await Board.objects.filter_by(board_group_id=group_id).all(session) + ] if not board_ids: if is_org_admin(member): return group @@ -153,7 +153,7 @@ async def apply_board_group_heartbeat( session: AsyncSession = Depends(get_session), actor: ActorContext = Depends(require_admin_or_agent), ) -> BoardGroupHeartbeatApplyResult: - group = await session.get(BoardGroup, group_id) + group = await BoardGroup.objects.by_id(group_id).first(session) if group is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) @@ -181,11 +181,11 @@ async def apply_board_group_heartbeat( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) if not agent.is_board_lead: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - board = await session.get(Board, agent.board_id) + board = await Board.objects.by_id(agent.board_id).first(session) if board is None or board.board_group_id != group_id: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - boards = list(await session.exec(select(Board).where(col(Board.board_group_id) == group_id))) + boards = await Board.objects.filter_by(board_group_id=group_id).all(session) board_by_id = {board.id: board for board in boards} board_ids = list(board_by_id.keys()) if not board_ids: @@ -196,7 +196,7 @@ async def apply_board_group_heartbeat( failed_agent_ids=[], ) - agents = list(await session.exec(select(Agent).where(col(Agent.board_id).in_(board_ids)))) + agents = await Agent.objects.by_field_in("board_id", board_ids).all(session) if not payload.include_board_leads: agents = [agent for agent in agents if not agent.is_board_lead] @@ -232,7 +232,7 @@ async def apply_board_group_heartbeat( failed_agent_ids: list[UUID] = [] gateway_ids = list(agents_by_gateway_id.keys()) - gateways = list(await session.exec(select(Gateway).where(col(Gateway.id).in_(gateway_ids)))) + gateways = await Gateway.objects.by_ids(gateway_ids).all(session) gateway_by_id = {gateway.id: gateway for gateway in gateways} for gateway_id, gateway_agents in agents_by_gateway_id.items(): gateway = gateway_by_id.get(gateway_id) diff --git a/backend/app/api/board_memory.py b/backend/app/api/board_memory.py index 197fe7b..dd277fe 100644 --- a/backend/app/api/board_memory.py +++ b/backend/app/api/board_memory.py @@ -8,7 +8,7 @@ from uuid import UUID from fastapi import APIRouter, Depends, Query, Request from sqlalchemy import func -from sqlmodel import col, select +from sqlmodel import col from sqlmodel.ext.asyncio.session import AsyncSession from sse_starlette.sse import EventSourceResponse @@ -58,7 +58,7 @@ def _serialize_memory(memory: BoardMemory) -> dict[str, object]: async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None: if board.gateway_id is None: return None - gateway = await session.get(Gateway, board.gateway_id) + gateway = await Gateway.objects.by_id(board.gateway_id).first(session) if gateway is None or not gateway.url: return None return GatewayClientConfig(url=gateway.url, token=gateway.token) @@ -83,17 +83,17 @@ async def _fetch_memory_events( is_chat: bool | None = None, ) -> list[BoardMemory]: statement = ( - select(BoardMemory).where(col(BoardMemory.board_id) == board_id) + BoardMemory.objects.filter_by(board_id=board_id) # Old/invalid rows (empty/whitespace-only content) can exist; exclude them to # satisfy the NonEmptyStr response schema. - .where(func.length(func.trim(col(BoardMemory.content))) > 0) + .filter(func.length(func.trim(col(BoardMemory.content))) > 0) ) if is_chat is not None: - statement = statement.where(col(BoardMemory.is_chat) == is_chat) - statement = statement.where(col(BoardMemory.created_at) >= since).order_by( + statement = statement.filter(col(BoardMemory.is_chat) == is_chat) + statement = statement.filter(col(BoardMemory.created_at) >= since).order_by( col(BoardMemory.created_at) ) - return list(await session.exec(statement)) + return await statement.all(session) async def _notify_chat_targets( @@ -114,8 +114,7 @@ async def _notify_chat_targets( # Special-case control commands to reach all board agents. # These are intended to be parsed verbatim by agent runtimes. if command in {"/pause", "/resume"}: - statement = select(Agent).where(col(Agent.board_id) == board.id) - pause_targets: list[Agent] = list(await session.exec(statement)) + pause_targets: list[Agent] = await Agent.objects.filter_by(board_id=board.id).all(session) for agent in pause_targets: if actor.actor_type == "agent" and actor.agent and agent.id == actor.agent.id: continue @@ -134,9 +133,8 @@ async def _notify_chat_targets( return mentions = extract_mentions(memory.content) - statement = select(Agent).where(col(Agent.board_id) == board.id) targets: dict[str, Agent] = {} - for agent in await session.exec(statement): + for agent in await Agent.objects.filter_by(board_id=board.id).all(session): if agent.is_board_lead: targets[str(agent.id)] = agent continue @@ -188,15 +186,15 @@ async def list_board_memory( actor: ActorContext = Depends(require_admin_or_agent), ) -> DefaultLimitOffsetPage[BoardMemoryRead]: statement = ( - select(BoardMemory).where(col(BoardMemory.board_id) == board.id) + BoardMemory.objects.filter_by(board_id=board.id) # Old/invalid rows (empty/whitespace-only content) can exist; exclude them to # satisfy the NonEmptyStr response schema. - .where(func.length(func.trim(col(BoardMemory.content))) > 0) + .filter(func.length(func.trim(col(BoardMemory.content))) > 0) ) if is_chat is not None: - statement = statement.where(col(BoardMemory.is_chat) == is_chat) + statement = statement.filter(col(BoardMemory.is_chat) == is_chat) statement = statement.order_by(col(BoardMemory.created_at).desc()) - return await paginate(session, statement) + return await paginate(session, statement.statement) @router.get("/stream") diff --git a/backend/app/api/board_onboarding.py b/backend/app/api/board_onboarding.py index 1fc0cb8..7d58642 100644 --- a/backend/app/api/board_onboarding.py +++ b/backend/app/api/board_onboarding.py @@ -6,7 +6,7 @@ from uuid import uuid4 from fastapi import APIRouter, Depends, HTTPException, status from pydantic import ValidationError -from sqlmodel import col, select +from sqlmodel import col from sqlmodel.ext.asyncio.session import AsyncSession from app.api.deps import ( @@ -50,7 +50,7 @@ async def _gateway_config( ) -> tuple[Gateway, GatewayClientConfig]: if not board.gateway_id: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) - gateway = await session.get(Gateway, board.gateway_id) + gateway = await Gateway.objects.by_id(board.gateway_id).first(session) if gateway is None or not gateway.url or not gateway.main_session_key: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) return gateway, GatewayClientConfig(url=gateway.url, token=gateway.token) @@ -80,12 +80,10 @@ async def _ensure_lead_agent( identity_profile: dict[str, str] | None = None, ) -> Agent: existing = ( - await session.exec( - select(Agent) - .where(Agent.board_id == board.id) - .where(col(Agent.is_board_lead).is_(True)) - ) - ).first() + await Agent.objects.filter_by(board_id=board.id) + .filter(col(Agent.is_board_lead).is_(True)) + .first(session) + ) if existing: desired_name = agent_name or _lead_agent_name(board) if existing.name != desired_name: @@ -147,12 +145,10 @@ async def get_onboarding( session: AsyncSession = Depends(get_session), ) -> BoardOnboardingSession: onboarding = ( - await session.exec( - select(BoardOnboardingSession) - .where(BoardOnboardingSession.board_id == board.id) - .order_by(col(BoardOnboardingSession.created_at).desc()) - ) - ).first() + await BoardOnboardingSession.objects.filter_by(board_id=board.id) + .order_by(col(BoardOnboardingSession.updated_at).desc()) + .first(session) + ) if onboarding is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) return onboarding @@ -165,12 +161,10 @@ async def start_onboarding( session: AsyncSession = Depends(get_session), ) -> BoardOnboardingSession: onboarding = ( - await session.exec( - select(BoardOnboardingSession) - .where(BoardOnboardingSession.board_id == board.id) - .where(BoardOnboardingSession.status == "active") - ) - ).first() + await BoardOnboardingSession.objects.filter_by(board_id=board.id) + .filter(col(BoardOnboardingSession.status) == "active") + .first(session) + ) if onboarding: return onboarding @@ -248,12 +242,10 @@ async def answer_onboarding( session: AsyncSession = Depends(get_session), ) -> BoardOnboardingSession: onboarding = ( - await session.exec( - select(BoardOnboardingSession) - .where(BoardOnboardingSession.board_id == board.id) - .order_by(col(BoardOnboardingSession.created_at).desc()) - ) - ).first() + await BoardOnboardingSession.objects.filter_by(board_id=board.id) + .order_by(col(BoardOnboardingSession.updated_at).desc()) + .first(session) + ) if onboarding is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) @@ -295,18 +287,16 @@ async def agent_onboarding_update( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) if board.gateway_id: - gateway = await session.get(Gateway, board.gateway_id) + gateway = await Gateway.objects.by_id(board.gateway_id).first(session) if gateway and gateway.main_session_key and agent.openclaw_session_id: if agent.openclaw_session_id != gateway.main_session_key: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) onboarding = ( - await session.exec( - select(BoardOnboardingSession) - .where(BoardOnboardingSession.board_id == board.id) - .order_by(col(BoardOnboardingSession.created_at).desc()) - ) - ).first() + await BoardOnboardingSession.objects.filter_by(board_id=board.id) + .order_by(col(BoardOnboardingSession.updated_at).desc()) + .first(session) + ) if onboarding is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) if onboarding.status == "confirmed": @@ -351,12 +341,10 @@ async def confirm_onboarding( auth: AuthContext = Depends(require_admin_auth), ) -> Board: onboarding = ( - await session.exec( - select(BoardOnboardingSession) - .where(BoardOnboardingSession.board_id == board.id) - .order_by(col(BoardOnboardingSession.created_at).desc()) - ) - ).first() + await BoardOnboardingSession.objects.filter_by(board_id=board.id) + .order_by(col(BoardOnboardingSession.updated_at).desc()) + .first(session) + ) if onboarding is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) diff --git a/backend/app/api/boards.py b/backend/app/api/boards.py index 7753a3f..f341572 100644 --- a/backend/app/api/boards.py +++ b/backend/app/api/boards.py @@ -163,7 +163,7 @@ async def _board_gateway( ) -> tuple[Gateway | None, GatewayClientConfig | None]: if not board.gateway_id: return None, None - config = await session.get(Gateway, board.gateway_id) + config = await Gateway.objects.by_id(board.gateway_id).first(session) if config is None: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, @@ -292,7 +292,7 @@ async def delete_board( session: AsyncSession = Depends(get_session), board: Board = Depends(get_board_for_user_write), ) -> OkResponse: - agents = list(await session.exec(select(Agent).where(Agent.board_id == board.id))) + agents = await Agent.objects.filter_by(board_id=board.id).all(session) task_ids = list(await session.exec(select(Task.id).where(Task.board_id == board.id))) config, client_config = await _board_gateway(session, board) diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index d76c9f1..e7571be 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -59,7 +59,7 @@ async def require_org_member( member = await ensure_member_for_user(session, auth.user) if member is None: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - organization = await session.get(Organization, member.organization_id) + organization = await Organization.objects.by_id(member.organization_id).first(session) if organization is None: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) return OrganizationContext(organization=organization, member=member) @@ -77,7 +77,7 @@ async def get_board_or_404( board_id: str, session: AsyncSession = Depends(get_session), ) -> Board: - board = await session.get(Board, board_id) + board = await Board.objects.by_id(board_id).first(session) if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) return board @@ -88,7 +88,7 @@ async def get_board_for_actor_read( session: AsyncSession = Depends(get_session), actor: ActorContext = Depends(require_admin_or_agent), ) -> Board: - board = await session.get(Board, board_id) + board = await Board.objects.by_id(board_id).first(session) if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) if actor.actor_type == "agent": @@ -106,7 +106,7 @@ async def get_board_for_actor_write( session: AsyncSession = Depends(get_session), actor: ActorContext = Depends(require_admin_or_agent), ) -> Board: - board = await session.get(Board, board_id) + board = await Board.objects.by_id(board_id).first(session) if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) if actor.actor_type == "agent": @@ -124,7 +124,7 @@ async def get_board_for_user_read( session: AsyncSession = Depends(get_session), auth: AuthContext = Depends(get_auth_context), ) -> Board: - board = await session.get(Board, board_id) + board = await Board.objects.by_id(board_id).first(session) if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) if auth.user is None: @@ -138,7 +138,7 @@ async def get_board_for_user_write( session: AsyncSession = Depends(get_session), auth: AuthContext = Depends(get_auth_context), ) -> Board: - board = await session.get(Board, board_id) + board = await Board.objects.by_id(board_id).first(session) if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) if auth.user is None: @@ -152,7 +152,7 @@ async def get_task_or_404( board: Board = Depends(get_board_for_actor_read), session: AsyncSession = Depends(get_session), ) -> Task: - task = await session.get(Task, task_id) + task = await Task.objects.by_id(task_id).first(session) if task is None or task.board_id != board.id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) return task diff --git a/backend/app/api/gateway.py b/backend/app/api/gateway.py index 83dcd6d..06ccf54 100644 --- a/backend/app/api/gateway.py +++ b/backend/app/api/gateway.py @@ -56,7 +56,7 @@ async def _resolve_gateway( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="board_id or gateway_url is required", ) - board = await session.get(Board, board_id) + board = await Board.objects.by_id(board_id).first(session) if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found") if isinstance(user, object) and user is not None: @@ -66,7 +66,7 @@ async def _resolve_gateway( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Board gateway_id is required", ) - gateway = await session.get(Gateway, board.gateway_id) + gateway = await Gateway.objects.by_id(board.gateway_id).first(session) if gateway is None: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, @@ -216,7 +216,7 @@ async def get_gateway_session( sessions_list = list(sessions.get("sessions") or []) else: sessions_list = list(sessions or []) - if main_session and not any(session.get("key") == main_session for session in sessions_list): + if main_session and not any(item.get("key") == main_session for item in sessions_list): try: await ensure_session(main_session, config=config, label="Main Agent") refreshed = await openclaw_call("sessions.list", config=config) diff --git a/backend/app/api/gateways.py b/backend/app/api/gateways.py index a6acf81..b37264e 100644 --- a/backend/app/api/gateways.py +++ b/backend/app/api/gateways.py @@ -2,12 +2,11 @@ from __future__ import annotations from uuid import UUID -from fastapi import APIRouter, Depends, Query -from sqlmodel import col, select +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlmodel import col from sqlmodel.ext.asyncio.session import AsyncSession from app.api.deps import require_org_admin -from app.api.queryset import api_qs from app.core.agent_tokens import generate_agent_token, hash_agent_token from app.core.auth import AuthContext, get_auth_context from app.core.time import utcnow @@ -43,14 +42,14 @@ async def _require_gateway( gateway_id: UUID, organization_id: UUID, ) -> Gateway: - return await ( - api_qs(Gateway) - .filter( - col(Gateway.id) == gateway_id, - col(Gateway.organization_id) == organization_id, - ) - .first_or_404(session, detail="Gateway not found") + gateway = ( + await Gateway.objects.by_id(gateway_id) + .filter(col(Gateway.organization_id) == organization_id) + .first(session) ) + if gateway is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found") + return gateway async def _find_main_agent( @@ -60,26 +59,22 @@ async def _find_main_agent( previous_session_key: str | None = None, ) -> Agent | None: if gateway.main_session_key: - agent = ( - await session.exec( - select(Agent).where(Agent.openclaw_session_id == gateway.main_session_key) - ) - ).first() + agent = await Agent.objects.filter_by(openclaw_session_id=gateway.main_session_key).first( + session + ) if agent: return agent if previous_session_key: - agent = ( - await session.exec( - select(Agent).where(Agent.openclaw_session_id == previous_session_key) - ) - ).first() + agent = await Agent.objects.filter_by(openclaw_session_id=previous_session_key).first( + session + ) if agent: return agent names = {_main_agent_name(gateway)} if previous_name: names.add(f"{previous_name} Main") for name in names: - agent = (await session.exec(select(Agent).where(Agent.name == name))).first() + agent = await Agent.objects.filter_by(name=name).first(session) if agent: return agent return None @@ -153,8 +148,7 @@ async def list_gateways( ctx: OrganizationContext = Depends(require_org_admin), ) -> DefaultLimitOffsetPage[GatewayRead]: statement = ( - api_qs(Gateway) - .filter(col(Gateway.organization_id) == ctx.organization.id) + Gateway.objects.filter_by(organization_id=ctx.organization.id) .order_by(col(Gateway.created_at).desc()) .statement ) diff --git a/backend/app/api/organizations.py b/backend/app/api/organizations.py index 73a9c60..6561dcd 100644 --- a/backend/app/api/organizations.py +++ b/backend/app/api/organizations.py @@ -10,7 +10,6 @@ from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession from app.api.deps import require_org_admin, require_org_member -from app.api.queryset import api_qs from app.core.auth import AuthContext, get_auth_context from app.core.time import utcnow from app.db import crud @@ -81,14 +80,10 @@ async def _require_org_member( organization_id: UUID, member_id: UUID, ) -> OrganizationMember: - return await ( - api_qs(OrganizationMember) - .filter( - col(OrganizationMember.id) == member_id, - col(OrganizationMember.organization_id) == organization_id, - ) - .first_or_404(session) - ) + member = await OrganizationMember.objects.by_id(member_id).first(session) + if member is None or member.organization_id != organization_id: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + return member async def _require_org_invite( @@ -97,14 +92,10 @@ async def _require_org_invite( organization_id: UUID, invite_id: UUID, ) -> OrganizationInvite: - return await ( - api_qs(OrganizationInvite) - .filter( - col(OrganizationInvite.id) == invite_id, - col(OrganizationInvite.organization_id) == organization_id, - ) - .first_or_404(session) - ) + invite = await OrganizationInvite.objects.by_id(invite_id).first(session) + if invite is None or invite.organization_id != organization_id: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + return invite @router.post("", response_model=OrganizationRead) @@ -157,7 +148,7 @@ async def list_my_organizations( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) await get_active_membership(session, auth.user) - db_user = await session.get(User, auth.user.id) + db_user = await User.objects.by_id(auth.user.id).first(session) active_id = db_user.active_organization_id if db_user else auth.user.active_organization_id statement = ( @@ -189,7 +180,7 @@ async def set_active_org( member = await set_active_organization( session, user=auth.user, organization_id=payload.organization_id ) - organization = await session.get(Organization, member.organization_id) + organization = await Organization.objects.by_id(member.organization_id).first(session) if organization is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) return OrganizationRead.model_validate(organization, from_attributes=True) @@ -293,14 +284,10 @@ async def get_my_membership( session: AsyncSession = Depends(get_session), ctx: OrganizationContext = Depends(require_org_member), ) -> OrganizationMemberRead: - user = await session.get(User, ctx.member.user_id) - access_rows = list( - await session.exec( - select(OrganizationBoardAccess).where( - col(OrganizationBoardAccess.organization_member_id) == ctx.member.id - ) - ) - ) + user = await User.objects.by_id(ctx.member.user_id).first(session) + access_rows = await OrganizationBoardAccess.objects.filter_by( + organization_member_id=ctx.member.id + ).all(session) model = _member_to_read(ctx.member, user) model.board_access = [ OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows @@ -342,14 +329,10 @@ async def get_org_member( ) if not is_org_admin(ctx.member) and member.user_id != ctx.member.user_id: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - user = await session.get(User, member.user_id) - access_rows = list( - await session.exec( - select(OrganizationBoardAccess).where( - col(OrganizationBoardAccess.organization_member_id) == member.id - ) - ) - ) + user = await User.objects.by_id(member.user_id).first(session) + access_rows = await OrganizationBoardAccess.objects.filter_by( + organization_member_id=member.id + ).all(session) model = _member_to_read(member, user) model.board_access = [ OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows @@ -374,7 +357,7 @@ async def update_org_member( updates["role"] = normalize_role(updates["role"]) updates["updated_at"] = utcnow() member = await crud.patch(session, member, updates) - user = await session.get(User, member.user_id) + user = await User.objects.by_id(member.user_id).first(session) return _member_to_read(member, user) @@ -393,20 +376,19 @@ async def update_member_access( board_ids = {entry.board_id for entry in payload.board_access} if board_ids: - valid_board_ids = set( - await session.exec( - select(Board.id) - .where(col(Board.id).in_(board_ids)) - .where(col(Board.organization_id) == ctx.organization.id) - ) - ) + valid_board_ids = { + board.id + for board in await Board.objects.filter_by(organization_id=ctx.organization.id) + .filter(col(Board.id).in_(board_ids)) + .all(session) + } if valid_board_ids != board_ids: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) await apply_member_access_update(session, member=member, update=payload) await session.commit() await session.refresh(member) - user = await session.get(User, member.user_id) + user = await User.objects.by_id(member.user_id).first(session) return _member_to_read(member, user) @@ -416,9 +398,11 @@ async def remove_org_member( session: AsyncSession = Depends(get_session), ctx: OrganizationContext = Depends(require_org_admin), ) -> OkResponse: - member = await session.get(OrganizationMember, member_id) - if member is None or member.organization_id != ctx.organization.id: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + member = await _require_org_member( + session, + organization_id=ctx.organization.id, + member_id=member_id, + ) if member.user_id == ctx.member.user_id: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -430,15 +414,12 @@ async def remove_org_member( detail="Only owners can remove owners", ) if member.role == "owner": - owner_ids = list( - await session.exec( - select(OrganizationMember.id).where( - col(OrganizationMember.organization_id) == ctx.organization.id, - col(OrganizationMember.role) == "owner", - ) - ) + owners = ( + await OrganizationMember.objects.filter_by(organization_id=ctx.organization.id) + .filter(col(OrganizationMember.role) == "owner") + .all(session) ) - if len(owner_ids) <= 1: + if len(owners) <= 1: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail="Organization must have at least one owner", @@ -451,17 +432,22 @@ async def remove_org_member( ), ) - user = await session.get(User, member.user_id) + user = await User.objects.by_id(member.user_id).first(session) if user is not None and user.active_organization_id == ctx.organization.id: - fallback_org_id = ( - await session.exec( - select(OrganizationMember.organization_id) - .where(col(OrganizationMember.user_id) == user.id) - .where(col(OrganizationMember.organization_id) != ctx.organization.id) - .order_by(col(OrganizationMember.created_at).asc()) + fallback_membership = ( + await OrganizationMember.objects.filter( + col(OrganizationMember.user_id) == user.id, + col(OrganizationMember.organization_id) != ctx.organization.id, + ) + .order_by(col(OrganizationMember.created_at).asc()) + .first(session) + ) + if isinstance(fallback_membership, UUID): + user.active_organization_id = fallback_membership + else: + user.active_organization_id = ( + fallback_membership.organization_id if fallback_membership is not None else None ) - ).first() - user.active_organization_id = fallback_org_id session.add(user) await crud.delete(session, member) @@ -474,8 +460,7 @@ async def list_org_invites( ctx: OrganizationContext = Depends(require_org_admin), ) -> DefaultLimitOffsetPage[OrganizationInviteRead]: statement = ( - api_qs(OrganizationInvite) - .filter(col(OrganizationInvite.organization_id) == ctx.organization.id) + OrganizationInvite.objects.filter_by(organization_id=ctx.organization.id) .filter(col(OrganizationInvite.accepted_at).is_(None)) .order_by(col(OrganizationInvite.created_at).desc()) .statement @@ -522,13 +507,12 @@ async def create_org_invite( board_ids = {entry.board_id for entry in payload.board_access} if board_ids: - valid_board_ids = set( - await session.exec( - select(Board.id) - .where(col(Board.id).in_(board_ids)) - .where(col(Board.organization_id) == ctx.organization.id) - ) - ) + valid_board_ids = { + board.id + for board in await Board.objects.filter_by(organization_id=ctx.organization.id) + .filter(col(Board.id).in_(board_ids)) + .all(session) + } if valid_board_ids != board_ids: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) await apply_invite_board_access(session, invite=invite, entries=payload.board_access) @@ -566,13 +550,10 @@ async def accept_org_invite( ) -> OrganizationMemberRead: if auth.user is None: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - invite = ( - await session.exec( - select(OrganizationInvite) - .where(col(OrganizationInvite.token) == payload.token) - .where(col(OrganizationInvite.accepted_at).is_(None)) - ) - ).first() + invite = await OrganizationInvite.objects.filter( + col(OrganizationInvite.token) == payload.token, + col(OrganizationInvite.accepted_at).is_(None), + ).first(session) if invite is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) if invite.invited_email and auth.user.email: @@ -597,5 +578,5 @@ async def accept_org_invite( await session.commit() member = existing - user = await session.get(User, member.user_id) + user = await User.objects.by_id(member.user_id).first(session) return _member_to_read(member, user) diff --git a/backend/app/api/tasks.py b/backend/app/api/tasks.py index 281e262..579cc52 100644 --- a/backend/app/api/tasks.py +++ b/backend/app/api/tasks.py @@ -243,7 +243,7 @@ def _serialize_comment(event: ActivityEvent) -> dict[str, object]: async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None: if not board.gateway_id: return None - gateway = await session.get(Gateway, board.gateway_id) + gateway = await Gateway.objects.by_id(board.gateway_id).first(session) if gateway is None or not gateway.url: return None return GatewayClientConfig(url=gateway.url, token=gateway.token) @@ -331,12 +331,10 @@ async def _notify_lead_on_task_create( task: Task, ) -> None: lead = ( - await session.exec( - select(Agent) - .where(Agent.board_id == board.id) - .where(col(Agent.is_board_lead).is_(True)) - ) - ).first() + await Agent.objects.filter_by(board_id=board.id) + .filter(col(Agent.is_board_lead).is_(True)) + .first(session) + ) if lead is None or not lead.openclaw_session_id: return config = await _gateway_config(session, board) @@ -390,12 +388,10 @@ async def _notify_lead_on_task_unassigned( task: Task, ) -> None: lead = ( - await session.exec( - select(Agent) - .where(Agent.board_id == board.id) - .where(col(Agent.is_board_lead).is_(True)) - ) - ).first() + await Agent.objects.filter_by(board_id=board.id) + .filter(col(Agent.is_board_lead).is_(True)) + .first(session) + ) if lead is None or not lead.openclaw_session_id: return config = await _gateway_config(session, board) @@ -635,7 +631,7 @@ async def create_task( await session.commit() await _notify_lead_on_task_create(session=session, board=board, task=task) if task.assigned_agent_id: - assigned_agent = await session.get(Agent, task.assigned_agent_id) + assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session) if assigned_agent: await _notify_agent_on_task_assign( session=session, @@ -670,7 +666,7 @@ async def update_task( ) board_id = task.board_id if actor.actor_type == "user" and actor.user is not None: - board = await session.get(Board, board_id) + board = await Board.objects.by_id(board_id).first(session) if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) await require_board_access(session, user=actor.user, board=board, write=True) @@ -740,7 +736,7 @@ async def update_task( if "assigned_agent_id" in updates: assigned_id = updates["assigned_agent_id"] if assigned_id: - agent = await session.get(Agent, assigned_id) + agent = await Agent.objects.by_id(assigned_id).first(session) if agent is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) if agent.is_board_lead: @@ -796,9 +792,13 @@ async def update_task( await session.refresh(task) if task.assigned_agent_id and task.assigned_agent_id != previous_assigned: - assigned_agent = await session.get(Agent, task.assigned_agent_id) + assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session) if assigned_agent: - board = await session.get(Board, task.board_id) if task.board_id else None + board = ( + await Board.objects.by_id(task.board_id).first(session) + if task.board_id + else None + ) if board: await _notify_agent_on_task_assign( session=session, @@ -879,7 +879,7 @@ async def update_task( task.in_progress_at = utcnow() if "assigned_agent_id" in updates and updates["assigned_agent_id"]: - agent = await session.get(Agent, updates["assigned_agent_id"]) + agent = await Agent.objects.by_id(updates["assigned_agent_id"]).first(session) if agent is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) if agent.board_id and task.board_id and agent.board_id != task.board_id: @@ -941,7 +941,9 @@ async def update_task( if task.status == "inbox" and task.assigned_agent_id is None: if previous_status != "inbox" or previous_assigned is not None: - board = await session.get(Board, task.board_id) if task.board_id else None + board = ( + await Board.objects.by_id(task.board_id).first(session) if task.board_id else None + ) if board: await _notify_lead_on_task_unassigned( session=session, @@ -953,9 +955,13 @@ async def update_task( # Don't notify the actor about their own assignment. pass else: - assigned_agent = await session.get(Agent, task.assigned_agent_id) + assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session) if assigned_agent: - board = await session.get(Board, task.board_id) if task.board_id else None + board = ( + await Board.objects.by_id(task.board_id).first(session) + if task.board_id + else None + ) if board: await _notify_agent_on_task_assign( session=session, @@ -985,7 +991,7 @@ async def delete_task( ) -> OkResponse: if task.board_id is None: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) - board = await session.get(Board, task.board_id) + board = await Board.objects.by_id(task.board_id).first(session) if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) if auth.user is None: @@ -1032,7 +1038,7 @@ async def create_task_comment( if task.board_id is None: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) if actor.actor_type == "user" and actor.user is not None: - board = await session.get(Board, task.board_id) + board = await Board.objects.by_id(task.board_id).first(session) if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) await require_board_access(session, user=actor.user, board=board, write=True) @@ -1059,18 +1065,17 @@ async def create_task_comment( mention_names = extract_mentions(payload.message) targets: dict[UUID, Agent] = {} if mention_names and task.board_id: - statement = select(Agent).where(col(Agent.board_id) == task.board_id) - for agent in await session.exec(statement): + for agent in await Agent.objects.filter_by(board_id=task.board_id).all(session): if matches_agent_mention(agent, mention_names): targets[agent.id] = agent if not mention_names and task.assigned_agent_id: - assigned_agent = await session.get(Agent, task.assigned_agent_id) + assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session) if assigned_agent: targets[assigned_agent.id] = assigned_agent if actor.actor_type == "agent" and actor.agent: targets.pop(actor.agent.id, None) if targets: - board = await session.get(Board, task.board_id) if task.board_id else None + board = await Board.objects.by_id(task.board_id).first(session) if task.board_id else None config = await _gateway_config(session, board) if board else None if board and config: snippet = payload.message.strip() diff --git a/backend/app/db/crud.py b/backend/app/db/crud.py index 5b789f9..e3221e2 100644 --- a/backend/app/db/crud.py +++ b/backend/app/db/crud.py @@ -27,7 +27,8 @@ def _lookup_statement(model: type[ModelT], lookup: Mapping[str, Any]) -> SelectO async def get_by_id(session: AsyncSession, model: type[ModelT], obj_id: Any) -> ModelT | None: - return await session.get(model, obj_id) + stmt = _lookup_statement(model, {"id": obj_id}).limit(1) + return (await session.exec(stmt)).first() async def get(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT: diff --git a/backend/app/db/query_manager.py b/backend/app/db/query_manager.py new file mode 100644 index 0000000..32a42d8 --- /dev/null +++ b/backend/app/db/query_manager.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Generic, TypeVar + +from sqlalchemy import false +from sqlmodel import SQLModel, col + +from app.db.queryset import QuerySet, qs + +ModelT = TypeVar("ModelT", bound=SQLModel) + + +@dataclass(frozen=True) +class ModelManager(Generic[ModelT]): + model: type[ModelT] + id_field: str = "id" + + def all(self) -> QuerySet[ModelT]: + return qs(self.model) + + def none(self) -> QuerySet[ModelT]: + return qs(self.model).filter(false()) + + def filter(self, *criteria: Any) -> QuerySet[ModelT]: + return self.all().filter(*criteria) + + def where(self, *criteria: Any) -> QuerySet[ModelT]: + return self.filter(*criteria) + + def filter_by(self, **kwargs: Any) -> QuerySet[ModelT]: + queryset = self.all() + for field_name, value in kwargs.items(): + queryset = queryset.filter(col(getattr(self.model, field_name)) == value) + return queryset + + def by_id(self, obj_id: Any) -> QuerySet[ModelT]: + return self.by_field(self.id_field, obj_id) + + def by_ids(self, obj_ids: list[Any] | tuple[Any, ...] | set[Any]) -> QuerySet[ModelT]: + return self.by_field_in(self.id_field, obj_ids) + + def by_field(self, field_name: str, value: Any) -> QuerySet[ModelT]: + return self.filter(col(getattr(self.model, field_name)) == value) + + def by_field_in( + self, + field_name: str, + values: list[Any] | tuple[Any, ...] | set[Any], + ) -> QuerySet[ModelT]: + seq = tuple(values) + if not seq: + return self.none() + return self.filter(col(getattr(self.model, field_name)).in_(seq)) + + +class ManagerDescriptor(Generic[ModelT]): + def __get__(self, instance: object, owner: type[ModelT]) -> ModelManager[ModelT]: + return ModelManager(owner) diff --git a/backend/app/db/queryset.py b/backend/app/db/queryset.py index b6b3158..533b10b 100644 --- a/backend/app/db/queryset.py +++ b/backend/app/db/queryset.py @@ -17,6 +17,13 @@ class QuerySet(Generic[ModelT]): def filter(self, *criteria: Any) -> QuerySet[ModelT]: return replace(self, statement=self.statement.where(*criteria)) + def where(self, *criteria: Any) -> QuerySet[ModelT]: + return self.filter(*criteria) + + def filter_by(self, **kwargs: Any) -> QuerySet[ModelT]: + statement = self.statement.filter_by(**kwargs) + return replace(self, statement=statement) + def order_by(self, *ordering: Any) -> QuerySet[ModelT]: return replace(self, statement=self.statement.order_by(*ordering)) diff --git a/backend/app/models/activity_events.py b/backend/app/models/activity_events.py index 15f7c69..bb1ae39 100644 --- a/backend/app/models/activity_events.py +++ b/backend/app/models/activity_events.py @@ -3,12 +3,13 @@ from __future__ import annotations from datetime import datetime from uuid import UUID, uuid4 -from sqlmodel import Field, SQLModel +from sqlmodel import Field from app.core.time import utcnow +from app.models.base import QueryModel -class ActivityEvent(SQLModel, table=True): +class ActivityEvent(QueryModel, table=True): __tablename__ = "activity_events" id: UUID = Field(default_factory=uuid4, primary_key=True) diff --git a/backend/app/models/agents.py b/backend/app/models/agents.py index cce8f82..675b82a 100644 --- a/backend/app/models/agents.py +++ b/backend/app/models/agents.py @@ -5,12 +5,13 @@ from typing import Any from uuid import UUID, uuid4 from sqlalchemy import JSON, Column, Text -from sqlmodel import Field, SQLModel +from sqlmodel import Field from app.core.time import utcnow +from app.models.base import QueryModel -class Agent(SQLModel, table=True): +class Agent(QueryModel, table=True): __tablename__ = "agents" id: UUID = Field(default_factory=uuid4, primary_key=True) diff --git a/backend/app/models/approvals.py b/backend/app/models/approvals.py index 8c762dc..db200b1 100644 --- a/backend/app/models/approvals.py +++ b/backend/app/models/approvals.py @@ -4,12 +4,13 @@ from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import JSON, Column -from sqlmodel import Field, SQLModel +from sqlmodel import Field from app.core.time import utcnow +from app.models.base import QueryModel -class Approval(SQLModel, table=True): +class Approval(QueryModel, table=True): __tablename__ = "approvals" id: UUID = Field(default_factory=uuid4, primary_key=True) diff --git a/backend/app/models/base.py b/backend/app/models/base.py new file mode 100644 index 0000000..bab7759 --- /dev/null +++ b/backend/app/models/base.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from typing import ClassVar, Self + +from sqlmodel import SQLModel + +from app.db.query_manager import ManagerDescriptor + + +class QueryModel(SQLModel, table=False): + objects: ClassVar[ManagerDescriptor[Self]] = ManagerDescriptor() diff --git a/backend/app/models/board_group_memory.py b/backend/app/models/board_group_memory.py index 96b14b7..92effde 100644 --- a/backend/app/models/board_group_memory.py +++ b/backend/app/models/board_group_memory.py @@ -4,12 +4,13 @@ from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import JSON, Column -from sqlmodel import Field, SQLModel +from sqlmodel import Field from app.core.time import utcnow +from app.models.base import QueryModel -class BoardGroupMemory(SQLModel, table=True): +class BoardGroupMemory(QueryModel, table=True): __tablename__ = "board_group_memory" id: UUID = Field(default_factory=uuid4, primary_key=True) diff --git a/backend/app/models/board_memory.py b/backend/app/models/board_memory.py index 296340e..0de3e3b 100644 --- a/backend/app/models/board_memory.py +++ b/backend/app/models/board_memory.py @@ -4,12 +4,13 @@ from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import JSON, Column -from sqlmodel import Field, SQLModel +from sqlmodel import Field from app.core.time import utcnow +from app.models.base import QueryModel -class BoardMemory(SQLModel, table=True): +class BoardMemory(QueryModel, table=True): __tablename__ = "board_memory" id: UUID = Field(default_factory=uuid4, primary_key=True) diff --git a/backend/app/models/board_onboarding.py b/backend/app/models/board_onboarding.py index b23d599..fafed3c 100644 --- a/backend/app/models/board_onboarding.py +++ b/backend/app/models/board_onboarding.py @@ -4,12 +4,13 @@ from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import JSON, Column -from sqlmodel import Field, SQLModel +from sqlmodel import Field from app.core.time import utcnow +from app.models.base import QueryModel -class BoardOnboardingSession(SQLModel, table=True): +class BoardOnboardingSession(QueryModel, table=True): __tablename__ = "board_onboarding_sessions" id: UUID = Field(default_factory=uuid4, primary_key=True) diff --git a/backend/app/models/gateways.py b/backend/app/models/gateways.py index 1cfd62c..fba5abd 100644 --- a/backend/app/models/gateways.py +++ b/backend/app/models/gateways.py @@ -3,12 +3,13 @@ from __future__ import annotations from datetime import datetime from uuid import UUID, uuid4 -from sqlmodel import Field, SQLModel +from sqlmodel import Field from app.core.time import utcnow +from app.models.base import QueryModel -class Gateway(SQLModel, table=True): +class Gateway(QueryModel, table=True): __tablename__ = "gateways" id: UUID = Field(default_factory=uuid4, primary_key=True) diff --git a/backend/app/models/organization_board_access.py b/backend/app/models/organization_board_access.py index a846a4f..b413507 100644 --- a/backend/app/models/organization_board_access.py +++ b/backend/app/models/organization_board_access.py @@ -4,12 +4,13 @@ from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import UniqueConstraint -from sqlmodel import Field, SQLModel +from sqlmodel import Field from app.core.time import utcnow +from app.models.base import QueryModel -class OrganizationBoardAccess(SQLModel, table=True): +class OrganizationBoardAccess(QueryModel, table=True): __tablename__ = "organization_board_access" __table_args__ = ( UniqueConstraint( diff --git a/backend/app/models/organization_invite_board_access.py b/backend/app/models/organization_invite_board_access.py index 3c85019..f816115 100644 --- a/backend/app/models/organization_invite_board_access.py +++ b/backend/app/models/organization_invite_board_access.py @@ -4,12 +4,13 @@ from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import UniqueConstraint -from sqlmodel import Field, SQLModel +from sqlmodel import Field from app.core.time import utcnow +from app.models.base import QueryModel -class OrganizationInviteBoardAccess(SQLModel, table=True): +class OrganizationInviteBoardAccess(QueryModel, table=True): __tablename__ = "organization_invite_board_access" __table_args__ = ( UniqueConstraint( diff --git a/backend/app/models/organization_invites.py b/backend/app/models/organization_invites.py index f4247d1..85228c2 100644 --- a/backend/app/models/organization_invites.py +++ b/backend/app/models/organization_invites.py @@ -4,12 +4,13 @@ from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import UniqueConstraint -from sqlmodel import Field, SQLModel +from sqlmodel import Field from app.core.time import utcnow +from app.models.base import QueryModel -class OrganizationInvite(SQLModel, table=True): +class OrganizationInvite(QueryModel, table=True): __tablename__ = "organization_invites" __table_args__ = (UniqueConstraint("token", name="uq_org_invites_token"),) diff --git a/backend/app/models/organization_members.py b/backend/app/models/organization_members.py index 3b37f64..b521f8c 100644 --- a/backend/app/models/organization_members.py +++ b/backend/app/models/organization_members.py @@ -4,12 +4,13 @@ from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import UniqueConstraint -from sqlmodel import Field, SQLModel +from sqlmodel import Field from app.core.time import utcnow +from app.models.base import QueryModel -class OrganizationMember(SQLModel, table=True): +class OrganizationMember(QueryModel, table=True): __tablename__ = "organization_members" __table_args__ = ( UniqueConstraint( diff --git a/backend/app/models/organizations.py b/backend/app/models/organizations.py index b306aac..3f51244 100644 --- a/backend/app/models/organizations.py +++ b/backend/app/models/organizations.py @@ -4,12 +4,13 @@ from datetime import datetime from uuid import UUID, uuid4 from sqlalchemy import UniqueConstraint -from sqlmodel import Field, SQLModel +from sqlmodel import Field from app.core.time import utcnow +from app.models.base import QueryModel -class Organization(SQLModel, table=True): +class Organization(QueryModel, table=True): __tablename__ = "organizations" __table_args__ = (UniqueConstraint("name", name="uq_organizations_name"),) diff --git a/backend/app/models/task_fingerprints.py b/backend/app/models/task_fingerprints.py index e720b36..f69ce74 100644 --- a/backend/app/models/task_fingerprints.py +++ b/backend/app/models/task_fingerprints.py @@ -3,12 +3,13 @@ from __future__ import annotations from datetime import datetime from uuid import UUID, uuid4 -from sqlmodel import Field, SQLModel +from sqlmodel import Field from app.core.time import utcnow +from app.models.base import QueryModel -class TaskFingerprint(SQLModel, table=True): +class TaskFingerprint(QueryModel, table=True): __tablename__ = "task_fingerprints" id: UUID = Field(default_factory=uuid4, primary_key=True) diff --git a/backend/app/models/tenancy.py b/backend/app/models/tenancy.py index bd2768b..5d31db5 100644 --- a/backend/app/models/tenancy.py +++ b/backend/app/models/tenancy.py @@ -1,7 +1,7 @@ from __future__ import annotations -from sqlmodel import SQLModel +from app.models.base import QueryModel -class TenantScoped(SQLModel, table=False): +class TenantScoped(QueryModel, table=False): pass diff --git a/backend/app/models/users.py b/backend/app/models/users.py index fb73c7a..de4c6d0 100644 --- a/backend/app/models/users.py +++ b/backend/app/models/users.py @@ -2,10 +2,12 @@ from __future__ import annotations from uuid import UUID, uuid4 -from sqlmodel import Field, SQLModel +from sqlmodel import Field + +from app.models.base import QueryModel -class User(SQLModel, table=True): +class User(QueryModel, table=True): __tablename__ = "users" id: UUID = Field(default_factory=uuid4, primary_key=True) diff --git a/backend/app/queries/__init__.py b/backend/app/queries/__init__.py deleted file mode 100644 index 9d48db4..0000000 --- a/backend/app/queries/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from __future__ import annotations diff --git a/backend/app/queries/organizations.py b/backend/app/queries/organizations.py deleted file mode 100644 index 9699825..0000000 --- a/backend/app/queries/organizations.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import annotations - -from uuid import UUID - -from sqlmodel import col - -from app.db.queryset import QuerySet, qs -from app.models.organization_board_access import OrganizationBoardAccess -from app.models.organization_invites import OrganizationInvite -from app.models.organization_members import OrganizationMember -from app.models.organizations import Organization - - -def organization_by_name(name: str) -> QuerySet[Organization]: - return qs(Organization).filter(col(Organization.name) == name) - - -def member_by_user_and_org(*, user_id: UUID, organization_id: UUID) -> QuerySet[OrganizationMember]: - return qs(OrganizationMember).filter( - col(OrganizationMember.organization_id) == organization_id, - col(OrganizationMember.user_id) == user_id, - ) - - -def first_membership_for_user(user_id: UUID) -> QuerySet[OrganizationMember]: - return ( - qs(OrganizationMember) - .filter(col(OrganizationMember.user_id) == user_id) - .order_by(col(OrganizationMember.created_at).asc()) - ) - - -def pending_invite_by_email(email: str) -> QuerySet[OrganizationInvite]: - return ( - qs(OrganizationInvite) - .filter(col(OrganizationInvite.accepted_at).is_(None)) - .filter(col(OrganizationInvite.invited_email) == email) - .order_by(col(OrganizationInvite.created_at).asc()) - ) - - -def board_access_for_member_and_board( - *, - organization_member_id: UUID, - board_id: UUID, -) -> QuerySet[OrganizationBoardAccess]: - return qs(OrganizationBoardAccess).filter( - col(OrganizationBoardAccess.organization_member_id) == organization_member_id, - col(OrganizationBoardAccess.board_id) == board_id, - ) diff --git a/backend/app/services/board_group_snapshot.py b/backend/app/services/board_group_snapshot.py index 9d1a5d9..530b734 100644 --- a/backend/app/services/board_group_snapshot.py +++ b/backend/app/services/board_group_snapshot.py @@ -42,7 +42,7 @@ async def build_group_snapshot( include_done: bool = False, per_board_task_limit: int = 5, ) -> BoardGroupSnapshot: - statement = select(Board).where(col(Board.board_group_id) == group.id) + statement = Board.objects.filter_by(board_group_id=group.id).statement if exclude_board_id is not None: statement = statement.where(col(Board.id) != exclude_board_id) boards = list(await session.exec(statement.order_by(func.lower(col(Board.name)).asc()))) @@ -146,7 +146,7 @@ async def build_board_group_snapshot( ) -> BoardGroupSnapshot: if not board.board_group_id: return BoardGroupSnapshot(group=None, boards=[]) - group = await session.get(BoardGroup, board.board_group_id) + group = await BoardGroup.objects.by_id(board.board_group_id).first(session) if group is None: return BoardGroupSnapshot(group=None, boards=[]) return await build_group_snapshot( diff --git a/backend/app/services/board_snapshot.py b/backend/app/services/board_snapshot.py index f9c1582..6960ae0 100644 --- a/backend/app/services/board_snapshot.py +++ b/backend/app/services/board_snapshot.py @@ -97,9 +97,9 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap board_read = BoardRead.model_validate(board, from_attributes=True) tasks = list( - await session.exec( - select(Task).where(col(Task.board_id) == board.id).order_by(col(Task.created_at).desc()) - ) + await Task.objects.filter_by(board_id=board.id) + .order_by(col(Task.created_at).desc()) + .all(session) ) task_ids = [task.id for task in tasks] @@ -114,12 +114,10 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap ) main_session_keys = await _gateway_main_session_keys(session) - agents = list( - await session.exec( - select(Agent) - .where(col(Agent.board_id) == board.id) - .order_by(col(Agent.created_at).desc()) - ) + agents = ( + await Agent.objects.filter_by(board_id=board.id) + .order_by(col(Agent.created_at).desc()) + .all(session) ) agent_reads = [_agent_to_read(agent, main_session_keys) for agent in agents] agent_name_by_id = {agent.id: agent.name for agent in agents} @@ -134,13 +132,11 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap ).one() ) - approvals = list( - await session.exec( - select(Approval) - .where(col(Approval.board_id) == board.id) - .order_by(col(Approval.created_at).desc()) - .limit(200) - ) + approvals = ( + await Approval.objects.filter_by(board_id=board.id) + .order_by(col(Approval.created_at).desc()) + .limit(200) + .all(session) ) approval_reads = [_approval_to_read(approval) for approval in approvals] @@ -173,17 +169,15 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap for task in tasks ] - chat_messages = list( - await session.exec( - select(BoardMemory) - .where(col(BoardMemory.board_id) == board.id) - .where(col(BoardMemory.is_chat).is_(True)) - # Old/invalid rows (empty/whitespace-only content) can exist; exclude them to - # satisfy the NonEmptyStr response schema. - .where(func.length(func.trim(col(BoardMemory.content))) > 0) - .order_by(col(BoardMemory.created_at).desc()) - .limit(200) - ) + chat_messages = ( + await BoardMemory.objects.filter_by(board_id=board.id) + .filter(col(BoardMemory.is_chat).is_(True)) + # Old/invalid rows (empty/whitespace-only content) can exist; exclude them to + # satisfy the NonEmptyStr response schema. + .filter(func.length(func.trim(col(BoardMemory.content))) > 0) + .order_by(col(BoardMemory.created_at).desc()) + .limit(200) + .all(session) ) chat_messages.sort(key=lambda item: item.created_at) chat_reads = [_memory_to_read(memory) for memory in chat_messages] diff --git a/backend/app/services/organizations.py b/backend/app/services/organizations.py index dcb752d..b5efdd3 100644 --- a/backend/app/services/organizations.py +++ b/backend/app/services/organizations.py @@ -19,7 +19,6 @@ from app.models.organization_invites import OrganizationInvite from app.models.organization_members import OrganizationMember from app.models.organizations import Organization from app.models.users import User -from app.queries import organizations as org_queries from app.schemas.organizations import OrganizationBoardAccessSpec, OrganizationMemberAccessUpdate DEFAULT_ORG_NAME = "Personal" @@ -38,7 +37,7 @@ def is_org_admin(member: OrganizationMember) -> bool: async def get_default_org(session: AsyncSession) -> Organization | None: - return await org_queries.organization_by_name(DEFAULT_ORG_NAME).first(session) + return await Organization.objects.filter_by(name=DEFAULT_ORG_NAME).first(session) async def ensure_default_org(session: AsyncSession) -> Organization: @@ -58,14 +57,18 @@ async def get_member( user_id: UUID, organization_id: UUID, ) -> OrganizationMember | None: - return await org_queries.member_by_user_and_org( + return await OrganizationMember.objects.filter_by( user_id=user_id, organization_id=organization_id, ).first(session) async def get_first_membership(session: AsyncSession, user_id: UUID) -> OrganizationMember | None: - return await org_queries.first_membership_for_user(user_id).first(session) + return ( + await OrganizationMember.objects.filter_by(user_id=user_id) + .order_by(col(OrganizationMember.created_at).asc()) + .first(session) + ) async def set_active_organization( @@ -88,7 +91,7 @@ async def get_active_membership( session: AsyncSession, user: User, ) -> OrganizationMember | None: - db_user = await session.get(User, user.id) + db_user = await User.objects.by_id(user.id).first(session) if db_user is None: db_user = user if db_user.active_organization_id: @@ -119,7 +122,14 @@ async def _find_pending_invite( session: AsyncSession, email: str, ) -> OrganizationInvite | None: - return await org_queries.pending_invite_by_email(email).first(session) + return ( + await OrganizationInvite.objects.filter( + col(OrganizationInvite.accepted_at).is_(None), + col(OrganizationInvite.invited_email) == email, + ) + .order_by(col(OrganizationInvite.created_at).asc()) + .first(session) + ) async def accept_invite( @@ -230,7 +240,7 @@ async def has_board_access( else: if member_all_boards_read(member): return True - access = await org_queries.board_access_for_member_and_board( + access = await OrganizationBoardAccess.objects.filter_by( organization_member_id=member.id, board_id=board.id, ).first(session) diff --git a/backend/app/services/template_sync.py b/backend/app/services/template_sync.py index 77e322c..ab25ae2 100644 --- a/backend/app/services/template_sync.py +++ b/backend/app/services/template_sync.py @@ -328,7 +328,7 @@ async def sync_gateway_templates( result.errors.append(GatewayTemplatesSyncError(message=str(exc))) return result - boards = list(await session.exec(select(Board).where(col(Board.gateway_id) == gateway.id))) + boards = await Board.objects.filter_by(gateway_id=gateway.id).all(session) boards_by_id = {board.id: board for board in boards} if board_id is not None: board = boards_by_id.get(board_id) @@ -345,12 +345,10 @@ async def sync_gateway_templates( paused_board_ids = await _paused_board_ids(session, list(boards_by_id.keys())) if boards_by_id: - agents = list( - await session.exec( - select(Agent) - .where(col(Agent.board_id).in_(list(boards_by_id.keys()))) - .order_by(col(Agent.created_at).asc()) - ) + agents = await ( + Agent.objects.by_field_in("board_id", list(boards_by_id.keys())) + .order_by(col(Agent.created_at).asc()) + .all(session) ) else: agents = [] @@ -471,10 +469,10 @@ async def sync_gateway_templates( if include_main: main_agent = ( - await session.exec( - select(Agent).where(col(Agent.openclaw_session_id) == gateway.main_session_key) - ) - ).first() + await Agent.objects.all() + .filter(col(Agent.openclaw_session_id) == gateway.main_session_key) + .first(session) + ) if main_agent is None: result.errors.append( GatewayTemplatesSyncError( diff --git a/backend/tests/test_organizations_member_remove_api.py b/backend/tests/test_organizations_member_remove_api.py index 4478020..78310bc 100644 --- a/backend/tests/test_organizations_member_remove_api.py +++ b/backend/tests/test_organizations_member_remove_api.py @@ -28,7 +28,6 @@ class _FakeExecResult: @dataclass class _FakeSession: exec_results: list[Any] - get_results: dict[tuple[type[Any], Any], Any] = field(default_factory=dict) executed: list[Any] = field(default_factory=list) deleted: list[Any] = field(default_factory=list) @@ -44,9 +43,6 @@ class _FakeSession: raise AssertionError("No more exec_results left for session.exec") return self.exec_results.pop(0) - async def get(self, model: type[Any], key: Any) -> Any: - return self.get_results.get((model, key)) - async def execute(self, statement: Any) -> None: self.executed.append(statement) @@ -79,11 +75,11 @@ async def test_remove_org_member_deletes_member_access_and_member() -> None: active_organization_id=org_id, ) session = _FakeSession( - exec_results=[_FakeExecResult(first_value=fallback_org_id)], - get_results={ - (OrganizationMember, member_id): member, - (User, target_user_id): user, - }, + exec_results=[ + _FakeExecResult(first_value=member), + _FakeExecResult(first_value=user), + _FakeExecResult(first_value=fallback_org_id), + ], ) ctx = SimpleNamespace( organization=SimpleNamespace(id=org_id), @@ -110,10 +106,7 @@ async def test_remove_org_member_disallows_self_removal() -> None: user_id=user_id, role="member", ) - session = _FakeSession( - exec_results=[], - get_results={(OrganizationMember, member.id): member}, - ) + session = _FakeSession(exec_results=[_FakeExecResult(first_value=member)]) ctx = SimpleNamespace( organization=SimpleNamespace(id=org_id), member=SimpleNamespace(user_id=user_id, role="owner"), @@ -137,10 +130,7 @@ async def test_remove_org_member_requires_owner_to_remove_owner() -> None: user_id=uuid4(), role="owner", ) - session = _FakeSession( - exec_results=[], - get_results={(OrganizationMember, member.id): member}, - ) + session = _FakeSession(exec_results=[_FakeExecResult(first_value=member)]) ctx = SimpleNamespace( organization=SimpleNamespace(id=org_id), member=SimpleNamespace(user_id=uuid4(), role="admin"), @@ -165,8 +155,10 @@ async def test_remove_org_member_rejects_removing_last_owner() -> None: role="owner", ) session = _FakeSession( - exec_results=[_FakeExecResult(all_values=[member.id])], - get_results={(OrganizationMember, member.id): member}, + exec_results=[ + _FakeExecResult(first_value=member), + _FakeExecResult(all_values=[member]), + ], ) ctx = SimpleNamespace( organization=SimpleNamespace(id=org_id), diff --git a/backend/tests/test_task_dependencies_integration.py b/backend/tests/test_task_dependencies_integration.py index 5106785..da1d115 100644 --- a/backend/tests/test_task_dependencies_integration.py +++ b/backend/tests/test_task_dependencies_integration.py @@ -5,7 +5,7 @@ from uuid import UUID, uuid4 import pytest from fastapi import HTTPException from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine -from sqlmodel import SQLModel +from sqlmodel import SQLModel, col, select from sqlmodel.ext.asyncio.session import AsyncSession from app.models.boards import Board @@ -122,7 +122,7 @@ async def test_dependency_queries_and_replace_and_dependents() -> None: assert deps_map.get(t2, []) == [] # mark t2 done, t3 not - task2 = await session.get(Task, t2) + task2 = (await session.exec(select(Task).where(col(Task.id) == t2))).first() assert task2 is not None task2.status = td.DONE_STATUS await session.commit()