refactor: replace SQLModel with QueryModel in various models and update query methods

This commit is contained in:
Abhimanyu Saharan
2026-02-09 02:04:14 +05:30
parent e19e47106b
commit 228b99bc9b
40 changed files with 413 additions and 419 deletions

View File

@@ -102,7 +102,7 @@ def _guard_board_access(agent_ctx: AgentAuthContext, board: Board) -> None:
async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig: async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig:
if not board.gateway_id: if not board.gateway_id:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
gateway = await session.get(Gateway, board.gateway_id) gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway is None or not gateway.url: if gateway is None or not gateway.url:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
return GatewayClientConfig(url=gateway.url, token=gateway.token) return GatewayClientConfig(url=gateway.url, token=gateway.token)
@@ -117,9 +117,7 @@ async def _require_gateway_main(
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Agent missing session key" status_code=status.HTTP_403_FORBIDDEN, detail="Agent missing session key"
) )
gateway = ( gateway = await Gateway.objects.filter_by(main_session_key=session_key).first(session)
await session.exec(select(Gateway).where(col(Gateway.main_session_key) == session_key))
).first()
if gateway is None: if gateway is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
@@ -139,7 +137,7 @@ async def _require_gateway_board(
gateway: Gateway, gateway: Gateway,
board_id: UUID | str, board_id: UUID | str,
) -> Board: ) -> Board:
board = await session.get(Board, board_id) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found")
if board.gateway_id != gateway.id: if board.gateway_id != gateway.id:
@@ -254,7 +252,7 @@ async def create_task(
}, },
) )
if task.assigned_agent_id: if task.assigned_agent_id:
agent = await session.get(Agent, task.assigned_agent_id) agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
if agent is None: if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if agent.is_board_lead: if agent.is_board_lead:
@@ -286,7 +284,7 @@ async def create_task(
) )
await session.commit() await session.commit()
if task.assigned_agent_id: if task.assigned_agent_id:
assigned_agent = await session.get(Agent, task.assigned_agent_id) assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
if assigned_agent: if assigned_agent:
await tasks_api._notify_agent_on_task_assign( await tasks_api._notify_agent_on_task_assign(
session=session, session=session,
@@ -466,7 +464,7 @@ async def nudge_agent(
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
if not agent_ctx.agent.is_board_lead: if not agent_ctx.agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
target = await session.get(Agent, agent_id) target = await Agent.objects.by_id(agent_id).first(session)
if target is None or (target.board_id and target.board_id != board.id): if target is None or (target.board_id and target.board_id != board.id):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if not target.openclaw_session_id: if not target.openclaw_session_id:
@@ -528,7 +526,7 @@ async def get_agent_soul(
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
if not agent_ctx.agent.is_board_lead and str(agent_ctx.agent.id) != agent_id: if not agent_ctx.agent.is_board_lead and str(agent_ctx.agent.id) != agent_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
target = await session.get(Agent, agent_id) target = await Agent.objects.by_id(agent_id).first(session)
if target is None or (target.board_id and target.board_id != board.id): if target is None or (target.board_id and target.board_id != board.id):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
config = await _gateway_config(session, board) config = await _gateway_config(session, board)
@@ -566,7 +564,7 @@ async def update_agent_soul(
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
if not agent_ctx.agent.is_board_lead: if not agent_ctx.agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
target = await session.get(Agent, agent_id) target = await Agent.objects.by_id(agent_id).first(session)
if target is None or (target.board_id and target.board_id != board.id): if target is None or (target.board_id and target.board_id != board.id):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
config = await _gateway_config(session, board) config = await _gateway_config(session, board)
@@ -629,7 +627,7 @@ async def ask_user_via_gateway_main(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Board is not attached to a gateway", detail="Board is not attached to a gateway",
) )
gateway = await session.get(Gateway, board.gateway_id) gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway is None or not gateway.url: if gateway is None or not gateway.url:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -689,9 +687,7 @@ async def ask_user_via_gateway_main(
agent_id=agent_ctx.agent.id, agent_id=agent_ctx.agent.id,
) )
main_agent = ( main_agent = await Agent.objects.filter_by(openclaw_session_id=main_session_key).first(session)
await session.exec(select(Agent).where(col(Agent.openclaw_session_id) == main_session_key))
).first()
await session.commit() await session.commit()

View File

@@ -109,7 +109,7 @@ async def _require_board(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="board_id is required", detail="board_id is required",
) )
board = await session.get(Board, board_id) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found")
if user is not None: if user is not None:
@@ -125,7 +125,7 @@ async def _require_gateway(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Board gateway_id is required", detail="Board gateway_id is required",
) )
gateway = await session.get(Gateway, board.gateway_id) gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway is None: if gateway is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -182,9 +182,7 @@ async def _find_gateway_for_main_session(
) -> Gateway | None: ) -> Gateway | None:
if not session_key: if not session_key:
return None return None
return ( return await Gateway.objects.filter_by(main_session_key=session_key).first(session)
await session.exec(select(Gateway).where(Gateway.main_session_key == session_key))
).first()
async def _ensure_gateway_session( async def _ensure_gateway_session(
@@ -237,7 +235,7 @@ async def _require_user_context(session: AsyncSession, user: User | None) -> Org
member = await get_active_membership(session, user) member = await get_active_membership(session, user)
if member is None: if member is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
organization = await session.get(Organization, member.organization_id) organization = await Organization.objects.by_id(member.organization_id).first(session)
if organization is None: if organization is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return OrganizationContext(organization=organization, member=member) return OrganizationContext(organization=organization, member=member)
@@ -258,7 +256,7 @@ async def _require_agent_access(
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return return
board = await session.get(Board, agent.board_id) board = await Board.objects.by_id(agent.board_id).first(session)
if board is None or board.organization_id != ctx.organization.id: if board is None or board.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if not await has_board_access(session, member=ctx.member, board=board, write=write): if not await has_board_access(session, member=ctx.member, board=board, write=write):
@@ -323,7 +321,7 @@ async def list_agents(
if board_id is not None: if board_id is not None:
statement = statement.where(col(Agent.board_id) == board_id) statement = statement.where(col(Agent.board_id) == board_id)
if gateway_id is not None: if gateway_id is not None:
gateway = await session.get(Gateway, gateway_id) gateway = await Gateway.objects.by_id(gateway_id).first(session)
if gateway is None or gateway.organization_id != ctx.organization.id: if gateway is None or gateway.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
statement = statement.join(Board, col(Agent.board_id) == col(Board.id)).where( statement = statement.join(Board, col(Agent.board_id) == col(Board.id)).where(
@@ -532,7 +530,7 @@ async def get_agent(
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> AgentRead: ) -> AgentRead:
agent = await session.get(Agent, agent_id) agent = await Agent.objects.by_id(agent_id).first(session)
if agent is None: if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
await _require_agent_access(session, agent=agent, ctx=ctx, write=False) await _require_agent_access(session, agent=agent, ctx=ctx, write=False)
@@ -549,7 +547,7 @@ async def update_agent(
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(get_auth_context),
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> AgentRead: ) -> AgentRead:
agent = await session.get(Agent, agent_id) agent = await Agent.objects.by_id(agent_id).first(session)
if agent is None: if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
await _require_agent_access(session, agent=agent, ctx=ctx, write=True) await _require_agent_access(session, agent=agent, ctx=ctx, write=True)
@@ -728,7 +726,7 @@ async def heartbeat_agent(
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent), actor: ActorContext = Depends(require_admin_or_agent),
) -> AgentRead: ) -> AgentRead:
agent = await session.get(Agent, agent_id) agent = await Agent.objects.by_id(agent_id).first(session)
if agent is None: if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if actor.actor_type == "agent" and actor.agent and actor.agent.id != agent.id: if actor.actor_type == "agent" and actor.agent and actor.agent.id != agent.id:
@@ -767,7 +765,7 @@ async def heartbeat_or_create_agent(
actor=actor, actor=actor,
) )
statement = select(Agent).where(Agent.name == payload.name) statement = Agent.objects.filter_by(name=payload.name).statement
if payload.board_id is not None: if payload.board_id is not None:
statement = statement.where(Agent.board_id == payload.board_id) statement = statement.where(Agent.board_id == payload.board_id)
agent = (await session.exec(statement)).first() agent = (await session.exec(statement)).first()
@@ -943,7 +941,7 @@ async def delete_agent(
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> OkResponse: ) -> OkResponse:
agent = await session.get(Agent, agent_id) agent = await Agent.objects.by_id(agent_id).first(session)
if agent is None: if agent is None:
return OkResponse() return OkResponse()
await _require_agent_access(session, agent=agent, ctx=ctx, write=True) await _require_agent_access(session, agent=agent, ctx=ctx, write=True)

View File

@@ -77,9 +77,8 @@ async def _fetch_approval_events(
since: datetime, since: datetime,
) -> list[Approval]: ) -> list[Approval]:
statement = ( statement = (
select(Approval) Approval.objects.filter_by(board_id=board_id)
.where(col(Approval.board_id) == board_id) .filter(
.where(
or_( or_(
col(Approval.created_at) >= since, col(Approval.created_at) >= since,
col(Approval.resolved_at) >= since, col(Approval.resolved_at) >= since,
@@ -87,7 +86,7 @@ async def _fetch_approval_events(
) )
.order_by(asc(col(Approval.created_at))) .order_by(asc(col(Approval.created_at)))
) )
return list(await session.exec(statement)) return await statement.all(session)
@router.get("", response_model=DefaultLimitOffsetPage[ApprovalRead]) @router.get("", response_model=DefaultLimitOffsetPage[ApprovalRead])
@@ -97,11 +96,11 @@ async def list_approvals(
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent), actor: ActorContext = Depends(require_admin_or_agent),
) -> DefaultLimitOffsetPage[ApprovalRead]: ) -> DefaultLimitOffsetPage[ApprovalRead]:
statement = select(Approval).where(col(Approval.board_id) == board.id) statement = Approval.objects.filter_by(board_id=board.id)
if status_filter: if status_filter:
statement = statement.where(col(Approval.status) == status_filter) statement = statement.filter(col(Approval.status) == status_filter)
statement = statement.order_by(col(Approval.created_at).desc()) statement = statement.order_by(col(Approval.created_at).desc())
return await paginate(session, statement) return await paginate(session, statement.statement)
@router.get("/stream") @router.get("/stream")
@@ -207,7 +206,7 @@ async def update_approval(
board: Board = Depends(get_board_for_user_write), board: Board = Depends(get_board_for_user_write),
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
) -> Approval: ) -> Approval:
approval = await session.get(Approval, approval_id) approval = await Approval.objects.by_id(approval_id).first(session)
if approval is None or approval.board_id != board.id: if approval is None or approval.board_id != board.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
updates = payload.model_dump(exclude_unset=True) updates = payload.model_dump(exclude_unset=True)

View File

@@ -8,7 +8,7 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import func from sqlalchemy import func
from sqlmodel import col, select from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
@@ -71,7 +71,7 @@ def _serialize_memory(memory: BoardGroupMemory) -> dict[str, object]:
async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None: async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None:
if board.gateway_id is None: if board.gateway_id is None:
return None return None
gateway = await session.get(Gateway, board.gateway_id) gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway is None or not gateway.url: if gateway is None or not gateway.url:
return None return None
return GatewayClientConfig(url=gateway.url, token=gateway.token) return GatewayClientConfig(url=gateway.url, token=gateway.token)
@@ -96,17 +96,17 @@ async def _fetch_memory_events(
is_chat: bool | None = None, is_chat: bool | None = None,
) -> list[BoardGroupMemory]: ) -> list[BoardGroupMemory]:
statement = ( statement = (
select(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id) == board_group_id) BoardGroupMemory.objects.filter_by(board_group_id=board_group_id)
# Old/invalid rows (empty/whitespace-only content) can exist; exclude them to # Old/invalid rows (empty/whitespace-only content) can exist; exclude them to
# satisfy the NonEmptyStr response schema. # satisfy the NonEmptyStr response schema.
.where(func.length(func.trim(col(BoardGroupMemory.content))) > 0) .filter(func.length(func.trim(col(BoardGroupMemory.content))) > 0)
) )
if is_chat is not None: if is_chat is not None:
statement = statement.where(col(BoardGroupMemory.is_chat) == is_chat) statement = statement.filter(col(BoardGroupMemory.is_chat) == is_chat)
statement = statement.where(col(BoardGroupMemory.created_at) >= since).order_by( statement = statement.filter(col(BoardGroupMemory.created_at) >= since).order_by(
col(BoardGroupMemory.created_at) col(BoardGroupMemory.created_at)
) )
return list(await session.exec(statement)) return await statement.all(session)
async def _require_group_access( async def _require_group_access(
@@ -116,7 +116,7 @@ async def _require_group_access(
ctx: OrganizationContext, ctx: OrganizationContext,
write: bool, write: bool,
) -> BoardGroup: ) -> BoardGroup:
group = await session.get(BoardGroup, group_id) group = await BoardGroup.objects.by_id(group_id).first(session)
if group is None: if group is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if group.organization_id != ctx.member.organization_id: if group.organization_id != ctx.member.organization_id:
@@ -127,9 +127,9 @@ async def _require_group_access(
if not write and member_all_boards_read(ctx.member): if not write and member_all_boards_read(ctx.member):
return group return group
board_ids = list( board_ids = [
await session.exec(select(Board.id).where(col(Board.board_group_id) == group_id)) board.id for board in await Board.objects.filter_by(board_group_id=group_id).all(session)
) ]
if not board_ids: if not board_ids:
if is_org_admin(ctx.member): if is_org_admin(ctx.member):
return group return group
@@ -156,12 +156,12 @@ async def _notify_group_memory_targets(
is_broadcast = "broadcast" in tags or "all" in mentions is_broadcast = "broadcast" in tags or "all" in mentions
# Fetch group boards + agents. # Fetch group boards + agents.
boards = list(await session.exec(select(Board).where(col(Board.board_group_id) == group.id))) boards = await Board.objects.filter_by(board_group_id=group.id).all(session)
if not boards: if not boards:
return return
board_by_id = {board.id: board for board in boards} board_by_id = {board.id: board for board in boards}
board_ids = list(board_by_id.keys()) board_ids = list(board_by_id.keys())
agents = list(await session.exec(select(Agent).where(col(Agent.board_id).in_(board_ids)))) agents = await Agent.objects.by_field_in("board_id", board_ids).all(session)
targets: dict[str, Agent] = {} targets: dict[str, Agent] = {}
for agent in agents: for agent in agents:
@@ -242,15 +242,15 @@ async def list_board_group_memory(
) -> DefaultLimitOffsetPage[BoardGroupMemoryRead]: ) -> DefaultLimitOffsetPage[BoardGroupMemoryRead]:
await _require_group_access(session, group_id=group_id, ctx=ctx, write=False) await _require_group_access(session, group_id=group_id, ctx=ctx, write=False)
statement = ( statement = (
select(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id) == group_id) BoardGroupMemory.objects.filter_by(board_group_id=group_id)
# Old/invalid rows (empty/whitespace-only content) can exist; exclude them to # Old/invalid rows (empty/whitespace-only content) can exist; exclude them to
# satisfy the NonEmptyStr response schema. # satisfy the NonEmptyStr response schema.
.where(func.length(func.trim(col(BoardGroupMemory.content))) > 0) .filter(func.length(func.trim(col(BoardGroupMemory.content))) > 0)
) )
if is_chat is not None: if is_chat is not None:
statement = statement.where(col(BoardGroupMemory.is_chat) == is_chat) statement = statement.filter(col(BoardGroupMemory.is_chat) == is_chat)
statement = statement.order_by(col(BoardGroupMemory.created_at).desc()) statement = statement.order_by(col(BoardGroupMemory.created_at).desc())
return await paginate(session, statement) return await paginate(session, statement.statement)
@group_router.get("/stream") @group_router.get("/stream")
@@ -297,7 +297,7 @@ async def create_board_group_memory(
) -> BoardGroupMemory: ) -> BoardGroupMemory:
group = await _require_group_access(session, group_id=group_id, ctx=ctx, write=True) group = await _require_group_access(session, group_id=group_id, ctx=ctx, write=True)
user = await session.get(User, ctx.member.user_id) user = await User.objects.by_id(ctx.member.user_id).first(session)
actor = ActorContext(actor_type="user", user=user) actor = ActorContext(actor_type="user", user=user)
tags = set(payload.tags or []) tags = set(payload.tags or [])
is_chat = "chat" in tags is_chat = "chat" in tags
@@ -332,19 +332,18 @@ async def list_board_group_memory_for_board(
) -> DefaultLimitOffsetPage[BoardGroupMemoryRead]: ) -> DefaultLimitOffsetPage[BoardGroupMemoryRead]:
group_id = board.board_group_id group_id = board.board_group_id
if group_id is None: if group_id is None:
statement = select(BoardGroupMemory).where(col(BoardGroupMemory.id).is_(None)) return await paginate(session, BoardGroupMemory.objects.by_ids([]).statement)
return await paginate(session, statement)
statement = ( queryset = (
select(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id) == group_id) BoardGroupMemory.objects.filter_by(board_group_id=group_id)
# Old/invalid rows (empty/whitespace-only content) can exist; exclude them to # Old/invalid rows (empty/whitespace-only content) can exist; exclude them to
# satisfy the NonEmptyStr response schema. # satisfy the NonEmptyStr response schema.
.where(func.length(func.trim(col(BoardGroupMemory.content))) > 0) .filter(func.length(func.trim(col(BoardGroupMemory.content))) > 0)
) )
if is_chat is not None: if is_chat is not None:
statement = statement.where(col(BoardGroupMemory.is_chat) == is_chat) queryset = queryset.filter(col(BoardGroupMemory.is_chat) == is_chat)
statement = statement.order_by(col(BoardGroupMemory.created_at).desc()) queryset = queryset.order_by(col(BoardGroupMemory.created_at).desc())
return await paginate(session, statement) return await paginate(session, queryset.statement)
@board_router.get("/stream") @board_router.get("/stream")
@@ -396,7 +395,7 @@ async def create_board_group_memory_for_board(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Board is not in a board group", detail="Board is not in a board group",
) )
group = await session.get(BoardGroup, group_id) group = await BoardGroup.objects.by_id(group_id).first(session)
if group is None: if group is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

