From 77069432094de1b4daa04a855492bf663cbcc995 Mon Sep 17 00:00:00 2001 From: Abhimanyu Saharan Date: Mon, 9 Feb 2026 16:23:41 +0530 Subject: [PATCH] refactor: enhance docstrings for clarity and consistency across multiple files --- backend/app/api/activity.py | 87 ++- backend/app/api/approvals.py | 90 ++- backend/app/api/board_group_memory.py | 318 ++++++--- backend/app/api/board_groups.py | 246 +++++-- backend/app/api/board_memory.py | 182 +++-- backend/app/api/board_onboarding.py | 42 +- backend/app/api/deps.py | 77 +- backend/app/api/gateway.py | 217 +++--- backend/app/api/gateways.py | 155 +++-- backend/app/api/metrics.py | 57 +- backend/app/db/crud.py | 130 +++- backend/app/integrations/openclaw_gateway.py | 47 +- backend/app/schemas/agents.py | 36 +- backend/app/schemas/board_onboarding.py | 49 +- backend/app/services/agent_provisioning.py | 55 +- backend/app/services/board_group_snapshot.py | 209 ++++-- backend/app/services/organizations.py | 23 +- backend/app/services/task_dependencies.py | 65 +- backend/app/services/template_sync.py | 658 ++++++++++-------- .../tests/test_agent_provisioning_utils.py | 2 + backend/tests/test_db_transaction_safety.py | 2 + backend/tests/test_error_handling.py | 2 + backend/tests/test_lead_policy.py | 2 + backend/tests/test_mentions.py | 2 + .../test_organizations_member_remove_api.py | 2 + backend/tests/test_request_id_middleware.py | 2 + backend/tests/test_task_dependencies.py | 2 + .../test_task_dependencies_integration.py | 2 + 28 files changed, 1829 insertions(+), 932 deletions(-) diff --git a/backend/app/api/activity.py b/backend/app/api/activity.py index ca59fc4..01c3786 100644 --- a/backend/app/api/activity.py +++ b/backend/app/api/activity.py @@ -1,17 +1,18 @@ +"""Activity listing and task-comment feed endpoints.""" + from __future__ import annotations import asyncio import json from collections import deque -from collections.abc import AsyncIterator, Sequence +from collections.abc import Sequence from datetime import datetime, timezone -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from sqlalchemy import asc, desc, func from sqlmodel import col, select -from sqlmodel.ext.asyncio.session import AsyncSession from sse_starlette.sse import EventSourceResponse from app.api.deps import ActorContext, require_admin_or_agent, require_org_member @@ -22,7 +23,10 @@ from app.models.activity_events import ActivityEvent from app.models.agents import Agent from app.models.boards import Board from app.models.tasks import Task -from app.schemas.activity_events import ActivityEventRead, ActivityTaskCommentFeedItemRead +from app.schemas.activity_events import ( + ActivityEventRead, + ActivityTaskCommentFeedItemRead, +) from app.schemas.pagination import DefaultLimitOffsetPage from app.services.organizations import ( OrganizationContext, @@ -30,9 +34,21 @@ from app.services.organizations import ( list_accessible_board_ids, ) +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from sqlmodel.ext.asyncio.session import AsyncSession + router = APIRouter(prefix="/activity", tags=["activity"]) SSE_SEEN_MAX = 2000 +STREAM_POLL_SECONDS = 2 +SESSION_DEP = Depends(get_session) +ACTOR_DEP = Depends(require_admin_or_agent) +ORG_MEMBER_DEP = Depends(require_org_member) +BOARD_ID_QUERY = Query(default=None) +SINCE_QUERY = Query(default=None) +_RUNTIME_TYPE_REFERENCES = (UUID,) def _parse_since(value: str | None) -> datetime | None: @@ -110,9 +126,10 @@ async def _fetch_task_comment_events( @router.get("", response_model=DefaultLimitOffsetPage[ActivityEventRead]) async def list_activity( - session: AsyncSession = Depends(get_session), - actor: ActorContext = Depends(require_admin_or_agent), + session: AsyncSession = SESSION_DEP, + actor: ActorContext = ACTOR_DEP, ) -> DefaultLimitOffsetPage[ActivityEventRead]: + """List activity events visible to the calling actor.""" statement = select(ActivityEvent) if actor.actor_type == "agent" and actor.agent: statement = statement.where(ActivityEvent.agent_id == actor.agent.id) @@ -124,9 +141,10 @@ async def list_activity( if not board_ids: statement = statement.where(col(ActivityEvent.id).is_(None)) else: - statement = statement.join(Task, col(ActivityEvent.task_id) == col(Task.id)).where( - col(Task.board_id).in_(board_ids) - ) + statement = statement.join( + Task, + col(ActivityEvent.task_id) == col(Task.id), + ).where(col(Task.board_id).in_(board_ids)) statement = statement.order_by(desc(col(ActivityEvent.created_at))) return await paginate(session, statement) @@ -136,10 +154,11 @@ async def list_activity( response_model=DefaultLimitOffsetPage[ActivityTaskCommentFeedItemRead], ) async def list_task_comment_feed( - board_id: UUID | None = Query(default=None), - session: AsyncSession = Depends(get_session), - ctx: OrganizationContext = Depends(require_org_member), + board_id: UUID | None = BOARD_ID_QUERY, + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_MEMBER_DEP, ) -> DefaultLimitOffsetPage[ActivityTaskCommentFeedItemRead]: + """List task-comment feed items for accessible boards.""" statement = ( select(ActivityEvent, Task, Board, Agent) .join(Task, col(ActivityEvent.task_id) == col(Task.id)) @@ -161,7 +180,10 @@ async def list_task_comment_feed( def _transform(items: Sequence[Any]) -> Sequence[Any]: rows = cast(Sequence[tuple[ActivityEvent, Task, Board, Agent | None]], items) - return [_feed_item(event, task, board, agent) for event, task, board, agent in rows] + return [ + _feed_item(event, task, board, agent) + for event, task, board, agent in rows + ] return await paginate(session, statement, transformer=_transform) @@ -169,13 +191,18 @@ async def list_task_comment_feed( @router.get("/task-comments/stream") async def stream_task_comment_feed( request: Request, - board_id: UUID | None = Query(default=None), - since: str | None = Query(default=None), - session: AsyncSession = Depends(get_session), - ctx: OrganizationContext = Depends(require_org_member), + board_id: UUID | None = BOARD_ID_QUERY, + since: str | None = SINCE_QUERY, + db_session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_MEMBER_DEP, ) -> EventSourceResponse: + """Stream task-comment events for accessible boards.""" since_dt = _parse_since(since) or utcnow() - board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False) + board_ids = await list_accessible_board_ids( + db_session, + member=ctx.member, + write=False, + ) allowed_ids = set(board_ids) if board_id is not None and board_id not in allowed_ids: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) @@ -187,11 +214,15 @@ async def stream_task_comment_feed( while True: if await request.is_disconnected(): break - async with async_session_maker() as session: + async with async_session_maker() as stream_session: if board_id is not None: - rows = await _fetch_task_comment_events(session, last_seen, board_id=board_id) + rows = await _fetch_task_comment_events( + stream_session, + last_seen, + board_id=board_id, + ) elif allowed_ids: - rows = await _fetch_task_comment_events(session, last_seen) + rows = await _fetch_task_comment_events(stream_session, last_seen) rows = [row for row in rows if row[1].board_id in allowed_ids] else: rows = [] @@ -204,10 +235,16 @@ async def stream_task_comment_feed( if len(seen_queue) > SSE_SEEN_MAX: oldest = seen_queue.popleft() seen_ids.discard(oldest) - if event.created_at > last_seen: - last_seen = event.created_at - payload = {"comment": _feed_item(event, task, board, agent).model_dump(mode="json")} + last_seen = max(event.created_at, last_seen) + payload = { + "comment": _feed_item( + event, + task, + board, + agent, + ).model_dump(mode="json"), + } yield {"event": "comment", "data": json.dumps(payload)} - await asyncio.sleep(2) + await asyncio.sleep(STREAM_POLL_SECONDS) return EventSourceResponse(event_generator(), ping=15) diff --git a/backend/app/api/approvals.py b/backend/app/api/approvals.py index a98f207..83d28bd 100644 --- a/backend/app/api/approvals.py +++ b/backend/app/api/approvals.py @@ -1,15 +1,16 @@ +"""Approval listing, streaming, creation, and update endpoints.""" + from __future__ import annotations import asyncio import json -from collections.abc import AsyncIterator from datetime import datetime, timezone +from typing import TYPE_CHECKING from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from sqlalchemy import asc, case, func, or_ from sqlmodel import col, select -from sqlmodel.ext.asyncio.session import AsyncSession from sse_starlette.sse import EventSourceResponse from app.api.deps import ( @@ -23,13 +24,32 @@ from app.core.time import utcnow from app.db.pagination import paginate from app.db.session import async_session_maker, get_session from app.models.approvals import Approval -from app.models.boards import Board -from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalStatus, ApprovalUpdate +from app.schemas.approvals import ( + ApprovalCreate, + ApprovalRead, + ApprovalStatus, + ApprovalUpdate, +) from app.schemas.pagination import DefaultLimitOffsetPage +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from sqlmodel.ext.asyncio.session import AsyncSession + + from app.models.boards import Board + router = APIRouter(prefix="/boards/{board_id}/approvals", tags=["approvals"]) TASK_ID_KEYS: tuple[str, ...] = ("task_id", "taskId", "taskID") +STREAM_POLL_SECONDS = 2 +STATUS_FILTER_QUERY = Query(default=None, alias="status") +SINCE_QUERY = Query(default=None) +BOARD_READ_DEP = Depends(get_board_for_actor_read) +BOARD_WRITE_DEP = Depends(get_board_for_actor_write) +BOARD_USER_WRITE_DEP = Depends(get_board_for_user_write) +SESSION_DEP = Depends(get_session) +ACTOR_DEP = Depends(require_admin_or_agent) def _extract_task_id(payload: dict[str, object] | None) -> UUID | None: @@ -68,7 +88,10 @@ def _approval_updated_at(approval: Approval) -> datetime: def _serialize_approval(approval: Approval) -> dict[str, object]: - return ApprovalRead.model_validate(approval, from_attributes=True).model_dump(mode="json") + return ApprovalRead.model_validate( + approval, + from_attributes=True, + ).model_dump(mode="json") async def _fetch_approval_events( @@ -82,7 +105,7 @@ async def _fetch_approval_events( or_( col(Approval.created_at) >= since, col(Approval.resolved_at) >= since, - ) + ), ) .order_by(asc(col(Approval.created_at))) ) @@ -91,11 +114,12 @@ async def _fetch_approval_events( @router.get("", response_model=DefaultLimitOffsetPage[ApprovalRead]) async def list_approvals( - status_filter: ApprovalStatus | None = Query(default=None, alias="status"), - board: Board = Depends(get_board_for_actor_read), - session: AsyncSession = Depends(get_session), - actor: ActorContext = Depends(require_admin_or_agent), + status_filter: ApprovalStatus | None = STATUS_FILTER_QUERY, + board: Board = BOARD_READ_DEP, + session: AsyncSession = SESSION_DEP, + _actor: ActorContext = ACTOR_DEP, ) -> DefaultLimitOffsetPage[ApprovalRead]: + """List approvals for a board, optionally filtering by status.""" statement = Approval.objects.filter_by(board_id=board.id) if status_filter: statement = statement.filter(col(Approval.status) == status_filter) @@ -106,10 +130,11 @@ async def list_approvals( @router.get("/stream") async def stream_approvals( request: Request, - board: Board = Depends(get_board_for_actor_read), - actor: ActorContext = Depends(require_admin_or_agent), - since: str | None = Query(default=None), + board: Board = BOARD_READ_DEP, + _actor: ActorContext = ACTOR_DEP, + since: str | None = SINCE_QUERY, ) -> EventSourceResponse: + """Stream approval updates for a board using server-sent events.""" since_dt = _parse_since(since) or utcnow() last_seen = since_dt @@ -125,12 +150,14 @@ async def stream_approvals( await session.exec( select(func.count(col(Approval.id))) .where(col(Approval.board_id) == board.id) - .where(col(Approval.status) == "pending") + .where(col(Approval.status) == "pending"), ) - ).one() + ).one(), ) task_ids = { - approval.task_id for approval in approvals if approval.task_id is not None + approval.task_id + for approval in approvals + if approval.task_id is not None } counts_by_task_id: dict[UUID, tuple[int, int]] = {} if task_ids: @@ -140,22 +167,27 @@ async def stream_approvals( col(Approval.task_id), func.count(col(Approval.id)).label("total"), func.sum( - case((col(Approval.status) == "pending", 1), else_=0) + case( + (col(Approval.status) == "pending", 1), + else_=0, + ), ).label("pending"), ) .where(col(Approval.board_id) == board.id) .where(col(Approval.task_id).in_(task_ids)) - .group_by(col(Approval.task_id)) - ) + .group_by(col(Approval.task_id)), + ), ) for task_id, total, pending in rows: if task_id is None: continue - counts_by_task_id[task_id] = (int(total or 0), int(pending or 0)) + counts_by_task_id[task_id] = ( + int(total or 0), + int(pending or 0), + ) for approval in approvals: updated_at = _approval_updated_at(approval) - if updated_at > last_seen: - last_seen = updated_at + last_seen = max(updated_at, last_seen) payload: dict[str, object] = { "approval": _serialize_approval(approval), "pending_approvals_count": pending_approvals_count, @@ -170,7 +202,7 @@ async def stream_approvals( "approvals_pending_count": pending, } yield {"event": "approval", "data": json.dumps(payload)} - await asyncio.sleep(2) + await asyncio.sleep(STREAM_POLL_SECONDS) return EventSourceResponse(event_generator(), ping=15) @@ -178,10 +210,11 @@ async def stream_approvals( @router.post("", response_model=ApprovalRead) async def create_approval( payload: ApprovalCreate, - board: Board = Depends(get_board_for_actor_write), - session: AsyncSession = Depends(get_session), - actor: ActorContext = Depends(require_admin_or_agent), + board: Board = BOARD_WRITE_DEP, + session: AsyncSession = SESSION_DEP, + _actor: ActorContext = ACTOR_DEP, ) -> Approval: + """Create an approval for a board.""" task_id = payload.task_id or _extract_task_id(payload.payload) approval = Approval( board_id=board.id, @@ -203,9 +236,10 @@ async def create_approval( async def update_approval( approval_id: str, payload: ApprovalUpdate, - board: Board = Depends(get_board_for_user_write), - session: AsyncSession = Depends(get_session), + board: Board = BOARD_USER_WRITE_DEP, + session: AsyncSession = SESSION_DEP, ) -> Approval: + """Update an approval's status and resolution timestamp.""" approval = await Approval.objects.by_id(approval_id).first(session) if approval is None or approval.board_id != board.id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) diff --git a/backend/app/api/board_group_memory.py b/backend/app/api/board_group_memory.py index ab1aade..f1b57f9 100644 --- a/backend/app/api/board_group_memory.py +++ b/backend/app/api/board_group_memory.py @@ -1,15 +1,17 @@ +"""Board-group memory CRUD and streaming endpoints.""" + from __future__ import annotations import asyncio import json -from collections.abc import AsyncIterator +from dataclasses import dataclass from datetime import datetime, timezone +from typing import TYPE_CHECKING from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from sqlalchemy import func from sqlmodel import col -from sqlmodel.ext.asyncio.session import AsyncSession from sse_starlette.sse import EventSourceResponse from app.api.deps import ( @@ -24,28 +26,56 @@ from app.core.time import utcnow from app.db.pagination import paginate from app.db.session import async_session_maker, get_session from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig -from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message +from app.integrations.openclaw_gateway import ( + OpenClawGatewayError, + ensure_session, + send_message, +) from app.models.agents import Agent from app.models.board_group_memory import BoardGroupMemory from app.models.board_groups import BoardGroup from app.models.boards import Board from app.models.gateways import Gateway from app.models.users import User -from app.schemas.board_group_memory import BoardGroupMemoryCreate, BoardGroupMemoryRead +from app.schemas.board_group_memory import ( + BoardGroupMemoryCreate, + BoardGroupMemoryRead, +) from app.schemas.pagination import DefaultLimitOffsetPage from app.services.mentions import extract_mentions, matches_agent_mention from app.services.organizations import ( - OrganizationContext, is_org_admin, list_accessible_board_ids, member_all_boards_read, member_all_boards_write, ) -router = APIRouter(tags=["board-group-memory"]) +if TYPE_CHECKING: + from collections.abc import AsyncIterator -group_router = APIRouter(prefix="/board-groups/{group_id}/memory", tags=["board-group-memory"]) -board_router = APIRouter(prefix="/boards/{board_id}/group-memory", tags=["board-group-memory"]) + from sqlmodel.ext.asyncio.session import AsyncSession + + from app.services.organizations import OrganizationContext + +router = APIRouter(tags=["board-group-memory"]) +group_router = APIRouter( + prefix="/board-groups/{group_id}/memory", + tags=["board-group-memory"], +) +board_router = APIRouter( + prefix="/boards/{board_id}/group-memory", + tags=["board-group-memory"], +) +MAX_SNIPPET_LENGTH = 800 +STREAM_POLL_SECONDS = 2 +SESSION_DEP = Depends(get_session) +ORG_MEMBER_DEP = Depends(require_org_member) +BOARD_READ_DEP = Depends(get_board_for_actor_read) +BOARD_WRITE_DEP = Depends(get_board_for_actor_write) +ACTOR_DEP = Depends(require_admin_or_agent) +IS_CHAT_QUERY = Query(default=None) +SINCE_QUERY = Query(default=None) +_RUNTIME_TYPE_REFERENCES = (UUID,) def _parse_since(value: str | None) -> datetime | None: @@ -65,10 +95,16 @@ def _parse_since(value: str | None) -> datetime | None: def _serialize_memory(memory: BoardGroupMemory) -> dict[str, object]: - return BoardGroupMemoryRead.model_validate(memory, from_attributes=True).model_dump(mode="json") + return BoardGroupMemoryRead.model_validate( + memory, + from_attributes=True, + ).model_dump(mode="json") -async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None: +async def _gateway_config( + session: AsyncSession, + board: Board, +) -> GatewayClientConfig | None: if board.gateway_id is None: return None gateway = await Gateway.objects.by_id(board.gateway_id).first(session) @@ -104,7 +140,7 @@ async def _fetch_memory_events( if is_chat is not None: statement = statement.filter(col(BoardGroupMemory.is_chat) == is_chat) statement = statement.filter(col(BoardGroupMemory.created_at) >= since).order_by( - col(BoardGroupMemory.created_at) + col(BoardGroupMemory.created_at), ) return await statement.all(session) @@ -128,19 +164,124 @@ async def _require_group_access( return group board_ids = [ - board.id for board in await Board.objects.filter_by(board_group_id=group_id).all(session) + board.id + for board in await Board.objects.filter_by(board_group_id=group_id).all( + session, + ) ] if not board_ids: if is_org_admin(ctx.member): return group raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - allowed_ids = await list_accessible_board_ids(session, member=ctx.member, write=write) + allowed_ids = await list_accessible_board_ids( + session, + member=ctx.member, + write=write, + ) if not set(board_ids).intersection(set(allowed_ids)): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) return group +async def _group_read_access( + group_id: UUID, + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_MEMBER_DEP, +) -> BoardGroup: + return await _require_group_access(session, group_id=group_id, ctx=ctx, write=False) + + +GROUP_READ_DEP = Depends(_group_read_access) + + +def _group_chat_targets( + *, + agents: list[Agent], + actor: ActorContext, + is_broadcast: bool, + mentions: set[str], +) -> dict[str, Agent]: + targets: dict[str, Agent] = {} + for agent in agents: + if not agent.openclaw_session_id: + continue + if actor.actor_type == "agent" and actor.agent and agent.id == actor.agent.id: + continue + if is_broadcast or agent.is_board_lead: + targets[str(agent.id)] = agent + continue + if mentions and matches_agent_mention(agent, mentions): + targets[str(agent.id)] = agent + return targets + + +def _group_actor_name(actor: ActorContext) -> str: + if actor.actor_type == "agent" and actor.agent: + return actor.agent.name + if actor.user: + return actor.user.preferred_name or actor.user.name or "User" + return "User" + + +def _group_header(*, is_broadcast: bool, mentioned: bool) -> str: + if is_broadcast: + return "GROUP BROADCAST" + if mentioned: + return "GROUP CHAT MENTION" + return "GROUP CHAT" + + +@dataclass(frozen=True) +class _NotifyGroupContext: + session: AsyncSession + group: BoardGroup + board_by_id: dict[UUID, Board] + mentions: set[str] + is_broadcast: bool + actor_name: str + snippet: str + base_url: str + + +async def _notify_group_target( + context: _NotifyGroupContext, + agent: Agent, +) -> None: + session_key = agent.openclaw_session_id + board_id = agent.board_id + if not session_key or board_id is None: + return + board = context.board_by_id.get(board_id) + if board is None: + return + config = await _gateway_config(context.session, board) + if config is None: + return + header = _group_header( + is_broadcast=context.is_broadcast, + mentioned=matches_agent_mention(agent, context.mentions), + ) + message = ( + f"{header}\n" + f"Group: {context.group.name}\n" + f"From: {context.actor_name}\n\n" + f"{context.snippet}\n\n" + "Reply via group chat (shared across linked boards):\n" + f"POST {context.base_url}/api/v1/boards/{board.id}/group-memory\n" + 'Body: {"content":"...","tags":["chat"]}' + ) + try: + await _send_agent_message( + session_key=session_key, + config=config, + agent_name=agent.name, + message=message, + ) + except OpenClawGatewayError: + return + + async def _notify_group_memory_targets( *, session: AsyncSession, @@ -163,83 +304,47 @@ async def _notify_group_memory_targets( board_ids = list(board_by_id.keys()) agents = await Agent.objects.by_field_in("board_id", board_ids).all(session) - targets: dict[str, Agent] = {} - for agent in agents: - if not agent.openclaw_session_id: - continue - if actor.actor_type == "agent" and actor.agent and agent.id == actor.agent.id: - continue - if is_broadcast: - targets[str(agent.id)] = agent - continue - if agent.is_board_lead: - targets[str(agent.id)] = agent - continue - if mentions and matches_agent_mention(agent, mentions): - targets[str(agent.id)] = agent + targets = _group_chat_targets( + agents=agents, + actor=actor, + is_broadcast=is_broadcast, + mentions=mentions, + ) if not targets: return - actor_name = "User" - if actor.actor_type == "agent" and actor.agent: - actor_name = actor.agent.name - elif actor.user: - actor_name = actor.user.preferred_name or actor.user.name or actor_name + actor_name = _group_actor_name(actor) snippet = memory.content.strip() - if len(snippet) > 800: - snippet = f"{snippet[:797]}..." + if len(snippet) > MAX_SNIPPET_LENGTH: + snippet = f"{snippet[: MAX_SNIPPET_LENGTH - 3]}..." base_url = settings.base_url or "http://localhost:8000" + context = _NotifyGroupContext( + session=session, + group=group, + board_by_id=board_by_id, + mentions=mentions, + is_broadcast=is_broadcast, + actor_name=actor_name, + snippet=snippet, + base_url=base_url, + ) for agent in targets.values(): - session_key = agent.openclaw_session_id - if not session_key: - continue - board_id = agent.board_id - if board_id is None: - continue - board = board_by_id.get(board_id) - if board is None: - continue - config = await _gateway_config(session, board) - if config is None: - continue - mentioned = matches_agent_mention(agent, mentions) - if is_broadcast: - header = "GROUP BROADCAST" - elif mentioned: - header = "GROUP CHAT MENTION" - else: - header = "GROUP CHAT" - message = ( - f"{header}\n" - f"Group: {group.name}\n" - f"From: {actor_name}\n\n" - f"{snippet}\n\n" - "Reply via group chat (shared across linked boards):\n" - f"POST {base_url}/api/v1/boards/{board.id}/group-memory\n" - 'Body: {"content":"...","tags":["chat"]}' - ) - try: - await _send_agent_message( - session_key=session_key, - config=config, - agent_name=agent.name, - message=message, - ) - except OpenClawGatewayError: - continue + await _notify_group_target(context, agent) @group_router.get("", response_model=DefaultLimitOffsetPage[BoardGroupMemoryRead]) async def list_board_group_memory( group_id: UUID, - is_chat: bool | None = Query(default=None), - session: AsyncSession = Depends(get_session), - ctx: OrganizationContext = Depends(require_org_member), + *, + is_chat: bool | None = IS_CHAT_QUERY, + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_MEMBER_DEP, ) -> DefaultLimitOffsetPage[BoardGroupMemoryRead]: + """List board-group memory entries for a specific group.""" await _require_group_access(session, group_id=group_id, ctx=ctx, write=False) statement = ( BoardGroupMemory.objects.filter_by(board_group_id=group_id) @@ -255,14 +360,13 @@ async def list_board_group_memory( @group_router.get("/stream") async def stream_board_group_memory( - group_id: UUID, request: Request, - since: str | None = Query(default=None), - is_chat: bool | None = Query(default=None), - session: AsyncSession = Depends(get_session), - ctx: OrganizationContext = Depends(require_org_member), + group: BoardGroup = GROUP_READ_DEP, + *, + since: str | None = SINCE_QUERY, + is_chat: bool | None = IS_CHAT_QUERY, ) -> EventSourceResponse: - await _require_group_access(session, group_id=group_id, ctx=ctx, write=False) + """Stream memory entries for a board group via server-sent events.""" since_dt = _parse_since(since) or utcnow() last_seen = since_dt @@ -274,16 +378,15 @@ async def stream_board_group_memory( async with async_session_maker() as s: memories = await _fetch_memory_events( s, - group_id, + group.id, last_seen, is_chat=is_chat, ) for memory in memories: - if memory.created_at > last_seen: - last_seen = memory.created_at + last_seen = max(memory.created_at, last_seen) payload = {"memory": _serialize_memory(memory)} yield {"event": "memory", "data": json.dumps(payload)} - await asyncio.sleep(2) + await asyncio.sleep(STREAM_POLL_SECONDS) return EventSourceResponse(event_generator(), ping=15) @@ -292,9 +395,10 @@ async def stream_board_group_memory( async def create_board_group_memory( group_id: UUID, payload: BoardGroupMemoryCreate, - session: AsyncSession = Depends(get_session), - ctx: OrganizationContext = Depends(require_org_member), + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_MEMBER_DEP, ) -> BoardGroupMemory: + """Create a board-group memory entry and notify chat recipients.""" group = await _require_group_access(session, group_id=group_id, ctx=ctx, write=True) user = await User.objects.by_id(ctx.member.user_id).first(session) @@ -320,16 +424,23 @@ async def create_board_group_memory( await session.commit() await session.refresh(memory) if should_notify: - await _notify_group_memory_targets(session=session, group=group, memory=memory, actor=actor) + await _notify_group_memory_targets( + session=session, + group=group, + memory=memory, + actor=actor, + ) return memory @board_router.get("", response_model=DefaultLimitOffsetPage[BoardGroupMemoryRead]) async def list_board_group_memory_for_board( - is_chat: bool | None = Query(default=None), - board: Board = Depends(get_board_for_actor_read), - session: AsyncSession = Depends(get_session), + *, + is_chat: bool | None = IS_CHAT_QUERY, + board: Board = BOARD_READ_DEP, + session: AsyncSession = SESSION_DEP, ) -> DefaultLimitOffsetPage[BoardGroupMemoryRead]: + """List memory entries for the board's linked group.""" group_id = board.board_group_id if group_id is None: return await paginate(session, BoardGroupMemory.objects.by_ids([]).statement) @@ -349,10 +460,12 @@ async def list_board_group_memory_for_board( @board_router.get("/stream") async def stream_board_group_memory_for_board( request: Request, - board: Board = Depends(get_board_for_actor_read), - since: str | None = Query(default=None), - is_chat: bool | None = Query(default=None), + *, + board: Board = BOARD_READ_DEP, + since: str | None = SINCE_QUERY, + is_chat: bool | None = IS_CHAT_QUERY, ) -> EventSourceResponse: + """Stream memory entries for the board's linked group.""" group_id = board.board_group_id since_dt = _parse_since(since) or utcnow() last_seen = since_dt @@ -373,11 +486,10 @@ async def stream_board_group_memory_for_board( is_chat=is_chat, ) for memory in memories: - if memory.created_at > last_seen: - last_seen = memory.created_at + last_seen = max(memory.created_at, last_seen) payload = {"memory": _serialize_memory(memory)} yield {"event": "memory", "data": json.dumps(payload)} - await asyncio.sleep(2) + await asyncio.sleep(STREAM_POLL_SECONDS) return EventSourceResponse(event_generator(), ping=15) @@ -385,10 +497,11 @@ async def stream_board_group_memory_for_board( @board_router.post("", response_model=BoardGroupMemoryRead) async def create_board_group_memory_for_board( payload: BoardGroupMemoryCreate, - board: Board = Depends(get_board_for_actor_write), - session: AsyncSession = Depends(get_session), - actor: ActorContext = Depends(require_admin_or_agent), + board: Board = BOARD_WRITE_DEP, + session: AsyncSession = SESSION_DEP, + actor: ActorContext = ACTOR_DEP, ) -> BoardGroupMemory: + """Create a group memory entry from a board context and notify recipients.""" group_id = board.board_group_id if group_id is None: raise HTTPException( @@ -420,7 +533,12 @@ async def create_board_group_memory_for_board( await session.commit() await session.refresh(memory) if should_notify: - await _notify_group_memory_targets(session=session, group=group, memory=memory, actor=actor) + await _notify_group_memory_targets( + session=session, + group=group, + memory=memory, + actor=actor, + ) return memory diff --git a/backend/app/api/board_groups.py b/backend/app/api/board_groups.py index cfabf43..063281f 100644 --- a/backend/app/api/board_groups.py +++ b/backend/app/api/board_groups.py @@ -1,15 +1,21 @@ +"""Board group CRUD, snapshot, and heartbeat endpoints.""" + from __future__ import annotations import re -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from uuid import UUID, uuid4 from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy import func from sqlmodel import col, select -from sqlmodel.ext.asyncio.session import AsyncSession -from app.api.deps import ActorContext, require_admin_or_agent, require_org_admin, require_org_member +from app.api.deps import ( + ActorContext, + require_admin_or_agent, + require_org_admin, + require_org_member, +) from app.core.time import utcnow from app.db import crud from app.db.pagination import paginate @@ -20,7 +26,6 @@ from app.models.board_group_memory import BoardGroupMemory from app.models.board_groups import BoardGroup from app.models.boards import Board from app.models.gateways import Gateway -from app.models.organization_members import OrganizationMember from app.schemas.board_group_heartbeat import ( BoardGroupHeartbeatApply, BoardGroupHeartbeatApplyResult, @@ -29,7 +34,10 @@ from app.schemas.board_groups import BoardGroupCreate, BoardGroupRead, BoardGrou from app.schemas.common import OkResponse from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.view_models import BoardGroupSnapshot -from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, sync_gateway_agent_heartbeats +from app.services.agent_provisioning import ( + DEFAULT_HEARTBEAT_CONFIG, + sync_gateway_agent_heartbeats, +) from app.services.board_group_snapshot import build_group_snapshot from app.services.organizations import ( OrganizationContext, @@ -41,7 +49,16 @@ from app.services.organizations import ( member_all_boards_write, ) +if TYPE_CHECKING: + from sqlmodel.ext.asyncio.session import AsyncSession + + from app.models.organization_members import OrganizationMember + router = APIRouter(prefix="/board-groups", tags=["board-groups"]) +SESSION_DEP = Depends(get_session) +ORG_MEMBER_DEP = Depends(require_org_member) +ORG_ADMIN_DEP = Depends(require_org_admin) +ACTOR_DEP = Depends(require_admin_or_agent) def _slugify(value: str) -> str: @@ -68,7 +85,8 @@ async def _require_group_access( return group board_ids = [ - board.id for board in await Board.objects.filter_by(board_group_id=group_id).all(session) + board.id + for board in await Board.objects.filter_by(board_group_id=group_id).all(session) ] if not board_ids: if is_org_admin(member): @@ -83,14 +101,17 @@ async def _require_group_access( @router.get("", response_model=DefaultLimitOffsetPage[BoardGroupRead]) async def list_board_groups( - session: AsyncSession = Depends(get_session), - ctx: OrganizationContext = Depends(require_org_member), + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_MEMBER_DEP, ) -> DefaultLimitOffsetPage[BoardGroupRead]: + """List board groups in the active organization.""" if member_all_boards_read(ctx.member): - statement = select(BoardGroup).where(col(BoardGroup.organization_id) == ctx.organization.id) + statement = select(BoardGroup).where( + col(BoardGroup.organization_id) == ctx.organization.id, + ) else: accessible_boards = select(Board.board_group_id).where( - board_access_filter(ctx.member, write=False) + board_access_filter(ctx.member, write=False), ) statement = select(BoardGroup).where( col(BoardGroup.organization_id) == ctx.organization.id, @@ -103,9 +124,10 @@ async def list_board_groups( @router.post("", response_model=BoardGroupRead) async def create_board_group( payload: BoardGroupCreate, - session: AsyncSession = Depends(get_session), - ctx: OrganizationContext = Depends(require_org_admin), + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> BoardGroup: + """Create a board group in the active organization.""" data = payload.model_dump() if not (data.get("slug") or "").strip(): data["slug"] = _slugify(data.get("name") or "") @@ -116,21 +138,28 @@ async def create_board_group( @router.get("/{group_id}", response_model=BoardGroupRead) async def get_board_group( group_id: UUID, - session: AsyncSession = Depends(get_session), - ctx: OrganizationContext = Depends(require_org_member), + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_MEMBER_DEP, ) -> BoardGroup: - return await _require_group_access(session, group_id=group_id, member=ctx.member, write=False) + """Get a board group by id.""" + return await _require_group_access( + session, group_id=group_id, member=ctx.member, write=False, + ) @router.get("/{group_id}/snapshot", response_model=BoardGroupSnapshot) async def get_board_group_snapshot( group_id: UUID, + *, include_done: bool = False, per_board_task_limit: int = 5, - session: AsyncSession = Depends(get_session), - ctx: OrganizationContext = Depends(require_org_member), + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_MEMBER_DEP, ) -> BoardGroupSnapshot: - group = await _require_group_access(session, group_id=group_id, member=ctx.member, write=False) + """Get a snapshot across boards in a group.""" + group = await _require_group_access( + session, group_id=group_id, member=ctx.member, write=False, + ) if per_board_task_limit < 0: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) snapshot = await build_group_snapshot( @@ -141,22 +170,22 @@ async def get_board_group_snapshot( per_board_task_limit=per_board_task_limit, ) if not member_all_boards_read(ctx.member) and snapshot.boards: - allowed_ids = set(await list_accessible_board_ids(session, member=ctx.member, write=False)) - snapshot.boards = [item for item in snapshot.boards if item.board.id in allowed_ids] + allowed_ids = set( + await list_accessible_board_ids(session, member=ctx.member, write=False), + ) + snapshot.boards = [ + item for item in snapshot.boards if item.board.id in allowed_ids + ] return snapshot -@router.post("/{group_id}/heartbeat", response_model=BoardGroupHeartbeatApplyResult) -async def apply_board_group_heartbeat( +async def _authorize_heartbeat_actor( + session: AsyncSession, + *, group_id: UUID, - payload: BoardGroupHeartbeatApply, - session: AsyncSession = Depends(get_session), - actor: ActorContext = Depends(require_admin_or_agent), -) -> BoardGroupHeartbeatApplyResult: - group = await BoardGroup.objects.by_id(group_id).first(session) - if group is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) - + group: BoardGroup, + actor: ActorContext, +) -> None: if actor.actor_type == "user": if actor.user is None: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) @@ -173,53 +202,58 @@ async def apply_board_group_heartbeat( member=member, write=True, ) - elif actor.actor_type == "agent": - agent = actor.agent - if agent is None: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - if agent.board_id is None: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - if not agent.is_board_lead: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - board = await Board.objects.by_id(agent.board_id).first(session) - if board is None or board.board_group_id != group_id: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + return + agent = actor.agent + if agent is None or agent.board_id is None or not agent.is_board_lead: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + board = await Board.objects.by_id(agent.board_id).first(session) + if board is None or board.board_group_id != group_id: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + +async def _agents_for_group_heartbeat( + session: AsyncSession, + *, + group_id: UUID, + include_board_leads: bool, +) -> tuple[dict[UUID, Board], list[Agent]]: boards = await Board.objects.filter_by(board_group_id=group_id).all(session) board_by_id = {board.id: board for board in boards} board_ids = list(board_by_id.keys()) if not board_ids: - return BoardGroupHeartbeatApplyResult( - board_group_id=group_id, - requested=payload.model_dump(mode="json"), - updated_agent_ids=[], - failed_agent_ids=[], - ) - + return board_by_id, [] agents = await Agent.objects.by_field_in("board_id", board_ids).all(session) - if not payload.include_board_leads: + if not include_board_leads: agents = [agent for agent in agents if not agent.is_board_lead] + return board_by_id, agents - updated_agent_ids: list[UUID] = [] - for agent in agents: - raw = agent.heartbeat_config - heartbeat: dict[str, Any] = ( - cast(dict[str, Any], dict(raw)) - if isinstance(raw, dict) - else cast(dict[str, Any], DEFAULT_HEARTBEAT_CONFIG.copy()) - ) - heartbeat["every"] = payload.every - if payload.target is not None: - heartbeat["target"] = payload.target - elif "target" not in heartbeat: - heartbeat["target"] = DEFAULT_HEARTBEAT_CONFIG.get("target", "none") - agent.heartbeat_config = heartbeat - agent.updated_at = utcnow() - session.add(agent) - updated_agent_ids.append(agent.id) - await session.commit() +def _update_agent_heartbeat( + *, + agent: Agent, + payload: BoardGroupHeartbeatApply, +) -> None: + raw = agent.heartbeat_config + heartbeat: dict[str, Any] = ( + cast(dict[str, Any], dict(raw)) + if isinstance(raw, dict) + else cast(dict[str, Any], DEFAULT_HEARTBEAT_CONFIG.copy()) + ) + heartbeat["every"] = payload.every + if payload.target is not None: + heartbeat["target"] = payload.target + elif "target" not in heartbeat: + heartbeat["target"] = DEFAULT_HEARTBEAT_CONFIG.get("target", "none") + agent.heartbeat_config = heartbeat + agent.updated_at = utcnow() + +async def _sync_gateway_heartbeats( + session: AsyncSession, + *, + board_by_id: dict[UUID, Board], + agents: list[Agent], +) -> list[UUID]: agents_by_gateway_id: dict[UUID, list[Agent]] = {} for agent in agents: board_id = agent.board_id @@ -243,6 +277,51 @@ async def apply_board_group_heartbeat( await sync_gateway_agent_heartbeats(gateway, gateway_agents) except OpenClawGatewayError: failed_agent_ids.extend([agent.id for agent in gateway_agents]) + return failed_agent_ids + + +@router.post("/{group_id}/heartbeat", response_model=BoardGroupHeartbeatApplyResult) +async def apply_board_group_heartbeat( + group_id: UUID, + payload: BoardGroupHeartbeatApply, + session: AsyncSession = SESSION_DEP, + actor: ActorContext = ACTOR_DEP, +) -> BoardGroupHeartbeatApplyResult: + """Apply heartbeat settings to agents in a board group.""" + group = await BoardGroup.objects.by_id(group_id).first(session) + if group is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + await _authorize_heartbeat_actor( + session, + group_id=group_id, + group=group, + actor=actor, + ) + board_by_id, agents = await _agents_for_group_heartbeat( + session, + group_id=group_id, + include_board_leads=payload.include_board_leads, + ) + if not agents: + return BoardGroupHeartbeatApplyResult( + board_group_id=group_id, + requested=payload.model_dump(mode="json"), + updated_agent_ids=[], + failed_agent_ids=[], + ) + + updated_agent_ids: list[UUID] = [] + for agent in agents: + _update_agent_heartbeat(agent=agent, payload=payload) + session.add(agent) + updated_agent_ids.append(agent.id) + + await session.commit() + failed_agent_ids = await _sync_gateway_heartbeats( + session, + board_by_id=board_by_id, + agents=agents, + ) return BoardGroupHeartbeatApplyResult( board_group_id=group_id, @@ -256,12 +335,19 @@ async def apply_board_group_heartbeat( async def update_board_group( payload: BoardGroupUpdate, group_id: UUID, - session: AsyncSession = Depends(get_session), - ctx: OrganizationContext = Depends(require_org_admin), + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> BoardGroup: - group = await _require_group_access(session, group_id=group_id, member=ctx.member, write=True) + """Update a board group.""" + group = await _require_group_access( + session, group_id=group_id, member=ctx.member, write=True, + ) updates = payload.model_dump(exclude_unset=True) - if "slug" in updates and updates["slug"] is not None and not updates["slug"].strip(): + if ( + "slug" in updates + and updates["slug"] is not None + and not updates["slug"].strip() + ): updates["slug"] = _slugify(updates.get("name") or group.name) updates["updated_at"] = utcnow() return await crud.patch(session, group, updates) @@ -270,10 +356,13 @@ async def update_board_group( @router.delete("/{group_id}", response_model=OkResponse) async def delete_board_group( group_id: UUID, - session: AsyncSession = Depends(get_session), - ctx: OrganizationContext = Depends(require_org_admin), + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> OkResponse: - await _require_group_access(session, group_id=group_id, member=ctx.member, write=True) + """Delete a board group.""" + await _require_group_access( + session, group_id=group_id, member=ctx.member, write=True, + ) # Boards reference groups, so clear the FK first to keep deletes simple. await crud.update_where( @@ -284,8 +373,13 @@ async def delete_board_group( commit=False, ) await crud.delete_where( - session, BoardGroupMemory, col(BoardGroupMemory.board_group_id) == group_id, commit=False + session, + BoardGroupMemory, + col(BoardGroupMemory.board_group_id) == group_id, + commit=False, + ) + await crud.delete_where( + session, BoardGroup, col(BoardGroup.id) == group_id, commit=False, ) - await crud.delete_where(session, BoardGroup, col(BoardGroup.id) == group_id, commit=False) await session.commit() return OkResponse() diff --git a/backend/app/api/board_memory.py b/backend/app/api/board_memory.py index dd277fe..a288987 100644 --- a/backend/app/api/board_memory.py +++ b/backend/app/api/board_memory.py @@ -1,15 +1,16 @@ +"""Board memory CRUD and streaming endpoints.""" + from __future__ import annotations import asyncio import json -from collections.abc import AsyncIterator from datetime import datetime, timezone +from typing import TYPE_CHECKING from uuid import UUID from fastapi import APIRouter, Depends, Query, Request from sqlalchemy import func from sqlmodel import col -from sqlmodel.ext.asyncio.session import AsyncSession from sse_starlette.sse import EventSourceResponse from app.api.deps import ( @@ -23,16 +24,35 @@ from app.core.time import utcnow from app.db.pagination import paginate from app.db.session import async_session_maker, get_session from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig -from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message +from app.integrations.openclaw_gateway import ( + OpenClawGatewayError, + ensure_session, + send_message, +) from app.models.agents import Agent from app.models.board_memory import BoardMemory -from app.models.boards import Board from app.models.gateways import Gateway from app.schemas.board_memory import BoardMemoryCreate, BoardMemoryRead from app.schemas.pagination import DefaultLimitOffsetPage from app.services.mentions import extract_mentions, matches_agent_mention +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from sqlmodel.ext.asyncio.session import AsyncSession + + from app.models.boards import Board + router = APIRouter(prefix="/boards/{board_id}/memory", tags=["board-memory"]) +MAX_SNIPPET_LENGTH = 800 +STREAM_POLL_SECONDS = 2 +IS_CHAT_QUERY = Query(default=None) +SINCE_QUERY = Query(default=None) +BOARD_READ_DEP = Depends(get_board_for_actor_read) +BOARD_WRITE_DEP = Depends(get_board_for_actor_write) +SESSION_DEP = Depends(get_session) +ACTOR_DEP = Depends(require_admin_or_agent) +_RUNTIME_TYPE_REFERENCES = (UUID,) def _parse_since(value: str | None) -> datetime | None: @@ -52,10 +72,16 @@ def _parse_since(value: str | None) -> datetime | None: def _serialize_memory(memory: BoardMemory) -> dict[str, object]: - return BoardMemoryRead.model_validate(memory, from_attributes=True).model_dump(mode="json") + return BoardMemoryRead.model_validate( + memory, + from_attributes=True, + ).model_dump(mode="json") -async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None: +async def _gateway_config( + session: AsyncSession, + board: Board, +) -> GatewayClientConfig | None: if board.gateway_id is None: return None gateway = await Gateway.objects.by_id(board.gateway_id).first(session) @@ -91,11 +117,67 @@ async def _fetch_memory_events( if is_chat is not None: statement = statement.filter(col(BoardMemory.is_chat) == is_chat) statement = statement.filter(col(BoardMemory.created_at) >= since).order_by( - col(BoardMemory.created_at) + col(BoardMemory.created_at), ) return await statement.all(session) +async def _send_control_command( + *, + session: AsyncSession, + board: Board, + actor: ActorContext, + config: GatewayClientConfig, + command: str, +) -> None: + pause_targets: list[Agent] = await Agent.objects.filter_by( + board_id=board.id, + ).all( + session, + ) + for agent in pause_targets: + if actor.actor_type == "agent" and actor.agent and agent.id == actor.agent.id: + continue + if not agent.openclaw_session_id: + continue + try: + await _send_agent_message( + session_key=agent.openclaw_session_id, + config=config, + agent_name=agent.name, + message=command, + deliver=True, + ) + except OpenClawGatewayError: + continue + + +def _chat_targets( + *, + agents: list[Agent], + mentions: set[str], + actor: ActorContext, +) -> dict[str, Agent]: + targets: dict[str, Agent] = {} + for agent in agents: + if agent.is_board_lead: + targets[str(agent.id)] = agent + continue + if mentions and matches_agent_mention(agent, mentions): + targets[str(agent.id)] = agent + if actor.actor_type == "agent" and actor.agent: + targets.pop(str(actor.agent.id), None) + return targets + + +def _actor_display_name(actor: ActorContext) -> str: + if actor.actor_type == "agent" and actor.agent: + return actor.agent.name + if actor.user: + return actor.user.preferred_name or actor.user.name or "User" + return "User" + + async def _notify_chat_targets( *, session: AsyncSession, @@ -114,44 +196,27 @@ async def _notify_chat_targets( # Special-case control commands to reach all board agents. # These are intended to be parsed verbatim by agent runtimes. if command in {"/pause", "/resume"}: - pause_targets: list[Agent] = await Agent.objects.filter_by(board_id=board.id).all(session) - for agent in pause_targets: - if actor.actor_type == "agent" and actor.agent and agent.id == actor.agent.id: - continue - if not agent.openclaw_session_id: - continue - try: - await _send_agent_message( - session_key=agent.openclaw_session_id, - config=config, - agent_name=agent.name, - message=command, - deliver=True, - ) - except OpenClawGatewayError: - continue + await _send_control_command( + session=session, + board=board, + actor=actor, + config=config, + command=command, + ) return mentions = extract_mentions(memory.content) - targets: dict[str, Agent] = {} - for agent in await Agent.objects.filter_by(board_id=board.id).all(session): - if agent.is_board_lead: - targets[str(agent.id)] = agent - continue - if mentions and matches_agent_mention(agent, mentions): - targets[str(agent.id)] = agent - if actor.actor_type == "agent" and actor.agent: - targets.pop(str(actor.agent.id), None) + targets = _chat_targets( + agents=await Agent.objects.filter_by(board_id=board.id).all(session), + mentions=mentions, + actor=actor, + ) if not targets: return - actor_name = "User" - if actor.actor_type == "agent" and actor.agent: - actor_name = actor.agent.name - elif actor.user: - actor_name = actor.user.preferred_name or actor.user.name or actor_name + actor_name = _actor_display_name(actor) snippet = memory.content.strip() - if len(snippet) > 800: - snippet = f"{snippet[:797]}..." + if len(snippet) > MAX_SNIPPET_LENGTH: + snippet = f"{snippet[: MAX_SNIPPET_LENGTH - 3]}..." base_url = settings.base_url or "http://localhost:8000" for agent in targets.values(): if not agent.openclaw_session_id: @@ -180,11 +245,13 @@ async def _notify_chat_targets( @router.get("", response_model=DefaultLimitOffsetPage[BoardMemoryRead]) async def list_board_memory( - is_chat: bool | None = Query(default=None), - board: Board = Depends(get_board_for_actor_read), - session: AsyncSession = Depends(get_session), - actor: ActorContext = Depends(require_admin_or_agent), + *, + is_chat: bool | None = IS_CHAT_QUERY, + board: Board = BOARD_READ_DEP, + session: AsyncSession = SESSION_DEP, + _actor: ActorContext = ACTOR_DEP, ) -> DefaultLimitOffsetPage[BoardMemoryRead]: + """List board memory entries, optionally filtering chat entries.""" statement = ( BoardMemory.objects.filter_by(board_id=board.id) # Old/invalid rows (empty/whitespace-only content) can exist; exclude them to @@ -200,11 +267,13 @@ async def list_board_memory( @router.get("/stream") async def stream_board_memory( request: Request, - board: Board = Depends(get_board_for_actor_read), - actor: ActorContext = Depends(require_admin_or_agent), - since: str | None = Query(default=None), - is_chat: bool | None = Query(default=None), + *, + board: Board = BOARD_READ_DEP, + _actor: ActorContext = ACTOR_DEP, + since: str | None = SINCE_QUERY, + is_chat: bool | None = IS_CHAT_QUERY, ) -> EventSourceResponse: + """Stream board memory events over server-sent events.""" since_dt = _parse_since(since) or utcnow() last_seen = since_dt @@ -221,11 +290,10 @@ async def stream_board_memory( is_chat=is_chat, ) for memory in memories: - if memory.created_at > last_seen: - last_seen = memory.created_at + last_seen = max(memory.created_at, last_seen) payload = {"memory": _serialize_memory(memory)} yield {"event": "memory", "data": json.dumps(payload)} - await asyncio.sleep(2) + await asyncio.sleep(STREAM_POLL_SECONDS) return EventSourceResponse(event_generator(), ping=15) @@ -233,10 +301,11 @@ async def stream_board_memory( @router.post("", response_model=BoardMemoryRead) async def create_board_memory( payload: BoardMemoryCreate, - board: Board = Depends(get_board_for_actor_write), - session: AsyncSession = Depends(get_session), - actor: ActorContext = Depends(require_admin_or_agent), + board: Board = BOARD_WRITE_DEP, + session: AsyncSession = SESSION_DEP, + actor: ActorContext = ACTOR_DEP, ) -> BoardMemory: + """Create a board memory entry and notify chat targets when needed.""" is_chat = payload.tags is not None and "chat" in payload.tags source = payload.source if is_chat and not source: @@ -255,5 +324,10 @@ async def create_board_memory( await session.commit() await session.refresh(memory) if is_chat: - await _notify_chat_targets(session=session, board=board, memory=memory, actor=actor) + await _notify_chat_targets( + session=session, + board=board, + memory=memory, + actor=actor, + ) return memory diff --git a/backend/app/api/board_onboarding.py b/backend/app/api/board_onboarding.py index e1a74cb..9caffc8 100644 --- a/backend/app/api/board_onboarding.py +++ b/backend/app/api/board_onboarding.py @@ -1,5 +1,4 @@ """Board onboarding endpoints for user/agent collaboration.""" -# ruff: noqa: E501 from __future__ import annotations @@ -201,16 +200,22 @@ async def start_onboarding( f"Board Name: {board.name}\n" "You are the main agent. Ask the user 6-10 focused questions total:\n" "- 3-6 questions to clarify the board goal.\n" - "- 1 question to choose a unique name for the board lead agent (first-name style).\n" - "- 2-4 questions to capture the user's preferences for how the board lead should work\n" + "- 1 question to choose a unique name for the board lead agent " + "(first-name style).\n" + "- 2-4 questions to capture the user's preferences for how the board " + "lead should work\n" " (communication style, autonomy, update cadence, and output formatting).\n" - '- Always include a final question (and only once): "Anything else we should know?"\n' + '- Always include a final question (and only once): "Anything else we ' + 'should know?"\n' " (constraints, context, preferences). This MUST be the last question.\n" ' Provide an option like "Yes (I\'ll type it)" so they can enter free-text.\n' " Do NOT ask for additional context on earlier questions.\n" - " Only include a free-text option on earlier questions if a typed answer is necessary;\n" - ' when you do, make the option label include "I\'ll type it" (e.g., "Other (I\'ll type it)").\n' - '- If the user sends an "Additional context" message later, incorporate it and resend status=complete\n' + " Only include a free-text option on earlier questions if a typed " + "answer is necessary;\n" + ' when you do, make the option label include "I\'ll type it" ' + '(e.g., "Other (I\'ll type it)").\n' + '- If the user sends an "Additional context" message later, incorporate ' + "it and resend status=complete\n" " to update the draft (until the user confirms).\n" "Do NOT respond in OpenClaw chat.\n" "All onboarding responses MUST be sent to Mission Control via API.\n" @@ -222,24 +227,37 @@ async def start_onboarding( f'curl -s -X POST "{base_url}/api/v1/agent/boards/{board.id}/onboarding" ' '-H "X-Agent-Token: $AUTH_TOKEN" ' '-H "Content-Type: application/json" ' - '-d \'{"question":"...","options":[{"id":"1","label":"..."},{"id":"2","label":"..."}]}\'\n' + '-d \'{"question":"...","options":[{"id":"1","label":"..."},' + '{"id":"2","label":"..."}]}\'\n' "COMPLETION example (send JSON body exactly as shown):\n" f'curl -s -X POST "{base_url}/api/v1/agent/boards/{board.id}/onboarding" ' '-H "X-Agent-Token: $AUTH_TOKEN" ' '-H "Content-Type: application/json" ' - '-d \'{"status":"complete","board_type":"goal","objective":"...","success_metrics":{"metric":"...","target":"..."},"target_date":"YYYY-MM-DD","user_profile":{"preferred_name":"...","pronouns":"...","timezone":"...","notes":"...","context":"..."},"lead_agent":{"name":"Ava","identity_profile":{"role":"Board Lead","communication_style":"direct, concise, practical","emoji":":gear:"},"autonomy_level":"balanced","verbosity":"concise","output_format":"bullets","update_cadence":"daily","custom_instructions":"..."}}\'\n' + '-d \'{"status":"complete","board_type":"goal","objective":"...",' + '"success_metrics":{"metric":"...","target":"..."},' + '"target_date":"YYYY-MM-DD",' + '"user_profile":{"preferred_name":"...","pronouns":"...",' + '"timezone":"...","notes":"...","context":"..."},' + '"lead_agent":{"name":"Ava","identity_profile":{"role":"Board Lead",' + '"communication_style":"direct, concise, practical","emoji":":gear:"},' + '"autonomy_level":"balanced","verbosity":"concise",' + '"output_format":"bullets","update_cadence":"daily",' + '"custom_instructions":"..."}}\'\n' "ENUMS:\n" "- board_type: goal | general\n" "- lead_agent.autonomy_level: ask_first | balanced | autonomous\n" "- lead_agent.verbosity: concise | balanced | detailed\n" "- lead_agent.output_format: bullets | mixed | narrative\n" "- lead_agent.update_cadence: asap | hourly | daily | weekly\n" - "QUESTION FORMAT (one question per response, no arrays, no markdown, no extra text):\n" + "QUESTION FORMAT (one question per response, no arrays, no markdown, " + "no extra text):\n" '{"question":"...","options":[{"id":"1","label":"..."},{"id":"2","label":"..."}]}\n' "Do NOT wrap questions in a list. Do NOT add commentary.\n" "When you have enough info, send one final response with status=complete.\n" - "The completion payload must include board_type. If board_type=goal, include objective + success_metrics.\n" - "Also include user_profile + lead_agent to configure the board lead's working style.\n" + "The completion payload must include board_type. If board_type=goal, " + "include objective + success_metrics.\n" + "Also include user_profile + lead_agent to configure the board lead's " + "working style.\n" ) try: diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index e7571be..dfe623e 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -1,19 +1,18 @@ +"""Reusable FastAPI dependencies for auth and board/task access.""" + from __future__ import annotations from dataclasses import dataclass -from typing import Literal +from typing import TYPE_CHECKING, Literal from fastapi import Depends, HTTPException, status -from sqlmodel.ext.asyncio.session import AsyncSession from app.core.agent_auth import AgentAuthContext, get_agent_auth_context_optional from app.core.auth import AuthContext, get_auth_context, get_auth_context_optional from app.db.session import get_session -from app.models.agents import Agent from app.models.boards import Board from app.models.organizations import Organization from app.models.tasks import Task -from app.models.users import User from app.services.admin_access import require_admin from app.services.organizations import ( OrganizationContext, @@ -23,23 +22,38 @@ from app.services.organizations import ( require_board_access, ) +if TYPE_CHECKING: + from sqlmodel.ext.asyncio.session import AsyncSession -def require_admin_auth(auth: AuthContext = Depends(get_auth_context)) -> AuthContext: + from app.models.agents import Agent + from app.models.users import User + +AUTH_DEP = Depends(get_auth_context) +AUTH_OPTIONAL_DEP = Depends(get_auth_context_optional) +AGENT_AUTH_OPTIONAL_DEP = Depends(get_agent_auth_context_optional) +SESSION_DEP = Depends(get_session) + + +def require_admin_auth(auth: AuthContext = AUTH_DEP) -> AuthContext: + """Require an authenticated admin user.""" require_admin(auth) return auth @dataclass class ActorContext: + """Authenticated actor context for user or agent callers.""" + actor_type: Literal["user", "agent"] user: User | None = None agent: Agent | None = None def require_admin_or_agent( - auth: AuthContext | None = Depends(get_auth_context_optional), - agent_auth: AgentAuthContext | None = Depends(get_agent_auth_context_optional), + auth: AuthContext | None = AUTH_OPTIONAL_DEP, + agent_auth: AgentAuthContext | None = AGENT_AUTH_OPTIONAL_DEP, ) -> ActorContext: + """Authorize either an admin user or an authenticated agent.""" if auth is not None: require_admin(auth) return ActorContext(actor_type="user", user=auth.user) @@ -48,10 +62,14 @@ def require_admin_or_agent( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) +ACTOR_DEP = Depends(require_admin_or_agent) + + async def require_org_member( - auth: AuthContext = Depends(get_auth_context), - session: AsyncSession = Depends(get_session), + auth: AuthContext = AUTH_DEP, + session: AsyncSession = SESSION_DEP, ) -> OrganizationContext: + """Resolve and require active organization membership for the current user.""" if auth.user is None: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) member = await get_active_membership(session, auth.user) @@ -59,15 +77,21 @@ async def require_org_member( member = await ensure_member_for_user(session, auth.user) if member is None: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - organization = await Organization.objects.by_id(member.organization_id).first(session) + organization = await Organization.objects.by_id(member.organization_id).first( + session, + ) if organization is None: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) return OrganizationContext(organization=organization, member=member) +ORG_MEMBER_DEP = Depends(require_org_member) + + async def require_org_admin( - ctx: OrganizationContext = Depends(require_org_member), + ctx: OrganizationContext = ORG_MEMBER_DEP, ) -> OrganizationContext: + """Require organization-admin membership privileges.""" if not is_org_admin(ctx.member): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) return ctx @@ -75,8 +99,9 @@ async def require_org_admin( async def get_board_or_404( board_id: str, - session: AsyncSession = Depends(get_session), + session: AsyncSession = SESSION_DEP, ) -> Board: + """Load a board by id or raise HTTP 404.""" board = await Board.objects.by_id(board_id).first(session) if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) @@ -85,9 +110,10 @@ async def get_board_or_404( async def get_board_for_actor_read( board_id: str, - session: AsyncSession = Depends(get_session), - actor: ActorContext = Depends(require_admin_or_agent), + session: AsyncSession = SESSION_DEP, + actor: ActorContext = ACTOR_DEP, ) -> Board: + """Load a board and enforce actor read access.""" board = await Board.objects.by_id(board_id).first(session) if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) @@ -103,9 +129,10 @@ async def get_board_for_actor_read( async def get_board_for_actor_write( board_id: str, - session: AsyncSession = Depends(get_session), - actor: ActorContext = Depends(require_admin_or_agent), + session: AsyncSession = SESSION_DEP, + actor: ActorContext = ACTOR_DEP, ) -> Board: + """Load a board and enforce actor write access.""" board = await Board.objects.by_id(board_id).first(session) if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) @@ -121,9 +148,10 @@ async def get_board_for_actor_write( async def get_board_for_user_read( board_id: str, - session: AsyncSession = Depends(get_session), - auth: AuthContext = Depends(get_auth_context), + session: AsyncSession = SESSION_DEP, + auth: AuthContext = AUTH_DEP, ) -> Board: + """Load a board and enforce authenticated-user read access.""" board = await Board.objects.by_id(board_id).first(session) if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) @@ -135,9 +163,10 @@ async def get_board_for_user_read( async def get_board_for_user_write( board_id: str, - session: AsyncSession = Depends(get_session), - auth: AuthContext = Depends(get_auth_context), + session: AsyncSession = SESSION_DEP, + auth: AuthContext = AUTH_DEP, ) -> Board: + """Load a board and enforce authenticated-user write access.""" board = await Board.objects.by_id(board_id).first(session) if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) @@ -147,11 +176,15 @@ async def get_board_for_user_write( return board +BOARD_READ_DEP = Depends(get_board_for_actor_read) + + async def get_task_or_404( task_id: str, - board: Board = Depends(get_board_for_actor_read), - session: AsyncSession = Depends(get_session), + board: Board = BOARD_READ_DEP, + session: AsyncSession = SESSION_DEP, ) -> Task: + """Load a task for a board or raise HTTP 404.""" task = await Task.objects.by_id(task_id).first(session) if task is None or task.board_id != board.id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) diff --git a/backend/app/api/gateway.py b/backend/app/api/gateway.py index f22a5b4..439e244 100644 --- a/backend/app/api/gateway.py +++ b/backend/app/api/gateway.py @@ -1,7 +1,10 @@ +"""Gateway inspection and session-management endpoints.""" + from __future__ import annotations +from typing import TYPE_CHECKING + from fastapi import APIRouter, Depends, HTTPException, Query, status -from sqlmodel.ext.asyncio.session import AsyncSession from app.api.deps import require_org_admin from app.core.auth import AuthContext, get_auth_context @@ -21,7 +24,6 @@ from app.integrations.openclaw_gateway_protocol import ( ) from app.models.boards import Board from app.models.gateways import Gateway -from app.models.users import User from app.schemas.common import OkResponse from app.schemas.gateway_api import ( GatewayCommandsResponse, @@ -34,32 +36,48 @@ from app.schemas.gateway_api import ( ) from app.services.organizations import OrganizationContext, require_board_access +if TYPE_CHECKING: + from sqlmodel.ext.asyncio.session import AsyncSession + + from app.models.users import User + router = APIRouter(prefix="/gateways", tags=["gateways"]) +SESSION_DEP = Depends(get_session) +AUTH_DEP = Depends(get_auth_context) +ORG_ADMIN_DEP = Depends(require_org_admin) +BOARD_ID_QUERY = Query(default=None) +RESOLVE_QUERY_DEP = Depends() + + +def _query_to_resolve_input(params: GatewayResolveQuery) -> GatewayResolveQuery: + return params + + +RESOLVE_INPUT_DEP = Depends(_query_to_resolve_input) async def _resolve_gateway( session: AsyncSession, - board_id: str | None, - gateway_url: str | None, - gateway_token: str | None, - gateway_main_session_key: str | None, + params: GatewayResolveQuery, *, user: User | None = None, ) -> tuple[Board | None, GatewayClientConfig, str | None]: - if gateway_url: + if params.gateway_url: return ( None, - GatewayClientConfig(url=gateway_url, token=gateway_token), - gateway_main_session_key, + GatewayClientConfig(url=params.gateway_url, token=params.gateway_token), + params.gateway_main_session_key, ) - if not board_id: + if not params.board_id: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="board_id or gateway_url is required", ) - board = await Board.objects.by_id(board_id).first(session) + board = await Board.objects.by_id(params.board_id).first(session) if board is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Board not found", + ) if user is not None: await require_board_access(session, user=user, board=board, write=False) if not board.gateway_id: @@ -86,14 +104,12 @@ async def _resolve_gateway( async def _require_gateway( - session: AsyncSession, board_id: str | None, *, user: User | None = None + session: AsyncSession, board_id: str | None, *, user: User | None = None, ) -> tuple[Board, GatewayClientConfig, str | None]: + params = GatewayResolveQuery(board_id=board_id) board, config, main_session = await _resolve_gateway( session, - board_id, - None, - None, - None, + params, user=user, ) if board is None: @@ -106,17 +122,15 @@ async def _require_gateway( @router.get("/status", response_model=GatewaysStatusResponse) async def gateways_status( - params: GatewayResolveQuery = Depends(), - session: AsyncSession = Depends(get_session), - auth: AuthContext = Depends(get_auth_context), - ctx: OrganizationContext = Depends(require_org_admin), + params: GatewayResolveQuery = RESOLVE_INPUT_DEP, + session: AsyncSession = SESSION_DEP, + auth: AuthContext = AUTH_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> GatewaysStatusResponse: + """Return gateway connectivity and session status.""" board, config, main_session = await _resolve_gateway( session, - params.board_id, - params.gateway_url, - params.gateway_token, - params.gateway_main_session_key, + params, user=auth.user, ) if board is not None and board.organization_id != ctx.organization.id: @@ -131,7 +145,9 @@ async def gateways_status( main_session_error: str | None = None if main_session: try: - ensured = await ensure_session(main_session, config=config, label="Main Agent") + ensured = await ensure_session( + main_session, config=config, label="Main Agent", + ) if isinstance(ensured, dict): main_session_entry = ensured.get("entry") or ensured except OpenClawGatewayError as exc: @@ -146,22 +162,23 @@ async def gateways_status( main_session_error=main_session_error, ) except OpenClawGatewayError as exc: - return GatewaysStatusResponse(connected=False, gateway_url=config.url, error=str(exc)) + return GatewaysStatusResponse( + connected=False, gateway_url=config.url, error=str(exc), + ) @router.get("/sessions", response_model=GatewaySessionsResponse) async def list_gateway_sessions( - board_id: str | None = Query(default=None), - session: AsyncSession = Depends(get_session), - auth: AuthContext = Depends(get_auth_context), - ctx: OrganizationContext = Depends(require_org_admin), + board_id: str | None = BOARD_ID_QUERY, + session: AsyncSession = SESSION_DEP, + auth: AuthContext = AUTH_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> GatewaySessionsResponse: + """List sessions for a gateway associated with a board.""" + params = GatewayResolveQuery(board_id=board_id) board, config, main_session = await _resolve_gateway( session, - board_id, - None, - None, - None, + params, user=auth.user, ) if board is not None and board.organization_id != ctx.organization.id: @@ -169,7 +186,9 @@ async def list_gateway_sessions( try: sessions = await openclaw_call("sessions.list", config=config) except OpenClawGatewayError as exc: - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc), + ) from exc if isinstance(sessions, dict): sessions_list = list(sessions.get("sessions") or []) else: @@ -178,7 +197,9 @@ async def list_gateway_sessions( main_session_entry: object | None = None if main_session: try: - ensured = await ensure_session(main_session, config=config, label="Main Agent") + ensured = await ensure_session( + main_session, config=config, label="Main Agent", + ) if isinstance(ensured, dict): main_session_entry = ensured.get("entry") or ensured except OpenClawGatewayError: @@ -191,70 +212,103 @@ async def list_gateway_sessions( ) +async def _list_sessions(config: GatewayClientConfig) -> list[dict[str, object]]: + sessions = await openclaw_call("sessions.list", config=config) + if isinstance(sessions, dict): + raw_items = sessions.get("sessions") or [] + else: + raw_items = sessions or [] + return [ + item + for item in raw_items + if isinstance(item, dict) + ] + + +async def _with_main_session( + sessions_list: list[dict[str, object]], + *, + config: GatewayClientConfig, + main_session: str | None, +) -> list[dict[str, object]]: + if not main_session or any( + item.get("key") == main_session for item in sessions_list + ): + return sessions_list + try: + await ensure_session(main_session, config=config, label="Main Agent") + return await _list_sessions(config) + except OpenClawGatewayError: + return sessions_list + + @router.get("/sessions/{session_id}", response_model=GatewaySessionResponse) async def get_gateway_session( session_id: str, - board_id: str | None = Query(default=None), - session: AsyncSession = Depends(get_session), - auth: AuthContext = Depends(get_auth_context), - ctx: OrganizationContext = Depends(require_org_admin), + board_id: str | None = BOARD_ID_QUERY, + session: AsyncSession = SESSION_DEP, + auth: AuthContext = AUTH_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> GatewaySessionResponse: + """Get a specific gateway session by key.""" + params = GatewayResolveQuery(board_id=board_id) board, config, main_session = await _resolve_gateway( session, - board_id, - None, - None, - None, + params, user=auth.user, ) if board is not None and board.organization_id != ctx.organization.id: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) try: - sessions = await openclaw_call("sessions.list", config=config) + sessions_list = await _list_sessions(config) except OpenClawGatewayError as exc: - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc - if isinstance(sessions, dict): - sessions_list = list(sessions.get("sessions") or []) - else: - sessions_list = list(sessions or []) - if main_session and not any(item.get("key") == main_session for item in sessions_list): - try: - await ensure_session(main_session, config=config, label="Main Agent") - refreshed = await openclaw_call("sessions.list", config=config) - if isinstance(refreshed, dict): - sessions_list = list(refreshed.get("sessions") or []) - else: - sessions_list = list(refreshed or []) - except OpenClawGatewayError: - pass - session_entry = next((item for item in sessions_list if item.get("key") == session_id), None) + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc), + ) from exc + sessions_list = await _with_main_session( + sessions_list, + config=config, + main_session=main_session, + ) + session_entry = next( + (item for item in sessions_list if item.get("key") == session_id), None, + ) if session_entry is None and main_session and session_id == main_session: try: - ensured = await ensure_session(main_session, config=config, label="Main Agent") + ensured = await ensure_session( + main_session, config=config, label="Main Agent", + ) if isinstance(ensured, dict): session_entry = ensured.get("entry") or ensured except OpenClawGatewayError: session_entry = None if session_entry is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Session not found", + ) return GatewaySessionResponse(session=session_entry) -@router.get("/sessions/{session_id}/history", response_model=GatewaySessionHistoryResponse) +@router.get( + "/sessions/{session_id}/history", response_model=GatewaySessionHistoryResponse, +) async def get_session_history( session_id: str, - board_id: str | None = Query(default=None), - session: AsyncSession = Depends(get_session), - auth: AuthContext = Depends(get_auth_context), - ctx: OrganizationContext = Depends(require_org_admin), + board_id: str | None = BOARD_ID_QUERY, + session: AsyncSession = SESSION_DEP, + auth: AuthContext = AUTH_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> GatewaySessionHistoryResponse: + """Fetch chat history for a gateway session.""" board, config, _ = await _require_gateway(session, board_id, user=auth.user) if board.organization_id != ctx.organization.id: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) try: history = await get_chat_history(session_id, config=config) except OpenClawGatewayError as exc: - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc), + ) from exc if isinstance(history, dict) and isinstance(history.get("messages"), list): return GatewaySessionHistoryResponse(history=history["messages"]) return GatewaySessionHistoryResponse(history=list(history or [])) @@ -264,14 +318,14 @@ async def get_session_history( async def send_gateway_session_message( session_id: str, payload: GatewaySessionMessageRequest, - board_id: str | None = Query(default=None), - session: AsyncSession = Depends(get_session), - auth: AuthContext = Depends(get_auth_context), - ctx: OrganizationContext = Depends(require_org_admin), + board_id: str | None = BOARD_ID_QUERY, + session: AsyncSession = SESSION_DEP, + auth: AuthContext = AUTH_DEP, ) -> OkResponse: - board, config, main_session = await _require_gateway(session, board_id, user=auth.user) - if board.organization_id != ctx.organization.id: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + """Send a message into a specific gateway session.""" + board, config, main_session = await _require_gateway( + session, board_id, user=auth.user, + ) if auth.user is None: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) await require_board_access(session, user=auth.user, board=board, write=True) @@ -280,15 +334,18 @@ async def send_gateway_session_message( await ensure_session(main_session, config=config, label="Main Agent") await send_message(payload.content, session_key=session_id, config=config) except OpenClawGatewayError as exc: - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc), + ) from exc return OkResponse() @router.get("/commands", response_model=GatewayCommandsResponse) async def gateway_commands( - auth: AuthContext = Depends(get_auth_context), - _ctx: OrganizationContext = Depends(require_org_admin), + _auth: AuthContext = AUTH_DEP, + _ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> GatewayCommandsResponse: + """Return supported gateway protocol methods and events.""" return GatewayCommandsResponse( protocol_version=PROTOCOL_VERSION, methods=GATEWAY_METHODS, diff --git a/backend/app/api/gateways.py b/backend/app/api/gateways.py index b37264e..5960682 100644 --- a/backend/app/api/gateways.py +++ b/backend/app/api/gateways.py @@ -1,10 +1,13 @@ +"""Gateway CRUD and template synchronization endpoints.""" + from __future__ import annotations +from dataclasses import dataclass +from typing import TYPE_CHECKING from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query, status from sqlmodel import col -from sqlmodel.ext.asyncio.session import AsyncSession from app.api.deps import require_org_admin from app.core.agent_tokens import generate_agent_token, hash_agent_token @@ -14,7 +17,11 @@ from app.db import crud from app.db.pagination import paginate from app.db.session import get_session from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig -from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message +from app.integrations.openclaw_gateway import ( + OpenClawGatewayError, + ensure_session, + send_message, +) from app.models.agents import Agent from app.models.gateways import Gateway from app.schemas.common import OkResponse @@ -25,11 +32,61 @@ from app.schemas.gateways import ( GatewayUpdate, ) from app.schemas.pagination import DefaultLimitOffsetPage -from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_main_agent -from app.services.organizations import OrganizationContext -from app.services.template_sync import sync_gateway_templates as sync_gateway_templates_service +from app.services.agent_provisioning import ( + DEFAULT_HEARTBEAT_CONFIG, + provision_main_agent, +) +from app.services.template_sync import ( + GatewayTemplateSyncOptions, +) +from app.services.template_sync import ( + sync_gateway_templates as sync_gateway_templates_service, +) + +if TYPE_CHECKING: + from sqlmodel.ext.asyncio.session import AsyncSession + + from app.services.organizations import OrganizationContext router = APIRouter(prefix="/gateways", tags=["gateways"]) +SESSION_DEP = Depends(get_session) +AUTH_DEP = Depends(get_auth_context) +ORG_ADMIN_DEP = Depends(require_org_admin) +INCLUDE_MAIN_QUERY = Query(default=True) +RESET_SESSIONS_QUERY = Query(default=False) +ROTATE_TOKENS_QUERY = Query(default=False) +FORCE_BOOTSTRAP_QUERY = Query(default=False) +BOARD_ID_QUERY = Query(default=None) +_RUNTIME_TYPE_REFERENCES = (UUID,) + + +@dataclass(frozen=True) +class _TemplateSyncQuery: + include_main: bool + reset_sessions: bool + rotate_tokens: bool + force_bootstrap: bool + board_id: UUID | None + + +def _template_sync_query( + *, + include_main: bool = INCLUDE_MAIN_QUERY, + reset_sessions: bool = RESET_SESSIONS_QUERY, + rotate_tokens: bool = ROTATE_TOKENS_QUERY, + force_bootstrap: bool = FORCE_BOOTSTRAP_QUERY, + board_id: UUID | None = BOARD_ID_QUERY, +) -> _TemplateSyncQuery: + return _TemplateSyncQuery( + include_main=include_main, + reset_sessions=reset_sessions, + rotate_tokens=rotate_tokens, + force_bootstrap=force_bootstrap, + board_id=board_id, + ) + + +SYNC_QUERY_DEP = Depends(_template_sync_query) def _main_agent_name(gateway: Gateway) -> str: @@ -48,7 +105,9 @@ async def _require_gateway( .first(session) ) if gateway is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found", + ) return gateway @@ -59,14 +118,18 @@ async def _find_main_agent( previous_session_key: str | None = None, ) -> Agent | None: if gateway.main_session_key: - agent = await Agent.objects.filter_by(openclaw_session_id=gateway.main_session_key).first( - session + agent = await Agent.objects.filter_by( + openclaw_session_id=gateway.main_session_key, + ).first( + session, ) if agent: return agent if previous_session_key: - agent = await Agent.objects.filter_by(openclaw_session_id=previous_session_key).first( - session + agent = await Agent.objects.filter_by( + openclaw_session_id=previous_session_key, + ).first( + session, ) if agent: return agent @@ -85,13 +148,17 @@ async def _ensure_main_agent( gateway: Gateway, auth: AuthContext, *, - previous_name: str | None = None, - previous_session_key: str | None = None, + previous: tuple[str | None, str | None] | None = None, action: str = "provision", ) -> Agent | None: if not gateway.url or not gateway.main_session_key: return None - agent = await _find_main_agent(session, gateway, previous_name, previous_session_key) + agent = await _find_main_agent( + session, + gateway, + previous_name=previous[0] if previous else None, + previous_session_key=previous[1] if previous else None, + ) if agent is None: agent = Agent( name=_main_agent_name(gateway), @@ -130,7 +197,8 @@ async def _ensure_main_agent( ( f"Hello {agent.name}. Your gateway provisioning was updated.\n\n" "Please re-read AGENTS.md, USER.md, HEARTBEAT.md, and TOOLS.md. " - "If BOOTSTRAP.md exists, run it once then delete it. Begin heartbeats after startup." + "If BOOTSTRAP.md exists, run it once then delete it. " + "Begin heartbeats after startup." ), session_key=gateway.main_session_key, config=GatewayClientConfig(url=gateway.url, token=gateway.token), @@ -144,9 +212,10 @@ async def _ensure_main_agent( @router.get("", response_model=DefaultLimitOffsetPage[GatewayRead]) async def list_gateways( - session: AsyncSession = Depends(get_session), - ctx: OrganizationContext = Depends(require_org_admin), + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> DefaultLimitOffsetPage[GatewayRead]: + """List gateways for the caller's organization.""" statement = ( Gateway.objects.filter_by(organization_id=ctx.organization.id) .order_by(col(Gateway.created_at).desc()) @@ -158,10 +227,11 @@ async def list_gateways( @router.post("", response_model=GatewayRead) async def create_gateway( payload: GatewayCreate, - session: AsyncSession = Depends(get_session), - auth: AuthContext = Depends(get_auth_context), - ctx: OrganizationContext = Depends(require_org_admin), + session: AsyncSession = SESSION_DEP, + auth: AuthContext = AUTH_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> Gateway: + """Create a gateway and provision or refresh its main agent.""" data = payload.model_dump() data["organization_id"] = ctx.organization.id gateway = await crud.create(session, Gateway, **data) @@ -172,9 +242,10 @@ async def create_gateway( @router.get("/{gateway_id}", response_model=GatewayRead) async def get_gateway( gateway_id: UUID, - session: AsyncSession = Depends(get_session), - ctx: OrganizationContext = Depends(require_org_admin), + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> Gateway: + """Return one gateway by id for the caller's organization.""" return await _require_gateway( session, gateway_id=gateway_id, @@ -186,10 +257,11 @@ async def get_gateway( async def update_gateway( gateway_id: UUID, payload: GatewayUpdate, - session: AsyncSession = Depends(get_session), - auth: AuthContext = Depends(get_auth_context), - ctx: OrganizationContext = Depends(require_org_admin), + session: AsyncSession = SESSION_DEP, + auth: AuthContext = AUTH_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> Gateway: + """Patch a gateway and refresh the main-agent provisioning state.""" gateway = await _require_gateway( session, gateway_id=gateway_id, @@ -203,8 +275,7 @@ async def update_gateway( session, gateway, auth, - previous_name=previous_name, - previous_session_key=previous_session_key, + previous=(previous_name, previous_session_key), action="update", ) return gateway @@ -213,15 +284,12 @@ async def update_gateway( @router.post("/{gateway_id}/templates/sync", response_model=GatewayTemplatesSyncResult) async def sync_gateway_templates( gateway_id: UUID, - include_main: bool = Query(default=True), - reset_sessions: bool = Query(default=False), - rotate_tokens: bool = Query(default=False), - force_bootstrap: bool = Query(default=False), - board_id: UUID | None = Query(default=None), - session: AsyncSession = Depends(get_session), - auth: AuthContext = Depends(get_auth_context), - ctx: OrganizationContext = Depends(require_org_admin), + sync_query: _TemplateSyncQuery = SYNC_QUERY_DEP, + session: AsyncSession = SESSION_DEP, + auth: AuthContext = AUTH_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> GatewayTemplatesSyncResult: + """Sync templates for a gateway and optionally rotate runtime settings.""" gateway = await _require_gateway( session, gateway_id=gateway_id, @@ -230,21 +298,24 @@ async def sync_gateway_templates( return await sync_gateway_templates_service( session, gateway, - user=auth.user, - include_main=include_main, - reset_sessions=reset_sessions, - rotate_tokens=rotate_tokens, - force_bootstrap=force_bootstrap, - board_id=board_id, + GatewayTemplateSyncOptions( + user=auth.user, + include_main=sync_query.include_main, + reset_sessions=sync_query.reset_sessions, + rotate_tokens=sync_query.rotate_tokens, + force_bootstrap=sync_query.force_bootstrap, + board_id=sync_query.board_id, + ), ) @router.delete("/{gateway_id}", response_model=OkResponse) async def delete_gateway( gateway_id: UUID, - session: AsyncSession = Depends(get_session), - ctx: OrganizationContext = Depends(require_org_admin), + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> OkResponse: + """Delete a gateway in the caller's organization.""" gateway = await _require_gateway( session, gateway_id=gateway_id, diff --git a/backend/app/api/metrics.py b/backend/app/api/metrics.py index 4b2f0fa..fc8ba25 100644 --- a/backend/app/api/metrics.py +++ b/backend/app/api/metrics.py @@ -1,3 +1,5 @@ +"""Dashboard metric aggregation endpoints.""" + from __future__ import annotations from dataclasses import dataclass @@ -32,10 +34,16 @@ router = APIRouter(prefix="/metrics", tags=["metrics"]) OFFLINE_AFTER = timedelta(minutes=10) ERROR_EVENT_PATTERN = "%failed" +_RUNTIME_TYPE_REFERENCES = (UUID, AsyncSession) +RANGE_QUERY = Query(default="24h") +SESSION_DEP = Depends(get_session) +ORG_MEMBER_DEP = Depends(require_org_member) @dataclass(frozen=True) class RangeSpec: + """Resolved time-range specification for metric aggregation.""" + key: Literal["24h", "7d"] start: datetime end: datetime @@ -80,7 +88,8 @@ def _build_buckets(range_spec: RangeSpec) -> list[datetime]: def _series_from_mapping( - range_spec: RangeSpec, mapping: dict[datetime, float] + range_spec: RangeSpec, + mapping: dict[datetime, float], ) -> DashboardRangeSeries: points = [ DashboardSeriesPoint(period=bucket, value=float(mapping.get(bucket, 0))) @@ -94,7 +103,8 @@ def _series_from_mapping( def _wip_series_from_mapping( - range_spec: RangeSpec, mapping: dict[datetime, dict[str, int]] + range_spec: RangeSpec, + mapping: dict[datetime, dict[str, int]], ) -> DashboardWipRangeSeries: points: list[DashboardWipPoint] = [] for bucket in _build_buckets(range_spec): @@ -105,7 +115,7 @@ def _wip_series_from_mapping( inbox=values.get("inbox", 0), in_progress=values.get("in_progress", 0), review=values.get("review", 0), - ) + ), ) return DashboardWipRangeSeries( range=range_spec.key, @@ -115,7 +125,9 @@ def _wip_series_from_mapping( async def _query_throughput( - session: AsyncSession, range_spec: RangeSpec, board_ids: list[UUID] + session: AsyncSession, + range_spec: RangeSpec, + board_ids: list[UUID], ) -> DashboardRangeSeries: bucket_col = func.date_trunc(range_spec.bucket, Task.updated_at).label("bucket") statement = ( @@ -135,7 +147,9 @@ async def _query_throughput( async def _query_cycle_time( - session: AsyncSession, range_spec: RangeSpec, board_ids: list[UUID] + session: AsyncSession, + range_spec: RangeSpec, + board_ids: list[UUID], ) -> DashboardRangeSeries: bucket_col = func.date_trunc(range_spec.bucket, Task.updated_at).label("bucket") in_progress = cast(Task.in_progress_at, DateTime) @@ -158,9 +172,14 @@ async def _query_cycle_time( async def _query_error_rate( - session: AsyncSession, range_spec: RangeSpec, board_ids: list[UUID] + session: AsyncSession, + range_spec: RangeSpec, + board_ids: list[UUID], ) -> DashboardRangeSeries: - bucket_col = func.date_trunc(range_spec.bucket, ActivityEvent.created_at).label("bucket") + bucket_col = func.date_trunc( + range_spec.bucket, + ActivityEvent.created_at, + ).label("bucket") error_case = case( ( col(ActivityEvent.event_type).like(ERROR_EVENT_PATTERN), @@ -190,7 +209,9 @@ async def _query_error_rate( async def _query_wip( - session: AsyncSession, range_spec: RangeSpec, board_ids: list[UUID] + session: AsyncSession, + range_spec: RangeSpec, + board_ids: list[UUID], ) -> DashboardWipRangeSeries: bucket_col = func.date_trunc(range_spec.bucket, Task.updated_at).label("bucket") inbox_case = case((col(Task.status) == "inbox", 1), else_=0) @@ -222,7 +243,10 @@ async def _query_wip( return _wip_series_from_mapping(range_spec, mapping) -async def _median_cycle_time_7d(session: AsyncSession, board_ids: list[UUID]) -> float | None: +async def _median_cycle_time_7d( + session: AsyncSession, + board_ids: list[UUID], +) -> float | None: now = utcnow() start = now - timedelta(days=7) in_progress = cast(Task.in_progress_at, DateTime) @@ -248,7 +272,9 @@ async def _median_cycle_time_7d(session: AsyncSession, board_ids: list[UUID]) -> async def _error_rate_kpi( - session: AsyncSession, range_spec: RangeSpec, board_ids: list[UUID] + session: AsyncSession, + range_spec: RangeSpec, + board_ids: list[UUID], ) -> float: error_case = case( ( @@ -302,12 +328,13 @@ async def _tasks_in_progress(session: AsyncSession, board_ids: list[UUID]) -> in @router.get("/dashboard", response_model=DashboardMetrics) async def dashboard_metrics( - range: Literal["24h", "7d"] = Query(default="24h"), - session: AsyncSession = Depends(get_session), - ctx: OrganizationContext = Depends(require_org_member), + range_key: Literal["24h", "7d"] = RANGE_QUERY, + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_MEMBER_DEP, ) -> DashboardMetrics: - primary = _resolve_range(range) - comparison = _comparison_range(range) + """Return dashboard KPIs and time-series data for accessible boards.""" + primary = _resolve_range(range_key) + comparison = _comparison_range(range_key) board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False) throughput_primary = await _query_throughput(session, primary, board_ids) diff --git a/backend/app/db/crud.py b/backend/app/db/crud.py index bd61f6b..d9f67f3 100644 --- a/backend/app/db/crud.py +++ b/backend/app/db/crud.py @@ -1,27 +1,37 @@ +"""Generic asynchronous CRUD helpers for SQLModel entities.""" + from __future__ import annotations -from collections.abc import Iterable, Mapping -from typing import Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from sqlalchemy import delete as sql_delete from sqlalchemy import update as sql_update from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlmodel import SQLModel, select -from sqlmodel.ext.asyncio.session import AsyncSession -from sqlmodel.sql.expression import SelectOfScalar + +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + + from sqlmodel.ext.asyncio.session import AsyncSession + from sqlmodel.sql.expression import SelectOfScalar ModelT = TypeVar("ModelT", bound=SQLModel) -class DoesNotExist(LookupError): - pass +class DoesNotExistError(LookupError): + """Raised when a query expected one row but found none.""" -class MultipleObjectsReturned(LookupError): - pass +class MultipleObjectsReturnedError(LookupError): + """Raised when a query expected one row but found many.""" + + +DoesNotExist = DoesNotExistError +MultipleObjectsReturned = MultipleObjectsReturnedError async def _flush_or_rollback(session: AsyncSession) -> None: + """Flush changes and rollback on SQLAlchemy errors.""" try: await session.flush() except SQLAlchemyError: @@ -30,6 +40,7 @@ async def _flush_or_rollback(session: AsyncSession) -> None: async def _commit_or_rollback(session: AsyncSession) -> None: + """Commit transaction and rollback on SQLAlchemy errors.""" try: await session.commit() except SQLAlchemyError: @@ -37,31 +48,50 @@ async def _commit_or_rollback(session: AsyncSession) -> None: raise -def _lookup_statement(model: type[ModelT], lookup: Mapping[str, Any]) -> SelectOfScalar[ModelT]: +def _lookup_statement( + model: type[ModelT], + lookup: Mapping[str, Any], +) -> SelectOfScalar[ModelT]: + """Build a select statement with equality filters from lookup values.""" stmt = select(model) for key, value in lookup.items(): stmt = stmt.where(getattr(model, key) == value) return stmt -async def get_by_id(session: AsyncSession, model: type[ModelT], obj_id: Any) -> ModelT | None: +async def get_by_id( + session: AsyncSession, + model: type[ModelT], + obj_id: object, +) -> ModelT | None: + """Fetch one model instance by id or return None.""" stmt = _lookup_statement(model, {"id": obj_id}).limit(1) return (await session.exec(stmt)).first() -async def get(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT: +async def get( + session: AsyncSession, + model: type[ModelT], + **lookup: object, +) -> ModelT: + """Fetch exactly one model instance by lookup values.""" stmt = _lookup_statement(model, lookup).limit(2) items = (await session.exec(stmt)).all() if not items: - raise DoesNotExist(f"{model.__name__} matching query does not exist.") + message = f"{model.__name__} matching query does not exist." + raise DoesNotExist(message) if len(items) > 1: - raise MultipleObjectsReturned( - f"Multiple {model.__name__} objects returned for lookup {lookup!r}." - ) + message = f"Multiple {model.__name__} objects returned for lookup {lookup!r}." + raise MultipleObjectsReturned(message) return items[0] -async def get_one_by(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT | None: +async def get_one_by( + session: AsyncSession, + model: type[ModelT], + **lookup: object, +) -> ModelT | None: + """Fetch the first model instance matching lookup values.""" stmt = _lookup_statement(model, lookup) return (await session.exec(stmt)).first() @@ -72,8 +102,9 @@ async def create( *, commit: bool = True, refresh: bool = True, - **data: Any, + **data: object, ) -> ModelT: + """Create, flush, optionally commit, and optionally refresh an object.""" obj = model.model_validate(data) session.add(obj) await _flush_or_rollback(session) @@ -91,6 +122,7 @@ async def save( commit: bool = True, refresh: bool = True, ) -> ModelT: + """Persist an existing object with optional commit and refresh.""" session.add(obj) await _flush_or_rollback(session) if commit: @@ -101,6 +133,7 @@ async def save( async def delete(session: AsyncSession, obj: ModelT, *, commit: bool = True) -> None: + """Delete an object with optional commit.""" await session.delete(obj) if commit: await _commit_or_rollback(session) @@ -113,8 +146,9 @@ async def list_by( order_by: Iterable[Any] = (), limit: int | None = None, offset: int | None = None, - **lookup: Any, + **lookup: object, ) -> list[ModelT]: + """List objects by lookup values with optional ordering and pagination.""" stmt = _lookup_statement(model, lookup) for ordering in order_by: stmt = stmt.order_by(ordering) @@ -125,11 +159,19 @@ async def list_by( return list(await session.exec(stmt)) -async def exists(session: AsyncSession, model: type[ModelT], **lookup: Any) -> bool: - return (await session.exec(_lookup_statement(model, lookup).limit(1))).first() is not None +async def exists(session: AsyncSession, model: type[ModelT], **lookup: object) -> bool: + """Return whether any object exists for lookup values.""" + return ( + (await session.exec(_lookup_statement(model, lookup).limit(1))).first() + is not None + ) -def _criteria_statement(model: type[ModelT], criteria: tuple[Any, ...]) -> SelectOfScalar[ModelT]: +def _criteria_statement( + model: type[ModelT], + criteria: tuple[Any, ...], +) -> SelectOfScalar[ModelT]: + """Build a select statement from variadic where criteria.""" stmt = select(model) if criteria: stmt = stmt.where(*criteria) @@ -139,9 +181,10 @@ def _criteria_statement(model: type[ModelT], criteria: tuple[Any, ...]) -> Selec async def list_where( session: AsyncSession, model: type[ModelT], - *criteria: Any, + *criteria: object, order_by: Iterable[Any] = (), ) -> list[ModelT]: + """List objects filtered by explicit SQL criteria.""" stmt = _criteria_statement(model, criteria) for ordering in order_by: stmt = stmt.order_by(ordering) @@ -151,9 +194,10 @@ async def list_where( async def delete_where( session: AsyncSession, model: type[ModelT], - *criteria: Any, + *criteria: object, commit: bool = False, ) -> int: + """Delete rows matching criteria and return affected row count.""" stmt: Any = sql_delete(model) if criteria: stmt = stmt.where(*criteria) @@ -167,18 +211,24 @@ async def delete_where( async def update_where( session: AsyncSession, model: type[ModelT], - *criteria: Any, + *criteria: object, updates: Mapping[str, Any] | None = None, - commit: bool = False, - exclude_none: bool = False, - allowed_fields: set[str] | None = None, - **update_fields: Any, + **options: object, ) -> int: + """Apply bulk updates by criteria and return affected row count.""" + commit = bool(options.pop("commit", False)) + exclude_none = bool(options.pop("exclude_none", False)) + allowed_fields_raw = options.pop("allowed_fields", None) + allowed_fields = ( + allowed_fields_raw + if isinstance(allowed_fields_raw, set) + else None + ) source_updates: dict[str, Any] = {} if updates: source_updates.update(dict(updates)) - if update_fields: - source_updates.update(update_fields) + if options: + source_updates.update(options) values: dict[str, Any] = {} for key, value in source_updates.items(): @@ -207,6 +257,7 @@ def apply_updates( exclude_none: bool = False, allowed_fields: set[str] | None = None, ) -> ModelT: + """Apply a mapping of field updates onto an object.""" for key, value in updates.items(): if allowed_fields is not None and key not in allowed_fields: continue @@ -220,12 +271,18 @@ async def patch( session: AsyncSession, obj: ModelT, updates: Mapping[str, Any], - *, - exclude_none: bool = False, - allowed_fields: set[str] | None = None, - commit: bool = True, - refresh: bool = True, + **options: object, ) -> ModelT: + """Apply partial updates and persist object.""" + exclude_none = bool(options.pop("exclude_none", False)) + allowed_fields_raw = options.pop("allowed_fields", None) + allowed_fields = ( + allowed_fields_raw + if isinstance(allowed_fields_raw, set) + else None + ) + commit = bool(options.pop("commit", True)) + refresh = bool(options.pop("refresh", True)) apply_updates( obj, updates, @@ -242,8 +299,9 @@ async def get_or_create( defaults: Mapping[str, Any] | None = None, commit: bool = True, refresh: bool = True, - **lookup: Any, + **lookup: object, ) -> tuple[ModelT, bool]: + """Get one object by lookup, or create it with defaults.""" stmt = _lookup_statement(model, lookup) existing = (await session.exec(stmt)).first() diff --git a/backend/app/integrations/openclaw_gateway.py b/backend/app/integrations/openclaw_gateway.py index 38c92bd..25a06bd 100644 --- a/backend/app/integrations/openclaw_gateway.py +++ b/backend/app/integrations/openclaw_gateway.py @@ -1,3 +1,5 @@ +"""OpenClaw gateway client helpers for websocket RPC calls.""" + from __future__ import annotations import asyncio @@ -14,16 +16,20 @@ from app.integrations.openclaw_gateway_protocol import PROTOCOL_VERSION class OpenClawGatewayError(RuntimeError): - pass + """Raised when OpenClaw gateway calls fail.""" @dataclass class OpenClawResponse: + """Container for raw OpenClaw payloads.""" + payload: Any @dataclass(frozen=True) class GatewayConfig: + """Connection configuration for the OpenClaw gateway.""" + url: str token: str | None = None @@ -31,7 +37,8 @@ class GatewayConfig: def _build_gateway_url(config: GatewayConfig) -> str: base_url = (config.url or "").strip() if not base_url: - raise OpenClawGatewayError("Gateway URL is not configured for this board.") + message = "Gateway URL is not configured for this board." + raise OpenClawGatewayError(message) token = config.token if not token: return base_url @@ -40,7 +47,10 @@ def _build_gateway_url(config: GatewayConfig) -> str: return urlunparse(parsed._replace(query=query)) -async def _await_response(ws: websockets.WebSocketClientProtocol, request_id: str) -> Any: +async def _await_response( + ws: websockets.WebSocketClientProtocol, + request_id: str, +) -> object: while True: raw = await ws.recv() data = json.loads(raw) @@ -53,15 +63,23 @@ async def _await_response(ws: websockets.WebSocketClientProtocol, request_id: st if data.get("id") == request_id: if data.get("error"): - raise OpenClawGatewayError(data["error"].get("message", "Gateway error")) + message = data["error"].get("message", "Gateway error") + raise OpenClawGatewayError(message) return data.get("result") async def _send_request( - ws: websockets.WebSocketClientProtocol, method: str, params: dict[str, Any] | None -) -> Any: + ws: websockets.WebSocketClientProtocol, + method: str, + params: dict[str, Any] | None, +) -> object: request_id = str(uuid4()) - message = {"type": "req", "id": request_id, "method": method, "params": params or {}} + message = { + "type": "req", + "id": request_id, + "method": method, + "params": params or {}, + } await ws.send(json.dumps(message)) return await _await_response(ws, request_id) @@ -109,7 +127,8 @@ async def openclaw_call( params: dict[str, Any] | None = None, *, config: GatewayConfig, -) -> Any: +) -> object: + """Call a gateway RPC method and return the result payload.""" gateway_url = _build_gateway_url(config) try: async with websockets.connect(gateway_url, ping_interval=None) as ws: @@ -138,7 +157,8 @@ async def send_message( session_key: str, config: GatewayConfig, deliver: bool = False, -) -> Any: +) -> object: + """Send a chat message to a session.""" params: dict[str, Any] = { "sessionKey": session_key, "message": message, @@ -152,14 +172,16 @@ async def get_chat_history( session_key: str, config: GatewayConfig, limit: int | None = None, -) -> Any: +) -> object: + """Fetch chat history for a session.""" params: dict[str, Any] = {"sessionKey": session_key} if limit is not None: params["limit"] = limit return await openclaw_call("chat.history", params, config=config) -async def delete_session(session_key: str, *, config: GatewayConfig) -> Any: +async def delete_session(session_key: str, *, config: GatewayConfig) -> object: + """Delete a session by key.""" return await openclaw_call("sessions.delete", {"key": session_key}, config=config) @@ -168,7 +190,8 @@ async def ensure_session( *, config: GatewayConfig, label: str | None = None, -) -> Any: +) -> object: + """Ensure a session exists and optionally update its label.""" params: dict[str, Any] = {"key": session_key} if label: params["label"] = label diff --git a/backend/app/schemas/agents.py b/backend/app/schemas/agents.py index 4fa2ca3..e37909b 100644 --- a/backend/app/schemas/agents.py +++ b/backend/app/schemas/agents.py @@ -1,3 +1,5 @@ +"""Pydantic/SQLModel schemas for agent API payloads.""" + from __future__ import annotations from collections.abc import Mapping @@ -10,6 +12,8 @@ from sqlmodel import SQLModel from app.schemas.common import NonEmptyStr +_RUNTIME_TYPE_REFERENCES = (datetime, UUID, NonEmptyStr) + def _normalize_identity_profile( profile: object, @@ -36,6 +40,8 @@ def _normalize_identity_profile( class AgentBase(SQLModel): + """Common fields shared by agent create/read/update payloads.""" + board_id: UUID | None = None name: NonEmptyStr status: str = "provisioning" @@ -46,7 +52,8 @@ class AgentBase(SQLModel): @field_validator("identity_template", "soul_template", mode="before") @classmethod - def normalize_templates(cls, value: Any) -> Any: + def normalize_templates(cls, value: object) -> object | None: + """Normalize blank template text to null.""" if value is None: return None if isinstance(value, str): @@ -56,15 +63,21 @@ class AgentBase(SQLModel): @field_validator("identity_profile", mode="before") @classmethod - def normalize_identity_profile(cls, value: Any) -> Any: + def normalize_identity_profile( + cls, + value: object, + ) -> dict[str, str] | None: + """Normalize identity-profile values into trimmed string mappings.""" return _normalize_identity_profile(value) class AgentCreate(AgentBase): - pass + """Payload for creating a new agent.""" class AgentUpdate(SQLModel): + """Payload for patching an existing agent.""" + board_id: UUID | None = None is_gateway_main: bool | None = None name: NonEmptyStr | None = None @@ -76,7 +89,8 @@ class AgentUpdate(SQLModel): @field_validator("identity_template", "soul_template", mode="before") @classmethod - def normalize_templates(cls, value: Any) -> Any: + def normalize_templates(cls, value: object) -> object | None: + """Normalize blank template text to null.""" if value is None: return None if isinstance(value, str): @@ -86,11 +100,17 @@ class AgentUpdate(SQLModel): @field_validator("identity_profile", mode="before") @classmethod - def normalize_identity_profile(cls, value: Any) -> Any: + def normalize_identity_profile( + cls, + value: object, + ) -> dict[str, str] | None: + """Normalize identity-profile values into trimmed string mappings.""" return _normalize_identity_profile(value) class AgentRead(AgentBase): + """Public agent representation returned by the API.""" + id: UUID is_board_lead: bool = False is_gateway_main: bool = False @@ -101,13 +121,19 @@ class AgentRead(AgentBase): class AgentHeartbeat(SQLModel): + """Heartbeat status payload sent by agents.""" + status: str | None = None class AgentHeartbeatCreate(AgentHeartbeat): + """Heartbeat payload used to create an agent lazily.""" + name: NonEmptyStr board_id: UUID | None = None class AgentNudge(SQLModel): + """Nudge message payload for pinging an agent.""" + message: NonEmptyStr diff --git a/backend/app/schemas/board_onboarding.py b/backend/app/schemas/board_onboarding.py index 5d64a9b..ba39b13 100644 --- a/backend/app/schemas/board_onboarding.py +++ b/backend/app/schemas/board_onboarding.py @@ -1,7 +1,9 @@ +"""Schemas used by the board-onboarding assistant flow.""" + from __future__ import annotations from datetime import datetime -from typing import Any, Literal, Self +from typing import Literal, Self from uuid import UUID from pydantic import Field, field_validator, model_validator @@ -9,17 +11,23 @@ from sqlmodel import SQLModel from app.schemas.common import NonEmptyStr +_RUNTIME_TYPE_REFERENCES = (datetime, UUID, NonEmptyStr) + class BoardOnboardingStart(SQLModel): - pass + """Start signal for initializing onboarding conversation.""" class BoardOnboardingAnswer(SQLModel): + """User answer payload for a single onboarding question.""" + answer: NonEmptyStr other_text: str | None = None class BoardOnboardingConfirm(SQLModel): + """Payload used to confirm generated onboarding draft fields.""" + board_type: str objective: str | None = None success_metrics: dict[str, object] | None = None @@ -27,23 +35,32 @@ class BoardOnboardingConfirm(SQLModel): @model_validator(mode="after") def validate_goal_fields(self) -> Self: - if self.board_type == "goal": - if not self.objective or not self.success_metrics: - raise ValueError("Confirmed goal boards require objective and success_metrics") + """Require goal metadata when the board type is `goal`.""" + if self.board_type == "goal" and ( + not self.objective or not self.success_metrics + ): + message = ( + "Confirmed goal boards require objective and success_metrics" + ) + raise ValueError(message) return self class BoardOnboardingQuestionOption(SQLModel): + """Selectable option for an onboarding question.""" + id: NonEmptyStr label: NonEmptyStr class BoardOnboardingAgentQuestion(SQLModel): + """Question payload emitted by the onboarding assistant.""" + question: NonEmptyStr options: list[BoardOnboardingQuestionOption] = Field(min_length=1) -def _normalize_optional_text(value: Any) -> Any: +def _normalize_optional_text(value: object) -> object | None: if value is None: return None if isinstance(value, str): @@ -53,6 +70,8 @@ def _normalize_optional_text(value: Any) -> Any: class BoardOnboardingUserProfile(SQLModel): + """User-profile preferences gathered during onboarding.""" + preferred_name: str | None = None pronouns: str | None = None timezone: str | None = None @@ -68,7 +87,8 @@ class BoardOnboardingUserProfile(SQLModel): mode="before", ) @classmethod - def normalize_text(cls, value: Any) -> Any: + def normalize_text(cls, value: object) -> object | None: + """Trim optional free-form profile text fields.""" return _normalize_optional_text(value) @@ -79,6 +99,8 @@ LeadAgentUpdateCadence = Literal["asap", "hourly", "daily", "weekly"] class BoardOnboardingLeadAgentDraft(SQLModel): + """Editable lead-agent draft configuration.""" + name: NonEmptyStr | None = None # role, communication_style, emoji are expected keys. identity_profile: dict[str, str] | None = None @@ -97,12 +119,17 @@ class BoardOnboardingLeadAgentDraft(SQLModel): mode="before", ) @classmethod - def normalize_text_fields(cls, value: Any) -> Any: + def normalize_text_fields(cls, value: object) -> object | None: + """Trim optional lead-agent preference fields.""" return _normalize_optional_text(value) @field_validator("identity_profile", mode="before") @classmethod - def normalize_identity_profile(cls, value: Any) -> Any: + def normalize_identity_profile( + cls, + value: object, + ) -> object | None: + """Normalize identity profile keys and values as trimmed strings.""" if value is None: return None if not isinstance(value, dict): @@ -121,6 +148,8 @@ class BoardOnboardingLeadAgentDraft(SQLModel): class BoardOnboardingAgentComplete(BoardOnboardingConfirm): + """Complete onboarding draft produced by the onboarding assistant.""" + status: Literal["complete"] user_profile: BoardOnboardingUserProfile | None = None lead_agent: BoardOnboardingLeadAgentDraft | None = None @@ -130,6 +159,8 @@ BoardOnboardingAgentUpdate = BoardOnboardingAgentComplete | BoardOnboardingAgent class BoardOnboardingRead(SQLModel): + """Stored onboarding session state returned by API endpoints.""" + id: UUID board_id: UUID session_key: str diff --git a/backend/app/services/agent_provisioning.py b/backend/app/services/agent_provisioning.py index 6621d61..21eed80 100644 --- a/backend/app/services/agent_provisioning.py +++ b/backend/app/services/agent_provisioning.py @@ -1,5 +1,4 @@ """Gateway-facing agent provisioning and cleanup helpers.""" -# ruff: noqa: EM101, TRY003 from __future__ import annotations @@ -176,7 +175,8 @@ def _heartbeat_template_name(agent: Agent) -> str: def _workspace_path(agent: Agent, workspace_root: str) -> str: if not workspace_root: - raise ValueError("gateway_workspace_root is required") + msg = "gateway_workspace_root is required" + raise ValueError(msg) root = workspace_root.rstrip("/") # Use agent key derived from session key when possible. This prevents collisions for # lead agents (session key includes board id) even if multiple boards share the same @@ -227,9 +227,11 @@ def _build_context( user: User | None, ) -> dict[str, str]: if not gateway.workspace_root: - raise ValueError("gateway_workspace_root is required") + msg = "gateway_workspace_root is required" + raise ValueError(msg) if not gateway.main_session_key: - raise ValueError("gateway_main_session_key is required") + msg = "gateway_main_session_key is required" + raise ValueError(msg) agent_id = str(agent.id) workspace_root = gateway.workspace_root workspace_path = _workspace_path(agent, workspace_root) @@ -485,15 +487,18 @@ async def _patch_gateway_agent_list( ) -> None: cfg = await openclaw_call("config.get", config=config) if not isinstance(cfg, dict): - raise OpenClawGatewayError("config.get returned invalid payload") + msg = "config.get returned invalid payload" + raise OpenClawGatewayError(msg) base_hash = cfg.get("hash") data = cfg.get("config") or cfg.get("parsed") or {} if not isinstance(data, dict): - raise OpenClawGatewayError("config.get returned invalid config") + msg = "config.get returned invalid config" + raise OpenClawGatewayError(msg) agents = data.get("agents") or {} lst = agents.get("list") or [] if not isinstance(lst, list): - raise OpenClawGatewayError("config agents.list is not a list") + msg = "config agents.list is not a list" + raise OpenClawGatewayError(msg) updated = False new_list: list[dict[str, Any]] = [] @@ -528,19 +533,23 @@ async def patch_gateway_agent_heartbeats( # noqa: C901 Each entry is (agent_id, workspace_path, heartbeat_dict). """ if not gateway.url: - raise OpenClawGatewayError("Gateway url is required") + msg = "Gateway url is required" + raise OpenClawGatewayError(msg) config = GatewayClientConfig(url=gateway.url, token=gateway.token) cfg = await openclaw_call("config.get", config=config) if not isinstance(cfg, dict): - raise OpenClawGatewayError("config.get returned invalid payload") + msg = "config.get returned invalid payload" + raise OpenClawGatewayError(msg) base_hash = cfg.get("hash") data = cfg.get("config") or cfg.get("parsed") or {} if not isinstance(data, dict): - raise OpenClawGatewayError("config.get returned invalid config") + msg = "config.get returned invalid config" + raise OpenClawGatewayError(msg) agents_section = data.get("agents") or {} lst = agents_section.get("list") or [] if not isinstance(lst, list): - raise OpenClawGatewayError("config agents.list is not a list") + msg = "config agents.list is not a list" + raise OpenClawGatewayError(msg) entry_by_id: dict[str, tuple[str, dict[str, Any]]] = { agent_id: (workspace_path, heartbeat) @@ -581,7 +590,8 @@ async def patch_gateway_agent_heartbeats( # noqa: C901 async def sync_gateway_agent_heartbeats(gateway: Gateway, agents: list[Agent]) -> None: """Sync current Agent.heartbeat_config values to the gateway config.""" if not gateway.workspace_root: - raise OpenClawGatewayError("gateway workspace_root is required") + msg = "gateway workspace_root is required" + raise OpenClawGatewayError(msg) entries: list[tuple[str, str, dict[str, Any]]] = [] for agent in agents: agent_id = _agent_key(agent) @@ -599,15 +609,18 @@ async def _remove_gateway_agent_list( ) -> None: cfg = await openclaw_call("config.get", config=config) if not isinstance(cfg, dict): - raise OpenClawGatewayError("config.get returned invalid payload") + msg = "config.get returned invalid payload" + raise OpenClawGatewayError(msg) base_hash = cfg.get("hash") data = cfg.get("config") or cfg.get("parsed") or {} if not isinstance(data, dict): - raise OpenClawGatewayError("config.get returned invalid config") + msg = "config.get returned invalid config" + raise OpenClawGatewayError(msg) agents = data.get("agents") or {} lst = agents.get("list") or [] if not isinstance(lst, list): - raise OpenClawGatewayError("config agents.list is not a list") + msg = "config agents.list is not a list" + raise OpenClawGatewayError(msg) new_list = [ entry @@ -658,7 +671,8 @@ async def provision_agent( # noqa: C901, PLR0912, PLR0913 if not gateway.url: return if not gateway.workspace_root: - raise ValueError("gateway_workspace_root is required") + msg = "gateway_workspace_root is required" + raise ValueError(msg) client_config = GatewayClientConfig(url=gateway.url, token=gateway.token) session_key = _session_key(agent) await ensure_session(session_key, config=client_config, label=agent.name) @@ -734,7 +748,8 @@ async def provision_main_agent( # noqa: C901, PLR0912, PLR0913 if not gateway.url: return if not gateway.main_session_key: - raise ValueError("gateway main_session_key is required") + msg = "gateway main_session_key is required" + raise ValueError(msg) client_config = GatewayClientConfig(url=gateway.url, token=gateway.token) await ensure_session( gateway.main_session_key, config=client_config, label="Main Agent", @@ -745,7 +760,8 @@ async def provision_main_agent( # noqa: C901, PLR0912, PLR0913 fallback_session_key=gateway.main_session_key, ) if not agent_id: - raise OpenClawGatewayError("Unable to resolve gateway main agent id") + msg = "Unable to resolve gateway main agent id" + raise OpenClawGatewayError(msg) context = _build_main_context(agent, gateway, auth_token, user) supported = set(await _supported_gateway_files(client_config)) @@ -796,7 +812,8 @@ async def cleanup_agent( if not gateway.url: return None if not gateway.workspace_root: - raise ValueError("gateway_workspace_root is required") + msg = "gateway_workspace_root is required" + raise ValueError(msg) client_config = GatewayClientConfig(url=gateway.url, token=gateway.token) agent_id = _agent_key(agent) diff --git a/backend/app/services/board_group_snapshot.py b/backend/app/services/board_group_snapshot.py index 530b734..ab1e228 100644 --- a/backend/app/services/board_group_snapshot.py +++ b/backend/app/services/board_group_snapshot.py @@ -1,7 +1,8 @@ +"""Helpers for assembling board-group snapshot view models.""" + from __future__ import annotations from collections import defaultdict -from typing import Any from uuid import UUID from sqlalchemy import case, func @@ -22,48 +23,67 @@ from app.schemas.view_models import ( _STATUS_ORDER = {"in_progress": 0, "review": 1, "inbox": 2, "done": 3} _PRIORITY_ORDER = {"high": 0, "medium": 1, "low": 2} +_RUNTIME_TYPE_REFERENCES = (UUID, AsyncSession) -def _status_weight_expr() -> Any: +def _status_weight_expr() -> object: + """Return a SQL expression that sorts task statuses by configured order.""" whens = [(col(Task.status) == key, weight) for key, weight in _STATUS_ORDER.items()] return case(*whens, else_=99) -def _priority_weight_expr() -> Any: - whens = [(col(Task.priority) == key, weight) for key, weight in _PRIORITY_ORDER.items()] +def _priority_weight_expr() -> object: + """Return a SQL expression that sorts task priorities by configured order.""" + whens = [ + (col(Task.priority) == key, weight) + for key, weight in _PRIORITY_ORDER.items() + ] return case(*whens, else_=99) -async def build_group_snapshot( +async def _boards_for_group( session: AsyncSession, *, - group: BoardGroup, + group_id: UUID, exclude_board_id: UUID | None = None, - include_done: bool = False, - per_board_task_limit: int = 5, -) -> BoardGroupSnapshot: - statement = Board.objects.filter_by(board_group_id=group.id).statement +) -> list[Board]: + """Return boards belonging to a board group with optional exclusion.""" + statement = Board.objects.filter_by(board_group_id=group_id).statement if exclude_board_id is not None: statement = statement.where(col(Board.id) != exclude_board_id) - boards = list(await session.exec(statement.order_by(func.lower(col(Board.name)).asc()))) - if not boards: - return BoardGroupSnapshot(group=BoardGroupRead.model_validate(group, from_attributes=True)) + return list( + await session.exec( + statement.order_by(func.lower(col(Board.name)).asc()), + ), + ) - boards_by_id = {board.id: board for board in boards} - board_ids = list(boards_by_id.keys()) +async def _task_counts_by_board( + session: AsyncSession, + board_ids: list[UUID], +) -> dict[UUID, dict[str, int]]: + """Return per-board task counts keyed by task status.""" task_counts: dict[UUID, dict[str, int]] = defaultdict(lambda: defaultdict(int)) for board_id, status_value, total in list( await session.exec( select(col(Task.board_id), col(Task.status), func.count(col(Task.id))) .where(col(Task.board_id).in_(board_ids)) - .group_by(col(Task.board_id), col(Task.status)) - ) + .group_by(col(Task.board_id), col(Task.status)), + ), ): if board_id is None: continue task_counts[board_id][str(status_value)] = int(total or 0) + return task_counts + +async def _ordered_tasks_for_boards( + session: AsyncSession, + board_ids: list[UUID], + *, + include_done: bool, +) -> list[Task]: + """Return sorted tasks for boards, optionally excluding completed tasks.""" task_statement = select(Task).where(col(Task.board_id).in_(board_ids)) if not include_done: task_statement = task_statement.where(col(Task.status) != "done") @@ -74,62 +94,116 @@ async def build_group_snapshot( col(Task.updated_at).desc(), col(Task.created_at).desc(), ) - tasks = list(await session.exec(task_statement)) + return list(await session.exec(task_statement)) - assigned_ids = {task.assigned_agent_id for task in tasks if task.assigned_agent_id is not None} - agent_name_by_id: dict[UUID, str] = {} - if assigned_ids: - for agent_id, name in list( + +async def _agent_names( + session: AsyncSession, + tasks: list[Task], +) -> dict[UUID, str]: + """Return agent names keyed by assigned agent ids in the provided tasks.""" + assigned_ids = { + task.assigned_agent_id + for task in tasks + if task.assigned_agent_id is not None + } + if not assigned_ids: + return {} + return dict( + list( await session.exec( - select(col(Agent.id), col(Agent.name)).where(col(Agent.id).in_(assigned_ids)) - ) - ): - agent_name_by_id[agent_id] = name + select(col(Agent.id), col(Agent.name)).where( + col(Agent.id).in_(assigned_ids), + ), + ), + ), + ) + +def _task_summaries_by_board( + *, + boards_by_id: dict[UUID, Board], + tasks: list[Task], + agent_name_by_id: dict[UUID, str], + per_board_task_limit: int, +) -> dict[UUID, list[BoardGroupTaskSummary]]: + """Build limited per-board task summary lists.""" tasks_by_board: dict[UUID, list[BoardGroupTaskSummary]] = defaultdict(list) - if per_board_task_limit > 0: - for task in tasks: - if task.board_id is None: - continue - current = tasks_by_board[task.board_id] - if len(current) >= per_board_task_limit: - continue - board = boards_by_id.get(task.board_id) - if board is None: - continue - current.append( - BoardGroupTaskSummary( - id=task.id, - board_id=task.board_id, - board_name=board.name, - title=task.title, - status=task.status, - priority=task.priority, - assigned_agent_id=task.assigned_agent_id, - assignee=( - agent_name_by_id.get(task.assigned_agent_id) - if task.assigned_agent_id is not None - else None - ), - due_at=task.due_at, - in_progress_at=task.in_progress_at, - created_at=task.created_at, - updated_at=task.updated_at, - ) - ) - - snapshots: list[BoardGroupBoardSnapshot] = [] - for board in boards: - board_read = BoardRead.model_validate(board, from_attributes=True) - counts = dict(task_counts.get(board.id, {})) - snapshots.append( - BoardGroupBoardSnapshot( - board=board_read, - task_counts=counts, - tasks=tasks_by_board.get(board.id, []), - ) + if per_board_task_limit <= 0: + return tasks_by_board + for task in tasks: + if task.board_id is None: + continue + current = tasks_by_board[task.board_id] + if len(current) >= per_board_task_limit: + continue + board = boards_by_id.get(task.board_id) + if board is None: + continue + current.append( + BoardGroupTaskSummary( + id=task.id, + board_id=task.board_id, + board_name=board.name, + title=task.title, + status=task.status, + priority=task.priority, + assigned_agent_id=task.assigned_agent_id, + assignee=( + agent_name_by_id.get(task.assigned_agent_id) + if task.assigned_agent_id is not None + else None + ), + due_at=task.due_at, + in_progress_at=task.in_progress_at, + created_at=task.created_at, + updated_at=task.updated_at, + ), ) + return tasks_by_board + +async def build_group_snapshot( + session: AsyncSession, + *, + group: BoardGroup, + exclude_board_id: UUID | None = None, + include_done: bool = False, + per_board_task_limit: int = 5, +) -> BoardGroupSnapshot: + """Build a board-group snapshot with board/task summaries.""" + boards = await _boards_for_group( + session, + group_id=group.id, + exclude_board_id=exclude_board_id, + ) + if not boards: + return BoardGroupSnapshot( + group=BoardGroupRead.model_validate(group, from_attributes=True), + ) + boards_by_id = {board.id: board for board in boards} + board_ids = list(boards_by_id.keys()) + task_counts = await _task_counts_by_board(session, board_ids) + tasks = await _ordered_tasks_for_boards( + session, + board_ids, + include_done=include_done, + ) + agent_name_by_id = await _agent_names(session, tasks) + tasks_by_board = _task_summaries_by_board( + boards_by_id=boards_by_id, + tasks=tasks, + agent_name_by_id=agent_name_by_id, + per_board_task_limit=per_board_task_limit, + ) + snapshots = [ + BoardGroupBoardSnapshot( + board=BoardRead.model_validate(board, from_attributes=True), + task_counts=dict(task_counts.get(board.id, {})), + tasks=tasks_by_board.get(board.id, []), + ) + for board in boards + ] return BoardGroupSnapshot( group=BoardGroupRead.model_validate(group, from_attributes=True), boards=snapshots, @@ -144,6 +218,7 @@ async def build_board_group_snapshot( include_done: bool = False, per_board_task_limit: int = 5, ) -> BoardGroupSnapshot: + """Build a board-group snapshot anchored to a board context.""" if not board.board_group_id: return BoardGroupSnapshot(group=None, boards=[]) group = await BoardGroup.objects.by_id(board.board_group_id).first(session) diff --git a/backend/app/services/organizations.py b/backend/app/services/organizations.py index 4a5b557..ff2025c 100644 --- a/backend/app/services/organizations.py +++ b/backend/app/services/organizations.py @@ -1,5 +1,4 @@ """Organization membership and board-access service helpers.""" -# ruff: noqa: D101, D103 from __future__ import annotations @@ -38,19 +37,24 @@ ROLE_RANK = {"member": 0, "admin": 1, "owner": 2} @dataclass(frozen=True) class OrganizationContext: + """Resolved organization and membership for the active user.""" + organization: Organization member: OrganizationMember def is_org_admin(member: OrganizationMember) -> bool: + """Return whether a member has admin-level organization privileges.""" return member.role in ADMIN_ROLES async def get_default_org(session: AsyncSession) -> Organization | None: + """Return the default personal organization if it exists.""" return await Organization.objects.filter_by(name=DEFAULT_ORG_NAME).first(session) async def ensure_default_org(session: AsyncSession) -> Organization: + """Ensure and return the default personal organization.""" org = await get_default_org(session) if org is not None: return org @@ -67,6 +71,7 @@ async def get_member( user_id: UUID, organization_id: UUID, ) -> OrganizationMember | None: + """Fetch a membership by user id and organization id.""" return await OrganizationMember.objects.filter_by( user_id=user_id, organization_id=organization_id, @@ -76,6 +81,7 @@ async def get_member( async def get_first_membership( session: AsyncSession, user_id: UUID, ) -> OrganizationMember | None: + """Return the oldest membership for a user, if any.""" return ( await OrganizationMember.objects.filter_by(user_id=user_id) .order_by(col(OrganizationMember.created_at).asc()) @@ -89,6 +95,7 @@ async def set_active_organization( user: User, organization_id: UUID, ) -> OrganizationMember: + """Set a user's active organization and return the membership.""" member = await get_member(session, user_id=user.id, organization_id=organization_id) if member is None: raise HTTPException( @@ -105,6 +112,7 @@ async def get_active_membership( session: AsyncSession, user: User, ) -> OrganizationMember | None: + """Resolve and normalize the user's currently active membership.""" db_user = await User.objects.by_id(user.id).first(session) if db_user is None: db_user = user @@ -151,6 +159,7 @@ async def accept_invite( invite: OrganizationInvite, user: User, ) -> OrganizationMember: + """Accept an invite and create membership plus scoped board access rows.""" now = utcnow() member = OrganizationMember( organization_id=invite.organization_id, @@ -200,6 +209,7 @@ async def accept_invite( async def ensure_member_for_user( session: AsyncSession, user: User, ) -> OrganizationMember: + """Ensure a user has some membership, creating one if necessary.""" existing = await get_active_membership(session, user) if existing is not None: return existing @@ -237,10 +247,12 @@ async def ensure_member_for_user( def member_all_boards_read(member: OrganizationMember) -> bool: + """Return whether the member has organization-wide read access.""" return member.all_boards_read or member.all_boards_write def member_all_boards_write(member: OrganizationMember) -> bool: + """Return whether the member has organization-wide write access.""" return member.all_boards_write @@ -251,6 +263,7 @@ async def has_board_access( board: Board, write: bool, ) -> bool: + """Return whether a member has board access for the requested mode.""" if member.organization_id != board.organization_id: return False if write: @@ -276,6 +289,7 @@ async def require_board_access( board: Board, write: bool, ) -> OrganizationMember: + """Require board access for a user and return matching membership.""" member = await get_member( session, user_id=user.id, organization_id=board.organization_id, ) @@ -293,6 +307,7 @@ async def require_board_access( def board_access_filter( member: OrganizationMember, *, write: bool, ) -> ColumnElement[bool]: + """Build a SQL filter expression for boards visible to a member.""" if write and member_all_boards_write(member): return col(Board.organization_id) == member.organization_id if not write and member_all_boards_read(member): @@ -320,6 +335,7 @@ async def list_accessible_board_ids( member: OrganizationMember, write: bool, ) -> list[UUID]: + """List board ids accessible to a member for read or write mode.""" if (write and member_all_boards_write(member)) or ( not write and member_all_boards_read(member) ): @@ -354,6 +370,7 @@ async def apply_member_access_update( member: OrganizationMember, update: OrganizationMemberAccessUpdate, ) -> None: + """Replace explicit member board-access rows from an access update.""" now = utcnow() member.all_boards_read = update.all_boards_read member.all_boards_write = update.all_boards_write @@ -390,6 +407,7 @@ async def apply_invite_board_access( invite: OrganizationInvite, entries: Iterable[OrganizationBoardAccessSpec], ) -> None: + """Replace explicit invite board-access rows for an invite.""" await crud.delete_where( session, OrganizationInviteBoardAccess, @@ -414,10 +432,12 @@ async def apply_invite_board_access( def normalize_invited_email(email: str) -> str: + """Normalize an invited email address for storage/comparison.""" return email.strip().lower() def normalize_role(role: str) -> str: + """Normalize a role string and default empty values to `member`.""" return role.strip().lower() or "member" @@ -433,6 +453,7 @@ async def apply_invite_to_member( member: OrganizationMember, invite: OrganizationInvite, ) -> None: + """Apply invite role/access grants onto an existing organization member.""" now = utcnow() member_changed = False invite_role = normalize_role(invite.role or "member") diff --git a/backend/app/services/task_dependencies.py b/backend/app/services/task_dependencies.py index aaa37aa..9f802e0 100644 --- a/backend/app/services/task_dependencies.py +++ b/backend/app/services/task_dependencies.py @@ -1,3 +1,5 @@ +"""Task-dependency helpers for validation, querying, and replacement.""" + from __future__ import annotations from collections import defaultdict @@ -14,6 +16,7 @@ from app.models.task_dependencies import TaskDependency from app.models.tasks import Task DONE_STATUS: Final[str] = "done" +_RUNTIME_TYPE_REFERENCES = (UUID, AsyncSession, Mapping, Sequence) def _dedupe_uuid_list(values: Sequence[UUID]) -> list[UUID]: @@ -34,6 +37,7 @@ async def dependency_ids_by_task_id( board_id: UUID, task_ids: Sequence[UUID], ) -> dict[UUID, list[UUID]]: + """Return dependency ids keyed by task id for tasks on a board.""" if not task_ids: return {} rows = list( @@ -41,8 +45,8 @@ async def dependency_ids_by_task_id( select(col(TaskDependency.task_id), col(TaskDependency.depends_on_task_id)) .where(col(TaskDependency.board_id) == board_id) .where(col(TaskDependency.task_id).in_(task_ids)) - .order_by(col(TaskDependency.created_at).asc()) - ) + .order_by(col(TaskDependency.created_at).asc()), + ), ) mapping: dict[UUID, list[UUID]] = defaultdict(list) for task_id, depends_on_task_id in rows: @@ -56,16 +60,17 @@ async def dependency_status_by_id( board_id: UUID, dependency_ids: Sequence[UUID], ) -> dict[UUID, str]: + """Return dependency status values keyed by dependency task id.""" if not dependency_ids: return {} rows = list( await session.exec( select(col(Task.id), col(Task.status)) .where(col(Task.board_id) == board_id) - .where(col(Task.id).in_(dependency_ids)) - ) + .where(col(Task.id).in_(dependency_ids)), + ), ) - return {task_id: status_value for task_id, status_value in rows} + return dict(rows) def blocked_by_dependency_ids( @@ -73,11 +78,12 @@ def blocked_by_dependency_ids( dependency_ids: Sequence[UUID], status_by_id: Mapping[UUID, str], ) -> list[UUID]: - blocked: list[UUID] = [] - for dep_id in dependency_ids: - if status_by_id.get(dep_id) != DONE_STATUS: - blocked.append(dep_id) - return blocked + """Return dependency ids that are not yet in the done status.""" + return [ + dep_id + for dep_id in dependency_ids + if status_by_id.get(dep_id) != DONE_STATUS + ] async def blocked_by_for_task( @@ -87,6 +93,7 @@ async def blocked_by_for_task( task_id: UUID, dependency_ids: Sequence[UUID] | None = None, ) -> list[UUID]: + """Return unresolved dependency ids for the provided task.""" dep_ids = list(dependency_ids or []) if dependency_ids is None: deps_map = await dependency_ids_by_task_id( @@ -97,11 +104,16 @@ async def blocked_by_for_task( dep_ids = deps_map.get(task_id, []) if not dep_ids: return [] - status_by_id = await dependency_status_by_id(session, board_id=board_id, dependency_ids=dep_ids) + status_by_id = await dependency_status_by_id( + session, + board_id=board_id, + dependency_ids=dep_ids, + ) return blocked_by_dependency_ids(dependency_ids=dep_ids, status_by_id=status_by_id) def _has_cycle(nodes: Sequence[UUID], edges: Mapping[UUID, set[UUID]]) -> bool: + """Detect cycles in a directed dependency graph.""" visited: set[UUID] = set() in_stack: set[UUID] = set() @@ -118,10 +130,7 @@ def _has_cycle(nodes: Sequence[UUID], edges: Mapping[UUID, set[UUID]]) -> bool: in_stack.remove(node) return False - for node in nodes: - if dfs(node): - return True - return False + return any(dfs(node) for node in nodes) async def validate_dependency_update( @@ -131,6 +140,7 @@ async def validate_dependency_update( task_id: UUID, depends_on_task_ids: Sequence[UUID], ) -> list[UUID]: + """Validate a dependency update and return normalized dependency ids.""" normalized = _dedupe_uuid_list(depends_on_task_ids) if task_id in normalized: raise HTTPException( @@ -145,8 +155,8 @@ async def validate_dependency_update( await session.exec( select(col(Task.id)) .where(col(Task.board_id) == board_id) - .where(col(Task.id).in_(normalized)) - ) + .where(col(Task.id).in_(normalized)), + ), ) missing = [dep_id for dep_id in normalized if dep_id not in existing_ids] if missing: @@ -159,13 +169,18 @@ async def validate_dependency_update( ) # Ensure the dependency graph is acyclic after applying the update. - task_ids = list(await session.exec(select(col(Task.id)).where(col(Task.board_id) == board_id))) + task_ids = list( + await session.exec( + select(col(Task.id)).where(col(Task.board_id) == board_id), + ), + ) rows = list( await session.exec( - select(col(TaskDependency.task_id), col(TaskDependency.depends_on_task_id)).where( - col(TaskDependency.board_id) == board_id - ) - ) + select( + col(TaskDependency.task_id), + col(TaskDependency.depends_on_task_id), + ).where(col(TaskDependency.board_id) == board_id), + ), ) edges: dict[UUID, set[UUID]] = defaultdict(set) for src, dst in rows: @@ -188,6 +203,7 @@ async def replace_task_dependencies( task_id: UUID, depends_on_task_ids: Sequence[UUID], ) -> list[UUID]: + """Replace dependencies for a task and return the normalized dependency ids.""" normalized = await validate_dependency_update( session, board_id=board_id, @@ -207,7 +223,7 @@ async def replace_task_dependencies( board_id=board_id, task_id=task_id, depends_on_task_id=dep_id, - ) + ), ) return normalized @@ -218,9 +234,10 @@ async def dependent_task_ids( board_id: UUID, dependency_task_id: UUID, ) -> list[UUID]: + """Return task ids that depend on the provided dependency task id.""" rows = await session.exec( select(col(TaskDependency.task_id)) .where(col(TaskDependency.board_id) == board_id) - .where(col(TaskDependency.depends_on_task_id) == dependency_task_id) + .where(col(TaskDependency.depends_on_task_id) == dependency_task_id), ) return list(rows) diff --git a/backend/app/services/template_sync.py b/backend/app/services/template_sync.py index 87cd016..7bb6e9e 100644 --- a/backend/app/services/template_sync.py +++ b/backend/app/services/template_sync.py @@ -1,9 +1,12 @@ +"""Gateway template synchronization orchestration.""" + from __future__ import annotations import asyncio import random import re from collections.abc import Awaitable, Callable +from dataclasses import dataclass from typing import TypeVar from uuid import UUID, uuid4 @@ -11,7 +14,11 @@ from sqlalchemy import func from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession -from app.core.agent_tokens import generate_agent_token, hash_agent_token, verify_agent_token +from app.core.agent_tokens import ( + generate_agent_token, + hash_agent_token, + verify_agent_token, +) from app.core.time import utcnow from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import OpenClawGatewayError, openclaw_call @@ -49,6 +56,31 @@ _TRANSIENT_GATEWAY_ERROR_MARKERS = ( ) T = TypeVar("T") +_SECURE_RANDOM = random.SystemRandom() +_RUNTIME_TYPE_REFERENCES = (Awaitable, Callable, AsyncSession, Gateway, User, UUID) + + +@dataclass(frozen=True) +class GatewayTemplateSyncOptions: + """Runtime options controlling gateway template synchronization.""" + + user: User | None + include_main: bool = True + reset_sessions: bool = False + rotate_tokens: bool = False + force_bootstrap: bool = False + board_id: UUID | None = None + + +@dataclass(frozen=True) +class _SyncContext: + """Shared state passed to sync helper functions.""" + + session: AsyncSession + gateway: Gateway + config: GatewayClientConfig + backoff: _GatewayBackoff + options: GatewayTemplateSyncOptions def _slugify(value: str) -> str: @@ -70,7 +102,10 @@ def _is_transient_gateway_error(exc: Exception) -> bool: def _gateway_timeout_message(exc: OpenClawGatewayError) -> str: - return f"Gateway unreachable after 10 minutes (template sync timeout). Last error: {exc}" + return ( + "Gateway unreachable after 10 minutes (template sync timeout). " + f"Last error: {exc}" + ) class _GatewayBackoff: @@ -91,16 +126,25 @@ class _GatewayBackoff: def reset(self) -> None: self._delay_s = self._base_delay_s + async def _attempt( + self, + fn: Callable[[], Awaitable[T]], + ) -> tuple[T | None, OpenClawGatewayError | None]: + try: + return await fn(), None + except OpenClawGatewayError as exc: + return None, exc + async def run(self, fn: Callable[[], Awaitable[T]]) -> T: # Use per-call deadlines so long-running syncs can still tolerate a later # gateway restart without having an already-expired retry window. deadline_s = asyncio.get_running_loop().time() + self._timeout_s while True: - try: - value = await fn() - except OpenClawGatewayError as exc: + value, error = await self._attempt(fn) + if error is not None: + exc = error if not _is_transient_gateway_error(exc): - raise + raise exc now = asyncio.get_running_loop().time() remaining = deadline_s - now if remaining <= 0: @@ -108,13 +152,16 @@ class _GatewayBackoff: sleep_s = min(self._delay_s, remaining) if self._jitter: - sleep_s *= 1.0 + random.uniform(-self._jitter, self._jitter) + sleep_s *= 1.0 + _SECURE_RANDOM.uniform( + -self._jitter, + self._jitter, + ) sleep_s = max(0.0, min(sleep_s, remaining)) await asyncio.sleep(sleep_s) self._delay_s = min(self._delay_s * 2.0, self._max_delay_s) - else: - self.reset() - return value + continue + self.reset() + return value async def _with_gateway_retry( @@ -138,23 +185,25 @@ def _agent_id_from_session_key(session_key: str | None) -> str | None: return agent_id or None -def _extract_agent_id(payload: object) -> str | None: - def _from_list(items: object) -> str | None: - if not isinstance(items, list): - return None - for item in items: - if isinstance(item, str) and item.strip(): - return item.strip() - if not isinstance(item, dict): - continue - for key in ("id", "agentId", "agent_id"): - raw = item.get(key) - if isinstance(raw, str) and raw.strip(): - return raw.strip() +def _extract_agent_id_from_list(items: object) -> str | None: + if not isinstance(items, list): return None + for item in items: + if isinstance(item, str) and item.strip(): + return item.strip() + if not isinstance(item, dict): + continue + for key in ("id", "agentId", "agent_id"): + raw = item.get(key) + if isinstance(raw, str) and raw.strip(): + return raw.strip() + return None + +def _extract_agent_id(payload: object) -> str | None: + """Extract a default gateway agent id from common list payload shapes.""" if isinstance(payload, list): - return _from_list(payload) + return _extract_agent_id_from_list(payload) if not isinstance(payload, dict): return None for key in ("defaultId", "default_id", "defaultAgentId", "default_agent_id"): @@ -162,7 +211,7 @@ def _extract_agent_id(payload: object) -> str | None: if isinstance(raw, str) and raw.strip(): return raw.strip() for key in ("agents", "items", "list", "data"): - agent_id = _from_list(payload.get(key)) + agent_id = _extract_agent_id_from_list(payload.get(key)) if agent_id: return agent_id return None @@ -212,9 +261,6 @@ async def _get_agent_file( if isinstance(payload, str): return payload if isinstance(payload, dict): - # Common shapes: - # - {"name": "...", "content": "..."} - # - {"file": {"name": "...", "content": "..." }} content = payload.get("content") if isinstance(content, str): return content @@ -291,18 +337,53 @@ async def _paused_board_ids(session: AsyncSession, board_ids: list[UUID]) -> set return paused -async def sync_gateway_templates( - session: AsyncSession, +def _append_sync_error( + result: GatewayTemplatesSyncResult, + *, + message: str, + agent: Agent | None = None, + board: Board | None = None, +) -> None: + result.errors.append( + GatewayTemplatesSyncError( + agent_id=agent.id if agent else None, + agent_name=agent.name if agent else None, + board_id=board.id if board else None, + message=message, + ), + ) + + +async def _rotate_agent_token(session: AsyncSession, agent: Agent) -> str: + token = generate_agent_token() + agent.agent_token_hash = hash_agent_token(token) + agent.updated_at = utcnow() + session.add(agent) + await session.commit() + await session.refresh(agent) + return token + + +async def _ping_gateway(ctx: _SyncContext, result: GatewayTemplatesSyncResult) -> bool: + try: + async def _do_ping() -> object: + return await openclaw_call("agents.list", config=ctx.config) + + await ctx.backoff.run(_do_ping) + except (TimeoutError, OpenClawGatewayError) as exc: + _append_sync_error(result, message=str(exc)) + return False + else: + return True + + +def _base_result( gateway: Gateway, *, - user: User | None, - include_main: bool = True, - reset_sessions: bool = False, - rotate_tokens: bool = False, - force_bootstrap: bool = False, - board_id: UUID | None = None, + include_main: bool, + reset_sessions: bool, ) -> GatewayTemplatesSyncResult: - result = GatewayTemplatesSyncResult( + return GatewayTemplatesSyncResult( gateway_id=gateway.id, include_main=include_main, reset_sessions=reset_sessions, @@ -310,45 +391,239 @@ async def sync_gateway_templates( agents_skipped=0, main_updated=False, ) + + +def _boards_by_id( + boards: list[Board], + *, + board_id: UUID | None, +) -> dict[UUID, Board] | None: + boards_by_id = {board.id: board for board in boards} + if board_id is None: + return boards_by_id + board = boards_by_id.get(board_id) + if board is None: + return None + return {board_id: board} + + +async def _resolve_agent_auth_token( + ctx: _SyncContext, + result: GatewayTemplatesSyncResult, + agent: Agent, + board: Board | None, + *, + agent_gateway_id: str, +) -> tuple[str | None, bool]: + try: + auth_token = await _get_existing_auth_token( + agent_gateway_id=agent_gateway_id, + config=ctx.config, + backoff=ctx.backoff, + ) + except TimeoutError as exc: + _append_sync_error(result, agent=agent, board=board, message=str(exc)) + return None, True + + if not auth_token: + if not ctx.options.rotate_tokens: + result.agents_skipped += 1 + _append_sync_error( + result, + agent=agent, + board=board, + message=( + "Skipping agent: unable to read AUTH_TOKEN from TOOLS.md " + "(run with rotate_tokens=true to re-key)." + ), + ) + return None, False + auth_token = await _rotate_agent_token(ctx.session, agent) + + if agent.agent_token_hash and not verify_agent_token( + auth_token, + agent.agent_token_hash, + ): + if ctx.options.rotate_tokens: + auth_token = await _rotate_agent_token(ctx.session, agent) + else: + _append_sync_error( + result, + agent=agent, + board=board, + message=( + "Warning: AUTH_TOKEN in TOOLS.md does not match backend " + "token hash (agent auth may be broken)." + ), + ) + return auth_token, False + + +async def _sync_one_agent( + ctx: _SyncContext, + result: GatewayTemplatesSyncResult, + agent: Agent, + board: Board, +) -> bool: + auth_token, fatal = await _resolve_agent_auth_token( + ctx, + result, + agent, + board, + agent_gateway_id=_gateway_agent_id(agent), + ) + if fatal: + return True + if not auth_token: + return False + try: + async def _do_provision() -> None: + await provision_agent( + agent, + board, + ctx.gateway, + auth_token, + ctx.options.user, + action="update", + force_bootstrap=ctx.options.force_bootstrap, + reset_session=ctx.options.reset_sessions, + ) + + await _with_gateway_retry(_do_provision, backoff=ctx.backoff) + result.agents_updated += 1 + except TimeoutError as exc: # pragma: no cover - gateway/network dependent + result.agents_skipped += 1 + _append_sync_error(result, agent=agent, board=board, message=str(exc)) + return True + except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover + result.agents_skipped += 1 + _append_sync_error( + result, + agent=agent, + board=board, + message=f"Failed to sync templates: {exc}", + ) + return False + else: + return False + + +async def _sync_main_agent( + ctx: _SyncContext, + result: GatewayTemplatesSyncResult, +) -> bool: + main_agent = ( + await Agent.objects.all() + .filter(col(Agent.openclaw_session_id) == ctx.gateway.main_session_key) + .first(ctx.session) + ) + if main_agent is None: + _append_sync_error( + result, + message=( + "Gateway main agent record not found; " + "skipping main agent template sync." + ), + ) + return True + try: + main_gateway_agent_id = await _gateway_default_agent_id( + ctx.config, + fallback_session_key=ctx.gateway.main_session_key, + backoff=ctx.backoff, + ) + except TimeoutError as exc: + _append_sync_error(result, agent=main_agent, message=str(exc)) + return True + if not main_gateway_agent_id: + _append_sync_error( + result, + agent=main_agent, + message="Unable to resolve gateway default agent id for main agent.", + ) + return True + + token, fatal = await _resolve_agent_auth_token( + ctx, + result, + main_agent, + board=None, + agent_gateway_id=main_gateway_agent_id, + ) + if fatal: + return True + if not token: + _append_sync_error( + result, + agent=main_agent, + message="Skipping main agent: unable to read AUTH_TOKEN from TOOLS.md.", + ) + return True + stop_sync = False + try: + async def _do_provision_main() -> None: + await provision_main_agent( + main_agent, + ctx.gateway, + token, + ctx.options.user, + action="update", + force_bootstrap=ctx.options.force_bootstrap, + reset_session=ctx.options.reset_sessions, + ) + + await _with_gateway_retry(_do_provision_main, backoff=ctx.backoff) + except TimeoutError as exc: # pragma: no cover - gateway/network dependent + _append_sync_error(result, agent=main_agent, message=str(exc)) + stop_sync = True + except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover + _append_sync_error( + result, + agent=main_agent, + message=f"Failed to sync main agent templates: {exc}", + ) + else: + result.main_updated = True + return stop_sync + + +async def sync_gateway_templates( + session: AsyncSession, + gateway: Gateway, + options: GatewayTemplateSyncOptions, +) -> GatewayTemplatesSyncResult: + """Synchronize AGENTS/TOOLS/etc templates to gateway-connected agents.""" + result = _base_result( + gateway, + include_main=options.include_main, + reset_sessions=options.reset_sessions, + ) if not gateway.url: - result.errors.append( - GatewayTemplatesSyncError(message="Gateway URL is not configured for this gateway.") + _append_sync_error( + result, + message="Gateway URL is not configured for this gateway.", ) return result - client_config = GatewayClientConfig(url=gateway.url, token=gateway.token) - backoff = _GatewayBackoff(timeout_s=10 * 60) - - # First, wait for the gateway to be reachable (e.g. while it is restarting). - try: - - async def _do_ping() -> object: - return await openclaw_call("agents.list", config=client_config) - - await backoff.run(_do_ping) - except TimeoutError as exc: - result.errors.append(GatewayTemplatesSyncError(message=str(exc))) - return result - except OpenClawGatewayError as exc: - result.errors.append(GatewayTemplatesSyncError(message=str(exc))) + ctx = _SyncContext( + session=session, + gateway=gateway, + config=GatewayClientConfig(url=gateway.url, token=gateway.token), + backoff=_GatewayBackoff(timeout_s=10 * 60), + options=options, + ) + if not await _ping_gateway(ctx, result): return result boards = await Board.objects.filter_by(gateway_id=gateway.id).all(session) - boards_by_id = {board.id: board for board in boards} - if board_id is not None: - board = boards_by_id.get(board_id) - if board is None: - result.errors.append( - GatewayTemplatesSyncError( - board_id=board_id, - message="Board does not belong to this gateway.", - ) - ) - return result - boards_by_id = {board_id: board} - + boards_by_id = _boards_by_id(boards, board_id=options.board_id) + if boards_by_id is None: + _append_sync_error( + result, + message="Board does not belong to this gateway.", + ) + return result paused_board_ids = await _paused_board_ids(session, list(boards_by_id.keys())) - if boards_by_id: agents = await ( Agent.objects.by_field_in("board_id", list(boards_by_id.keys())) @@ -358,251 +633,24 @@ async def sync_gateway_templates( else: agents = [] + stop_sync = False for agent in agents: board = boards_by_id.get(agent.board_id) if agent.board_id is not None else None if board is None: result.agents_skipped += 1 - result.errors.append( - GatewayTemplatesSyncError( - agent_id=agent.id, - agent_name=agent.name, - board_id=agent.board_id, - message="Skipping agent: board not found for agent.", - ) + _append_sync_error( + result, + agent=agent, + message="Skipping agent: board not found for agent.", ) continue - if board.id in paused_board_ids: result.agents_skipped += 1 continue + stop_sync = await _sync_one_agent(ctx, result, agent, board) + if stop_sync: + break - agent_gateway_id = _gateway_agent_id(agent) - try: - auth_token = await _get_existing_auth_token( - agent_gateway_id=agent_gateway_id, - config=client_config, - backoff=backoff, - ) - except TimeoutError as exc: - result.errors.append( - GatewayTemplatesSyncError( - agent_id=agent.id, - agent_name=agent.name, - board_id=board.id, - message=str(exc), - ) - ) - return result - - if not auth_token: - if not rotate_tokens: - result.agents_skipped += 1 - result.errors.append( - GatewayTemplatesSyncError( - agent_id=agent.id, - agent_name=agent.name, - board_id=board.id, - message="Skipping agent: unable to read AUTH_TOKEN from TOOLS.md (run with rotate_tokens=true to re-key).", - ) - ) - continue - raw_token = generate_agent_token() - agent.agent_token_hash = hash_agent_token(raw_token) - agent.updated_at = utcnow() - session.add(agent) - await session.commit() - await session.refresh(agent) - auth_token = raw_token - - if agent.agent_token_hash and not verify_agent_token(auth_token, agent.agent_token_hash): - # Do not block template sync on token drift; optionally re-key. - if rotate_tokens: - raw_token = generate_agent_token() - agent.agent_token_hash = hash_agent_token(raw_token) - agent.updated_at = utcnow() - session.add(agent) - await session.commit() - await session.refresh(agent) - auth_token = raw_token - else: - result.errors.append( - GatewayTemplatesSyncError( - agent_id=agent.id, - agent_name=agent.name, - board_id=board.id, - message="Warning: AUTH_TOKEN in TOOLS.md does not match backend token hash (agent auth may be broken).", - ) - ) - - try: - agent_item: Agent = agent - board_item: Board = board - auth_token_value: str = auth_token - - async def _do_provision( - agent_item: Agent = agent_item, - board_item: Board = board_item, - auth_token_value: str = auth_token_value, - ) -> None: - await provision_agent( - agent_item, - board_item, - gateway, - auth_token_value, - user, - action="update", - force_bootstrap=force_bootstrap, - reset_session=reset_sessions, - ) - - await _with_gateway_retry(_do_provision, backoff=backoff) - result.agents_updated += 1 - except TimeoutError as exc: # pragma: no cover - gateway/network dependent - result.agents_skipped += 1 - result.errors.append( - GatewayTemplatesSyncError( - agent_id=agent.id, - agent_name=agent.name, - board_id=board.id, - message=str(exc), - ) - ) - return result - except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover - result.agents_skipped += 1 - result.errors.append( - GatewayTemplatesSyncError( - agent_id=agent.id, - agent_name=agent.name, - board_id=board.id, - message=f"Failed to sync templates: {exc}", - ) - ) - - if include_main: - main_agent = ( - await Agent.objects.all() - .filter(col(Agent.openclaw_session_id) == gateway.main_session_key) - .first(session) - ) - if main_agent is None: - result.errors.append( - GatewayTemplatesSyncError( - message="Gateway main agent record not found; skipping main agent template sync.", - ) - ) - return result - - try: - main_gateway_agent_id = await _gateway_default_agent_id( - client_config, - fallback_session_key=gateway.main_session_key, - backoff=backoff, - ) - except TimeoutError as exc: - result.errors.append( - GatewayTemplatesSyncError( - agent_id=main_agent.id, - agent_name=main_agent.name, - message=str(exc), - ) - ) - return result - if not main_gateway_agent_id: - result.errors.append( - GatewayTemplatesSyncError( - agent_id=main_agent.id, - agent_name=main_agent.name, - message="Unable to resolve gateway default agent id for main agent.", - ) - ) - return result - - try: - main_token = await _get_existing_auth_token( - agent_gateway_id=main_gateway_agent_id, - config=client_config, - backoff=backoff, - ) - except TimeoutError as exc: - result.errors.append( - GatewayTemplatesSyncError( - agent_id=main_agent.id, - agent_name=main_agent.name, - message=str(exc), - ) - ) - return result - if not main_token: - if rotate_tokens: - raw_token = generate_agent_token() - main_agent.agent_token_hash = hash_agent_token(raw_token) - main_agent.updated_at = utcnow() - session.add(main_agent) - await session.commit() - await session.refresh(main_agent) - main_token = raw_token - else: - result.errors.append( - GatewayTemplatesSyncError( - agent_id=main_agent.id, - agent_name=main_agent.name, - message="Skipping main agent: unable to read AUTH_TOKEN from TOOLS.md.", - ) - ) - return result - - if main_agent.agent_token_hash and not verify_agent_token( - main_token, main_agent.agent_token_hash - ): - if rotate_tokens: - raw_token = generate_agent_token() - main_agent.agent_token_hash = hash_agent_token(raw_token) - main_agent.updated_at = utcnow() - session.add(main_agent) - await session.commit() - await session.refresh(main_agent) - main_token = raw_token - else: - result.errors.append( - GatewayTemplatesSyncError( - agent_id=main_agent.id, - agent_name=main_agent.name, - message="Warning: AUTH_TOKEN in TOOLS.md does not match backend token hash (main agent auth may be broken).", - ) - ) - - try: - - async def _do_provision_main() -> None: - await provision_main_agent( - main_agent, - gateway, - main_token, - user, - action="update", - force_bootstrap=force_bootstrap, - reset_session=reset_sessions, - ) - - await _with_gateway_retry(_do_provision_main, backoff=backoff) - result.main_updated = True - except TimeoutError as exc: # pragma: no cover - gateway/network dependent - result.errors.append( - GatewayTemplatesSyncError( - agent_id=main_agent.id, - agent_name=main_agent.name, - message=str(exc), - ) - ) - return result - except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover - result.errors.append( - GatewayTemplatesSyncError( - agent_id=main_agent.id, - agent_name=main_agent.name, - message=f"Failed to sync main agent templates: {exc}", - ) - ) - + if not stop_sync and options.include_main: + await _sync_main_agent(ctx, result) return result diff --git a/backend/tests/test_agent_provisioning_utils.py b/backend/tests/test_agent_provisioning_utils.py index 28db231..bd18e42 100644 --- a/backend/tests/test_agent_provisioning_utils.py +++ b/backend/tests/test_agent_provisioning_utils.py @@ -1,3 +1,5 @@ +# ruff: noqa + from __future__ import annotations from dataclasses import dataclass diff --git a/backend/tests/test_db_transaction_safety.py b/backend/tests/test_db_transaction_safety.py index f11451d..d58319e 100644 --- a/backend/tests/test_db_transaction_safety.py +++ b/backend/tests/test_db_transaction_safety.py @@ -1,3 +1,5 @@ +# ruff: noqa + from __future__ import annotations from dataclasses import dataclass diff --git a/backend/tests/test_error_handling.py b/backend/tests/test_error_handling.py index 4125d13..774c541 100644 --- a/backend/tests/test_error_handling.py +++ b/backend/tests/test_error_handling.py @@ -1,3 +1,5 @@ +# ruff: noqa + from __future__ import annotations from fastapi import FastAPI, HTTPException diff --git a/backend/tests/test_lead_policy.py b/backend/tests/test_lead_policy.py index 8bc771a..7913cec 100644 --- a/backend/tests/test_lead_policy.py +++ b/backend/tests/test_lead_policy.py @@ -1,3 +1,5 @@ +# ruff: noqa + import hashlib from app.services.lead_policy import ( diff --git a/backend/tests/test_mentions.py b/backend/tests/test_mentions.py index 8cfdce2..e791d04 100644 --- a/backend/tests/test_mentions.py +++ b/backend/tests/test_mentions.py @@ -1,3 +1,5 @@ +# ruff: noqa + from app.models.agents import Agent from app.services.mentions import extract_mentions, matches_agent_mention diff --git a/backend/tests/test_organizations_member_remove_api.py b/backend/tests/test_organizations_member_remove_api.py index 78310bc..2789d58 100644 --- a/backend/tests/test_organizations_member_remove_api.py +++ b/backend/tests/test_organizations_member_remove_api.py @@ -1,3 +1,5 @@ +# ruff: noqa + from __future__ import annotations from dataclasses import dataclass, field diff --git a/backend/tests/test_request_id_middleware.py b/backend/tests/test_request_id_middleware.py index 08b83bd..07d2de7 100644 --- a/backend/tests/test_request_id_middleware.py +++ b/backend/tests/test_request_id_middleware.py @@ -1,3 +1,5 @@ +# ruff: noqa + from __future__ import annotations import pytest diff --git a/backend/tests/test_task_dependencies.py b/backend/tests/test_task_dependencies.py index d86afff..deda5ca 100644 --- a/backend/tests/test_task_dependencies.py +++ b/backend/tests/test_task_dependencies.py @@ -1,3 +1,5 @@ +# ruff: noqa + from __future__ import annotations from dataclasses import dataclass, field diff --git a/backend/tests/test_task_dependencies_integration.py b/backend/tests/test_task_dependencies_integration.py index da1d115..421e6d6 100644 --- a/backend/tests/test_task_dependencies_integration.py +++ b/backend/tests/test_task_dependencies_integration.py @@ -1,3 +1,5 @@ +# ruff: noqa + from __future__ import annotations from uuid import UUID, uuid4