View File

@@ -56,7 +56,7 @@ async def _require_group_access(
member: OrganizationMember, member: OrganizationMember,
write: bool, write: bool,
) -> BoardGroup: ) -> BoardGroup:
group = await session.get(BoardGroup, group_id) group = await BoardGroup.objects.by_id(group_id).first(session)
if group is None: if group is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if group.organization_id != member.organization_id: if group.organization_id != member.organization_id:
@@ -67,9 +67,9 @@ async def _require_group_access(
if not write and member_all_boards_read(member): if not write and member_all_boards_read(member):
return group return group
board_ids = list( board_ids = [
await session.exec(select(Board.id).where(col(Board.board_group_id) == group_id)) board.id for board in await Board.objects.filter_by(board_group_id=group_id).all(session)
) ]
if not board_ids: if not board_ids:
if is_org_admin(member): if is_org_admin(member):
return group return group
@@ -153,7 +153,7 @@ async def apply_board_group_heartbeat(
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent), actor: ActorContext = Depends(require_admin_or_agent),
) -> BoardGroupHeartbeatApplyResult: ) -> BoardGroupHeartbeatApplyResult:
group = await session.get(BoardGroup, group_id) group = await BoardGroup.objects.by_id(group_id).first(session)
if group is None: if group is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -181,11 +181,11 @@ async def apply_board_group_heartbeat(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if not agent.is_board_lead: if not agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
board = await session.get(Board, agent.board_id) board = await Board.objects.by_id(agent.board_id).first(session)
if board is None or board.board_group_id != group_id: if board is None or board.board_group_id != group_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
boards = list(await session.exec(select(Board).where(col(Board.board_group_id) == group_id))) boards = await Board.objects.filter_by(board_group_id=group_id).all(session)
board_by_id = {board.id: board for board in boards} board_by_id = {board.id: board for board in boards}
board_ids = list(board_by_id.keys()) board_ids = list(board_by_id.keys())
if not board_ids: if not board_ids:
@@ -196,7 +196,7 @@ async def apply_board_group_heartbeat(
failed_agent_ids=[], failed_agent_ids=[],
) )
agents = list(await session.exec(select(Agent).where(col(Agent.board_id).in_(board_ids)))) agents = await Agent.objects.by_field_in("board_id", board_ids).all(session)
if not payload.include_board_leads: if not payload.include_board_leads:
agents = [agent for agent in agents if not agent.is_board_lead] agents = [agent for agent in agents if not agent.is_board_lead]
@@ -232,7 +232,7 @@ async def apply_board_group_heartbeat(
failed_agent_ids: list[UUID] = [] failed_agent_ids: list[UUID] = []
gateway_ids = list(agents_by_gateway_id.keys()) gateway_ids = list(agents_by_gateway_id.keys())
gateways = list(await session.exec(select(Gateway).where(col(Gateway.id).in_(gateway_ids)))) gateways = await Gateway.objects.by_ids(gateway_ids).all(session)
gateway_by_id = {gateway.id: gateway for gateway in gateways} gateway_by_id = {gateway.id: gateway for gateway in gateways}
for gateway_id, gateway_agents in agents_by_gateway_id.items(): for gateway_id, gateway_agents in agents_by_gateway_id.items():
gateway = gateway_by_id.get(gateway_id) gateway = gateway_by_id.get(gateway_id)

View File

@@ -8,7 +8,7 @@ from uuid import UUID
from fastapi import APIRouter, Depends, Query, Request from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import func from sqlalchemy import func
from sqlmodel import col, select from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
@@ -58,7 +58,7 @@ def _serialize_memory(memory: BoardMemory) -> dict[str, object]:
async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None: async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None:
if board.gateway_id is None: if board.gateway_id is None:
return None return None
gateway = await session.get(Gateway, board.gateway_id) gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway is None or not gateway.url: if gateway is None or not gateway.url:
return None return None
return GatewayClientConfig(url=gateway.url, token=gateway.token) return GatewayClientConfig(url=gateway.url, token=gateway.token)
@@ -83,17 +83,17 @@ async def _fetch_memory_events(
is_chat: bool | None = None, is_chat: bool | None = None,
) -> list[BoardMemory]: ) -> list[BoardMemory]:
statement = ( statement = (
select(BoardMemory).where(col(BoardMemory.board_id) == board_id) BoardMemory.objects.filter_by(board_id=board_id)
# Old/invalid rows (empty/whitespace-only content) can exist; exclude them to # Old/invalid rows (empty/whitespace-only content) can exist; exclude them to
# satisfy the NonEmptyStr response schema. # satisfy the NonEmptyStr response schema.
.where(func.length(func.trim(col(BoardMemory.content))) > 0) .filter(func.length(func.trim(col(BoardMemory.content))) > 0)
) )
if is_chat is not None: if is_chat is not None:
statement = statement.where(col(BoardMemory.is_chat) == is_chat) statement = statement.filter(col(BoardMemory.is_chat) == is_chat)
statement = statement.where(col(BoardMemory.created_at) >= since).order_by( statement = statement.filter(col(BoardMemory.created_at) >= since).order_by(
col(BoardMemory.created_at) col(BoardMemory.created_at)
) )
return list(await session.exec(statement)) return await statement.all(session)
async def _notify_chat_targets( async def _notify_chat_targets(
@@ -114,8 +114,7 @@ async def _notify_chat_targets(
# Special-case control commands to reach all board agents. # Special-case control commands to reach all board agents.
# These are intended to be parsed verbatim by agent runtimes. # These are intended to be parsed verbatim by agent runtimes.
if command in {"/pause", "/resume"}: if command in {"/pause", "/resume"}:
statement = select(Agent).where(col(Agent.board_id) == board.id) pause_targets: list[Agent] = await Agent.objects.filter_by(board_id=board.id).all(session)
pause_targets: list[Agent] = list(await session.exec(statement))
for agent in pause_targets: for agent in pause_targets:
if actor.actor_type == "agent" and actor.agent and agent.id == actor.agent.id: if actor.actor_type == "agent" and actor.agent and agent.id == actor.agent.id:
continue continue
@@ -134,9 +133,8 @@ async def _notify_chat_targets(
return return
mentions = extract_mentions(memory.content) mentions = extract_mentions(memory.content)
statement = select(Agent).where(col(Agent.board_id) == board.id)
targets: dict[str, Agent] = {} targets: dict[str, Agent] = {}
for agent in await session.exec(statement): for agent in await Agent.objects.filter_by(board_id=board.id).all(session):
if agent.is_board_lead: if agent.is_board_lead:
targets[str(agent.id)] = agent targets[str(agent.id)] = agent
continue continue
@@ -188,15 +186,15 @@ async def list_board_memory(
actor: ActorContext = Depends(require_admin_or_agent), actor: ActorContext = Depends(require_admin_or_agent),
) -> DefaultLimitOffsetPage[BoardMemoryRead]: ) -> DefaultLimitOffsetPage[BoardMemoryRead]:
statement = ( statement = (
select(BoardMemory).where(col(BoardMemory.board_id) == board.id) BoardMemory.objects.filter_by(board_id=board.id)
# Old/invalid rows (empty/whitespace-only content) can exist; exclude them to # Old/invalid rows (empty/whitespace-only content) can exist; exclude them to
# satisfy the NonEmptyStr response schema. # satisfy the NonEmptyStr response schema.
.where(func.length(func.trim(col(BoardMemory.content))) > 0) .filter(func.length(func.trim(col(BoardMemory.content))) > 0)
) )
if is_chat is not None: if is_chat is not None:
statement = statement.where(col(BoardMemory.is_chat) == is_chat) statement = statement.filter(col(BoardMemory.is_chat) == is_chat)
statement = statement.order_by(col(BoardMemory.created_at).desc()) statement = statement.order_by(col(BoardMemory.created_at).desc())
return await paginate(session, statement) return await paginate(session, statement.statement)
@router.get("/stream") @router.get("/stream")

View File

@@ -6,7 +6,7 @@ from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import ValidationError from pydantic import ValidationError
from sqlmodel import col, select from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import ( from app.api.deps import (
@@ -50,7 +50,7 @@ async def _gateway_config(
) -> tuple[Gateway, GatewayClientConfig]: ) -> tuple[Gateway, GatewayClientConfig]:
if not board.gateway_id: if not board.gateway_id:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
gateway = await session.get(Gateway, board.gateway_id) gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway is None or not gateway.url or not gateway.main_session_key: if gateway is None or not gateway.url or not gateway.main_session_key:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
return gateway, GatewayClientConfig(url=gateway.url, token=gateway.token) return gateway, GatewayClientConfig(url=gateway.url, token=gateway.token)
@@ -80,12 +80,10 @@ async def _ensure_lead_agent(
identity_profile: dict[str, str] | None = None, identity_profile: dict[str, str] | None = None,
) -> Agent: ) -> Agent:
existing = ( existing = (
await session.exec( await Agent.objects.filter_by(board_id=board.id)
select(Agent) .filter(col(Agent.is_board_lead).is_(True))
.where(Agent.board_id == board.id) .first(session)
.where(col(Agent.is_board_lead).is_(True)) )
)
).first()
if existing: if existing:
desired_name = agent_name or _lead_agent_name(board) desired_name = agent_name or _lead_agent_name(board)
if existing.name != desired_name: if existing.name != desired_name:
@@ -147,12 +145,10 @@ async def get_onboarding(
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
) -> BoardOnboardingSession: ) -> BoardOnboardingSession:
onboarding = ( onboarding = (
await session.exec( await BoardOnboardingSession.objects.filter_by(board_id=board.id)
select(BoardOnboardingSession) .order_by(col(BoardOnboardingSession.updated_at).desc())
.where(BoardOnboardingSession.board_id == board.id) .first(session)
.order_by(col(BoardOnboardingSession.created_at).desc()) )
)
).first()
if onboarding is None: if onboarding is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return onboarding return onboarding
@@ -165,12 +161,10 @@ async def start_onboarding(
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
) -> BoardOnboardingSession: ) -> BoardOnboardingSession:
onboarding = ( onboarding = (
await session.exec( await BoardOnboardingSession.objects.filter_by(board_id=board.id)
select(BoardOnboardingSession) .filter(col(BoardOnboardingSession.status) == "active")
.where(BoardOnboardingSession.board_id == board.id) .first(session)
.where(BoardOnboardingSession.status == "active") )
)
).first()
if onboarding: if onboarding:
return onboarding return onboarding
@@ -248,12 +242,10 @@ async def answer_onboarding(
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
) -> BoardOnboardingSession: ) -> BoardOnboardingSession:
onboarding = ( onboarding = (
await session.exec( await BoardOnboardingSession.objects.filter_by(board_id=board.id)
select(BoardOnboardingSession) .order_by(col(BoardOnboardingSession.updated_at).desc())
.where(BoardOnboardingSession.board_id == board.id) .first(session)
.order_by(col(BoardOnboardingSession.created_at).desc()) )
)
).first()
if onboarding is None: if onboarding is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -295,18 +287,16 @@ async def agent_onboarding_update(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if board.gateway_id: if board.gateway_id:
gateway = await session.get(Gateway, board.gateway_id) gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway and gateway.main_session_key and agent.openclaw_session_id: if gateway and gateway.main_session_key and agent.openclaw_session_id:
if agent.openclaw_session_id != gateway.main_session_key: if agent.openclaw_session_id != gateway.main_session_key:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
onboarding = ( onboarding = (
await session.exec( await BoardOnboardingSession.objects.filter_by(board_id=board.id)
select(BoardOnboardingSession) .order_by(col(BoardOnboardingSession.updated_at).desc())
.where(BoardOnboardingSession.board_id == board.id) .first(session)
.order_by(col(BoardOnboardingSession.created_at).desc()) )
)
).first()
if onboarding is None: if onboarding is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if onboarding.status == "confirmed": if onboarding.status == "confirmed":
@@ -351,12 +341,10 @@ async def confirm_onboarding(
auth: AuthContext = Depends(require_admin_auth), auth: AuthContext = Depends(require_admin_auth),
) -> Board: ) -> Board:
onboarding = ( onboarding = (
await session.exec( await BoardOnboardingSession.objects.filter_by(board_id=board.id)
select(BoardOnboardingSession) .order_by(col(BoardOnboardingSession.updated_at).desc())
.where(BoardOnboardingSession.board_id == board.id) .first(session)
.order_by(col(BoardOnboardingSession.created_at).desc()) )
)
).first()
if onboarding is None: if onboarding is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

View File

@@ -163,7 +163,7 @@ async def _board_gateway(
) -> tuple[Gateway | None, GatewayClientConfig | None]: ) -> tuple[Gateway | None, GatewayClientConfig | None]:
if not board.gateway_id: if not board.gateway_id:
return None, None return None, None
config = await session.get(Gateway, board.gateway_id) config = await Gateway.objects.by_id(board.gateway_id).first(session)
if config is None: if config is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -292,7 +292,7 @@ async def delete_board(
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
board: Board = Depends(get_board_for_user_write), board: Board = Depends(get_board_for_user_write),
) -> OkResponse: ) -> OkResponse:
agents = list(await session.exec(select(Agent).where(Agent.board_id == board.id))) agents = await Agent.objects.filter_by(board_id=board.id).all(session)
task_ids = list(await session.exec(select(Task.id).where(Task.board_id == board.id))) task_ids = list(await session.exec(select(Task.id).where(Task.board_id == board.id)))
config, client_config = await _board_gateway(session, board) config, client_config = await _board_gateway(session, board)

View File

@@ -59,7 +59,7 @@ async def require_org_member(
member = await ensure_member_for_user(session, auth.user) member = await ensure_member_for_user(session, auth.user)
if member is None: if member is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
organization = await session.get(Organization, member.organization_id) organization = await Organization.objects.by_id(member.organization_id).first(session)
if organization is None: if organization is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return OrganizationContext(organization=organization, member=member) return OrganizationContext(organization=organization, member=member)
@@ -77,7 +77,7 @@ async def get_board_or_404(
board_id: str, board_id: str,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
) -> Board: ) -> Board:
board = await session.get(Board, board_id) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return board return board
@@ -88,7 +88,7 @@ async def get_board_for_actor_read(
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent), actor: ActorContext = Depends(require_admin_or_agent),
) -> Board: ) -> Board:
board = await session.get(Board, board_id) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if actor.actor_type == "agent": if actor.actor_type == "agent":
@@ -106,7 +106,7 @@ async def get_board_for_actor_write(
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
actor: ActorContext = Depends(require_admin_or_agent), actor: ActorContext = Depends(require_admin_or_agent),
) -> Board: ) -> Board:
board = await session.get(Board, board_id) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if actor.actor_type == "agent": if actor.actor_type == "agent":
@@ -124,7 +124,7 @@ async def get_board_for_user_read(
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(get_auth_context),
) -> Board: ) -> Board:
board = await session.get(Board, board_id) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if auth.user is None: if auth.user is None:
@@ -138,7 +138,7 @@ async def get_board_for_user_write(
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(get_auth_context),
) -> Board: ) -> Board:
board = await session.get(Board, board_id) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if auth.user is None: if auth.user is None:
@@ -152,7 +152,7 @@ async def get_task_or_404(
board: Board = Depends(get_board_for_actor_read), board: Board = Depends(get_board_for_actor_read),
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
) -> Task: ) -> Task:
task = await session.get(Task, task_id) task = await Task.objects.by_id(task_id).first(session)
if task is None or task.board_id != board.id: if task is None or task.board_id != board.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return task return task

View File

@@ -56,7 +56,7 @@ async def _resolve_gateway(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="board_id or gateway_url is required", detail="board_id or gateway_url is required",
) )
board = await session.get(Board, board_id) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found")
if isinstance(user, object) and user is not None: if isinstance(user, object) and user is not None:
@@ -66,7 +66,7 @@ async def _resolve_gateway(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Board gateway_id is required", detail="Board gateway_id is required",
) )
gateway = await session.get(Gateway, board.gateway_id) gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway is None: if gateway is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -216,7 +216,7 @@ async def get_gateway_session(
sessions_list = list(sessions.get("sessions") or []) sessions_list = list(sessions.get("sessions") or [])
else: else:
sessions_list = list(sessions or []) sessions_list = list(sessions or [])
if main_session and not any(session.get("key") == main_session for session in sessions_list): if main_session and not any(item.get("key") == main_session for item in sessions_list):
try: try:
await ensure_session(main_session, config=config, label="Main Agent") await ensure_session(main_session, config=config, label="Main Agent")
refreshed = await openclaw_call("sessions.list", config=config) refreshed = await openclaw_call("sessions.list", config=config)

View File

@@ -2,12 +2,11 @@ from __future__ import annotations
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel import col, select from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import require_org_admin from app.api.deps import require_org_admin
from app.api.queryset import api_qs
from app.core.agent_tokens import generate_agent_token, hash_agent_token from app.core.agent_tokens import generate_agent_token, hash_agent_token
from app.core.auth import AuthContext, get_auth_context from app.core.auth import AuthContext, get_auth_context
from app.core.time import utcnow from app.core.time import utcnow
@@ -43,14 +42,14 @@ async def _require_gateway(
gateway_id: UUID, gateway_id: UUID,
organization_id: UUID, organization_id: UUID,
) -> Gateway: ) -> Gateway:
return await ( gateway = (
api_qs(Gateway) await Gateway.objects.by_id(gateway_id)
.filter( .filter(col(Gateway.organization_id) == organization_id)
col(Gateway.id) == gateway_id, .first(session)
col(Gateway.organization_id) == organization_id,
)
.first_or_404(session, detail="Gateway not found")
) )
if gateway is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found")
return gateway
async def _find_main_agent( async def _find_main_agent(
@@ -60,26 +59,22 @@ async def _find_main_agent(
previous_session_key: str | None = None, previous_session_key: str | None = None,
) -> Agent | None: ) -> Agent | None:
if gateway.main_session_key: if gateway.main_session_key:
agent = ( agent = await Agent.objects.filter_by(openclaw_session_id=gateway.main_session_key).first(
await session.exec( session
select(Agent).where(Agent.openclaw_session_id == gateway.main_session_key) )
)
).first()
if agent: if agent:
return agent return agent
if previous_session_key: if previous_session_key:
agent = ( agent = await Agent.objects.filter_by(openclaw_session_id=previous_session_key).first(
await session.exec( session
select(Agent).where(Agent.openclaw_session_id == previous_session_key) )
)
).first()
if agent: if agent:
return agent return agent
names = {_main_agent_name(gateway)} names = {_main_agent_name(gateway)}
if previous_name: if previous_name:
names.add(f"{previous_name} Main") names.add(f"{previous_name} Main")
for name in names: for name in names:
agent = (await session.exec(select(Agent).where(Agent.name == name))).first() agent = await Agent.objects.filter_by(name=name).first(session)
if agent: if agent:
return agent return agent
return None return None
@@ -153,8 +148,7 @@ async def list_gateways(
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> DefaultLimitOffsetPage[GatewayRead]: ) -> DefaultLimitOffsetPage[GatewayRead]:
statement = ( statement = (
api_qs(Gateway) Gateway.objects.filter_by(organization_id=ctx.organization.id)
.filter(col(Gateway.organization_id) == ctx.organization.id)
.order_by(col(Gateway.created_at).desc()) .order_by(col(Gateway.created_at).desc())
.statement .statement
) )

View File

@@ -10,7 +10,6 @@ from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import require_org_admin, require_org_member from app.api.deps import require_org_admin, require_org_member
from app.api.queryset import api_qs
from app.core.auth import AuthContext, get_auth_context from app.core.auth import AuthContext, get_auth_context
from app.core.time import utcnow from app.core.time import utcnow
from app.db import crud from app.db import crud
@@ -81,14 +80,10 @@ async def _require_org_member(
organization_id: UUID, organization_id: UUID,
member_id: UUID, member_id: UUID,
) -> OrganizationMember: ) -> OrganizationMember:
return await ( member = await OrganizationMember.objects.by_id(member_id).first(session)
api_qs(OrganizationMember) if member is None or member.organization_id != organization_id:
.filter( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
col(OrganizationMember.id) == member_id, return member
col(OrganizationMember.organization_id) == organization_id,
)
.first_or_404(session)
)
async def _require_org_invite( async def _require_org_invite(
@@ -97,14 +92,10 @@ async def _require_org_invite(
organization_id: UUID, organization_id: UUID,
invite_id: UUID, invite_id: UUID,
) -> OrganizationInvite: ) -> OrganizationInvite:
return await ( invite = await OrganizationInvite.objects.by_id(invite_id).first(session)
api_qs(OrganizationInvite) if invite is None or invite.organization_id != organization_id:
.filter( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
col(OrganizationInvite.id) == invite_id, return invite
col(OrganizationInvite.organization_id) == organization_id,
)
.first_or_404(session)
)
@router.post("", response_model=OrganizationRead) @router.post("", response_model=OrganizationRead)
@@ -157,7 +148,7 @@ async def list_my_organizations(
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
await get_active_membership(session, auth.user) await get_active_membership(session, auth.user)
db_user = await session.get(User, auth.user.id) db_user = await User.objects.by_id(auth.user.id).first(session)
active_id = db_user.active_organization_id if db_user else auth.user.active_organization_id active_id = db_user.active_organization_id if db_user else auth.user.active_organization_id
statement = ( statement = (
@@ -189,7 +180,7 @@ async def set_active_org(
member = await set_active_organization( member = await set_active_organization(
session, user=auth.user, organization_id=payload.organization_id session, user=auth.user, organization_id=payload.organization_id
) )
organization = await session.get(Organization, member.organization_id) organization = await Organization.objects.by_id(member.organization_id).first(session)
if organization is None: if organization is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return OrganizationRead.model_validate(organization, from_attributes=True) return OrganizationRead.model_validate(organization, from_attributes=True)
@@ -293,14 +284,10 @@ async def get_my_membership(
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_member), ctx: OrganizationContext = Depends(require_org_member),
) -> OrganizationMemberRead: ) -> OrganizationMemberRead:
user = await session.get(User, ctx.member.user_id) user = await User.objects.by_id(ctx.member.user_id).first(session)
access_rows = list( access_rows = await OrganizationBoardAccess.objects.filter_by(
await session.exec( organization_member_id=ctx.member.id
select(OrganizationBoardAccess).where( ).all(session)
col(OrganizationBoardAccess.organization_member_id) == ctx.member.id
)
)
)
model = _member_to_read(ctx.member, user) model = _member_to_read(ctx.member, user)
model.board_access = [ model.board_access = [
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows
@@ -342,14 +329,10 @@ async def get_org_member(
) )
if not is_org_admin(ctx.member) and member.user_id != ctx.member.user_id: if not is_org_admin(ctx.member) and member.user_id != ctx.member.user_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
user = await session.get(User, member.user_id) user = await User.objects.by_id(member.user_id).first(session)
access_rows = list( access_rows = await OrganizationBoardAccess.objects.filter_by(
await session.exec( organization_member_id=member.id
select(OrganizationBoardAccess).where( ).all(session)
col(OrganizationBoardAccess.organization_member_id) == member.id
)
)
)
model = _member_to_read(member, user) model = _member_to_read(member, user)
model.board_access = [ model.board_access = [
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows
@@ -374,7 +357,7 @@ async def update_org_member(
updates["role"] = normalize_role(updates["role"]) updates["role"] = normalize_role(updates["role"])
updates["updated_at"] = utcnow() updates["updated_at"] = utcnow()
member = await crud.patch(session, member, updates) member = await crud.patch(session, member, updates)
user = await session.get(User, member.user_id) user = await User.objects.by_id(member.user_id).first(session)
return _member_to_read(member, user) return _member_to_read(member, user)
@@ -393,20 +376,19 @@ async def update_member_access(
board_ids = {entry.board_id for entry in payload.board_access} board_ids = {entry.board_id for entry in payload.board_access}
if board_ids: if board_ids:
valid_board_ids = set( valid_board_ids = {
await session.exec( board.id
select(Board.id) for board in await Board.objects.filter_by(organization_id=ctx.organization.id)
.where(col(Board.id).in_(board_ids)) .filter(col(Board.id).in_(board_ids))
.where(col(Board.organization_id) == ctx.organization.id) .all(session)
) }
)
if valid_board_ids != board_ids: if valid_board_ids != board_ids:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
await apply_member_access_update(session, member=member, update=payload) await apply_member_access_update(session, member=member, update=payload)
await session.commit() await session.commit()
await session.refresh(member) await session.refresh(member)
user = await session.get(User, member.user_id) user = await User.objects.by_id(member.user_id).first(session)
return _member_to_read(member, user) return _member_to_read(member, user)
@@ -416,9 +398,11 @@ async def remove_org_member(
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> OkResponse: ) -> OkResponse:
member = await session.get(OrganizationMember, member_id) member = await _require_org_member(
if member is None or member.organization_id != ctx.organization.id: session,
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) organization_id=ctx.organization.id,
member_id=member_id,
)
if member.user_id == ctx.member.user_id: if member.user_id == ctx.member.user_id:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
@@ -430,15 +414,12 @@ async def remove_org_member(
detail="Only owners can remove owners", detail="Only owners can remove owners",
) )
if member.role == "owner": if member.role == "owner":
owner_ids = list( owners = (
await session.exec( await OrganizationMember.objects.filter_by(organization_id=ctx.organization.id)
select(OrganizationMember.id).where( .filter(col(OrganizationMember.role) == "owner")
col(OrganizationMember.organization_id) == ctx.organization.id, .all(session)
col(OrganizationMember.role) == "owner",
)
)
) )
if len(owner_ids) <= 1: if len(owners) <= 1:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail="Organization must have at least one owner", detail="Organization must have at least one owner",
@@ -451,17 +432,22 @@ async def remove_org_member(
), ),
) )
user = await session.get(User, member.user_id) user = await User.objects.by_id(member.user_id).first(session)
if user is not None and user.active_organization_id == ctx.organization.id: if user is not None and user.active_organization_id == ctx.organization.id:
fallback_org_id = ( fallback_membership = (
await session.exec( await OrganizationMember.objects.filter(
select(OrganizationMember.organization_id) col(OrganizationMember.user_id) == user.id,
.where(col(OrganizationMember.user_id) == user.id) col(OrganizationMember.organization_id) != ctx.organization.id,
.where(col(OrganizationMember.organization_id) != ctx.organization.id) )
.order_by(col(OrganizationMember.created_at).asc()) .order_by(col(OrganizationMember.created_at).asc())
.first(session)
)
if isinstance(fallback_membership, UUID):
user.active_organization_id = fallback_membership
else:
user.active_organization_id = (
fallback_membership.organization_id if fallback_membership is not None else None
) )
).first()
user.active_organization_id = fallback_org_id
session.add(user) session.add(user)
await crud.delete(session, member) await crud.delete(session, member)
@@ -474,8 +460,7 @@ async def list_org_invites(
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> DefaultLimitOffsetPage[OrganizationInviteRead]: ) -> DefaultLimitOffsetPage[OrganizationInviteRead]:
statement = ( statement = (
api_qs(OrganizationInvite) OrganizationInvite.objects.filter_by(organization_id=ctx.organization.id)
.filter(col(OrganizationInvite.organization_id) == ctx.organization.id)
.filter(col(OrganizationInvite.accepted_at).is_(None)) .filter(col(OrganizationInvite.accepted_at).is_(None))
.order_by(col(OrganizationInvite.created_at).desc()) .order_by(col(OrganizationInvite.created_at).desc())
.statement .statement
@@ -522,13 +507,12 @@ async def create_org_invite(
board_ids = {entry.board_id for entry in payload.board_access} board_ids = {entry.board_id for entry in payload.board_access}
if board_ids: if board_ids:
valid_board_ids = set( valid_board_ids = {
await session.exec( board.id
select(Board.id) for board in await Board.objects.filter_by(organization_id=ctx.organization.id)
.where(col(Board.id).in_(board_ids)) .filter(col(Board.id).in_(board_ids))
.where(col(Board.organization_id) == ctx.organization.id) .all(session)
) }
)
if valid_board_ids != board_ids: if valid_board_ids != board_ids:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
await apply_invite_board_access(session, invite=invite, entries=payload.board_access) await apply_invite_board_access(session, invite=invite, entries=payload.board_access)
@@ -566,13 +550,10 @@ async def accept_org_invite(
) -> OrganizationMemberRead: ) -> OrganizationMemberRead:
if auth.user is None: if auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
invite = ( invite = await OrganizationInvite.objects.filter(
await session.exec( col(OrganizationInvite.token) == payload.token,
select(OrganizationInvite) col(OrganizationInvite.accepted_at).is_(None),
.where(col(OrganizationInvite.token) == payload.token) ).first(session)
.where(col(OrganizationInvite.accepted_at).is_(None))
)
).first()
if invite is None: if invite is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if invite.invited_email and auth.user.email: if invite.invited_email and auth.user.email:
@@ -597,5 +578,5 @@ async def accept_org_invite(
await session.commit() await session.commit()
member = existing member = existing
user = await session.get(User, member.user_id) user = await User.objects.by_id(member.user_id).first(session)
return _member_to_read(member, user) return _member_to_read(member, user)

View File

@@ -243,7 +243,7 @@ def _serialize_comment(event: ActivityEvent) -> dict[str, object]:
async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None: async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None:
if not board.gateway_id: if not board.gateway_id:
return None return None
gateway = await session.get(Gateway, board.gateway_id) gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway is None or not gateway.url: if gateway is None or not gateway.url:
return None return None
return GatewayClientConfig(url=gateway.url, token=gateway.token) return GatewayClientConfig(url=gateway.url, token=gateway.token)
@@ -331,12 +331,10 @@ async def _notify_lead_on_task_create(
task: Task, task: Task,
) -> None: ) -> None:
lead = ( lead = (
await session.exec( await Agent.objects.filter_by(board_id=board.id)
select(Agent) .filter(col(Agent.is_board_lead).is_(True))
.where(Agent.board_id == board.id) .first(session)
.where(col(Agent.is_board_lead).is_(True)) )
)
).first()
if lead is None or not lead.openclaw_session_id: if lead is None or not lead.openclaw_session_id:
return return
config = await _gateway_config(session, board) config = await _gateway_config(session, board)
@@ -390,12 +388,10 @@ async def _notify_lead_on_task_unassigned(
task: Task, task: Task,
) -> None: ) -> None:
lead = ( lead = (
await session.exec( await Agent.objects.filter_by(board_id=board.id)
select(Agent) .filter(col(Agent.is_board_lead).is_(True))
.where(Agent.board_id == board.id) .first(session)
.where(col(Agent.is_board_lead).is_(True)) )
)
).first()
if lead is None or not lead.openclaw_session_id: if lead is None or not lead.openclaw_session_id:
return return
config = await _gateway_config(session, board) config = await _gateway_config(session, board)
@@ -635,7 +631,7 @@ async def create_task(
await session.commit() await session.commit()
await _notify_lead_on_task_create(session=session, board=board, task=task) await _notify_lead_on_task_create(session=session, board=board, task=task)
if task.assigned_agent_id: if task.assigned_agent_id:
assigned_agent = await session.get(Agent, task.assigned_agent_id) assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
if assigned_agent: if assigned_agent:
await _notify_agent_on_task_assign( await _notify_agent_on_task_assign(
session=session, session=session,
@@ -670,7 +666,7 @@ async def update_task(
) )
board_id = task.board_id board_id = task.board_id
if actor.actor_type == "user" and actor.user is not None: if actor.actor_type == "user" and actor.user is not None:
board = await session.get(Board, board_id) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
await require_board_access(session, user=actor.user, board=board, write=True) await require_board_access(session, user=actor.user, board=board, write=True)
@@ -740,7 +736,7 @@ async def update_task(
if "assigned_agent_id" in updates: if "assigned_agent_id" in updates:
assigned_id = updates["assigned_agent_id"] assigned_id = updates["assigned_agent_id"]
if assigned_id: if assigned_id:
agent = await session.get(Agent, assigned_id) agent = await Agent.objects.by_id(assigned_id).first(session)
if agent is None: if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if agent.is_board_lead: if agent.is_board_lead:
@@ -796,9 +792,13 @@ async def update_task(
await session.refresh(task) await session.refresh(task)
if task.assigned_agent_id and task.assigned_agent_id != previous_assigned: if task.assigned_agent_id and task.assigned_agent_id != previous_assigned:
assigned_agent = await session.get(Agent, task.assigned_agent_id) assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
if assigned_agent: if assigned_agent:
board = await session.get(Board, task.board_id) if task.board_id else None board = (
await Board.objects.by_id(task.board_id).first(session)
if task.board_id
else None
)
if board: if board:
await _notify_agent_on_task_assign( await _notify_agent_on_task_assign(
session=session, session=session,
@@ -879,7 +879,7 @@ async def update_task(
task.in_progress_at = utcnow() task.in_progress_at = utcnow()
if "assigned_agent_id" in updates and updates["assigned_agent_id"]: if "assigned_agent_id" in updates and updates["assigned_agent_id"]:
agent = await session.get(Agent, updates["assigned_agent_id"]) agent = await Agent.objects.by_id(updates["assigned_agent_id"]).first(session)
if agent is None: if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if agent.board_id and task.board_id and agent.board_id != task.board_id: if agent.board_id and task.board_id and agent.board_id != task.board_id:
@@ -941,7 +941,9 @@ async def update_task(
if task.status == "inbox" and task.assigned_agent_id is None: if task.status == "inbox" and task.assigned_agent_id is None:
if previous_status != "inbox" or previous_assigned is not None: if previous_status != "inbox" or previous_assigned is not None:
board = await session.get(Board, task.board_id) if task.board_id else None board = (
await Board.objects.by_id(task.board_id).first(session) if task.board_id else None
)
if board: if board:
await _notify_lead_on_task_unassigned( await _notify_lead_on_task_unassigned(
session=session, session=session,
@@ -953,9 +955,13 @@ async def update_task(
# Don't notify the actor about their own assignment. # Don't notify the actor about their own assignment.
pass pass
else: else:
assigned_agent = await session.get(Agent, task.assigned_agent_id) assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
if assigned_agent: if assigned_agent:
board = await session.get(Board, task.board_id) if task.board_id else None board = (
await Board.objects.by_id(task.board_id).first(session)
if task.board_id
else None
)
if board: if board:
await _notify_agent_on_task_assign( await _notify_agent_on_task_assign(
session=session, session=session,
@@ -985,7 +991,7 @@ async def delete_task(
) -> OkResponse: ) -> OkResponse:
if task.board_id is None: if task.board_id is None:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
board = await session.get(Board, task.board_id) board = await Board.objects.by_id(task.board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if auth.user is None: if auth.user is None:
@@ -1032,7 +1038,7 @@ async def create_task_comment(
if task.board_id is None: if task.board_id is None:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
if actor.actor_type == "user" and actor.user is not None: if actor.actor_type == "user" and actor.user is not None:
board = await session.get(Board, task.board_id) board = await Board.objects.by_id(task.board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
await require_board_access(session, user=actor.user, board=board, write=True) await require_board_access(session, user=actor.user, board=board, write=True)
@@ -1059,18 +1065,17 @@ async def create_task_comment(
mention_names = extract_mentions(payload.message) mention_names = extract_mentions(payload.message)
targets: dict[UUID, Agent] = {} targets: dict[UUID, Agent] = {}
if mention_names and task.board_id: if mention_names and task.board_id:
statement = select(Agent).where(col(Agent.board_id) == task.board_id) for agent in await Agent.objects.filter_by(board_id=task.board_id).all(session):
for agent in await session.exec(statement):
if matches_agent_mention(agent, mention_names): if matches_agent_mention(agent, mention_names):
targets[agent.id] = agent targets[agent.id] = agent
if not mention_names and task.assigned_agent_id: if not mention_names and task.assigned_agent_id:
assigned_agent = await session.get(Agent, task.assigned_agent_id) assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session)
if assigned_agent: if assigned_agent:
targets[assigned_agent.id] = assigned_agent targets[assigned_agent.id] = assigned_agent
if actor.actor_type == "agent" and actor.agent: if actor.actor_type == "agent" and actor.agent:
targets.pop(actor.agent.id, None) targets.pop(actor.agent.id, None)
if targets: if targets:
board = await session.get(Board, task.board_id) if task.board_id else None board = await Board.objects.by_id(task.board_id).first(session) if task.board_id else None
config = await _gateway_config(session, board) if board else None config = await _gateway_config(session, board) if board else None
if board and config: if board and config:
snippet = payload.message.strip() snippet = payload.message.strip()

View File

@@ -27,7 +27,8 @@ def _lookup_statement(model: type[ModelT], lookup: Mapping[str, Any]) -> SelectO
async def get_by_id(session: AsyncSession, model: type[ModelT], obj_id: Any) -> ModelT | None: async def get_by_id(session: AsyncSession, model: type[ModelT], obj_id: Any) -> ModelT | None:
return await session.get(model, obj_id) stmt = _lookup_statement(model, {"id": obj_id}).limit(1)
return (await session.exec(stmt)).first()
async def get(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT: async def get(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT:

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Generic, TypeVar
from sqlalchemy import false
from sqlmodel import SQLModel, col
from app.db.queryset import QuerySet, qs
ModelT = TypeVar("ModelT", bound=SQLModel)
@dataclass(frozen=True)
class ModelManager(Generic[ModelT]):
model: type[ModelT]
id_field: str = "id"
def all(self) -> QuerySet[ModelT]:
return qs(self.model)
def none(self) -> QuerySet[ModelT]:
return qs(self.model).filter(false())
def filter(self, *criteria: Any) -> QuerySet[ModelT]:
return self.all().filter(*criteria)
def where(self, *criteria: Any) -> QuerySet[ModelT]:
return self.filter(*criteria)
def filter_by(self, **kwargs: Any) -> QuerySet[ModelT]:
queryset = self.all()
for field_name, value in kwargs.items():
queryset = queryset.filter(col(getattr(self.model, field_name)) == value)
return queryset
def by_id(self, obj_id: Any) -> QuerySet[ModelT]:
return self.by_field(self.id_field, obj_id)
def by_ids(self, obj_ids: list[Any] | tuple[Any, ...] | set[Any]) -> QuerySet[ModelT]:
return self.by_field_in(self.id_field, obj_ids)
def by_field(self, field_name: str, value: Any) -> QuerySet[ModelT]:
return self.filter(col(getattr(self.model, field_name)) == value)
def by_field_in(
self,
field_name: str,
values: list[Any] | tuple[Any, ...] | set[Any],
) -> QuerySet[ModelT]:
seq = tuple(values)
if not seq:
return self.none()
return self.filter(col(getattr(self.model, field_name)).in_(seq))
class ManagerDescriptor(Generic[ModelT]):
def __get__(self, instance: object, owner: type[ModelT]) -> ModelManager[ModelT]:
return ModelManager(owner)

View File

@@ -17,6 +17,13 @@ class QuerySet(Generic[ModelT]):
def filter(self, *criteria: Any) -> QuerySet[ModelT]: def filter(self, *criteria: Any) -> QuerySet[ModelT]:
return replace(self, statement=self.statement.where(*criteria)) return replace(self, statement=self.statement.where(*criteria))
def where(self, *criteria: Any) -> QuerySet[ModelT]:
return self.filter(*criteria)
def filter_by(self, **kwargs: Any) -> QuerySet[ModelT]:
statement = self.statement.filter_by(**kwargs)
return replace(self, statement=statement)
def order_by(self, *ordering: Any) -> QuerySet[ModelT]: def order_by(self, *ordering: Any) -> QuerySet[ModelT]:
return replace(self, statement=self.statement.order_by(*ordering)) return replace(self, statement=self.statement.order_by(*ordering))

View File

@@ -3,12 +3,13 @@ from __future__ import annotations
from datetime import datetime from datetime import datetime
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlmodel import Field, SQLModel from sqlmodel import Field
from app.core.time import utcnow from app.core.time import utcnow
from app.models.base import QueryModel
class ActivityEvent(SQLModel, table=True): class ActivityEvent(QueryModel, table=True):
__tablename__ = "activity_events" __tablename__ = "activity_events"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -5,12 +5,13 @@ from typing import Any
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import JSON, Column, Text from sqlalchemy import JSON, Column, Text
from sqlmodel import Field, SQLModel from sqlmodel import Field
from app.core.time import utcnow from app.core.time import utcnow
from app.models.base import QueryModel
class Agent(SQLModel, table=True): class Agent(QueryModel, table=True):
__tablename__ = "agents" __tablename__ = "agents"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -4,12 +4,13 @@ from datetime import datetime
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import JSON, Column from sqlalchemy import JSON, Column
from sqlmodel import Field, SQLModel from sqlmodel import Field
from app.core.time import utcnow from app.core.time import utcnow
from app.models.base import QueryModel
class Approval(SQLModel, table=True): class Approval(QueryModel, table=True):
__tablename__ = "approvals" __tablename__ = "approvals"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -0,0 +1,11 @@
from __future__ import annotations
from typing import ClassVar, Self
from sqlmodel import SQLModel
from app.db.query_manager import ManagerDescriptor
class QueryModel(SQLModel, table=False):
objects: ClassVar[ManagerDescriptor[Self]] = ManagerDescriptor()

View File

@@ -4,12 +4,13 @@ from datetime import datetime
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import JSON, Column from sqlalchemy import JSON, Column
from sqlmodel import Field, SQLModel from sqlmodel import Field
from app.core.time import utcnow from app.core.time import utcnow
from app.models.base import QueryModel
class BoardGroupMemory(SQLModel, table=True): class BoardGroupMemory(QueryModel, table=True):
__tablename__ = "board_group_memory" __tablename__ = "board_group_memory"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -4,12 +4,13 @@ from datetime import datetime
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import JSON, Column from sqlalchemy import JSON, Column
from sqlmodel import Field, SQLModel from sqlmodel import Field
from app.core.time import utcnow from app.core.time import utcnow
from app.models.base import QueryModel
class BoardMemory(SQLModel, table=True): class BoardMemory(QueryModel, table=True):
__tablename__ = "board_memory" __tablename__ = "board_memory"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -4,12 +4,13 @@ from datetime import datetime
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import JSON, Column from sqlalchemy import JSON, Column
from sqlmodel import Field, SQLModel from sqlmodel import Field
from app.core.time import utcnow from app.core.time import utcnow
from app.models.base import QueryModel
class BoardOnboardingSession(SQLModel, table=True): class BoardOnboardingSession(QueryModel, table=True):
__tablename__ = "board_onboarding_sessions" __tablename__ = "board_onboarding_sessions"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -3,12 +3,13 @@ from __future__ import annotations
from datetime import datetime from datetime import datetime
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlmodel import Field, SQLModel from sqlmodel import Field
from app.core.time import utcnow from app.core.time import utcnow
from app.models.base import QueryModel
class Gateway(SQLModel, table=True): class Gateway(QueryModel, table=True):
__tablename__ = "gateways" __tablename__ = "gateways"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -4,12 +4,13 @@ from datetime import datetime
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint from sqlalchemy import UniqueConstraint
from sqlmodel import Field, SQLModel from sqlmodel import Field
from app.core.time import utcnow from app.core.time import utcnow
from app.models.base import QueryModel
class OrganizationBoardAccess(SQLModel, table=True): class OrganizationBoardAccess(QueryModel, table=True):
__tablename__ = "organization_board_access" __tablename__ = "organization_board_access"
__table_args__ = ( __table_args__ = (
UniqueConstraint( UniqueConstraint(

View File

@@ -4,12 +4,13 @@ from datetime import datetime
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint from sqlalchemy import UniqueConstraint
from sqlmodel import Field, SQLModel from sqlmodel import Field
from app.core.time import utcnow from app.core.time import utcnow
from app.models.base import QueryModel
class OrganizationInviteBoardAccess(SQLModel, table=True): class OrganizationInviteBoardAccess(QueryModel, table=True):
__tablename__ = "organization_invite_board_access" __tablename__ = "organization_invite_board_access"
__table_args__ = ( __table_args__ = (
UniqueConstraint( UniqueConstraint(

View File

@@ -4,12 +4,13 @@ from datetime import datetime
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint from sqlalchemy import UniqueConstraint
from sqlmodel import Field, SQLModel from sqlmodel import Field
from app.core.time import utcnow from app.core.time import utcnow
from app.models.base import QueryModel
class OrganizationInvite(SQLModel, table=True): class OrganizationInvite(QueryModel, table=True):
__tablename__ = "organization_invites" __tablename__ = "organization_invites"
__table_args__ = (UniqueConstraint("token", name="uq_org_invites_token"),) __table_args__ = (UniqueConstraint("token", name="uq_org_invites_token"),)

View File

@@ -4,12 +4,13 @@ from datetime import datetime
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint from sqlalchemy import UniqueConstraint
from sqlmodel import Field, SQLModel from sqlmodel import Field
from app.core.time import utcnow from app.core.time import utcnow
from app.models.base import QueryModel
class OrganizationMember(SQLModel, table=True): class OrganizationMember(QueryModel, table=True):
__tablename__ = "organization_members" __tablename__ = "organization_members"
__table_args__ = ( __table_args__ = (
UniqueConstraint( UniqueConstraint(

View File

@@ -4,12 +4,13 @@ from datetime import datetime
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint from sqlalchemy import UniqueConstraint
from sqlmodel import Field, SQLModel from sqlmodel import Field
from app.core.time import utcnow from app.core.time import utcnow
from app.models.base import QueryModel
class Organization(SQLModel, table=True): class Organization(QueryModel, table=True):
__tablename__ = "organizations" __tablename__ = "organizations"
__table_args__ = (UniqueConstraint("name", name="uq_organizations_name"),) __table_args__ = (UniqueConstraint("name", name="uq_organizations_name"),)

View File

@@ -3,12 +3,13 @@ from __future__ import annotations
from datetime import datetime from datetime import datetime
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlmodel import Field, SQLModel from sqlmodel import Field
from app.core.time import utcnow from app.core.time import utcnow
from app.models.base import QueryModel
class TaskFingerprint(SQLModel, table=True): class TaskFingerprint(QueryModel, table=True):
__tablename__ = "task_fingerprints" __tablename__ = "task_fingerprints"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from sqlmodel import SQLModel from app.models.base import QueryModel
class TenantScoped(SQLModel, table=False): class TenantScoped(QueryModel, table=False):
pass pass

View File

@@ -2,10 +2,12 @@ from __future__ import annotations
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlmodel import Field, SQLModel from sqlmodel import Field
from app.models.base import QueryModel
class User(SQLModel, table=True): class User(QueryModel, table=True):
__tablename__ = "users" __tablename__ = "users"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -1 +0,0 @@
from __future__ import annotations

View File

@@ -1,50 +0,0 @@
from __future__ import annotations
from uuid import UUID
from sqlmodel import col
from app.db.queryset import QuerySet, qs
from app.models.organization_board_access import OrganizationBoardAccess
from app.models.organization_invites import OrganizationInvite
from app.models.organization_members import OrganizationMember
from app.models.organizations import Organization
def organization_by_name(name: str) -> QuerySet[Organization]:
return qs(Organization).filter(col(Organization.name) == name)
def member_by_user_and_org(*, user_id: UUID, organization_id: UUID) -> QuerySet[OrganizationMember]:
return qs(OrganizationMember).filter(
col(OrganizationMember.organization_id) == organization_id,
col(OrganizationMember.user_id) == user_id,
)
def first_membership_for_user(user_id: UUID) -> QuerySet[OrganizationMember]:
return (
qs(OrganizationMember)
.filter(col(OrganizationMember.user_id) == user_id)
.order_by(col(OrganizationMember.created_at).asc())
)
def pending_invite_by_email(email: str) -> QuerySet[OrganizationInvite]:
return (
qs(OrganizationInvite)
.filter(col(OrganizationInvite.accepted_at).is_(None))
.filter(col(OrganizationInvite.invited_email) == email)
.order_by(col(OrganizationInvite.created_at).asc())
)
def board_access_for_member_and_board(
*,
organization_member_id: UUID,
board_id: UUID,
) -> QuerySet[OrganizationBoardAccess]:
return qs(OrganizationBoardAccess).filter(
col(OrganizationBoardAccess.organization_member_id) == organization_member_id,
col(OrganizationBoardAccess.board_id) == board_id,
)

View File

@@ -42,7 +42,7 @@ async def build_group_snapshot(
include_done: bool = False, include_done: bool = False,
per_board_task_limit: int = 5, per_board_task_limit: int = 5,
) -> BoardGroupSnapshot: ) -> BoardGroupSnapshot:
statement = select(Board).where(col(Board.board_group_id) == group.id) statement = Board.objects.filter_by(board_group_id=group.id).statement
if exclude_board_id is not None: if exclude_board_id is not None:
statement = statement.where(col(Board.id) != exclude_board_id) statement = statement.where(col(Board.id) != exclude_board_id)
boards = list(await session.exec(statement.order_by(func.lower(col(Board.name)).asc()))) boards = list(await session.exec(statement.order_by(func.lower(col(Board.name)).asc())))
@@ -146,7 +146,7 @@ async def build_board_group_snapshot(
) -> BoardGroupSnapshot: ) -> BoardGroupSnapshot:
if not board.board_group_id: if not board.board_group_id:
return BoardGroupSnapshot(group=None, boards=[]) return BoardGroupSnapshot(group=None, boards=[])
group = await session.get(BoardGroup, board.board_group_id) group = await BoardGroup.objects.by_id(board.board_group_id).first(session)
if group is None: if group is None:
return BoardGroupSnapshot(group=None, boards=[]) return BoardGroupSnapshot(group=None, boards=[])
return await build_group_snapshot( return await build_group_snapshot(

View File

@@ -97,9 +97,9 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
board_read = BoardRead.model_validate(board, from_attributes=True) board_read = BoardRead.model_validate(board, from_attributes=True)
tasks = list( tasks = list(
await session.exec( await Task.objects.filter_by(board_id=board.id)
select(Task).where(col(Task.board_id) == board.id).order_by(col(Task.created_at).desc()) .order_by(col(Task.created_at).desc())
) .all(session)
) )
task_ids = [task.id for task in tasks] task_ids = [task.id for task in tasks]
@@ -114,12 +114,10 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
) )
main_session_keys = await _gateway_main_session_keys(session) main_session_keys = await _gateway_main_session_keys(session)
agents = list( agents = (
await session.exec( await Agent.objects.filter_by(board_id=board.id)
select(Agent) .order_by(col(Agent.created_at).desc())
.where(col(Agent.board_id) == board.id) .all(session)
.order_by(col(Agent.created_at).desc())
)
) )
agent_reads = [_agent_to_read(agent, main_session_keys) for agent in agents] agent_reads = [_agent_to_read(agent, main_session_keys) for agent in agents]
agent_name_by_id = {agent.id: agent.name for agent in agents} agent_name_by_id = {agent.id: agent.name for agent in agents}
@@ -134,13 +132,11 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
).one() ).one()
) )
approvals = list( approvals = (
await session.exec( await Approval.objects.filter_by(board_id=board.id)
select(Approval) .order_by(col(Approval.created_at).desc())
.where(col(Approval.board_id) == board.id) .limit(200)
.order_by(col(Approval.created_at).desc()) .all(session)
.limit(200)
)
) )
approval_reads = [_approval_to_read(approval) for approval in approvals] approval_reads = [_approval_to_read(approval) for approval in approvals]
@@ -173,17 +169,15 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
for task in tasks for task in tasks
] ]
chat_messages = list( chat_messages = (
await session.exec( await BoardMemory.objects.filter_by(board_id=board.id)
select(BoardMemory) .filter(col(BoardMemory.is_chat).is_(True))
.where(col(BoardMemory.board_id) == board.id) # Old/invalid rows (empty/whitespace-only content) can exist; exclude them to
.where(col(BoardMemory.is_chat).is_(True)) # satisfy the NonEmptyStr response schema.
# Old/invalid rows (empty/whitespace-only content) can exist; exclude them to .filter(func.length(func.trim(col(BoardMemory.content))) > 0)
# satisfy the NonEmptyStr response schema. .order_by(col(BoardMemory.created_at).desc())
.where(func.length(func.trim(col(BoardMemory.content))) > 0) .limit(200)
.order_by(col(BoardMemory.created_at).desc()) .all(session)
.limit(200)
)
) )
chat_messages.sort(key=lambda item: item.created_at) chat_messages.sort(key=lambda item: item.created_at)
chat_reads = [_memory_to_read(memory) for memory in chat_messages] chat_reads = [_memory_to_read(memory) for memory in chat_messages]

View File

@@ -19,7 +19,6 @@ from app.models.organization_invites import OrganizationInvite
from app.models.organization_members import OrganizationMember from app.models.organization_members import OrganizationMember
from app.models.organizations import Organization from app.models.organizations import Organization
from app.models.users import User from app.models.users import User
from app.queries import organizations as org_queries
from app.schemas.organizations import OrganizationBoardAccessSpec, OrganizationMemberAccessUpdate from app.schemas.organizations import OrganizationBoardAccessSpec, OrganizationMemberAccessUpdate
DEFAULT_ORG_NAME = "Personal" DEFAULT_ORG_NAME = "Personal"
@@ -38,7 +37,7 @@ def is_org_admin(member: OrganizationMember) -> bool:
async def get_default_org(session: AsyncSession) -> Organization | None: async def get_default_org(session: AsyncSession) -> Organization | None:
return await org_queries.organization_by_name(DEFAULT_ORG_NAME).first(session) return await Organization.objects.filter_by(name=DEFAULT_ORG_NAME).first(session)
async def ensure_default_org(session: AsyncSession) -> Organization: async def ensure_default_org(session: AsyncSession) -> Organization:
@@ -58,14 +57,18 @@ async def get_member(
user_id: UUID, user_id: UUID,
organization_id: UUID, organization_id: UUID,
) -> OrganizationMember | None: ) -> OrganizationMember | None:
return await org_queries.member_by_user_and_org( return await OrganizationMember.objects.filter_by(
user_id=user_id, user_id=user_id,
organization_id=organization_id, organization_id=organization_id,
).first(session) ).first(session)
async def get_first_membership(session: AsyncSession, user_id: UUID) -> OrganizationMember | None: async def get_first_membership(session: AsyncSession, user_id: UUID) -> OrganizationMember | None:
return await org_queries.first_membership_for_user(user_id).first(session) return (
await OrganizationMember.objects.filter_by(user_id=user_id)
.order_by(col(OrganizationMember.created_at).asc())
.first(session)
)
async def set_active_organization( async def set_active_organization(
@@ -88,7 +91,7 @@ async def get_active_membership(
session: AsyncSession, session: AsyncSession,
user: User, user: User,
) -> OrganizationMember | None: ) -> OrganizationMember | None:
db_user = await session.get(User, user.id) db_user = await User.objects.by_id(user.id).first(session)
if db_user is None: if db_user is None:
db_user = user db_user = user
if db_user.active_organization_id: if db_user.active_organization_id:
@@ -119,7 +122,14 @@ async def _find_pending_invite(
session: AsyncSession, session: AsyncSession,
email: str, email: str,
) -> OrganizationInvite | None: ) -> OrganizationInvite | None:
return await org_queries.pending_invite_by_email(email).first(session) return (
await OrganizationInvite.objects.filter(
col(OrganizationInvite.accepted_at).is_(None),
col(OrganizationInvite.invited_email) == email,
)
.order_by(col(OrganizationInvite.created_at).asc())
.first(session)
)
async def accept_invite( async def accept_invite(
@@ -230,7 +240,7 @@ async def has_board_access(
else: else:
if member_all_boards_read(member): if member_all_boards_read(member):
return True return True
access = await org_queries.board_access_for_member_and_board( access = await OrganizationBoardAccess.objects.filter_by(
organization_member_id=member.id, organization_member_id=member.id,
board_id=board.id, board_id=board.id,
).first(session) ).first(session)

View File

@@ -328,7 +328,7 @@ async def sync_gateway_templates(
result.errors.append(GatewayTemplatesSyncError(message=str(exc))) result.errors.append(GatewayTemplatesSyncError(message=str(exc)))
return result return result
boards = list(await session.exec(select(Board).where(col(Board.gateway_id) == gateway.id))) boards = await Board.objects.filter_by(gateway_id=gateway.id).all(session)
boards_by_id = {board.id: board for board in boards} boards_by_id = {board.id: board for board in boards}
if board_id is not None: if board_id is not None:
board = boards_by_id.get(board_id) board = boards_by_id.get(board_id)
@@ -345,12 +345,10 @@ async def sync_gateway_templates(
paused_board_ids = await _paused_board_ids(session, list(boards_by_id.keys())) paused_board_ids = await _paused_board_ids(session, list(boards_by_id.keys()))
if boards_by_id: if boards_by_id:
agents = list( agents = await (
await session.exec( Agent.objects.by_field_in("board_id", list(boards_by_id.keys()))
select(Agent) .order_by(col(Agent.created_at).asc())
.where(col(Agent.board_id).in_(list(boards_by_id.keys()))) .all(session)
.order_by(col(Agent.created_at).asc())
)
) )
else: else:
agents = [] agents = []
@@ -471,10 +469,10 @@ async def sync_gateway_templates(
if include_main: if include_main:
main_agent = ( main_agent = (
await session.exec( await Agent.objects.all()
select(Agent).where(col(Agent.openclaw_session_id) == gateway.main_session_key) .filter(col(Agent.openclaw_session_id) == gateway.main_session_key)
) .first(session)
).first() )
if main_agent is None: if main_agent is None:
result.errors.append( result.errors.append(
GatewayTemplatesSyncError( GatewayTemplatesSyncError(

View File

@@ -28,7 +28,6 @@ class _FakeExecResult:
@dataclass @dataclass
class _FakeSession: class _FakeSession:
exec_results: list[Any] exec_results: list[Any]
get_results: dict[tuple[type[Any], Any], Any] = field(default_factory=dict)
executed: list[Any] = field(default_factory=list) executed: list[Any] = field(default_factory=list)
deleted: list[Any] = field(default_factory=list) deleted: list[Any] = field(default_factory=list)
@@ -44,9 +43,6 @@ class _FakeSession:
raise AssertionError("No more exec_results left for session.exec") raise AssertionError("No more exec_results left for session.exec")
return self.exec_results.pop(0) return self.exec_results.pop(0)
async def get(self, model: type[Any], key: Any) -> Any:
return self.get_results.get((model, key))
async def execute(self, statement: Any) -> None: async def execute(self, statement: Any) -> None:
self.executed.append(statement) self.executed.append(statement)
@@ -79,11 +75,11 @@ async def test_remove_org_member_deletes_member_access_and_member() -> None:
active_organization_id=org_id, active_organization_id=org_id,
) )
session = _FakeSession( session = _FakeSession(
exec_results=[_FakeExecResult(first_value=fallback_org_id)], exec_results=[
get_results={ _FakeExecResult(first_value=member),
(OrganizationMember, member_id): member, _FakeExecResult(first_value=user),
(User, target_user_id): user, _FakeExecResult(first_value=fallback_org_id),
}, ],
) )
ctx = SimpleNamespace( ctx = SimpleNamespace(
organization=SimpleNamespace(id=org_id), organization=SimpleNamespace(id=org_id),
@@ -110,10 +106,7 @@ async def test_remove_org_member_disallows_self_removal() -> None:
user_id=user_id, user_id=user_id,
role="member", role="member",
) )
session = _FakeSession( session = _FakeSession(exec_results=[_FakeExecResult(first_value=member)])
exec_results=[],
get_results={(OrganizationMember, member.id): member},
)
ctx = SimpleNamespace( ctx = SimpleNamespace(
organization=SimpleNamespace(id=org_id), organization=SimpleNamespace(id=org_id),
member=SimpleNamespace(user_id=user_id, role="owner"), member=SimpleNamespace(user_id=user_id, role="owner"),
@@ -137,10 +130,7 @@ async def test_remove_org_member_requires_owner_to_remove_owner() -> None:
user_id=uuid4(), user_id=uuid4(),
role="owner", role="owner",
) )
session = _FakeSession( session = _FakeSession(exec_results=[_FakeExecResult(first_value=member)])
exec_results=[],
get_results={(OrganizationMember, member.id): member},
)
ctx = SimpleNamespace( ctx = SimpleNamespace(
organization=SimpleNamespace(id=org_id), organization=SimpleNamespace(id=org_id),
member=SimpleNamespace(user_id=uuid4(), role="admin"), member=SimpleNamespace(user_id=uuid4(), role="admin"),
@@ -165,8 +155,10 @@ async def test_remove_org_member_rejects_removing_last_owner() -> None:
role="owner", role="owner",
) )
session = _FakeSession( session = _FakeSession(
exec_results=[_FakeExecResult(all_values=[member.id])], exec_results=[
get_results={(OrganizationMember, member.id): member}, _FakeExecResult(first_value=member),
_FakeExecResult(all_values=[member]),
],
) )
ctx = SimpleNamespace( ctx = SimpleNamespace(
organization=SimpleNamespace(id=org_id), organization=SimpleNamespace(id=org_id),

View File

@@ -5,7 +5,7 @@ from uuid import UUID, uuid4
import pytest import pytest
from fastapi import HTTPException from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel import SQLModel from sqlmodel import SQLModel, col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.boards import Board from app.models.boards import Board
@@ -122,7 +122,7 @@ async def test_dependency_queries_and_replace_and_dependents() -> None:
assert deps_map.get(t2, []) == [] assert deps_map.get(t2, []) == []
# mark t2 done, t3 not # mark t2 done, t3 not
task2 = await session.get(Task, t2) task2 = (await session.exec(select(Task).where(col(Task.id) == t2))).first()
assert task2 is not None assert task2 is not None
task2.status = td.DONE_STATUS task2.status = td.DONE_STATUS
await session.commit() await session.commit()