refactor: enhance docstrings for clarity and consistency across multiple files

This commit is contained in:
Abhimanyu Saharan
2026-02-09 16:23:41 +05:30
parent 7ca1899d9f
commit 7706943209
28 changed files with 1829 additions and 932 deletions

View File

@@ -1,17 +1,18 @@
"""Activity listing and task-comment feed endpoints."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import json import json
from collections import deque from collections import deque
from collections.abc import AsyncIterator, Sequence from collections.abc import Sequence
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, cast from typing import TYPE_CHECKING, Any, cast
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import asc, desc, func from sqlalchemy import asc, desc, func
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from app.api.deps import ActorContext, require_admin_or_agent, require_org_member from app.api.deps import ActorContext, require_admin_or_agent, require_org_member
@@ -22,7 +23,10 @@ from app.models.activity_events import ActivityEvent
from app.models.agents import Agent from app.models.agents import Agent
from app.models.boards import Board from app.models.boards import Board
from app.models.tasks import Task from app.models.tasks import Task
from app.schemas.activity_events import ActivityEventRead, ActivityTaskCommentFeedItemRead from app.schemas.activity_events import (
ActivityEventRead,
ActivityTaskCommentFeedItemRead,
)
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
from app.services.organizations import ( from app.services.organizations import (
OrganizationContext, OrganizationContext,
@@ -30,9 +34,21 @@ from app.services.organizations import (
list_accessible_board_ids, list_accessible_board_ids,
) )
if TYPE_CHECKING:
from collections.abc import AsyncIterator
from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter(prefix="/activity", tags=["activity"]) router = APIRouter(prefix="/activity", tags=["activity"])
SSE_SEEN_MAX = 2000 SSE_SEEN_MAX = 2000
STREAM_POLL_SECONDS = 2
SESSION_DEP = Depends(get_session)
ACTOR_DEP = Depends(require_admin_or_agent)
ORG_MEMBER_DEP = Depends(require_org_member)
BOARD_ID_QUERY = Query(default=None)
SINCE_QUERY = Query(default=None)
_RUNTIME_TYPE_REFERENCES = (UUID,)
def _parse_since(value: str | None) -> datetime | None: def _parse_since(value: str | None) -> datetime | None:
@@ -110,9 +126,10 @@ async def _fetch_task_comment_events(
@router.get("", response_model=DefaultLimitOffsetPage[ActivityEventRead]) @router.get("", response_model=DefaultLimitOffsetPage[ActivityEventRead])
async def list_activity( async def list_activity(
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
actor: ActorContext = Depends(require_admin_or_agent), actor: ActorContext = ACTOR_DEP,
) -> DefaultLimitOffsetPage[ActivityEventRead]: ) -> DefaultLimitOffsetPage[ActivityEventRead]:
"""List activity events visible to the calling actor."""
statement = select(ActivityEvent) statement = select(ActivityEvent)
if actor.actor_type == "agent" and actor.agent: if actor.actor_type == "agent" and actor.agent:
statement = statement.where(ActivityEvent.agent_id == actor.agent.id) statement = statement.where(ActivityEvent.agent_id == actor.agent.id)
@@ -124,9 +141,10 @@ async def list_activity(
if not board_ids: if not board_ids:
statement = statement.where(col(ActivityEvent.id).is_(None)) statement = statement.where(col(ActivityEvent.id).is_(None))
else: else:
statement = statement.join(Task, col(ActivityEvent.task_id) == col(Task.id)).where( statement = statement.join(
col(Task.board_id).in_(board_ids) Task,
) col(ActivityEvent.task_id) == col(Task.id),
).where(col(Task.board_id).in_(board_ids))
statement = statement.order_by(desc(col(ActivityEvent.created_at))) statement = statement.order_by(desc(col(ActivityEvent.created_at)))
return await paginate(session, statement) return await paginate(session, statement)
@@ -136,10 +154,11 @@ async def list_activity(
response_model=DefaultLimitOffsetPage[ActivityTaskCommentFeedItemRead], response_model=DefaultLimitOffsetPage[ActivityTaskCommentFeedItemRead],
) )
async def list_task_comment_feed( async def list_task_comment_feed(
board_id: UUID | None = Query(default=None), board_id: UUID | None = BOARD_ID_QUERY,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_member), ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> DefaultLimitOffsetPage[ActivityTaskCommentFeedItemRead]: ) -> DefaultLimitOffsetPage[ActivityTaskCommentFeedItemRead]:
"""List task-comment feed items for accessible boards."""
statement = ( statement = (
select(ActivityEvent, Task, Board, Agent) select(ActivityEvent, Task, Board, Agent)
.join(Task, col(ActivityEvent.task_id) == col(Task.id)) .join(Task, col(ActivityEvent.task_id) == col(Task.id))
@@ -161,7 +180,10 @@ async def list_task_comment_feed(
def _transform(items: Sequence[Any]) -> Sequence[Any]: def _transform(items: Sequence[Any]) -> Sequence[Any]:
rows = cast(Sequence[tuple[ActivityEvent, Task, Board, Agent | None]], items) rows = cast(Sequence[tuple[ActivityEvent, Task, Board, Agent | None]], items)
return [_feed_item(event, task, board, agent) for event, task, board, agent in rows] return [
_feed_item(event, task, board, agent)
for event, task, board, agent in rows
]
return await paginate(session, statement, transformer=_transform) return await paginate(session, statement, transformer=_transform)
@@ -169,13 +191,18 @@ async def list_task_comment_feed(
@router.get("/task-comments/stream") @router.get("/task-comments/stream")
async def stream_task_comment_feed( async def stream_task_comment_feed(
request: Request, request: Request,
board_id: UUID | None = Query(default=None), board_id: UUID | None = BOARD_ID_QUERY,
since: str | None = Query(default=None), since: str | None = SINCE_QUERY,
session: AsyncSession = Depends(get_session), db_session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_member), ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> EventSourceResponse: ) -> EventSourceResponse:
"""Stream task-comment events for accessible boards."""
since_dt = _parse_since(since) or utcnow() since_dt = _parse_since(since) or utcnow()
board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False) board_ids = await list_accessible_board_ids(
db_session,
member=ctx.member,
write=False,
)
allowed_ids = set(board_ids) allowed_ids = set(board_ids)
if board_id is not None and board_id not in allowed_ids: if board_id is not None and board_id not in allowed_ids:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@@ -187,11 +214,15 @@ async def stream_task_comment_feed(
while True: while True:
if await request.is_disconnected(): if await request.is_disconnected():
break break
async with async_session_maker() as session: async with async_session_maker() as stream_session:
if board_id is not None: if board_id is not None:
rows = await _fetch_task_comment_events(session, last_seen, board_id=board_id) rows = await _fetch_task_comment_events(
stream_session,
last_seen,
board_id=board_id,
)
elif allowed_ids: elif allowed_ids:
rows = await _fetch_task_comment_events(session, last_seen) rows = await _fetch_task_comment_events(stream_session, last_seen)
rows = [row for row in rows if row[1].board_id in allowed_ids] rows = [row for row in rows if row[1].board_id in allowed_ids]
else: else:
rows = [] rows = []
@@ -204,10 +235,16 @@ async def stream_task_comment_feed(
if len(seen_queue) > SSE_SEEN_MAX: if len(seen_queue) > SSE_SEEN_MAX:
oldest = seen_queue.popleft() oldest = seen_queue.popleft()
seen_ids.discard(oldest) seen_ids.discard(oldest)
if event.created_at > last_seen: last_seen = max(event.created_at, last_seen)
last_seen = event.created_at payload = {
payload = {"comment": _feed_item(event, task, board, agent).model_dump(mode="json")} "comment": _feed_item(
event,
task,
board,
agent,
).model_dump(mode="json"),
}
yield {"event": "comment", "data": json.dumps(payload)} yield {"event": "comment", "data": json.dumps(payload)}
await asyncio.sleep(2) await asyncio.sleep(STREAM_POLL_SECONDS)
return EventSourceResponse(event_generator(), ping=15) return EventSourceResponse(event_generator(), ping=15)

View File

@@ -1,15 +1,16 @@
"""Approval listing, streaming, creation, and update endpoints."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import json import json
from collections.abc import AsyncIterator
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import asc, case, func, or_ from sqlalchemy import asc, case, func, or_
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from app.api.deps import ( from app.api.deps import (
@@ -23,13 +24,32 @@ from app.core.time import utcnow
from app.db.pagination import paginate from app.db.pagination import paginate
from app.db.session import async_session_maker, get_session from app.db.session import async_session_maker, get_session
from app.models.approvals import Approval from app.models.approvals import Approval
from app.models.boards import Board from app.schemas.approvals import (
from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalStatus, ApprovalUpdate ApprovalCreate,
ApprovalRead,
ApprovalStatus,
ApprovalUpdate,
)
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
if TYPE_CHECKING:
from collections.abc import AsyncIterator
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.boards import Board
router = APIRouter(prefix="/boards/{board_id}/approvals", tags=["approvals"]) router = APIRouter(prefix="/boards/{board_id}/approvals", tags=["approvals"])
TASK_ID_KEYS: tuple[str, ...] = ("task_id", "taskId", "taskID") TASK_ID_KEYS: tuple[str, ...] = ("task_id", "taskId", "taskID")
STREAM_POLL_SECONDS = 2
STATUS_FILTER_QUERY = Query(default=None, alias="status")
SINCE_QUERY = Query(default=None)
BOARD_READ_DEP = Depends(get_board_for_actor_read)
BOARD_WRITE_DEP = Depends(get_board_for_actor_write)
BOARD_USER_WRITE_DEP = Depends(get_board_for_user_write)
SESSION_DEP = Depends(get_session)
ACTOR_DEP = Depends(require_admin_or_agent)
def _extract_task_id(payload: dict[str, object] | None) -> UUID | None: def _extract_task_id(payload: dict[str, object] | None) -> UUID | None:
@@ -68,7 +88,10 @@ def _approval_updated_at(approval: Approval) -> datetime:
def _serialize_approval(approval: Approval) -> dict[str, object]: def _serialize_approval(approval: Approval) -> dict[str, object]:
return ApprovalRead.model_validate(approval, from_attributes=True).model_dump(mode="json") return ApprovalRead.model_validate(
approval,
from_attributes=True,
).model_dump(mode="json")
async def _fetch_approval_events( async def _fetch_approval_events(
@@ -82,7 +105,7 @@ async def _fetch_approval_events(
or_( or_(
col(Approval.created_at) >= since, col(Approval.created_at) >= since,
col(Approval.resolved_at) >= since, col(Approval.resolved_at) >= since,
) ),
) )
.order_by(asc(col(Approval.created_at))) .order_by(asc(col(Approval.created_at)))
) )
@@ -91,11 +114,12 @@ async def _fetch_approval_events(
@router.get("", response_model=DefaultLimitOffsetPage[ApprovalRead]) @router.get("", response_model=DefaultLimitOffsetPage[ApprovalRead])
async def list_approvals( async def list_approvals(
status_filter: ApprovalStatus | None = Query(default=None, alias="status"), status_filter: ApprovalStatus | None = STATUS_FILTER_QUERY,
board: Board = Depends(get_board_for_actor_read), board: Board = BOARD_READ_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
actor: ActorContext = Depends(require_admin_or_agent), _actor: ActorContext = ACTOR_DEP,
) -> DefaultLimitOffsetPage[ApprovalRead]: ) -> DefaultLimitOffsetPage[ApprovalRead]:
"""List approvals for a board, optionally filtering by status."""
statement = Approval.objects.filter_by(board_id=board.id) statement = Approval.objects.filter_by(board_id=board.id)
if status_filter: if status_filter:
statement = statement.filter(col(Approval.status) == status_filter) statement = statement.filter(col(Approval.status) == status_filter)
@@ -106,10 +130,11 @@ async def list_approvals(
@router.get("/stream") @router.get("/stream")
async def stream_approvals( async def stream_approvals(
request: Request, request: Request,
board: Board = Depends(get_board_for_actor_read), board: Board = BOARD_READ_DEP,
actor: ActorContext = Depends(require_admin_or_agent), _actor: ActorContext = ACTOR_DEP,
since: str | None = Query(default=None), since: str | None = SINCE_QUERY,
) -> EventSourceResponse: ) -> EventSourceResponse:
"""Stream approval updates for a board using server-sent events."""
since_dt = _parse_since(since) or utcnow() since_dt = _parse_since(since) or utcnow()
last_seen = since_dt last_seen = since_dt
@@ -125,12 +150,14 @@ async def stream_approvals(
await session.exec( await session.exec(
select(func.count(col(Approval.id))) select(func.count(col(Approval.id)))
.where(col(Approval.board_id) == board.id) .where(col(Approval.board_id) == board.id)
.where(col(Approval.status) == "pending") .where(col(Approval.status) == "pending"),
) )
).one() ).one(),
) )
task_ids = { task_ids = {
approval.task_id for approval in approvals if approval.task_id is not None approval.task_id
for approval in approvals
if approval.task_id is not None
} }
counts_by_task_id: dict[UUID, tuple[int, int]] = {} counts_by_task_id: dict[UUID, tuple[int, int]] = {}
if task_ids: if task_ids:
@@ -140,22 +167,27 @@ async def stream_approvals(
col(Approval.task_id), col(Approval.task_id),
func.count(col(Approval.id)).label("total"), func.count(col(Approval.id)).label("total"),
func.sum( func.sum(
case((col(Approval.status) == "pending", 1), else_=0) case(
(col(Approval.status) == "pending", 1),
else_=0,
),
).label("pending"), ).label("pending"),
) )
.where(col(Approval.board_id) == board.id) .where(col(Approval.board_id) == board.id)
.where(col(Approval.task_id).in_(task_ids)) .where(col(Approval.task_id).in_(task_ids))
.group_by(col(Approval.task_id)) .group_by(col(Approval.task_id)),
) ),
) )
for task_id, total, pending in rows: for task_id, total, pending in rows:
if task_id is None: if task_id is None:
continue continue
counts_by_task_id[task_id] = (int(total or 0), int(pending or 0)) counts_by_task_id[task_id] = (
int(total or 0),
int(pending or 0),
)
for approval in approvals: for approval in approvals:
updated_at = _approval_updated_at(approval) updated_at = _approval_updated_at(approval)
if updated_at > last_seen: last_seen = max(updated_at, last_seen)
last_seen = updated_at
payload: dict[str, object] = { payload: dict[str, object] = {
"approval": _serialize_approval(approval), "approval": _serialize_approval(approval),
"pending_approvals_count": pending_approvals_count, "pending_approvals_count": pending_approvals_count,
@@ -170,7 +202,7 @@ async def stream_approvals(
"approvals_pending_count": pending, "approvals_pending_count": pending,
} }
yield {"event": "approval", "data": json.dumps(payload)} yield {"event": "approval", "data": json.dumps(payload)}
await asyncio.sleep(2) await asyncio.sleep(STREAM_POLL_SECONDS)
return EventSourceResponse(event_generator(), ping=15) return EventSourceResponse(event_generator(), ping=15)
@@ -178,10 +210,11 @@ async def stream_approvals(
@router.post("", response_model=ApprovalRead) @router.post("", response_model=ApprovalRead)
async def create_approval( async def create_approval(
payload: ApprovalCreate, payload: ApprovalCreate,
board: Board = Depends(get_board_for_actor_write), board: Board = BOARD_WRITE_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
actor: ActorContext = Depends(require_admin_or_agent), _actor: ActorContext = ACTOR_DEP,
) -> Approval: ) -> Approval:
"""Create an approval for a board."""
task_id = payload.task_id or _extract_task_id(payload.payload) task_id = payload.task_id or _extract_task_id(payload.payload)
approval = Approval( approval = Approval(
board_id=board.id, board_id=board.id,
@@ -203,9 +236,10 @@ async def create_approval(
async def update_approval( async def update_approval(
approval_id: str, approval_id: str,
payload: ApprovalUpdate, payload: ApprovalUpdate,
board: Board = Depends(get_board_for_user_write), board: Board = BOARD_USER_WRITE_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
) -> Approval: ) -> Approval:
"""Update an approval's status and resolution timestamp."""
approval = await Approval.objects.by_id(approval_id).first(session) approval = await Approval.objects.by_id(approval_id).first(session)
if approval is None or approval.board_id != board.id: if approval is None or approval.board_id != board.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

View File

@@ -1,15 +1,17 @@
"""Board-group memory CRUD and streaming endpoints."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import json import json
from collections.abc import AsyncIterator from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import func from sqlalchemy import func
from sqlmodel import col from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from app.api.deps import ( from app.api.deps import (
@@ -24,28 +26,56 @@ from app.core.time import utcnow
from app.db.pagination import paginate from app.db.pagination import paginate
from app.db.session import async_session_maker, get_session from app.db.session import async_session_maker, get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message from app.integrations.openclaw_gateway import (
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.agents import Agent from app.models.agents import Agent
from app.models.board_group_memory import BoardGroupMemory from app.models.board_group_memory import BoardGroupMemory
from app.models.board_groups import BoardGroup from app.models.board_groups import BoardGroup
from app.models.boards import Board from app.models.boards import Board
from app.models.gateways import Gateway from app.models.gateways import Gateway
from app.models.users import User from app.models.users import User
from app.schemas.board_group_memory import BoardGroupMemoryCreate, BoardGroupMemoryRead from app.schemas.board_group_memory import (
BoardGroupMemoryCreate,
BoardGroupMemoryRead,
)
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
from app.services.mentions import extract_mentions, matches_agent_mention from app.services.mentions import extract_mentions, matches_agent_mention
from app.services.organizations import ( from app.services.organizations import (
OrganizationContext,
is_org_admin, is_org_admin,
list_accessible_board_ids, list_accessible_board_ids,
member_all_boards_read, member_all_boards_read,
member_all_boards_write, member_all_boards_write,
) )
router = APIRouter(tags=["board-group-memory"]) if TYPE_CHECKING:
from collections.abc import AsyncIterator
group_router = APIRouter(prefix="/board-groups/{group_id}/memory", tags=["board-group-memory"]) from sqlmodel.ext.asyncio.session import AsyncSession
board_router = APIRouter(prefix="/boards/{board_id}/group-memory", tags=["board-group-memory"])
from app.services.organizations import OrganizationContext
router = APIRouter(tags=["board-group-memory"])
group_router = APIRouter(
prefix="/board-groups/{group_id}/memory",
tags=["board-group-memory"],
)
board_router = APIRouter(
prefix="/boards/{board_id}/group-memory",
tags=["board-group-memory"],
)
MAX_SNIPPET_LENGTH = 800
STREAM_POLL_SECONDS = 2
SESSION_DEP = Depends(get_session)
ORG_MEMBER_DEP = Depends(require_org_member)
BOARD_READ_DEP = Depends(get_board_for_actor_read)
BOARD_WRITE_DEP = Depends(get_board_for_actor_write)
ACTOR_DEP = Depends(require_admin_or_agent)
IS_CHAT_QUERY = Query(default=None)
SINCE_QUERY = Query(default=None)
_RUNTIME_TYPE_REFERENCES = (UUID,)
def _parse_since(value: str | None) -> datetime | None: def _parse_since(value: str | None) -> datetime | None:
@@ -65,10 +95,16 @@ def _parse_since(value: str | None) -> datetime | None:
def _serialize_memory(memory: BoardGroupMemory) -> dict[str, object]: def _serialize_memory(memory: BoardGroupMemory) -> dict[str, object]:
return BoardGroupMemoryRead.model_validate(memory, from_attributes=True).model_dump(mode="json") return BoardGroupMemoryRead.model_validate(
memory,
from_attributes=True,
).model_dump(mode="json")
async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None: async def _gateway_config(
session: AsyncSession,
board: Board,
) -> GatewayClientConfig | None:
if board.gateway_id is None: if board.gateway_id is None:
return None return None
gateway = await Gateway.objects.by_id(board.gateway_id).first(session) gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
@@ -104,7 +140,7 @@ async def _fetch_memory_events(
if is_chat is not None: if is_chat is not None:
statement = statement.filter(col(BoardGroupMemory.is_chat) == is_chat) statement = statement.filter(col(BoardGroupMemory.is_chat) == is_chat)
statement = statement.filter(col(BoardGroupMemory.created_at) >= since).order_by( statement = statement.filter(col(BoardGroupMemory.created_at) >= since).order_by(
col(BoardGroupMemory.created_at) col(BoardGroupMemory.created_at),
) )
return await statement.all(session) return await statement.all(session)
@@ -128,19 +164,124 @@ async def _require_group_access(
return group return group
board_ids = [ board_ids = [
board.id for board in await Board.objects.filter_by(board_group_id=group_id).all(session) board.id
for board in await Board.objects.filter_by(board_group_id=group_id).all(
session,
)
] ]
if not board_ids: if not board_ids:
if is_org_admin(ctx.member): if is_org_admin(ctx.member):
return group return group
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
allowed_ids = await list_accessible_board_ids(session, member=ctx.member, write=write) allowed_ids = await list_accessible_board_ids(
session,
member=ctx.member,
write=write,
)
if not set(board_ids).intersection(set(allowed_ids)): if not set(board_ids).intersection(set(allowed_ids)):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return group return group
async def _group_read_access(
group_id: UUID,
session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> BoardGroup:
return await _require_group_access(session, group_id=group_id, ctx=ctx, write=False)
GROUP_READ_DEP = Depends(_group_read_access)
def _group_chat_targets(
*,
agents: list[Agent],
actor: ActorContext,
is_broadcast: bool,
mentions: set[str],
) -> dict[str, Agent]:
targets: dict[str, Agent] = {}
for agent in agents:
if not agent.openclaw_session_id:
continue
if actor.actor_type == "agent" and actor.agent and agent.id == actor.agent.id:
continue
if is_broadcast or agent.is_board_lead:
targets[str(agent.id)] = agent
continue
if mentions and matches_agent_mention(agent, mentions):
targets[str(agent.id)] = agent
return targets
def _group_actor_name(actor: ActorContext) -> str:
if actor.actor_type == "agent" and actor.agent:
return actor.agent.name
if actor.user:
return actor.user.preferred_name or actor.user.name or "User"
return "User"
def _group_header(*, is_broadcast: bool, mentioned: bool) -> str:
if is_broadcast:
return "GROUP BROADCAST"
if mentioned:
return "GROUP CHAT MENTION"
return "GROUP CHAT"
@dataclass(frozen=True)
class _NotifyGroupContext:
session: AsyncSession
group: BoardGroup
board_by_id: dict[UUID, Board]
mentions: set[str]
is_broadcast: bool
actor_name: str
snippet: str
base_url: str
async def _notify_group_target(
context: _NotifyGroupContext,
agent: Agent,
) -> None:
session_key = agent.openclaw_session_id
board_id = agent.board_id
if not session_key or board_id is None:
return
board = context.board_by_id.get(board_id)
if board is None:
return
config = await _gateway_config(context.session, board)
if config is None:
return
header = _group_header(
is_broadcast=context.is_broadcast,
mentioned=matches_agent_mention(agent, context.mentions),
)
message = (
f"{header}\n"
f"Group: {context.group.name}\n"
f"From: {context.actor_name}\n\n"
f"{context.snippet}\n\n"
"Reply via group chat (shared across linked boards):\n"
f"POST {context.base_url}/api/v1/boards/{board.id}/group-memory\n"
'Body: {"content":"...","tags":["chat"]}'
)
try:
await _send_agent_message(
session_key=session_key,
config=config,
agent_name=agent.name,
message=message,
)
except OpenClawGatewayError:
return
async def _notify_group_memory_targets( async def _notify_group_memory_targets(
*, *,
session: AsyncSession, session: AsyncSession,
@@ -163,83 +304,47 @@ async def _notify_group_memory_targets(
board_ids = list(board_by_id.keys()) board_ids = list(board_by_id.keys())
agents = await Agent.objects.by_field_in("board_id", board_ids).all(session) agents = await Agent.objects.by_field_in("board_id", board_ids).all(session)
targets: dict[str, Agent] = {} targets = _group_chat_targets(
for agent in agents: agents=agents,
if not agent.openclaw_session_id: actor=actor,
continue is_broadcast=is_broadcast,
if actor.actor_type == "agent" and actor.agent and agent.id == actor.agent.id: mentions=mentions,
continue )
if is_broadcast:
targets[str(agent.id)] = agent
continue
if agent.is_board_lead:
targets[str(agent.id)] = agent
continue
if mentions and matches_agent_mention(agent, mentions):
targets[str(agent.id)] = agent
if not targets: if not targets:
return return
actor_name = "User" actor_name = _group_actor_name(actor)
if actor.actor_type == "agent" and actor.agent:
actor_name = actor.agent.name
elif actor.user:
actor_name = actor.user.preferred_name or actor.user.name or actor_name
snippet = memory.content.strip() snippet = memory.content.strip()
if len(snippet) > 800: if len(snippet) > MAX_SNIPPET_LENGTH:
snippet = f"{snippet[:797]}..." snippet = f"{snippet[: MAX_SNIPPET_LENGTH - 3]}..."
base_url = settings.base_url or "http://localhost:8000" base_url = settings.base_url or "http://localhost:8000"
context = _NotifyGroupContext(
session=session,
group=group,
board_by_id=board_by_id,
mentions=mentions,
is_broadcast=is_broadcast,
actor_name=actor_name,
snippet=snippet,
base_url=base_url,
)
for agent in targets.values(): for agent in targets.values():
session_key = agent.openclaw_session_id await _notify_group_target(context, agent)
if not session_key:
continue
board_id = agent.board_id
if board_id is None:
continue
board = board_by_id.get(board_id)
if board is None:
continue
config = await _gateway_config(session, board)
if config is None:
continue
mentioned = matches_agent_mention(agent, mentions)
if is_broadcast:
header = "GROUP BROADCAST"
elif mentioned:
header = "GROUP CHAT MENTION"
else:
header = "GROUP CHAT"
message = (
f"{header}\n"
f"Group: {group.name}\n"
f"From: {actor_name}\n\n"
f"{snippet}\n\n"
"Reply via group chat (shared across linked boards):\n"
f"POST {base_url}/api/v1/boards/{board.id}/group-memory\n"
'Body: {"content":"...","tags":["chat"]}'
)
try:
await _send_agent_message(
session_key=session_key,
config=config,
agent_name=agent.name,
message=message,
)
except OpenClawGatewayError:
continue
@group_router.get("", response_model=DefaultLimitOffsetPage[BoardGroupMemoryRead]) @group_router.get("", response_model=DefaultLimitOffsetPage[BoardGroupMemoryRead])
async def list_board_group_memory( async def list_board_group_memory(
group_id: UUID, group_id: UUID,
is_chat: bool | None = Query(default=None), *,
session: AsyncSession = Depends(get_session), is_chat: bool | None = IS_CHAT_QUERY,
ctx: OrganizationContext = Depends(require_org_member), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> DefaultLimitOffsetPage[BoardGroupMemoryRead]: ) -> DefaultLimitOffsetPage[BoardGroupMemoryRead]:
"""List board-group memory entries for a specific group."""
await _require_group_access(session, group_id=group_id, ctx=ctx, write=False) await _require_group_access(session, group_id=group_id, ctx=ctx, write=False)
statement = ( statement = (
BoardGroupMemory.objects.filter_by(board_group_id=group_id) BoardGroupMemory.objects.filter_by(board_group_id=group_id)
@@ -255,14 +360,13 @@ async def list_board_group_memory(
@group_router.get("/stream") @group_router.get("/stream")
async def stream_board_group_memory( async def stream_board_group_memory(
group_id: UUID,
request: Request, request: Request,
since: str | None = Query(default=None), group: BoardGroup = GROUP_READ_DEP,
is_chat: bool | None = Query(default=None), *,
session: AsyncSession = Depends(get_session), since: str | None = SINCE_QUERY,
ctx: OrganizationContext = Depends(require_org_member), is_chat: bool | None = IS_CHAT_QUERY,
) -> EventSourceResponse: ) -> EventSourceResponse:
await _require_group_access(session, group_id=group_id, ctx=ctx, write=False) """Stream memory entries for a board group via server-sent events."""
since_dt = _parse_since(since) or utcnow() since_dt = _parse_since(since) or utcnow()
last_seen = since_dt last_seen = since_dt
@@ -274,16 +378,15 @@ async def stream_board_group_memory(
async with async_session_maker() as s: async with async_session_maker() as s:
memories = await _fetch_memory_events( memories = await _fetch_memory_events(
s, s,
group_id, group.id,
last_seen, last_seen,
is_chat=is_chat, is_chat=is_chat,
) )
for memory in memories: for memory in memories:
if memory.created_at > last_seen: last_seen = max(memory.created_at, last_seen)
last_seen = memory.created_at
payload = {"memory": _serialize_memory(memory)} payload = {"memory": _serialize_memory(memory)}
yield {"event": "memory", "data": json.dumps(payload)} yield {"event": "memory", "data": json.dumps(payload)}
await asyncio.sleep(2) await asyncio.sleep(STREAM_POLL_SECONDS)
return EventSourceResponse(event_generator(), ping=15) return EventSourceResponse(event_generator(), ping=15)
@@ -292,9 +395,10 @@ async def stream_board_group_memory(
async def create_board_group_memory( async def create_board_group_memory(
group_id: UUID, group_id: UUID,
payload: BoardGroupMemoryCreate, payload: BoardGroupMemoryCreate,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_member), ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> BoardGroupMemory: ) -> BoardGroupMemory:
"""Create a board-group memory entry and notify chat recipients."""
group = await _require_group_access(session, group_id=group_id, ctx=ctx, write=True) group = await _require_group_access(session, group_id=group_id, ctx=ctx, write=True)
user = await User.objects.by_id(ctx.member.user_id).first(session) user = await User.objects.by_id(ctx.member.user_id).first(session)
@@ -320,16 +424,23 @@ async def create_board_group_memory(
await session.commit() await session.commit()
await session.refresh(memory) await session.refresh(memory)
if should_notify: if should_notify:
await _notify_group_memory_targets(session=session, group=group, memory=memory, actor=actor) await _notify_group_memory_targets(
session=session,
group=group,
memory=memory,
actor=actor,
)
return memory return memory
@board_router.get("", response_model=DefaultLimitOffsetPage[BoardGroupMemoryRead]) @board_router.get("", response_model=DefaultLimitOffsetPage[BoardGroupMemoryRead])
async def list_board_group_memory_for_board( async def list_board_group_memory_for_board(
is_chat: bool | None = Query(default=None), *,
board: Board = Depends(get_board_for_actor_read), is_chat: bool | None = IS_CHAT_QUERY,
session: AsyncSession = Depends(get_session), board: Board = BOARD_READ_DEP,
session: AsyncSession = SESSION_DEP,
) -> DefaultLimitOffsetPage[BoardGroupMemoryRead]: ) -> DefaultLimitOffsetPage[BoardGroupMemoryRead]:
"""List memory entries for the board's linked group."""
group_id = board.board_group_id group_id = board.board_group_id
if group_id is None: if group_id is None:
return await paginate(session, BoardGroupMemory.objects.by_ids([]).statement) return await paginate(session, BoardGroupMemory.objects.by_ids([]).statement)
@@ -349,10 +460,12 @@ async def list_board_group_memory_for_board(
@board_router.get("/stream") @board_router.get("/stream")
async def stream_board_group_memory_for_board( async def stream_board_group_memory_for_board(
request: Request, request: Request,
board: Board = Depends(get_board_for_actor_read), *,
since: str | None = Query(default=None), board: Board = BOARD_READ_DEP,
is_chat: bool | None = Query(default=None), since: str | None = SINCE_QUERY,
is_chat: bool | None = IS_CHAT_QUERY,
) -> EventSourceResponse: ) -> EventSourceResponse:
"""Stream memory entries for the board's linked group."""
group_id = board.board_group_id group_id = board.board_group_id
since_dt = _parse_since(since) or utcnow() since_dt = _parse_since(since) or utcnow()
last_seen = since_dt last_seen = since_dt
@@ -373,11 +486,10 @@ async def stream_board_group_memory_for_board(
is_chat=is_chat, is_chat=is_chat,
) )
for memory in memories: for memory in memories:
if memory.created_at > last_seen: last_seen = max(memory.created_at, last_seen)
last_seen = memory.created_at
payload = {"memory": _serialize_memory(memory)} payload = {"memory": _serialize_memory(memory)}
yield {"event": "memory", "data": json.dumps(payload)} yield {"event": "memory", "data": json.dumps(payload)}
await asyncio.sleep(2) await asyncio.sleep(STREAM_POLL_SECONDS)
return EventSourceResponse(event_generator(), ping=15) return EventSourceResponse(event_generator(), ping=15)
@@ -385,10 +497,11 @@ async def stream_board_group_memory_for_board(
@board_router.post("", response_model=BoardGroupMemoryRead) @board_router.post("", response_model=BoardGroupMemoryRead)
async def create_board_group_memory_for_board( async def create_board_group_memory_for_board(
payload: BoardGroupMemoryCreate, payload: BoardGroupMemoryCreate,
board: Board = Depends(get_board_for_actor_write), board: Board = BOARD_WRITE_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
actor: ActorContext = Depends(require_admin_or_agent), actor: ActorContext = ACTOR_DEP,
) -> BoardGroupMemory: ) -> BoardGroupMemory:
"""Create a group memory entry from a board context and notify recipients."""
group_id = board.board_group_id group_id = board.board_group_id
if group_id is None: if group_id is None:
raise HTTPException( raise HTTPException(
@@ -420,7 +533,12 @@ async def create_board_group_memory_for_board(
await session.commit() await session.commit()
await session.refresh(memory) await session.refresh(memory)
if should_notify: if should_notify:
await _notify_group_memory_targets(session=session, group=group, memory=memory, actor=actor) await _notify_group_memory_targets(
session=session,
group=group,
memory=memory,
actor=actor,
)
return memory return memory

View File

@@ -1,15 +1,21 @@
"""Board group CRUD, snapshot, and heartbeat endpoints."""
from __future__ import annotations from __future__ import annotations
import re import re
from typing import Any, cast from typing import TYPE_CHECKING, Any, cast
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func from sqlalchemy import func
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import ActorContext, require_admin_or_agent, require_org_admin, require_org_member from app.api.deps import (
ActorContext,
require_admin_or_agent,
require_org_admin,
require_org_member,
)
from app.core.time import utcnow from app.core.time import utcnow
from app.db import crud from app.db import crud
from app.db.pagination import paginate from app.db.pagination import paginate
@@ -20,7 +26,6 @@ from app.models.board_group_memory import BoardGroupMemory
from app.models.board_groups import BoardGroup from app.models.board_groups import BoardGroup
from app.models.boards import Board from app.models.boards import Board
from app.models.gateways import Gateway from app.models.gateways import Gateway
from app.models.organization_members import OrganizationMember
from app.schemas.board_group_heartbeat import ( from app.schemas.board_group_heartbeat import (
BoardGroupHeartbeatApply, BoardGroupHeartbeatApply,
BoardGroupHeartbeatApplyResult, BoardGroupHeartbeatApplyResult,
@@ -29,7 +34,10 @@ from app.schemas.board_groups import BoardGroupCreate, BoardGroupRead, BoardGrou
from app.schemas.common import OkResponse from app.schemas.common import OkResponse
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
from app.schemas.view_models import BoardGroupSnapshot from app.schemas.view_models import BoardGroupSnapshot
from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, sync_gateway_agent_heartbeats from app.services.agent_provisioning import (
DEFAULT_HEARTBEAT_CONFIG,
sync_gateway_agent_heartbeats,
)
from app.services.board_group_snapshot import build_group_snapshot from app.services.board_group_snapshot import build_group_snapshot
from app.services.organizations import ( from app.services.organizations import (
OrganizationContext, OrganizationContext,
@@ -41,7 +49,16 @@ from app.services.organizations import (
member_all_boards_write, member_all_boards_write,
) )
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.organization_members import OrganizationMember
router = APIRouter(prefix="/board-groups", tags=["board-groups"]) router = APIRouter(prefix="/board-groups", tags=["board-groups"])
SESSION_DEP = Depends(get_session)
ORG_MEMBER_DEP = Depends(require_org_member)
ORG_ADMIN_DEP = Depends(require_org_admin)
ACTOR_DEP = Depends(require_admin_or_agent)
def _slugify(value: str) -> str: def _slugify(value: str) -> str:
@@ -68,7 +85,8 @@ async def _require_group_access(
return group return group
board_ids = [ board_ids = [
board.id for board in await Board.objects.filter_by(board_group_id=group_id).all(session) board.id
for board in await Board.objects.filter_by(board_group_id=group_id).all(session)
] ]
if not board_ids: if not board_ids:
if is_org_admin(member): if is_org_admin(member):
@@ -83,14 +101,17 @@ async def _require_group_access(
@router.get("", response_model=DefaultLimitOffsetPage[BoardGroupRead]) @router.get("", response_model=DefaultLimitOffsetPage[BoardGroupRead])
async def list_board_groups( async def list_board_groups(
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_member), ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> DefaultLimitOffsetPage[BoardGroupRead]: ) -> DefaultLimitOffsetPage[BoardGroupRead]:
"""List board groups in the active organization."""
if member_all_boards_read(ctx.member): if member_all_boards_read(ctx.member):
statement = select(BoardGroup).where(col(BoardGroup.organization_id) == ctx.organization.id) statement = select(BoardGroup).where(
col(BoardGroup.organization_id) == ctx.organization.id,
)
else: else:
accessible_boards = select(Board.board_group_id).where( accessible_boards = select(Board.board_group_id).where(
board_access_filter(ctx.member, write=False) board_access_filter(ctx.member, write=False),
) )
statement = select(BoardGroup).where( statement = select(BoardGroup).where(
col(BoardGroup.organization_id) == ctx.organization.id, col(BoardGroup.organization_id) == ctx.organization.id,
@@ -103,9 +124,10 @@ async def list_board_groups(
@router.post("", response_model=BoardGroupRead) @router.post("", response_model=BoardGroupRead)
async def create_board_group( async def create_board_group(
payload: BoardGroupCreate, payload: BoardGroupCreate,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> BoardGroup: ) -> BoardGroup:
"""Create a board group in the active organization."""
data = payload.model_dump() data = payload.model_dump()
if not (data.get("slug") or "").strip(): if not (data.get("slug") or "").strip():
data["slug"] = _slugify(data.get("name") or "") data["slug"] = _slugify(data.get("name") or "")
@@ -116,21 +138,28 @@ async def create_board_group(
@router.get("/{group_id}", response_model=BoardGroupRead) @router.get("/{group_id}", response_model=BoardGroupRead)
async def get_board_group( async def get_board_group(
group_id: UUID, group_id: UUID,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_member), ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> BoardGroup: ) -> BoardGroup:
return await _require_group_access(session, group_id=group_id, member=ctx.member, write=False) """Get a board group by id."""
return await _require_group_access(
session, group_id=group_id, member=ctx.member, write=False,
)
@router.get("/{group_id}/snapshot", response_model=BoardGroupSnapshot) @router.get("/{group_id}/snapshot", response_model=BoardGroupSnapshot)
async def get_board_group_snapshot( async def get_board_group_snapshot(
group_id: UUID, group_id: UUID,
*,
include_done: bool = False, include_done: bool = False,
per_board_task_limit: int = 5, per_board_task_limit: int = 5,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_member), ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> BoardGroupSnapshot: ) -> BoardGroupSnapshot:
group = await _require_group_access(session, group_id=group_id, member=ctx.member, write=False) """Get a snapshot across boards in a group."""
group = await _require_group_access(
session, group_id=group_id, member=ctx.member, write=False,
)
if per_board_task_limit < 0: if per_board_task_limit < 0:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
snapshot = await build_group_snapshot( snapshot = await build_group_snapshot(
@@ -141,22 +170,22 @@ async def get_board_group_snapshot(
per_board_task_limit=per_board_task_limit, per_board_task_limit=per_board_task_limit,
) )
if not member_all_boards_read(ctx.member) and snapshot.boards: if not member_all_boards_read(ctx.member) and snapshot.boards:
allowed_ids = set(await list_accessible_board_ids(session, member=ctx.member, write=False)) allowed_ids = set(
snapshot.boards = [item for item in snapshot.boards if item.board.id in allowed_ids] await list_accessible_board_ids(session, member=ctx.member, write=False),
)
snapshot.boards = [
item for item in snapshot.boards if item.board.id in allowed_ids
]
return snapshot return snapshot
@router.post("/{group_id}/heartbeat", response_model=BoardGroupHeartbeatApplyResult) async def _authorize_heartbeat_actor(
async def apply_board_group_heartbeat( session: AsyncSession,
*,
group_id: UUID, group_id: UUID,
payload: BoardGroupHeartbeatApply, group: BoardGroup,
session: AsyncSession = Depends(get_session), actor: ActorContext,
actor: ActorContext = Depends(require_admin_or_agent), ) -> None:
) -> BoardGroupHeartbeatApplyResult:
group = await BoardGroup.objects.by_id(group_id).first(session)
if group is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if actor.actor_type == "user": if actor.actor_type == "user":
if actor.user is None: if actor.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
@@ -173,35 +202,37 @@ async def apply_board_group_heartbeat(
member=member, member=member,
write=True, write=True,
) )
elif actor.actor_type == "agent": return
agent = actor.agent agent = actor.agent
if agent is None: if agent is None or agent.board_id is None or not agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
if agent.board_id is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if not agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
board = await Board.objects.by_id(agent.board_id).first(session) board = await Board.objects.by_id(agent.board_id).first(session)
if board is None or board.board_group_id != group_id: if board is None or board.board_group_id != group_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
async def _agents_for_group_heartbeat(
session: AsyncSession,
*,
group_id: UUID,
include_board_leads: bool,
) -> tuple[dict[UUID, Board], list[Agent]]:
boards = await Board.objects.filter_by(board_group_id=group_id).all(session) boards = await Board.objects.filter_by(board_group_id=group_id).all(session)
board_by_id = {board.id: board for board in boards} board_by_id = {board.id: board for board in boards}
board_ids = list(board_by_id.keys()) board_ids = list(board_by_id.keys())
if not board_ids: if not board_ids:
return BoardGroupHeartbeatApplyResult( return board_by_id, []
board_group_id=group_id,
requested=payload.model_dump(mode="json"),
updated_agent_ids=[],
failed_agent_ids=[],
)
agents = await Agent.objects.by_field_in("board_id", board_ids).all(session) agents = await Agent.objects.by_field_in("board_id", board_ids).all(session)
if not payload.include_board_leads: if not include_board_leads:
agents = [agent for agent in agents if not agent.is_board_lead] agents = [agent for agent in agents if not agent.is_board_lead]
return board_by_id, agents
updated_agent_ids: list[UUID] = []
for agent in agents: def _update_agent_heartbeat(
*,
agent: Agent,
payload: BoardGroupHeartbeatApply,
) -> None:
raw = agent.heartbeat_config raw = agent.heartbeat_config
heartbeat: dict[str, Any] = ( heartbeat: dict[str, Any] = (
cast(dict[str, Any], dict(raw)) cast(dict[str, Any], dict(raw))
@@ -215,11 +246,14 @@ async def apply_board_group_heartbeat(
heartbeat["target"] = DEFAULT_HEARTBEAT_CONFIG.get("target", "none") heartbeat["target"] = DEFAULT_HEARTBEAT_CONFIG.get("target", "none")
agent.heartbeat_config = heartbeat agent.heartbeat_config = heartbeat
agent.updated_at = utcnow() agent.updated_at = utcnow()
session.add(agent)
updated_agent_ids.append(agent.id)
await session.commit()
async def _sync_gateway_heartbeats(
session: AsyncSession,
*,
board_by_id: dict[UUID, Board],
agents: list[Agent],
) -> list[UUID]:
agents_by_gateway_id: dict[UUID, list[Agent]] = {} agents_by_gateway_id: dict[UUID, list[Agent]] = {}
for agent in agents: for agent in agents:
board_id = agent.board_id board_id = agent.board_id
@@ -243,6 +277,51 @@ async def apply_board_group_heartbeat(
await sync_gateway_agent_heartbeats(gateway, gateway_agents) await sync_gateway_agent_heartbeats(gateway, gateway_agents)
except OpenClawGatewayError: except OpenClawGatewayError:
failed_agent_ids.extend([agent.id for agent in gateway_agents]) failed_agent_ids.extend([agent.id for agent in gateway_agents])
return failed_agent_ids
@router.post("/{group_id}/heartbeat", response_model=BoardGroupHeartbeatApplyResult)
async def apply_board_group_heartbeat(
group_id: UUID,
payload: BoardGroupHeartbeatApply,
session: AsyncSession = SESSION_DEP,
actor: ActorContext = ACTOR_DEP,
) -> BoardGroupHeartbeatApplyResult:
"""Apply heartbeat settings to agents in a board group."""
group = await BoardGroup.objects.by_id(group_id).first(session)
if group is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
await _authorize_heartbeat_actor(
session,
group_id=group_id,
group=group,
actor=actor,
)
board_by_id, agents = await _agents_for_group_heartbeat(
session,
group_id=group_id,
include_board_leads=payload.include_board_leads,
)
if not agents:
return BoardGroupHeartbeatApplyResult(
board_group_id=group_id,
requested=payload.model_dump(mode="json"),
updated_agent_ids=[],
failed_agent_ids=[],
)
updated_agent_ids: list[UUID] = []
for agent in agents:
_update_agent_heartbeat(agent=agent, payload=payload)
session.add(agent)
updated_agent_ids.append(agent.id)
await session.commit()
failed_agent_ids = await _sync_gateway_heartbeats(
session,
board_by_id=board_by_id,
agents=agents,
)
return BoardGroupHeartbeatApplyResult( return BoardGroupHeartbeatApplyResult(
board_group_id=group_id, board_group_id=group_id,
@@ -256,12 +335,19 @@ async def apply_board_group_heartbeat(
async def update_board_group( async def update_board_group(
payload: BoardGroupUpdate, payload: BoardGroupUpdate,
group_id: UUID, group_id: UUID,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> BoardGroup: ) -> BoardGroup:
group = await _require_group_access(session, group_id=group_id, member=ctx.member, write=True) """Update a board group."""
group = await _require_group_access(
session, group_id=group_id, member=ctx.member, write=True,
)
updates = payload.model_dump(exclude_unset=True) updates = payload.model_dump(exclude_unset=True)
if "slug" in updates and updates["slug"] is not None and not updates["slug"].strip(): if (
"slug" in updates
and updates["slug"] is not None
and not updates["slug"].strip()
):
updates["slug"] = _slugify(updates.get("name") or group.name) updates["slug"] = _slugify(updates.get("name") or group.name)
updates["updated_at"] = utcnow() updates["updated_at"] = utcnow()
return await crud.patch(session, group, updates) return await crud.patch(session, group, updates)
@@ -270,10 +356,13 @@ async def update_board_group(
@router.delete("/{group_id}", response_model=OkResponse) @router.delete("/{group_id}", response_model=OkResponse)
async def delete_board_group( async def delete_board_group(
group_id: UUID, group_id: UUID,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> OkResponse: ) -> OkResponse:
await _require_group_access(session, group_id=group_id, member=ctx.member, write=True) """Delete a board group."""
await _require_group_access(
session, group_id=group_id, member=ctx.member, write=True,
)
# Boards reference groups, so clear the FK first to keep deletes simple. # Boards reference groups, so clear the FK first to keep deletes simple.
await crud.update_where( await crud.update_where(
@@ -284,8 +373,13 @@ async def delete_board_group(
commit=False, commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, BoardGroupMemory, col(BoardGroupMemory.board_group_id) == group_id, commit=False session,
BoardGroupMemory,
col(BoardGroupMemory.board_group_id) == group_id,
commit=False,
)
await crud.delete_where(
session, BoardGroup, col(BoardGroup.id) == group_id, commit=False,
) )
await crud.delete_where(session, BoardGroup, col(BoardGroup.id) == group_id, commit=False)
await session.commit() await session.commit()
return OkResponse() return OkResponse()

View File

@@ -1,15 +1,16 @@
"""Board memory CRUD and streaming endpoints."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import json import json
from collections.abc import AsyncIterator
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, Query, Request from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import func from sqlalchemy import func
from sqlmodel import col from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from app.api.deps import ( from app.api.deps import (
@@ -23,16 +24,35 @@ from app.core.time import utcnow
from app.db.pagination import paginate from app.db.pagination import paginate
from app.db.session import async_session_maker, get_session from app.db.session import async_session_maker, get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message from app.integrations.openclaw_gateway import (
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.agents import Agent from app.models.agents import Agent
from app.models.board_memory import BoardMemory from app.models.board_memory import BoardMemory
from app.models.boards import Board
from app.models.gateways import Gateway from app.models.gateways import Gateway
from app.schemas.board_memory import BoardMemoryCreate, BoardMemoryRead from app.schemas.board_memory import BoardMemoryCreate, BoardMemoryRead
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
from app.services.mentions import extract_mentions, matches_agent_mention from app.services.mentions import extract_mentions, matches_agent_mention
if TYPE_CHECKING:
from collections.abc import AsyncIterator
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.boards import Board
router = APIRouter(prefix="/boards/{board_id}/memory", tags=["board-memory"]) router = APIRouter(prefix="/boards/{board_id}/memory", tags=["board-memory"])
MAX_SNIPPET_LENGTH = 800
STREAM_POLL_SECONDS = 2
IS_CHAT_QUERY = Query(default=None)
SINCE_QUERY = Query(default=None)
BOARD_READ_DEP = Depends(get_board_for_actor_read)
BOARD_WRITE_DEP = Depends(get_board_for_actor_write)
SESSION_DEP = Depends(get_session)
ACTOR_DEP = Depends(require_admin_or_agent)
_RUNTIME_TYPE_REFERENCES = (UUID,)
def _parse_since(value: str | None) -> datetime | None: def _parse_since(value: str | None) -> datetime | None:
@@ -52,10 +72,16 @@ def _parse_since(value: str | None) -> datetime | None:
def _serialize_memory(memory: BoardMemory) -> dict[str, object]: def _serialize_memory(memory: BoardMemory) -> dict[str, object]:
return BoardMemoryRead.model_validate(memory, from_attributes=True).model_dump(mode="json") return BoardMemoryRead.model_validate(
memory,
from_attributes=True,
).model_dump(mode="json")
async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None: async def _gateway_config(
session: AsyncSession,
board: Board,
) -> GatewayClientConfig | None:
if board.gateway_id is None: if board.gateway_id is None:
return None return None
gateway = await Gateway.objects.by_id(board.gateway_id).first(session) gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
@@ -91,11 +117,67 @@ async def _fetch_memory_events(
if is_chat is not None: if is_chat is not None:
statement = statement.filter(col(BoardMemory.is_chat) == is_chat) statement = statement.filter(col(BoardMemory.is_chat) == is_chat)
statement = statement.filter(col(BoardMemory.created_at) >= since).order_by( statement = statement.filter(col(BoardMemory.created_at) >= since).order_by(
col(BoardMemory.created_at) col(BoardMemory.created_at),
) )
return await statement.all(session) return await statement.all(session)
async def _send_control_command(
*,
session: AsyncSession,
board: Board,
actor: ActorContext,
config: GatewayClientConfig,
command: str,
) -> None:
pause_targets: list[Agent] = await Agent.objects.filter_by(
board_id=board.id,
).all(
session,
)
for agent in pause_targets:
if actor.actor_type == "agent" and actor.agent and agent.id == actor.agent.id:
continue
if not agent.openclaw_session_id:
continue
try:
await _send_agent_message(
session_key=agent.openclaw_session_id,
config=config,
agent_name=agent.name,
message=command,
deliver=True,
)
except OpenClawGatewayError:
continue
def _chat_targets(
*,
agents: list[Agent],
mentions: set[str],
actor: ActorContext,
) -> dict[str, Agent]:
targets: dict[str, Agent] = {}
for agent in agents:
if agent.is_board_lead:
targets[str(agent.id)] = agent
continue
if mentions and matches_agent_mention(agent, mentions):
targets[str(agent.id)] = agent
if actor.actor_type == "agent" and actor.agent:
targets.pop(str(actor.agent.id), None)
return targets
def _actor_display_name(actor: ActorContext) -> str:
if actor.actor_type == "agent" and actor.agent:
return actor.agent.name
if actor.user:
return actor.user.preferred_name or actor.user.name or "User"
return "User"
async def _notify_chat_targets( async def _notify_chat_targets(
*, *,
session: AsyncSession, session: AsyncSession,
@@ -114,44 +196,27 @@ async def _notify_chat_targets(
# Special-case control commands to reach all board agents. # Special-case control commands to reach all board agents.
# These are intended to be parsed verbatim by agent runtimes. # These are intended to be parsed verbatim by agent runtimes.
if command in {"/pause", "/resume"}: if command in {"/pause", "/resume"}:
pause_targets: list[Agent] = await Agent.objects.filter_by(board_id=board.id).all(session) await _send_control_command(
for agent in pause_targets: session=session,
if actor.actor_type == "agent" and actor.agent and agent.id == actor.agent.id: board=board,
continue actor=actor,
if not agent.openclaw_session_id:
continue
try:
await _send_agent_message(
session_key=agent.openclaw_session_id,
config=config, config=config,
agent_name=agent.name, command=command,
message=command,
deliver=True,
) )
except OpenClawGatewayError:
continue
return return
mentions = extract_mentions(memory.content) mentions = extract_mentions(memory.content)
targets: dict[str, Agent] = {} targets = _chat_targets(
for agent in await Agent.objects.filter_by(board_id=board.id).all(session): agents=await Agent.objects.filter_by(board_id=board.id).all(session),
if agent.is_board_lead: mentions=mentions,
targets[str(agent.id)] = agent actor=actor,
continue )
if mentions and matches_agent_mention(agent, mentions):
targets[str(agent.id)] = agent
if actor.actor_type == "agent" and actor.agent:
targets.pop(str(actor.agent.id), None)
if not targets: if not targets:
return return
actor_name = "User" actor_name = _actor_display_name(actor)
if actor.actor_type == "agent" and actor.agent:
actor_name = actor.agent.name
elif actor.user:
actor_name = actor.user.preferred_name or actor.user.name or actor_name
snippet = memory.content.strip() snippet = memory.content.strip()
if len(snippet) > 800: if len(snippet) > MAX_SNIPPET_LENGTH:
snippet = f"{snippet[:797]}..." snippet = f"{snippet[: MAX_SNIPPET_LENGTH - 3]}..."
base_url = settings.base_url or "http://localhost:8000" base_url = settings.base_url or "http://localhost:8000"
for agent in targets.values(): for agent in targets.values():
if not agent.openclaw_session_id: if not agent.openclaw_session_id:
@@ -180,11 +245,13 @@ async def _notify_chat_targets(
@router.get("", response_model=DefaultLimitOffsetPage[BoardMemoryRead]) @router.get("", response_model=DefaultLimitOffsetPage[BoardMemoryRead])
async def list_board_memory( async def list_board_memory(
is_chat: bool | None = Query(default=None), *,
board: Board = Depends(get_board_for_actor_read), is_chat: bool | None = IS_CHAT_QUERY,
session: AsyncSession = Depends(get_session), board: Board = BOARD_READ_DEP,
actor: ActorContext = Depends(require_admin_or_agent), session: AsyncSession = SESSION_DEP,
_actor: ActorContext = ACTOR_DEP,
) -> DefaultLimitOffsetPage[BoardMemoryRead]: ) -> DefaultLimitOffsetPage[BoardMemoryRead]:
"""List board memory entries, optionally filtering chat entries."""
statement = ( statement = (
BoardMemory.objects.filter_by(board_id=board.id) BoardMemory.objects.filter_by(board_id=board.id)
# Old/invalid rows (empty/whitespace-only content) can exist; exclude them to # Old/invalid rows (empty/whitespace-only content) can exist; exclude them to
@@ -200,11 +267,13 @@ async def list_board_memory(
@router.get("/stream") @router.get("/stream")
async def stream_board_memory( async def stream_board_memory(
request: Request, request: Request,
board: Board = Depends(get_board_for_actor_read), *,
actor: ActorContext = Depends(require_admin_or_agent), board: Board = BOARD_READ_DEP,
since: str | None = Query(default=None), _actor: ActorContext = ACTOR_DEP,
is_chat: bool | None = Query(default=None), since: str | None = SINCE_QUERY,
is_chat: bool | None = IS_CHAT_QUERY,
) -> EventSourceResponse: ) -> EventSourceResponse:
"""Stream board memory events over server-sent events."""
since_dt = _parse_since(since) or utcnow() since_dt = _parse_since(since) or utcnow()
last_seen = since_dt last_seen = since_dt
@@ -221,11 +290,10 @@ async def stream_board_memory(
is_chat=is_chat, is_chat=is_chat,
) )
for memory in memories: for memory in memories:
if memory.created_at > last_seen: last_seen = max(memory.created_at, last_seen)
last_seen = memory.created_at
payload = {"memory": _serialize_memory(memory)} payload = {"memory": _serialize_memory(memory)}
yield {"event": "memory", "data": json.dumps(payload)} yield {"event": "memory", "data": json.dumps(payload)}
await asyncio.sleep(2) await asyncio.sleep(STREAM_POLL_SECONDS)
return EventSourceResponse(event_generator(), ping=15) return EventSourceResponse(event_generator(), ping=15)
@@ -233,10 +301,11 @@ async def stream_board_memory(
@router.post("", response_model=BoardMemoryRead) @router.post("", response_model=BoardMemoryRead)
async def create_board_memory( async def create_board_memory(
payload: BoardMemoryCreate, payload: BoardMemoryCreate,
board: Board = Depends(get_board_for_actor_write), board: Board = BOARD_WRITE_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
actor: ActorContext = Depends(require_admin_or_agent), actor: ActorContext = ACTOR_DEP,
) -> BoardMemory: ) -> BoardMemory:
"""Create a board memory entry and notify chat targets when needed."""
is_chat = payload.tags is not None and "chat" in payload.tags is_chat = payload.tags is not None and "chat" in payload.tags
source = payload.source source = payload.source
if is_chat and not source: if is_chat and not source:
@@ -255,5 +324,10 @@ async def create_board_memory(
await session.commit() await session.commit()
await session.refresh(memory) await session.refresh(memory)
if is_chat: if is_chat:
await _notify_chat_targets(session=session, board=board, memory=memory, actor=actor) await _notify_chat_targets(
session=session,
board=board,
memory=memory,
actor=actor,
)
return memory return memory

View File

@@ -1,5 +1,4 @@
"""Board onboarding endpoints for user/agent collaboration.""" """Board onboarding endpoints for user/agent collaboration."""
# ruff: noqa: E501
from __future__ import annotations from __future__ import annotations
@@ -201,16 +200,22 @@ async def start_onboarding(
f"Board Name: {board.name}\n" f"Board Name: {board.name}\n"
"You are the main agent. Ask the user 6-10 focused questions total:\n" "You are the main agent. Ask the user 6-10 focused questions total:\n"
"- 3-6 questions to clarify the board goal.\n" "- 3-6 questions to clarify the board goal.\n"
"- 1 question to choose a unique name for the board lead agent (first-name style).\n" "- 1 question to choose a unique name for the board lead agent "
"- 2-4 questions to capture the user's preferences for how the board lead should work\n" "(first-name style).\n"
"- 2-4 questions to capture the user's preferences for how the board "
"lead should work\n"
" (communication style, autonomy, update cadence, and output formatting).\n" " (communication style, autonomy, update cadence, and output formatting).\n"
'- Always include a final question (and only once): "Anything else we should know?"\n' '- Always include a final question (and only once): "Anything else we '
'should know?"\n'
" (constraints, context, preferences). This MUST be the last question.\n" " (constraints, context, preferences). This MUST be the last question.\n"
' Provide an option like "Yes (I\'ll type it)" so they can enter free-text.\n' ' Provide an option like "Yes (I\'ll type it)" so they can enter free-text.\n'
" Do NOT ask for additional context on earlier questions.\n" " Do NOT ask for additional context on earlier questions.\n"
" Only include a free-text option on earlier questions if a typed answer is necessary;\n" " Only include a free-text option on earlier questions if a typed "
' when you do, make the option label include "I\'ll type it" (e.g., "Other (I\'ll type it)").\n' "answer is necessary;\n"
'- If the user sends an "Additional context" message later, incorporate it and resend status=complete\n' ' when you do, make the option label include "I\'ll type it" '
'(e.g., "Other (I\'ll type it)").\n'
'- If the user sends an "Additional context" message later, incorporate '
"it and resend status=complete\n"
" to update the draft (until the user confirms).\n" " to update the draft (until the user confirms).\n"
"Do NOT respond in OpenClaw chat.\n" "Do NOT respond in OpenClaw chat.\n"
"All onboarding responses MUST be sent to Mission Control via API.\n" "All onboarding responses MUST be sent to Mission Control via API.\n"
@@ -222,24 +227,37 @@ async def start_onboarding(
f'curl -s -X POST "{base_url}/api/v1/agent/boards/{board.id}/onboarding" ' f'curl -s -X POST "{base_url}/api/v1/agent/boards/{board.id}/onboarding" '
'-H "X-Agent-Token: $AUTH_TOKEN" ' '-H "X-Agent-Token: $AUTH_TOKEN" '
'-H "Content-Type: application/json" ' '-H "Content-Type: application/json" '
'-d \'{"question":"...","options":[{"id":"1","label":"..."},{"id":"2","label":"..."}]}\'\n' '-d \'{"question":"...","options":[{"id":"1","label":"..."},'
'{"id":"2","label":"..."}]}\'\n'
"COMPLETION example (send JSON body exactly as shown):\n" "COMPLETION example (send JSON body exactly as shown):\n"
f'curl -s -X POST "{base_url}/api/v1/agent/boards/{board.id}/onboarding" ' f'curl -s -X POST "{base_url}/api/v1/agent/boards/{board.id}/onboarding" '
'-H "X-Agent-Token: $AUTH_TOKEN" ' '-H "X-Agent-Token: $AUTH_TOKEN" '
'-H "Content-Type: application/json" ' '-H "Content-Type: application/json" '
'-d \'{"status":"complete","board_type":"goal","objective":"...","success_metrics":{"metric":"...","target":"..."},"target_date":"YYYY-MM-DD","user_profile":{"preferred_name":"...","pronouns":"...","timezone":"...","notes":"...","context":"..."},"lead_agent":{"name":"Ava","identity_profile":{"role":"Board Lead","communication_style":"direct, concise, practical","emoji":":gear:"},"autonomy_level":"balanced","verbosity":"concise","output_format":"bullets","update_cadence":"daily","custom_instructions":"..."}}\'\n' '-d \'{"status":"complete","board_type":"goal","objective":"...",'
'"success_metrics":{"metric":"...","target":"..."},'
'"target_date":"YYYY-MM-DD",'
'"user_profile":{"preferred_name":"...","pronouns":"...",'
'"timezone":"...","notes":"...","context":"..."},'
'"lead_agent":{"name":"Ava","identity_profile":{"role":"Board Lead",'
'"communication_style":"direct, concise, practical","emoji":":gear:"},'
'"autonomy_level":"balanced","verbosity":"concise",'
'"output_format":"bullets","update_cadence":"daily",'
'"custom_instructions":"..."}}\'\n'
"ENUMS:\n" "ENUMS:\n"
"- board_type: goal | general\n" "- board_type: goal | general\n"
"- lead_agent.autonomy_level: ask_first | balanced | autonomous\n" "- lead_agent.autonomy_level: ask_first | balanced | autonomous\n"
"- lead_agent.verbosity: concise | balanced | detailed\n" "- lead_agent.verbosity: concise | balanced | detailed\n"
"- lead_agent.output_format: bullets | mixed | narrative\n" "- lead_agent.output_format: bullets | mixed | narrative\n"
"- lead_agent.update_cadence: asap | hourly | daily | weekly\n" "- lead_agent.update_cadence: asap | hourly | daily | weekly\n"
"QUESTION FORMAT (one question per response, no arrays, no markdown, no extra text):\n" "QUESTION FORMAT (one question per response, no arrays, no markdown, "
"no extra text):\n"
'{"question":"...","options":[{"id":"1","label":"..."},{"id":"2","label":"..."}]}\n' '{"question":"...","options":[{"id":"1","label":"..."},{"id":"2","label":"..."}]}\n'
"Do NOT wrap questions in a list. Do NOT add commentary.\n" "Do NOT wrap questions in a list. Do NOT add commentary.\n"
"When you have enough info, send one final response with status=complete.\n" "When you have enough info, send one final response with status=complete.\n"
"The completion payload must include board_type. If board_type=goal, include objective + success_metrics.\n" "The completion payload must include board_type. If board_type=goal, "
"Also include user_profile + lead_agent to configure the board lead's working style.\n" "include objective + success_metrics.\n"
"Also include user_profile + lead_agent to configure the board lead's "
"working style.\n"
) )
try: try:

View File

@@ -1,19 +1,18 @@
"""Reusable FastAPI dependencies for auth and board/task access."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal from typing import TYPE_CHECKING, Literal
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.agent_auth import AgentAuthContext, get_agent_auth_context_optional from app.core.agent_auth import AgentAuthContext, get_agent_auth_context_optional
from app.core.auth import AuthContext, get_auth_context, get_auth_context_optional from app.core.auth import AuthContext, get_auth_context, get_auth_context_optional
from app.db.session import get_session from app.db.session import get_session
from app.models.agents import Agent
from app.models.boards import Board from app.models.boards import Board
from app.models.organizations import Organization from app.models.organizations import Organization
from app.models.tasks import Task from app.models.tasks import Task
from app.models.users import User
from app.services.admin_access import require_admin from app.services.admin_access import require_admin
from app.services.organizations import ( from app.services.organizations import (
OrganizationContext, OrganizationContext,
@@ -23,23 +22,38 @@ from app.services.organizations import (
require_board_access, require_board_access,
) )
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
def require_admin_auth(auth: AuthContext = Depends(get_auth_context)) -> AuthContext: from app.models.agents import Agent
from app.models.users import User
AUTH_DEP = Depends(get_auth_context)
AUTH_OPTIONAL_DEP = Depends(get_auth_context_optional)
AGENT_AUTH_OPTIONAL_DEP = Depends(get_agent_auth_context_optional)
SESSION_DEP = Depends(get_session)
def require_admin_auth(auth: AuthContext = AUTH_DEP) -> AuthContext:
"""Require an authenticated admin user."""
require_admin(auth) require_admin(auth)
return auth return auth
@dataclass @dataclass
class ActorContext: class ActorContext:
"""Authenticated actor context for user or agent callers."""
actor_type: Literal["user", "agent"] actor_type: Literal["user", "agent"]
user: User | None = None user: User | None = None
agent: Agent | None = None agent: Agent | None = None
def require_admin_or_agent( def require_admin_or_agent(
auth: AuthContext | None = Depends(get_auth_context_optional), auth: AuthContext | None = AUTH_OPTIONAL_DEP,
agent_auth: AgentAuthContext | None = Depends(get_agent_auth_context_optional), agent_auth: AgentAuthContext | None = AGENT_AUTH_OPTIONAL_DEP,
) -> ActorContext: ) -> ActorContext:
"""Authorize either an admin user or an authenticated agent."""
if auth is not None: if auth is not None:
require_admin(auth) require_admin(auth)
return ActorContext(actor_type="user", user=auth.user) return ActorContext(actor_type="user", user=auth.user)
@@ -48,10 +62,14 @@ def require_admin_or_agent(
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
ACTOR_DEP = Depends(require_admin_or_agent)
async def require_org_member( async def require_org_member(
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = AUTH_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
) -> OrganizationContext: ) -> OrganizationContext:
"""Resolve and require active organization membership for the current user."""
if auth.user is None: if auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
member = await get_active_membership(session, auth.user) member = await get_active_membership(session, auth.user)
@@ -59,15 +77,21 @@ async def require_org_member(
member = await ensure_member_for_user(session, auth.user) member = await ensure_member_for_user(session, auth.user)
if member is None: if member is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
organization = await Organization.objects.by_id(member.organization_id).first(session) organization = await Organization.objects.by_id(member.organization_id).first(
session,
)
if organization is None: if organization is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return OrganizationContext(organization=organization, member=member) return OrganizationContext(organization=organization, member=member)
ORG_MEMBER_DEP = Depends(require_org_member)
async def require_org_admin( async def require_org_admin(
ctx: OrganizationContext = Depends(require_org_member), ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> OrganizationContext: ) -> OrganizationContext:
"""Require organization-admin membership privileges."""
if not is_org_admin(ctx.member): if not is_org_admin(ctx.member):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return ctx return ctx
@@ -75,8 +99,9 @@ async def require_org_admin(
async def get_board_or_404( async def get_board_or_404(
board_id: str, board_id: str,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
) -> Board: ) -> Board:
"""Load a board by id or raise HTTP 404."""
board = await Board.objects.by_id(board_id).first(session) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -85,9 +110,10 @@ async def get_board_or_404(
async def get_board_for_actor_read( async def get_board_for_actor_read(
board_id: str, board_id: str,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
actor: ActorContext = Depends(require_admin_or_agent), actor: ActorContext = ACTOR_DEP,
) -> Board: ) -> Board:
"""Load a board and enforce actor read access."""
board = await Board.objects.by_id(board_id).first(session) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -103,9 +129,10 @@ async def get_board_for_actor_read(
async def get_board_for_actor_write( async def get_board_for_actor_write(
board_id: str, board_id: str,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
actor: ActorContext = Depends(require_admin_or_agent), actor: ActorContext = ACTOR_DEP,
) -> Board: ) -> Board:
"""Load a board and enforce actor write access."""
board = await Board.objects.by_id(board_id).first(session) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -121,9 +148,10 @@ async def get_board_for_actor_write(
async def get_board_for_user_read( async def get_board_for_user_read(
board_id: str, board_id: str,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = AUTH_DEP,
) -> Board: ) -> Board:
"""Load a board and enforce authenticated-user read access."""
board = await Board.objects.by_id(board_id).first(session) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -135,9 +163,10 @@ async def get_board_for_user_read(
async def get_board_for_user_write( async def get_board_for_user_write(
board_id: str, board_id: str,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = AUTH_DEP,
) -> Board: ) -> Board:
"""Load a board and enforce authenticated-user write access."""
board = await Board.objects.by_id(board_id).first(session) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -147,11 +176,15 @@ async def get_board_for_user_write(
return board return board
BOARD_READ_DEP = Depends(get_board_for_actor_read)
async def get_task_or_404( async def get_task_or_404(
task_id: str, task_id: str,
board: Board = Depends(get_board_for_actor_read), board: Board = BOARD_READ_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
) -> Task: ) -> Task:
"""Load a task for a board or raise HTTP 404."""
task = await Task.objects.by_id(task_id).first(session) task = await Task.objects.by_id(task_id).first(session)
if task is None or task.board_id != board.id: if task is None or task.board_id != board.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

View File

@@ -1,7 +1,10 @@
"""Gateway inspection and session-management endpoints."""
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING
from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import require_org_admin from app.api.deps import require_org_admin
from app.core.auth import AuthContext, get_auth_context from app.core.auth import AuthContext, get_auth_context
@@ -21,7 +24,6 @@ from app.integrations.openclaw_gateway_protocol import (
) )
from app.models.boards import Board from app.models.boards import Board
from app.models.gateways import Gateway from app.models.gateways import Gateway
from app.models.users import User
from app.schemas.common import OkResponse from app.schemas.common import OkResponse
from app.schemas.gateway_api import ( from app.schemas.gateway_api import (
GatewayCommandsResponse, GatewayCommandsResponse,
@@ -34,32 +36,48 @@ from app.schemas.gateway_api import (
) )
from app.services.organizations import OrganizationContext, require_board_access from app.services.organizations import OrganizationContext, require_board_access
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.users import User
router = APIRouter(prefix="/gateways", tags=["gateways"]) router = APIRouter(prefix="/gateways", tags=["gateways"])
SESSION_DEP = Depends(get_session)
AUTH_DEP = Depends(get_auth_context)
ORG_ADMIN_DEP = Depends(require_org_admin)
BOARD_ID_QUERY = Query(default=None)
RESOLVE_QUERY_DEP = Depends()
def _query_to_resolve_input(params: GatewayResolveQuery) -> GatewayResolveQuery:
return params
RESOLVE_INPUT_DEP = Depends(_query_to_resolve_input)
async def _resolve_gateway( async def _resolve_gateway(
session: AsyncSession, session: AsyncSession,
board_id: str | None, params: GatewayResolveQuery,
gateway_url: str | None,
gateway_token: str | None,
gateway_main_session_key: str | None,
*, *,
user: User | None = None, user: User | None = None,
) -> tuple[Board | None, GatewayClientConfig, str | None]: ) -> tuple[Board | None, GatewayClientConfig, str | None]:
if gateway_url: if params.gateway_url:
return ( return (
None, None,
GatewayClientConfig(url=gateway_url, token=gateway_token), GatewayClientConfig(url=params.gateway_url, token=params.gateway_token),
gateway_main_session_key, params.gateway_main_session_key,
) )
if not board_id: if not params.board_id:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="board_id or gateway_url is required", detail="board_id or gateway_url is required",
) )
board = await Board.objects.by_id(board_id).first(session) board = await Board.objects.by_id(params.board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Board not found",
)
if user is not None: if user is not None:
await require_board_access(session, user=user, board=board, write=False) await require_board_access(session, user=user, board=board, write=False)
if not board.gateway_id: if not board.gateway_id:
@@ -86,14 +104,12 @@ async def _resolve_gateway(
async def _require_gateway( async def _require_gateway(
session: AsyncSession, board_id: str | None, *, user: User | None = None session: AsyncSession, board_id: str | None, *, user: User | None = None,
) -> tuple[Board, GatewayClientConfig, str | None]: ) -> tuple[Board, GatewayClientConfig, str | None]:
params = GatewayResolveQuery(board_id=board_id)
board, config, main_session = await _resolve_gateway( board, config, main_session = await _resolve_gateway(
session, session,
board_id, params,
None,
None,
None,
user=user, user=user,
) )
if board is None: if board is None:
@@ -106,17 +122,15 @@ async def _require_gateway(
@router.get("/status", response_model=GatewaysStatusResponse) @router.get("/status", response_model=GatewaysStatusResponse)
async def gateways_status( async def gateways_status(
params: GatewayResolveQuery = Depends(), params: GatewayResolveQuery = RESOLVE_INPUT_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = AUTH_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> GatewaysStatusResponse: ) -> GatewaysStatusResponse:
"""Return gateway connectivity and session status."""
board, config, main_session = await _resolve_gateway( board, config, main_session = await _resolve_gateway(
session, session,
params.board_id, params,
params.gateway_url,
params.gateway_token,
params.gateway_main_session_key,
user=auth.user, user=auth.user,
) )
if board is not None and board.organization_id != ctx.organization.id: if board is not None and board.organization_id != ctx.organization.id:
@@ -131,7 +145,9 @@ async def gateways_status(
main_session_error: str | None = None main_session_error: str | None = None
if main_session: if main_session:
try: try:
ensured = await ensure_session(main_session, config=config, label="Main Agent") ensured = await ensure_session(
main_session, config=config, label="Main Agent",
)
if isinstance(ensured, dict): if isinstance(ensured, dict):
main_session_entry = ensured.get("entry") or ensured main_session_entry = ensured.get("entry") or ensured
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
@@ -146,22 +162,23 @@ async def gateways_status(
main_session_error=main_session_error, main_session_error=main_session_error,
) )
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
return GatewaysStatusResponse(connected=False, gateway_url=config.url, error=str(exc)) return GatewaysStatusResponse(
connected=False, gateway_url=config.url, error=str(exc),
)
@router.get("/sessions", response_model=GatewaySessionsResponse) @router.get("/sessions", response_model=GatewaySessionsResponse)
async def list_gateway_sessions( async def list_gateway_sessions(
board_id: str | None = Query(default=None), board_id: str | None = BOARD_ID_QUERY,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = AUTH_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> GatewaySessionsResponse: ) -> GatewaySessionsResponse:
"""List sessions for a gateway associated with a board."""
params = GatewayResolveQuery(board_id=board_id)
board, config, main_session = await _resolve_gateway( board, config, main_session = await _resolve_gateway(
session, session,
board_id, params,
None,
None,
None,
user=auth.user, user=auth.user,
) )
if board is not None and board.organization_id != ctx.organization.id: if board is not None and board.organization_id != ctx.organization.id:
@@ -169,7 +186,9 @@ async def list_gateway_sessions(
try: try:
sessions = await openclaw_call("sessions.list", config=config) sessions = await openclaw_call("sessions.list", config=config)
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
) from exc
if isinstance(sessions, dict): if isinstance(sessions, dict):
sessions_list = list(sessions.get("sessions") or []) sessions_list = list(sessions.get("sessions") or [])
else: else:
@@ -178,7 +197,9 @@ async def list_gateway_sessions(
main_session_entry: object | None = None main_session_entry: object | None = None
if main_session: if main_session:
try: try:
ensured = await ensure_session(main_session, config=config, label="Main Agent") ensured = await ensure_session(
main_session, config=config, label="Main Agent",
)
if isinstance(ensured, dict): if isinstance(ensured, dict):
main_session_entry = ensured.get("entry") or ensured main_session_entry = ensured.get("entry") or ensured
except OpenClawGatewayError: except OpenClawGatewayError:
@@ -191,70 +212,103 @@ async def list_gateway_sessions(
) )
async def _list_sessions(config: GatewayClientConfig) -> list[dict[str, object]]:
sessions = await openclaw_call("sessions.list", config=config)
if isinstance(sessions, dict):
raw_items = sessions.get("sessions") or []
else:
raw_items = sessions or []
return [
item
for item in raw_items
if isinstance(item, dict)
]
async def _with_main_session(
sessions_list: list[dict[str, object]],
*,
config: GatewayClientConfig,
main_session: str | None,
) -> list[dict[str, object]]:
if not main_session or any(
item.get("key") == main_session for item in sessions_list
):
return sessions_list
try:
await ensure_session(main_session, config=config, label="Main Agent")
return await _list_sessions(config)
except OpenClawGatewayError:
return sessions_list
@router.get("/sessions/{session_id}", response_model=GatewaySessionResponse) @router.get("/sessions/{session_id}", response_model=GatewaySessionResponse)
async def get_gateway_session( async def get_gateway_session(
session_id: str, session_id: str,
board_id: str | None = Query(default=None), board_id: str | None = BOARD_ID_QUERY,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = AUTH_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> GatewaySessionResponse: ) -> GatewaySessionResponse:
"""Get a specific gateway session by key."""
params = GatewayResolveQuery(board_id=board_id)
board, config, main_session = await _resolve_gateway( board, config, main_session = await _resolve_gateway(
session, session,
board_id, params,
None,
None,
None,
user=auth.user, user=auth.user,
) )
if board is not None and board.organization_id != ctx.organization.id: if board is not None and board.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
try: try:
sessions = await openclaw_call("sessions.list", config=config) sessions_list = await _list_sessions(config)
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc raise HTTPException(
if isinstance(sessions, dict): status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
sessions_list = list(sessions.get("sessions") or []) ) from exc
else: sessions_list = await _with_main_session(
sessions_list = list(sessions or []) sessions_list,
if main_session and not any(item.get("key") == main_session for item in sessions_list): config=config,
try: main_session=main_session,
await ensure_session(main_session, config=config, label="Main Agent") )
refreshed = await openclaw_call("sessions.list", config=config) session_entry = next(
if isinstance(refreshed, dict): (item for item in sessions_list if item.get("key") == session_id), None,
sessions_list = list(refreshed.get("sessions") or []) )
else:
sessions_list = list(refreshed or [])
except OpenClawGatewayError:
pass
session_entry = next((item for item in sessions_list if item.get("key") == session_id), None)
if session_entry is None and main_session and session_id == main_session: if session_entry is None and main_session and session_id == main_session:
try: try:
ensured = await ensure_session(main_session, config=config, label="Main Agent") ensured = await ensure_session(
main_session, config=config, label="Main Agent",
)
if isinstance(ensured, dict): if isinstance(ensured, dict):
session_entry = ensured.get("entry") or ensured session_entry = ensured.get("entry") or ensured
except OpenClawGatewayError: except OpenClawGatewayError:
session_entry = None session_entry = None
if session_entry is None: if session_entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Session not found",
)
return GatewaySessionResponse(session=session_entry) return GatewaySessionResponse(session=session_entry)
@router.get("/sessions/{session_id}/history", response_model=GatewaySessionHistoryResponse) @router.get(
"/sessions/{session_id}/history", response_model=GatewaySessionHistoryResponse,
)
async def get_session_history( async def get_session_history(
session_id: str, session_id: str,
board_id: str | None = Query(default=None), board_id: str | None = BOARD_ID_QUERY,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = AUTH_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> GatewaySessionHistoryResponse: ) -> GatewaySessionHistoryResponse:
"""Fetch chat history for a gateway session."""
board, config, _ = await _require_gateway(session, board_id, user=auth.user) board, config, _ = await _require_gateway(session, board_id, user=auth.user)
if board.organization_id != ctx.organization.id: if board.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
try: try:
history = await get_chat_history(session_id, config=config) history = await get_chat_history(session_id, config=config)
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
) from exc
if isinstance(history, dict) and isinstance(history.get("messages"), list): if isinstance(history, dict) and isinstance(history.get("messages"), list):
return GatewaySessionHistoryResponse(history=history["messages"]) return GatewaySessionHistoryResponse(history=history["messages"])
return GatewaySessionHistoryResponse(history=list(history or [])) return GatewaySessionHistoryResponse(history=list(history or []))
@@ -264,14 +318,14 @@ async def get_session_history(
async def send_gateway_session_message( async def send_gateway_session_message(
session_id: str, session_id: str,
payload: GatewaySessionMessageRequest, payload: GatewaySessionMessageRequest,
board_id: str | None = Query(default=None), board_id: str | None = BOARD_ID_QUERY,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = AUTH_DEP,
ctx: OrganizationContext = Depends(require_org_admin),
) -> OkResponse: ) -> OkResponse:
board, config, main_session = await _require_gateway(session, board_id, user=auth.user) """Send a message into a specific gateway session."""
if board.organization_id != ctx.organization.id: board, config, main_session = await _require_gateway(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) session, board_id, user=auth.user,
)
if auth.user is None: if auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
await require_board_access(session, user=auth.user, board=board, write=True) await require_board_access(session, user=auth.user, board=board, write=True)
@@ -280,15 +334,18 @@ async def send_gateway_session_message(
await ensure_session(main_session, config=config, label="Main Agent") await ensure_session(main_session, config=config, label="Main Agent")
await send_message(payload.content, session_key=session_id, config=config) await send_message(payload.content, session_key=session_id, config=config)
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
) from exc
return OkResponse() return OkResponse()
@router.get("/commands", response_model=GatewayCommandsResponse) @router.get("/commands", response_model=GatewayCommandsResponse)
async def gateway_commands( async def gateway_commands(
auth: AuthContext = Depends(get_auth_context), _auth: AuthContext = AUTH_DEP,
_ctx: OrganizationContext = Depends(require_org_admin), _ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> GatewayCommandsResponse: ) -> GatewayCommandsResponse:
"""Return supported gateway protocol methods and events."""
return GatewayCommandsResponse( return GatewayCommandsResponse(
protocol_version=PROTOCOL_VERSION, protocol_version=PROTOCOL_VERSION,
methods=GATEWAY_METHODS, methods=GATEWAY_METHODS,

View File

@@ -1,10 +1,13 @@
"""Gateway CRUD and template synchronization endpoints."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel import col from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import require_org_admin from app.api.deps import require_org_admin
from app.core.agent_tokens import generate_agent_token, hash_agent_token from app.core.agent_tokens import generate_agent_token, hash_agent_token
@@ -14,7 +17,11 @@ from app.db import crud
from app.db.pagination import paginate from app.db.pagination import paginate
from app.db.session import get_session from app.db.session import get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message from app.integrations.openclaw_gateway import (
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.agents import Agent from app.models.agents import Agent
from app.models.gateways import Gateway from app.models.gateways import Gateway
from app.schemas.common import OkResponse from app.schemas.common import OkResponse
@@ -25,11 +32,61 @@ from app.schemas.gateways import (
GatewayUpdate, GatewayUpdate,
) )
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_main_agent from app.services.agent_provisioning import (
from app.services.organizations import OrganizationContext DEFAULT_HEARTBEAT_CONFIG,
from app.services.template_sync import sync_gateway_templates as sync_gateway_templates_service provision_main_agent,
)
from app.services.template_sync import (
GatewayTemplateSyncOptions,
)
from app.services.template_sync import (
sync_gateway_templates as sync_gateway_templates_service,
)
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from app.services.organizations import OrganizationContext
router = APIRouter(prefix="/gateways", tags=["gateways"]) router = APIRouter(prefix="/gateways", tags=["gateways"])
SESSION_DEP = Depends(get_session)
AUTH_DEP = Depends(get_auth_context)
ORG_ADMIN_DEP = Depends(require_org_admin)
INCLUDE_MAIN_QUERY = Query(default=True)
RESET_SESSIONS_QUERY = Query(default=False)
ROTATE_TOKENS_QUERY = Query(default=False)
FORCE_BOOTSTRAP_QUERY = Query(default=False)
BOARD_ID_QUERY = Query(default=None)
_RUNTIME_TYPE_REFERENCES = (UUID,)
@dataclass(frozen=True)
class _TemplateSyncQuery:
include_main: bool
reset_sessions: bool
rotate_tokens: bool
force_bootstrap: bool
board_id: UUID | None
def _template_sync_query(
*,
include_main: bool = INCLUDE_MAIN_QUERY,
reset_sessions: bool = RESET_SESSIONS_QUERY,
rotate_tokens: bool = ROTATE_TOKENS_QUERY,
force_bootstrap: bool = FORCE_BOOTSTRAP_QUERY,
board_id: UUID | None = BOARD_ID_QUERY,
) -> _TemplateSyncQuery:
return _TemplateSyncQuery(
include_main=include_main,
reset_sessions=reset_sessions,
rotate_tokens=rotate_tokens,
force_bootstrap=force_bootstrap,
board_id=board_id,
)
SYNC_QUERY_DEP = Depends(_template_sync_query)
def _main_agent_name(gateway: Gateway) -> str: def _main_agent_name(gateway: Gateway) -> str:
@@ -48,7 +105,9 @@ async def _require_gateway(
.first(session) .first(session)
) )
if gateway is None: if gateway is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found",
)
return gateway return gateway
@@ -59,14 +118,18 @@ async def _find_main_agent(
previous_session_key: str | None = None, previous_session_key: str | None = None,
) -> Agent | None: ) -> Agent | None:
if gateway.main_session_key: if gateway.main_session_key:
agent = await Agent.objects.filter_by(openclaw_session_id=gateway.main_session_key).first( agent = await Agent.objects.filter_by(
session openclaw_session_id=gateway.main_session_key,
).first(
session,
) )
if agent: if agent:
return agent return agent
if previous_session_key: if previous_session_key:
agent = await Agent.objects.filter_by(openclaw_session_id=previous_session_key).first( agent = await Agent.objects.filter_by(
session openclaw_session_id=previous_session_key,
).first(
session,
) )
if agent: if agent:
return agent return agent
@@ -85,13 +148,17 @@ async def _ensure_main_agent(
gateway: Gateway, gateway: Gateway,
auth: AuthContext, auth: AuthContext,
*, *,
previous_name: str | None = None, previous: tuple[str | None, str | None] | None = None,
previous_session_key: str | None = None,
action: str = "provision", action: str = "provision",
) -> Agent | None: ) -> Agent | None:
if not gateway.url or not gateway.main_session_key: if not gateway.url or not gateway.main_session_key:
return None return None
agent = await _find_main_agent(session, gateway, previous_name, previous_session_key) agent = await _find_main_agent(
session,
gateway,
previous_name=previous[0] if previous else None,
previous_session_key=previous[1] if previous else None,
)
if agent is None: if agent is None:
agent = Agent( agent = Agent(
name=_main_agent_name(gateway), name=_main_agent_name(gateway),
@@ -130,7 +197,8 @@ async def _ensure_main_agent(
( (
f"Hello {agent.name}. Your gateway provisioning was updated.\n\n" f"Hello {agent.name}. Your gateway provisioning was updated.\n\n"
"Please re-read AGENTS.md, USER.md, HEARTBEAT.md, and TOOLS.md. " "Please re-read AGENTS.md, USER.md, HEARTBEAT.md, and TOOLS.md. "
"If BOOTSTRAP.md exists, run it once then delete it. Begin heartbeats after startup." "If BOOTSTRAP.md exists, run it once then delete it. "
"Begin heartbeats after startup."
), ),
session_key=gateway.main_session_key, session_key=gateway.main_session_key,
config=GatewayClientConfig(url=gateway.url, token=gateway.token), config=GatewayClientConfig(url=gateway.url, token=gateway.token),
@@ -144,9 +212,10 @@ async def _ensure_main_agent(
@router.get("", response_model=DefaultLimitOffsetPage[GatewayRead]) @router.get("", response_model=DefaultLimitOffsetPage[GatewayRead])
async def list_gateways( async def list_gateways(
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> DefaultLimitOffsetPage[GatewayRead]: ) -> DefaultLimitOffsetPage[GatewayRead]:
"""List gateways for the caller's organization."""
statement = ( statement = (
Gateway.objects.filter_by(organization_id=ctx.organization.id) Gateway.objects.filter_by(organization_id=ctx.organization.id)
.order_by(col(Gateway.created_at).desc()) .order_by(col(Gateway.created_at).desc())
@@ -158,10 +227,11 @@ async def list_gateways(
@router.post("", response_model=GatewayRead) @router.post("", response_model=GatewayRead)
async def create_gateway( async def create_gateway(
payload: GatewayCreate, payload: GatewayCreate,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = AUTH_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> Gateway: ) -> Gateway:
"""Create a gateway and provision or refresh its main agent."""
data = payload.model_dump() data = payload.model_dump()
data["organization_id"] = ctx.organization.id data["organization_id"] = ctx.organization.id
gateway = await crud.create(session, Gateway, **data) gateway = await crud.create(session, Gateway, **data)
@@ -172,9 +242,10 @@ async def create_gateway(
@router.get("/{gateway_id}", response_model=GatewayRead) @router.get("/{gateway_id}", response_model=GatewayRead)
async def get_gateway( async def get_gateway(
gateway_id: UUID, gateway_id: UUID,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> Gateway: ) -> Gateway:
"""Return one gateway by id for the caller's organization."""
return await _require_gateway( return await _require_gateway(
session, session,
gateway_id=gateway_id, gateway_id=gateway_id,
@@ -186,10 +257,11 @@ async def get_gateway(
async def update_gateway( async def update_gateway(
gateway_id: UUID, gateway_id: UUID,
payload: GatewayUpdate, payload: GatewayUpdate,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = AUTH_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> Gateway: ) -> Gateway:
"""Patch a gateway and refresh the main-agent provisioning state."""
gateway = await _require_gateway( gateway = await _require_gateway(
session, session,
gateway_id=gateway_id, gateway_id=gateway_id,
@@ -203,8 +275,7 @@ async def update_gateway(
session, session,
gateway, gateway,
auth, auth,
previous_name=previous_name, previous=(previous_name, previous_session_key),
previous_session_key=previous_session_key,
action="update", action="update",
) )
return gateway return gateway
@@ -213,15 +284,12 @@ async def update_gateway(
@router.post("/{gateway_id}/templates/sync", response_model=GatewayTemplatesSyncResult) @router.post("/{gateway_id}/templates/sync", response_model=GatewayTemplatesSyncResult)
async def sync_gateway_templates( async def sync_gateway_templates(
gateway_id: UUID, gateway_id: UUID,
include_main: bool = Query(default=True), sync_query: _TemplateSyncQuery = SYNC_QUERY_DEP,
reset_sessions: bool = Query(default=False), session: AsyncSession = SESSION_DEP,
rotate_tokens: bool = Query(default=False), auth: AuthContext = AUTH_DEP,
force_bootstrap: bool = Query(default=False), ctx: OrganizationContext = ORG_ADMIN_DEP,
board_id: UUID | None = Query(default=None),
session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context),
ctx: OrganizationContext = Depends(require_org_admin),
) -> GatewayTemplatesSyncResult: ) -> GatewayTemplatesSyncResult:
"""Sync templates for a gateway and optionally rotate runtime settings."""
gateway = await _require_gateway( gateway = await _require_gateway(
session, session,
gateway_id=gateway_id, gateway_id=gateway_id,
@@ -230,21 +298,24 @@ async def sync_gateway_templates(
return await sync_gateway_templates_service( return await sync_gateway_templates_service(
session, session,
gateway, gateway,
GatewayTemplateSyncOptions(
user=auth.user, user=auth.user,
include_main=include_main, include_main=sync_query.include_main,
reset_sessions=reset_sessions, reset_sessions=sync_query.reset_sessions,
rotate_tokens=rotate_tokens, rotate_tokens=sync_query.rotate_tokens,
force_bootstrap=force_bootstrap, force_bootstrap=sync_query.force_bootstrap,
board_id=board_id, board_id=sync_query.board_id,
),
) )
@router.delete("/{gateway_id}", response_model=OkResponse) @router.delete("/{gateway_id}", response_model=OkResponse)
async def delete_gateway( async def delete_gateway(
gateway_id: UUID, gateway_id: UUID,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> OkResponse: ) -> OkResponse:
"""Delete a gateway in the caller's organization."""
gateway = await _require_gateway( gateway = await _require_gateway(
session, session,
gateway_id=gateway_id, gateway_id=gateway_id,

View File

@@ -1,3 +1,5 @@
"""Dashboard metric aggregation endpoints."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
@@ -32,10 +34,16 @@ router = APIRouter(prefix="/metrics", tags=["metrics"])
OFFLINE_AFTER = timedelta(minutes=10) OFFLINE_AFTER = timedelta(minutes=10)
ERROR_EVENT_PATTERN = "%failed" ERROR_EVENT_PATTERN = "%failed"
_RUNTIME_TYPE_REFERENCES = (UUID, AsyncSession)
RANGE_QUERY = Query(default="24h")
SESSION_DEP = Depends(get_session)
ORG_MEMBER_DEP = Depends(require_org_member)
@dataclass(frozen=True) @dataclass(frozen=True)
class RangeSpec: class RangeSpec:
"""Resolved time-range specification for metric aggregation."""
key: Literal["24h", "7d"] key: Literal["24h", "7d"]
start: datetime start: datetime
end: datetime end: datetime
@@ -80,7 +88,8 @@ def _build_buckets(range_spec: RangeSpec) -> list[datetime]:
def _series_from_mapping( def _series_from_mapping(
range_spec: RangeSpec, mapping: dict[datetime, float] range_spec: RangeSpec,
mapping: dict[datetime, float],
) -> DashboardRangeSeries: ) -> DashboardRangeSeries:
points = [ points = [
DashboardSeriesPoint(period=bucket, value=float(mapping.get(bucket, 0))) DashboardSeriesPoint(period=bucket, value=float(mapping.get(bucket, 0)))
@@ -94,7 +103,8 @@ def _series_from_mapping(
def _wip_series_from_mapping( def _wip_series_from_mapping(
range_spec: RangeSpec, mapping: dict[datetime, dict[str, int]] range_spec: RangeSpec,
mapping: dict[datetime, dict[str, int]],
) -> DashboardWipRangeSeries: ) -> DashboardWipRangeSeries:
points: list[DashboardWipPoint] = [] points: list[DashboardWipPoint] = []
for bucket in _build_buckets(range_spec): for bucket in _build_buckets(range_spec):
@@ -105,7 +115,7 @@ def _wip_series_from_mapping(
inbox=values.get("inbox", 0), inbox=values.get("inbox", 0),
in_progress=values.get("in_progress", 0), in_progress=values.get("in_progress", 0),
review=values.get("review", 0), review=values.get("review", 0),
) ),
) )
return DashboardWipRangeSeries( return DashboardWipRangeSeries(
range=range_spec.key, range=range_spec.key,
@@ -115,7 +125,9 @@ def _wip_series_from_mapping(
async def _query_throughput( async def _query_throughput(
session: AsyncSession, range_spec: RangeSpec, board_ids: list[UUID] session: AsyncSession,
range_spec: RangeSpec,
board_ids: list[UUID],
) -> DashboardRangeSeries: ) -> DashboardRangeSeries:
bucket_col = func.date_trunc(range_spec.bucket, Task.updated_at).label("bucket") bucket_col = func.date_trunc(range_spec.bucket, Task.updated_at).label("bucket")
statement = ( statement = (
@@ -135,7 +147,9 @@ async def _query_throughput(
async def _query_cycle_time( async def _query_cycle_time(
session: AsyncSession, range_spec: RangeSpec, board_ids: list[UUID] session: AsyncSession,
range_spec: RangeSpec,
board_ids: list[UUID],
) -> DashboardRangeSeries: ) -> DashboardRangeSeries:
bucket_col = func.date_trunc(range_spec.bucket, Task.updated_at).label("bucket") bucket_col = func.date_trunc(range_spec.bucket, Task.updated_at).label("bucket")
in_progress = cast(Task.in_progress_at, DateTime) in_progress = cast(Task.in_progress_at, DateTime)
@@ -158,9 +172,14 @@ async def _query_cycle_time(
async def _query_error_rate( async def _query_error_rate(
session: AsyncSession, range_spec: RangeSpec, board_ids: list[UUID] session: AsyncSession,
range_spec: RangeSpec,
board_ids: list[UUID],
) -> DashboardRangeSeries: ) -> DashboardRangeSeries:
bucket_col = func.date_trunc(range_spec.bucket, ActivityEvent.created_at).label("bucket") bucket_col = func.date_trunc(
range_spec.bucket,
ActivityEvent.created_at,
).label("bucket")
error_case = case( error_case = case(
( (
col(ActivityEvent.event_type).like(ERROR_EVENT_PATTERN), col(ActivityEvent.event_type).like(ERROR_EVENT_PATTERN),
@@ -190,7 +209,9 @@ async def _query_error_rate(
async def _query_wip( async def _query_wip(
session: AsyncSession, range_spec: RangeSpec, board_ids: list[UUID] session: AsyncSession,
range_spec: RangeSpec,
board_ids: list[UUID],
) -> DashboardWipRangeSeries: ) -> DashboardWipRangeSeries:
bucket_col = func.date_trunc(range_spec.bucket, Task.updated_at).label("bucket") bucket_col = func.date_trunc(range_spec.bucket, Task.updated_at).label("bucket")
inbox_case = case((col(Task.status) == "inbox", 1), else_=0) inbox_case = case((col(Task.status) == "inbox", 1), else_=0)
@@ -222,7 +243,10 @@ async def _query_wip(
return _wip_series_from_mapping(range_spec, mapping) return _wip_series_from_mapping(range_spec, mapping)
async def _median_cycle_time_7d(session: AsyncSession, board_ids: list[UUID]) -> float | None: async def _median_cycle_time_7d(
session: AsyncSession,
board_ids: list[UUID],
) -> float | None:
now = utcnow() now = utcnow()
start = now - timedelta(days=7) start = now - timedelta(days=7)
in_progress = cast(Task.in_progress_at, DateTime) in_progress = cast(Task.in_progress_at, DateTime)
@@ -248,7 +272,9 @@ async def _median_cycle_time_7d(session: AsyncSession, board_ids: list[UUID]) ->
async def _error_rate_kpi( async def _error_rate_kpi(
session: AsyncSession, range_spec: RangeSpec, board_ids: list[UUID] session: AsyncSession,
range_spec: RangeSpec,
board_ids: list[UUID],
) -> float: ) -> float:
error_case = case( error_case = case(
( (
@@ -302,12 +328,13 @@ async def _tasks_in_progress(session: AsyncSession, board_ids: list[UUID]) -> in
@router.get("/dashboard", response_model=DashboardMetrics) @router.get("/dashboard", response_model=DashboardMetrics)
async def dashboard_metrics( async def dashboard_metrics(
range: Literal["24h", "7d"] = Query(default="24h"), range_key: Literal["24h", "7d"] = RANGE_QUERY,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_member), ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> DashboardMetrics: ) -> DashboardMetrics:
primary = _resolve_range(range) """Return dashboard KPIs and time-series data for accessible boards."""
comparison = _comparison_range(range) primary = _resolve_range(range_key)
comparison = _comparison_range(range_key)
board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False) board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False)
throughput_primary = await _query_throughput(session, primary, board_ids) throughput_primary = await _query_throughput(session, primary, board_ids)

View File

@@ -1,27 +1,37 @@
"""Generic asynchronous CRUD helpers for SQLModel entities."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable, Mapping from typing import TYPE_CHECKING, Any, TypeVar
from typing import Any, TypeVar
from sqlalchemy import delete as sql_delete from sqlalchemy import delete as sql_delete
from sqlalchemy import update as sql_update from sqlalchemy import update as sql_update
from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlmodel import SQLModel, select from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar if TYPE_CHECKING:
from collections.abc import Iterable, Mapping
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar
ModelT = TypeVar("ModelT", bound=SQLModel) ModelT = TypeVar("ModelT", bound=SQLModel)
class DoesNotExist(LookupError): class DoesNotExistError(LookupError):
pass """Raised when a query expected one row but found none."""
class MultipleObjectsReturned(LookupError): class MultipleObjectsReturnedError(LookupError):
pass """Raised when a query expected one row but found many."""
DoesNotExist = DoesNotExistError
MultipleObjectsReturned = MultipleObjectsReturnedError
async def _flush_or_rollback(session: AsyncSession) -> None: async def _flush_or_rollback(session: AsyncSession) -> None:
"""Flush changes and rollback on SQLAlchemy errors."""
try: try:
await session.flush() await session.flush()
except SQLAlchemyError: except SQLAlchemyError:
@@ -30,6 +40,7 @@ async def _flush_or_rollback(session: AsyncSession) -> None:
async def _commit_or_rollback(session: AsyncSession) -> None: async def _commit_or_rollback(session: AsyncSession) -> None:
"""Commit transaction and rollback on SQLAlchemy errors."""
try: try:
await session.commit() await session.commit()
except SQLAlchemyError: except SQLAlchemyError:
@@ -37,31 +48,50 @@ async def _commit_or_rollback(session: AsyncSession) -> None:
raise raise
def _lookup_statement(model: type[ModelT], lookup: Mapping[str, Any]) -> SelectOfScalar[ModelT]: def _lookup_statement(
model: type[ModelT],
lookup: Mapping[str, Any],
) -> SelectOfScalar[ModelT]:
"""Build a select statement with equality filters from lookup values."""
stmt = select(model) stmt = select(model)
for key, value in lookup.items(): for key, value in lookup.items():
stmt = stmt.where(getattr(model, key) == value) stmt = stmt.where(getattr(model, key) == value)
return stmt return stmt
async def get_by_id(session: AsyncSession, model: type[ModelT], obj_id: Any) -> ModelT | None: async def get_by_id(
session: AsyncSession,
model: type[ModelT],
obj_id: object,
) -> ModelT | None:
"""Fetch one model instance by id or return None."""
stmt = _lookup_statement(model, {"id": obj_id}).limit(1) stmt = _lookup_statement(model, {"id": obj_id}).limit(1)
return (await session.exec(stmt)).first() return (await session.exec(stmt)).first()
async def get(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT: async def get(
session: AsyncSession,
model: type[ModelT],
**lookup: object,
) -> ModelT:
"""Fetch exactly one model instance by lookup values."""
stmt = _lookup_statement(model, lookup).limit(2) stmt = _lookup_statement(model, lookup).limit(2)
items = (await session.exec(stmt)).all() items = (await session.exec(stmt)).all()
if not items: if not items:
raise DoesNotExist(f"{model.__name__} matching query does not exist.") message = f"{model.__name__} matching query does not exist."
raise DoesNotExist(message)
if len(items) > 1: if len(items) > 1:
raise MultipleObjectsReturned( message = f"Multiple {model.__name__} objects returned for lookup {lookup!r}."
f"Multiple {model.__name__} objects returned for lookup {lookup!r}." raise MultipleObjectsReturned(message)
)
return items[0] return items[0]
async def get_one_by(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT | None: async def get_one_by(
session: AsyncSession,
model: type[ModelT],
**lookup: object,
) -> ModelT | None:
"""Fetch the first model instance matching lookup values."""
stmt = _lookup_statement(model, lookup) stmt = _lookup_statement(model, lookup)
return (await session.exec(stmt)).first() return (await session.exec(stmt)).first()
@@ -72,8 +102,9 @@ async def create(
*, *,
commit: bool = True, commit: bool = True,
refresh: bool = True, refresh: bool = True,
**data: Any, **data: object,
) -> ModelT: ) -> ModelT:
"""Create, flush, optionally commit, and optionally refresh an object."""
obj = model.model_validate(data) obj = model.model_validate(data)
session.add(obj) session.add(obj)
await _flush_or_rollback(session) await _flush_or_rollback(session)
@@ -91,6 +122,7 @@ async def save(
commit: bool = True, commit: bool = True,
refresh: bool = True, refresh: bool = True,
) -> ModelT: ) -> ModelT:
"""Persist an existing object with optional commit and refresh."""
session.add(obj) session.add(obj)
await _flush_or_rollback(session) await _flush_or_rollback(session)
if commit: if commit:
@@ -101,6 +133,7 @@ async def save(
async def delete(session: AsyncSession, obj: ModelT, *, commit: bool = True) -> None: async def delete(session: AsyncSession, obj: ModelT, *, commit: bool = True) -> None:
"""Delete an object with optional commit."""
await session.delete(obj) await session.delete(obj)
if commit: if commit:
await _commit_or_rollback(session) await _commit_or_rollback(session)
@@ -113,8 +146,9 @@ async def list_by(
order_by: Iterable[Any] = (), order_by: Iterable[Any] = (),
limit: int | None = None, limit: int | None = None,
offset: int | None = None, offset: int | None = None,
**lookup: Any, **lookup: object,
) -> list[ModelT]: ) -> list[ModelT]:
"""List objects by lookup values with optional ordering and pagination."""
stmt = _lookup_statement(model, lookup) stmt = _lookup_statement(model, lookup)
for ordering in order_by: for ordering in order_by:
stmt = stmt.order_by(ordering) stmt = stmt.order_by(ordering)
@@ -125,11 +159,19 @@ async def list_by(
return list(await session.exec(stmt)) return list(await session.exec(stmt))
async def exists(session: AsyncSession, model: type[ModelT], **lookup: Any) -> bool: async def exists(session: AsyncSession, model: type[ModelT], **lookup: object) -> bool:
return (await session.exec(_lookup_statement(model, lookup).limit(1))).first() is not None """Return whether any object exists for lookup values."""
return (
(await session.exec(_lookup_statement(model, lookup).limit(1))).first()
is not None
)
def _criteria_statement(model: type[ModelT], criteria: tuple[Any, ...]) -> SelectOfScalar[ModelT]: def _criteria_statement(
model: type[ModelT],
criteria: tuple[Any, ...],
) -> SelectOfScalar[ModelT]:
"""Build a select statement from variadic where criteria."""
stmt = select(model) stmt = select(model)
if criteria: if criteria:
stmt = stmt.where(*criteria) stmt = stmt.where(*criteria)
@@ -139,9 +181,10 @@ def _criteria_statement(model: type[ModelT], criteria: tuple[Any, ...]) -> Selec
async def list_where( async def list_where(
session: AsyncSession, session: AsyncSession,
model: type[ModelT], model: type[ModelT],
*criteria: Any, *criteria: object,
order_by: Iterable[Any] = (), order_by: Iterable[Any] = (),
) -> list[ModelT]: ) -> list[ModelT]:
"""List objects filtered by explicit SQL criteria."""
stmt = _criteria_statement(model, criteria) stmt = _criteria_statement(model, criteria)
for ordering in order_by: for ordering in order_by:
stmt = stmt.order_by(ordering) stmt = stmt.order_by(ordering)
@@ -151,9 +194,10 @@ async def list_where(
async def delete_where( async def delete_where(
session: AsyncSession, session: AsyncSession,
model: type[ModelT], model: type[ModelT],
*criteria: Any, *criteria: object,
commit: bool = False, commit: bool = False,
) -> int: ) -> int:
"""Delete rows matching criteria and return affected row count."""
stmt: Any = sql_delete(model) stmt: Any = sql_delete(model)
if criteria: if criteria:
stmt = stmt.where(*criteria) stmt = stmt.where(*criteria)
@@ -167,18 +211,24 @@ async def delete_where(
async def update_where( async def update_where(
session: AsyncSession, session: AsyncSession,
model: type[ModelT], model: type[ModelT],
*criteria: Any, *criteria: object,
updates: Mapping[str, Any] | None = None, updates: Mapping[str, Any] | None = None,
commit: bool = False, **options: object,
exclude_none: bool = False,
allowed_fields: set[str] | None = None,
**update_fields: Any,
) -> int: ) -> int:
"""Apply bulk updates by criteria and return affected row count."""
commit = bool(options.pop("commit", False))
exclude_none = bool(options.pop("exclude_none", False))
allowed_fields_raw = options.pop("allowed_fields", None)
allowed_fields = (
allowed_fields_raw
if isinstance(allowed_fields_raw, set)
else None
)
source_updates: dict[str, Any] = {} source_updates: dict[str, Any] = {}
if updates: if updates:
source_updates.update(dict(updates)) source_updates.update(dict(updates))
if update_fields: if options:
source_updates.update(update_fields) source_updates.update(options)
values: dict[str, Any] = {} values: dict[str, Any] = {}
for key, value in source_updates.items(): for key, value in source_updates.items():
@@ -207,6 +257,7 @@ def apply_updates(
exclude_none: bool = False, exclude_none: bool = False,
allowed_fields: set[str] | None = None, allowed_fields: set[str] | None = None,
) -> ModelT: ) -> ModelT:
"""Apply a mapping of field updates onto an object."""
for key, value in updates.items(): for key, value in updates.items():
if allowed_fields is not None and key not in allowed_fields: if allowed_fields is not None and key not in allowed_fields:
continue continue
@@ -220,12 +271,18 @@ async def patch(
session: AsyncSession, session: AsyncSession,
obj: ModelT, obj: ModelT,
updates: Mapping[str, Any], updates: Mapping[str, Any],
*, **options: object,
exclude_none: bool = False,
allowed_fields: set[str] | None = None,
commit: bool = True,
refresh: bool = True,
) -> ModelT: ) -> ModelT:
"""Apply partial updates and persist object."""
exclude_none = bool(options.pop("exclude_none", False))
allowed_fields_raw = options.pop("allowed_fields", None)
allowed_fields = (
allowed_fields_raw
if isinstance(allowed_fields_raw, set)
else None
)
commit = bool(options.pop("commit", True))
refresh = bool(options.pop("refresh", True))
apply_updates( apply_updates(
obj, obj,
updates, updates,
@@ -242,8 +299,9 @@ async def get_or_create(
defaults: Mapping[str, Any] | None = None, defaults: Mapping[str, Any] | None = None,
commit: bool = True, commit: bool = True,
refresh: bool = True, refresh: bool = True,
**lookup: Any, **lookup: object,
) -> tuple[ModelT, bool]: ) -> tuple[ModelT, bool]:
"""Get one object by lookup, or create it with defaults."""
stmt = _lookup_statement(model, lookup) stmt = _lookup_statement(model, lookup)
existing = (await session.exec(stmt)).first() existing = (await session.exec(stmt)).first()

View File

@@ -1,3 +1,5 @@
"""OpenClaw gateway client helpers for websocket RPC calls."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
@@ -14,16 +16,20 @@ from app.integrations.openclaw_gateway_protocol import PROTOCOL_VERSION
class OpenClawGatewayError(RuntimeError): class OpenClawGatewayError(RuntimeError):
pass """Raised when OpenClaw gateway calls fail."""
@dataclass @dataclass
class OpenClawResponse: class OpenClawResponse:
"""Container for raw OpenClaw payloads."""
payload: Any payload: Any
@dataclass(frozen=True) @dataclass(frozen=True)
class GatewayConfig: class GatewayConfig:
"""Connection configuration for the OpenClaw gateway."""
url: str url: str
token: str | None = None token: str | None = None
@@ -31,7 +37,8 @@ class GatewayConfig:
def _build_gateway_url(config: GatewayConfig) -> str: def _build_gateway_url(config: GatewayConfig) -> str:
base_url = (config.url or "").strip() base_url = (config.url or "").strip()
if not base_url: if not base_url:
raise OpenClawGatewayError("Gateway URL is not configured for this board.") message = "Gateway URL is not configured for this board."
raise OpenClawGatewayError(message)
token = config.token token = config.token
if not token: if not token:
return base_url return base_url
@@ -40,7 +47,10 @@ def _build_gateway_url(config: GatewayConfig) -> str:
return urlunparse(parsed._replace(query=query)) return urlunparse(parsed._replace(query=query))
async def _await_response(ws: websockets.WebSocketClientProtocol, request_id: str) -> Any: async def _await_response(
ws: websockets.WebSocketClientProtocol,
request_id: str,
) -> object:
while True: while True:
raw = await ws.recv() raw = await ws.recv()
data = json.loads(raw) data = json.loads(raw)
@@ -53,15 +63,23 @@ async def _await_response(ws: websockets.WebSocketClientProtocol, request_id: st
if data.get("id") == request_id: if data.get("id") == request_id:
if data.get("error"): if data.get("error"):
raise OpenClawGatewayError(data["error"].get("message", "Gateway error")) message = data["error"].get("message", "Gateway error")
raise OpenClawGatewayError(message)
return data.get("result") return data.get("result")
async def _send_request( async def _send_request(
ws: websockets.WebSocketClientProtocol, method: str, params: dict[str, Any] | None ws: websockets.WebSocketClientProtocol,
) -> Any: method: str,
params: dict[str, Any] | None,
) -> object:
request_id = str(uuid4()) request_id = str(uuid4())
message = {"type": "req", "id": request_id, "method": method, "params": params or {}} message = {
"type": "req",
"id": request_id,
"method": method,
"params": params or {},
}
await ws.send(json.dumps(message)) await ws.send(json.dumps(message))
return await _await_response(ws, request_id) return await _await_response(ws, request_id)
@@ -109,7 +127,8 @@ async def openclaw_call(
params: dict[str, Any] | None = None, params: dict[str, Any] | None = None,
*, *,
config: GatewayConfig, config: GatewayConfig,
) -> Any: ) -> object:
"""Call a gateway RPC method and return the result payload."""
gateway_url = _build_gateway_url(config) gateway_url = _build_gateway_url(config)
try: try:
async with websockets.connect(gateway_url, ping_interval=None) as ws: async with websockets.connect(gateway_url, ping_interval=None) as ws:
@@ -138,7 +157,8 @@ async def send_message(
session_key: str, session_key: str,
config: GatewayConfig, config: GatewayConfig,
deliver: bool = False, deliver: bool = False,
) -> Any: ) -> object:
"""Send a chat message to a session."""
params: dict[str, Any] = { params: dict[str, Any] = {
"sessionKey": session_key, "sessionKey": session_key,
"message": message, "message": message,
@@ -152,14 +172,16 @@ async def get_chat_history(
session_key: str, session_key: str,
config: GatewayConfig, config: GatewayConfig,
limit: int | None = None, limit: int | None = None,
) -> Any: ) -> object:
"""Fetch chat history for a session."""
params: dict[str, Any] = {"sessionKey": session_key} params: dict[str, Any] = {"sessionKey": session_key}
if limit is not None: if limit is not None:
params["limit"] = limit params["limit"] = limit
return await openclaw_call("chat.history", params, config=config) return await openclaw_call("chat.history", params, config=config)
async def delete_session(session_key: str, *, config: GatewayConfig) -> Any: async def delete_session(session_key: str, *, config: GatewayConfig) -> object:
"""Delete a session by key."""
return await openclaw_call("sessions.delete", {"key": session_key}, config=config) return await openclaw_call("sessions.delete", {"key": session_key}, config=config)
@@ -168,7 +190,8 @@ async def ensure_session(
*, *,
config: GatewayConfig, config: GatewayConfig,
label: str | None = None, label: str | None = None,
) -> Any: ) -> object:
"""Ensure a session exists and optionally update its label."""
params: dict[str, Any] = {"key": session_key} params: dict[str, Any] = {"key": session_key}
if label: if label:
params["label"] = label params["label"] = label

View File

@@ -1,3 +1,5 @@
"""Pydantic/SQLModel schemas for agent API payloads."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
@@ -10,6 +12,8 @@ from sqlmodel import SQLModel
from app.schemas.common import NonEmptyStr from app.schemas.common import NonEmptyStr
_RUNTIME_TYPE_REFERENCES = (datetime, UUID, NonEmptyStr)
def _normalize_identity_profile( def _normalize_identity_profile(
profile: object, profile: object,
@@ -36,6 +40,8 @@ def _normalize_identity_profile(
class AgentBase(SQLModel): class AgentBase(SQLModel):
"""Common fields shared by agent create/read/update payloads."""
board_id: UUID | None = None board_id: UUID | None = None
name: NonEmptyStr name: NonEmptyStr
status: str = "provisioning" status: str = "provisioning"
@@ -46,7 +52,8 @@ class AgentBase(SQLModel):
@field_validator("identity_template", "soul_template", mode="before") @field_validator("identity_template", "soul_template", mode="before")
@classmethod @classmethod
def normalize_templates(cls, value: Any) -> Any: def normalize_templates(cls, value: object) -> object | None:
"""Normalize blank template text to null."""
if value is None: if value is None:
return None return None
if isinstance(value, str): if isinstance(value, str):
@@ -56,15 +63,21 @@ class AgentBase(SQLModel):
@field_validator("identity_profile", mode="before") @field_validator("identity_profile", mode="before")
@classmethod @classmethod
def normalize_identity_profile(cls, value: Any) -> Any: def normalize_identity_profile(
cls,
value: object,
) -> dict[str, str] | None:
"""Normalize identity-profile values into trimmed string mappings."""
return _normalize_identity_profile(value) return _normalize_identity_profile(value)
class AgentCreate(AgentBase): class AgentCreate(AgentBase):
pass """Payload for creating a new agent."""
class AgentUpdate(SQLModel): class AgentUpdate(SQLModel):
"""Payload for patching an existing agent."""
board_id: UUID | None = None board_id: UUID | None = None
is_gateway_main: bool | None = None is_gateway_main: bool | None = None
name: NonEmptyStr | None = None name: NonEmptyStr | None = None
@@ -76,7 +89,8 @@ class AgentUpdate(SQLModel):
@field_validator("identity_template", "soul_template", mode="before") @field_validator("identity_template", "soul_template", mode="before")
@classmethod @classmethod
def normalize_templates(cls, value: Any) -> Any: def normalize_templates(cls, value: object) -> object | None:
"""Normalize blank template text to null."""
if value is None: if value is None:
return None return None
if isinstance(value, str): if isinstance(value, str):
@@ -86,11 +100,17 @@ class AgentUpdate(SQLModel):
@field_validator("identity_profile", mode="before") @field_validator("identity_profile", mode="before")
@classmethod @classmethod
def normalize_identity_profile(cls, value: Any) -> Any: def normalize_identity_profile(
cls,
value: object,
) -> dict[str, str] | None:
"""Normalize identity-profile values into trimmed string mappings."""
return _normalize_identity_profile(value) return _normalize_identity_profile(value)
class AgentRead(AgentBase): class AgentRead(AgentBase):
"""Public agent representation returned by the API."""
id: UUID id: UUID
is_board_lead: bool = False is_board_lead: bool = False
is_gateway_main: bool = False is_gateway_main: bool = False
@@ -101,13 +121,19 @@ class AgentRead(AgentBase):
class AgentHeartbeat(SQLModel): class AgentHeartbeat(SQLModel):
"""Heartbeat status payload sent by agents."""
status: str | None = None status: str | None = None
class AgentHeartbeatCreate(AgentHeartbeat): class AgentHeartbeatCreate(AgentHeartbeat):
"""Heartbeat payload used to create an agent lazily."""
name: NonEmptyStr name: NonEmptyStr
board_id: UUID | None = None board_id: UUID | None = None
class AgentNudge(SQLModel): class AgentNudge(SQLModel):
"""Nudge message payload for pinging an agent."""
message: NonEmptyStr message: NonEmptyStr

View File

@@ -1,7 +1,9 @@
"""Schemas used by the board-onboarding assistant flow."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime
from typing import Any, Literal, Self from typing import Literal, Self
from uuid import UUID from uuid import UUID
from pydantic import Field, field_validator, model_validator from pydantic import Field, field_validator, model_validator
@@ -9,17 +11,23 @@ from sqlmodel import SQLModel
from app.schemas.common import NonEmptyStr from app.schemas.common import NonEmptyStr
_RUNTIME_TYPE_REFERENCES = (datetime, UUID, NonEmptyStr)
class BoardOnboardingStart(SQLModel): class BoardOnboardingStart(SQLModel):
pass """Start signal for initializing onboarding conversation."""
class BoardOnboardingAnswer(SQLModel): class BoardOnboardingAnswer(SQLModel):
"""User answer payload for a single onboarding question."""
answer: NonEmptyStr answer: NonEmptyStr
other_text: str | None = None other_text: str | None = None
class BoardOnboardingConfirm(SQLModel): class BoardOnboardingConfirm(SQLModel):
"""Payload used to confirm generated onboarding draft fields."""
board_type: str board_type: str
objective: str | None = None objective: str | None = None
success_metrics: dict[str, object] | None = None success_metrics: dict[str, object] | None = None
@@ -27,23 +35,32 @@ class BoardOnboardingConfirm(SQLModel):
@model_validator(mode="after") @model_validator(mode="after")
def validate_goal_fields(self) -> Self: def validate_goal_fields(self) -> Self:
if self.board_type == "goal": """Require goal metadata when the board type is `goal`."""
if not self.objective or not self.success_metrics: if self.board_type == "goal" and (
raise ValueError("Confirmed goal boards require objective and success_metrics") not self.objective or not self.success_metrics
):
message = (
"Confirmed goal boards require objective and success_metrics"
)
raise ValueError(message)
return self return self
class BoardOnboardingQuestionOption(SQLModel): class BoardOnboardingQuestionOption(SQLModel):
"""Selectable option for an onboarding question."""
id: NonEmptyStr id: NonEmptyStr
label: NonEmptyStr label: NonEmptyStr
class BoardOnboardingAgentQuestion(SQLModel): class BoardOnboardingAgentQuestion(SQLModel):
"""Question payload emitted by the onboarding assistant."""
question: NonEmptyStr question: NonEmptyStr
options: list[BoardOnboardingQuestionOption] = Field(min_length=1) options: list[BoardOnboardingQuestionOption] = Field(min_length=1)
def _normalize_optional_text(value: Any) -> Any: def _normalize_optional_text(value: object) -> object | None:
if value is None: if value is None:
return None return None
if isinstance(value, str): if isinstance(value, str):
@@ -53,6 +70,8 @@ def _normalize_optional_text(value: Any) -> Any:
class BoardOnboardingUserProfile(SQLModel): class BoardOnboardingUserProfile(SQLModel):
"""User-profile preferences gathered during onboarding."""
preferred_name: str | None = None preferred_name: str | None = None
pronouns: str | None = None pronouns: str | None = None
timezone: str | None = None timezone: str | None = None
@@ -68,7 +87,8 @@ class BoardOnboardingUserProfile(SQLModel):
mode="before", mode="before",
) )
@classmethod @classmethod
def normalize_text(cls, value: Any) -> Any: def normalize_text(cls, value: object) -> object | None:
"""Trim optional free-form profile text fields."""
return _normalize_optional_text(value) return _normalize_optional_text(value)
@@ -79,6 +99,8 @@ LeadAgentUpdateCadence = Literal["asap", "hourly", "daily", "weekly"]
class BoardOnboardingLeadAgentDraft(SQLModel): class BoardOnboardingLeadAgentDraft(SQLModel):
"""Editable lead-agent draft configuration."""
name: NonEmptyStr | None = None name: NonEmptyStr | None = None
# role, communication_style, emoji are expected keys. # role, communication_style, emoji are expected keys.
identity_profile: dict[str, str] | None = None identity_profile: dict[str, str] | None = None
@@ -97,12 +119,17 @@ class BoardOnboardingLeadAgentDraft(SQLModel):
mode="before", mode="before",
) )
@classmethod @classmethod
def normalize_text_fields(cls, value: Any) -> Any: def normalize_text_fields(cls, value: object) -> object | None:
"""Trim optional lead-agent preference fields."""
return _normalize_optional_text(value) return _normalize_optional_text(value)
@field_validator("identity_profile", mode="before") @field_validator("identity_profile", mode="before")
@classmethod @classmethod
def normalize_identity_profile(cls, value: Any) -> Any: def normalize_identity_profile(
cls,
value: object,
) -> object | None:
"""Normalize identity profile keys and values as trimmed strings."""
if value is None: if value is None:
return None return None
if not isinstance(value, dict): if not isinstance(value, dict):
@@ -121,6 +148,8 @@ class BoardOnboardingLeadAgentDraft(SQLModel):
class BoardOnboardingAgentComplete(BoardOnboardingConfirm): class BoardOnboardingAgentComplete(BoardOnboardingConfirm):
"""Complete onboarding draft produced by the onboarding assistant."""
status: Literal["complete"] status: Literal["complete"]
user_profile: BoardOnboardingUserProfile | None = None user_profile: BoardOnboardingUserProfile | None = None
lead_agent: BoardOnboardingLeadAgentDraft | None = None lead_agent: BoardOnboardingLeadAgentDraft | None = None
@@ -130,6 +159,8 @@ BoardOnboardingAgentUpdate = BoardOnboardingAgentComplete | BoardOnboardingAgent
class BoardOnboardingRead(SQLModel): class BoardOnboardingRead(SQLModel):
"""Stored onboarding session state returned by API endpoints."""
id: UUID id: UUID
board_id: UUID board_id: UUID
session_key: str session_key: str

View File

@@ -1,5 +1,4 @@
"""Gateway-facing agent provisioning and cleanup helpers.""" """Gateway-facing agent provisioning and cleanup helpers."""
# ruff: noqa: EM101, TRY003
from __future__ import annotations from __future__ import annotations
@@ -176,7 +175,8 @@ def _heartbeat_template_name(agent: Agent) -> str:
def _workspace_path(agent: Agent, workspace_root: str) -> str: def _workspace_path(agent: Agent, workspace_root: str) -> str:
if not workspace_root: if not workspace_root:
raise ValueError("gateway_workspace_root is required") msg = "gateway_workspace_root is required"
raise ValueError(msg)
root = workspace_root.rstrip("/") root = workspace_root.rstrip("/")
# Use agent key derived from session key when possible. This prevents collisions for # Use agent key derived from session key when possible. This prevents collisions for
# lead agents (session key includes board id) even if multiple boards share the same # lead agents (session key includes board id) even if multiple boards share the same
@@ -227,9 +227,11 @@ def _build_context(
user: User | None, user: User | None,
) -> dict[str, str]: ) -> dict[str, str]:
if not gateway.workspace_root: if not gateway.workspace_root:
raise ValueError("gateway_workspace_root is required") msg = "gateway_workspace_root is required"
raise ValueError(msg)
if not gateway.main_session_key: if not gateway.main_session_key:
raise ValueError("gateway_main_session_key is required") msg = "gateway_main_session_key is required"
raise ValueError(msg)
agent_id = str(agent.id) agent_id = str(agent.id)
workspace_root = gateway.workspace_root workspace_root = gateway.workspace_root
workspace_path = _workspace_path(agent, workspace_root) workspace_path = _workspace_path(agent, workspace_root)
@@ -485,15 +487,18 @@ async def _patch_gateway_agent_list(
) -> None: ) -> None:
cfg = await openclaw_call("config.get", config=config) cfg = await openclaw_call("config.get", config=config)
if not isinstance(cfg, dict): if not isinstance(cfg, dict):
raise OpenClawGatewayError("config.get returned invalid payload") msg = "config.get returned invalid payload"
raise OpenClawGatewayError(msg)
base_hash = cfg.get("hash") base_hash = cfg.get("hash")
data = cfg.get("config") or cfg.get("parsed") or {} data = cfg.get("config") or cfg.get("parsed") or {}
if not isinstance(data, dict): if not isinstance(data, dict):
raise OpenClawGatewayError("config.get returned invalid config") msg = "config.get returned invalid config"
raise OpenClawGatewayError(msg)
agents = data.get("agents") or {} agents = data.get("agents") or {}
lst = agents.get("list") or [] lst = agents.get("list") or []
if not isinstance(lst, list): if not isinstance(lst, list):
raise OpenClawGatewayError("config agents.list is not a list") msg = "config agents.list is not a list"
raise OpenClawGatewayError(msg)
updated = False updated = False
new_list: list[dict[str, Any]] = [] new_list: list[dict[str, Any]] = []
@@ -528,19 +533,23 @@ async def patch_gateway_agent_heartbeats( # noqa: C901
Each entry is (agent_id, workspace_path, heartbeat_dict). Each entry is (agent_id, workspace_path, heartbeat_dict).
""" """
if not gateway.url: if not gateway.url:
raise OpenClawGatewayError("Gateway url is required") msg = "Gateway url is required"
raise OpenClawGatewayError(msg)
config = GatewayClientConfig(url=gateway.url, token=gateway.token) config = GatewayClientConfig(url=gateway.url, token=gateway.token)
cfg = await openclaw_call("config.get", config=config) cfg = await openclaw_call("config.get", config=config)
if not isinstance(cfg, dict): if not isinstance(cfg, dict):
raise OpenClawGatewayError("config.get returned invalid payload") msg = "config.get returned invalid payload"
raise OpenClawGatewayError(msg)
base_hash = cfg.get("hash") base_hash = cfg.get("hash")
data = cfg.get("config") or cfg.get("parsed") or {} data = cfg.get("config") or cfg.get("parsed") or {}
if not isinstance(data, dict): if not isinstance(data, dict):
raise OpenClawGatewayError("config.get returned invalid config") msg = "config.get returned invalid config"
raise OpenClawGatewayError(msg)
agents_section = data.get("agents") or {} agents_section = data.get("agents") or {}
lst = agents_section.get("list") or [] lst = agents_section.get("list") or []
if not isinstance(lst, list): if not isinstance(lst, list):
raise OpenClawGatewayError("config agents.list is not a list") msg = "config agents.list is not a list"
raise OpenClawGatewayError(msg)
entry_by_id: dict[str, tuple[str, dict[str, Any]]] = { entry_by_id: dict[str, tuple[str, dict[str, Any]]] = {
agent_id: (workspace_path, heartbeat) agent_id: (workspace_path, heartbeat)
@@ -581,7 +590,8 @@ async def patch_gateway_agent_heartbeats( # noqa: C901
async def sync_gateway_agent_heartbeats(gateway: Gateway, agents: list[Agent]) -> None: async def sync_gateway_agent_heartbeats(gateway: Gateway, agents: list[Agent]) -> None:
"""Sync current Agent.heartbeat_config values to the gateway config.""" """Sync current Agent.heartbeat_config values to the gateway config."""
if not gateway.workspace_root: if not gateway.workspace_root:
raise OpenClawGatewayError("gateway workspace_root is required") msg = "gateway workspace_root is required"
raise OpenClawGatewayError(msg)
entries: list[tuple[str, str, dict[str, Any]]] = [] entries: list[tuple[str, str, dict[str, Any]]] = []
for agent in agents: for agent in agents:
agent_id = _agent_key(agent) agent_id = _agent_key(agent)
@@ -599,15 +609,18 @@ async def _remove_gateway_agent_list(
) -> None: ) -> None:
cfg = await openclaw_call("config.get", config=config) cfg = await openclaw_call("config.get", config=config)
if not isinstance(cfg, dict): if not isinstance(cfg, dict):
raise OpenClawGatewayError("config.get returned invalid payload") msg = "config.get returned invalid payload"
raise OpenClawGatewayError(msg)
base_hash = cfg.get("hash") base_hash = cfg.get("hash")
data = cfg.get("config") or cfg.get("parsed") or {} data = cfg.get("config") or cfg.get("parsed") or {}
if not isinstance(data, dict): if not isinstance(data, dict):
raise OpenClawGatewayError("config.get returned invalid config") msg = "config.get returned invalid config"
raise OpenClawGatewayError(msg)
agents = data.get("agents") or {} agents = data.get("agents") or {}
lst = agents.get("list") or [] lst = agents.get("list") or []
if not isinstance(lst, list): if not isinstance(lst, list):
raise OpenClawGatewayError("config agents.list is not a list") msg = "config agents.list is not a list"
raise OpenClawGatewayError(msg)
new_list = [ new_list = [
entry entry
@@ -658,7 +671,8 @@ async def provision_agent( # noqa: C901, PLR0912, PLR0913
if not gateway.url: if not gateway.url:
return return
if not gateway.workspace_root: if not gateway.workspace_root:
raise ValueError("gateway_workspace_root is required") msg = "gateway_workspace_root is required"
raise ValueError(msg)
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token) client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
session_key = _session_key(agent) session_key = _session_key(agent)
await ensure_session(session_key, config=client_config, label=agent.name) await ensure_session(session_key, config=client_config, label=agent.name)
@@ -734,7 +748,8 @@ async def provision_main_agent( # noqa: C901, PLR0912, PLR0913
if not gateway.url: if not gateway.url:
return return
if not gateway.main_session_key: if not gateway.main_session_key:
raise ValueError("gateway main_session_key is required") msg = "gateway main_session_key is required"
raise ValueError(msg)
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token) client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
await ensure_session( await ensure_session(
gateway.main_session_key, config=client_config, label="Main Agent", gateway.main_session_key, config=client_config, label="Main Agent",
@@ -745,7 +760,8 @@ async def provision_main_agent( # noqa: C901, PLR0912, PLR0913
fallback_session_key=gateway.main_session_key, fallback_session_key=gateway.main_session_key,
) )
if not agent_id: if not agent_id:
raise OpenClawGatewayError("Unable to resolve gateway main agent id") msg = "Unable to resolve gateway main agent id"
raise OpenClawGatewayError(msg)
context = _build_main_context(agent, gateway, auth_token, user) context = _build_main_context(agent, gateway, auth_token, user)
supported = set(await _supported_gateway_files(client_config)) supported = set(await _supported_gateway_files(client_config))
@@ -796,7 +812,8 @@ async def cleanup_agent(
if not gateway.url: if not gateway.url:
return None return None
if not gateway.workspace_root: if not gateway.workspace_root:
raise ValueError("gateway_workspace_root is required") msg = "gateway_workspace_root is required"
raise ValueError(msg)
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token) client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
agent_id = _agent_key(agent) agent_id = _agent_key(agent)

View File

@@ -1,7 +1,8 @@
"""Helpers for assembling board-group snapshot view models."""
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from typing import Any
from uuid import UUID from uuid import UUID
from sqlalchemy import case, func from sqlalchemy import case, func
@@ -22,48 +23,67 @@ from app.schemas.view_models import (
_STATUS_ORDER = {"in_progress": 0, "review": 1, "inbox": 2, "done": 3} _STATUS_ORDER = {"in_progress": 0, "review": 1, "inbox": 2, "done": 3}
_PRIORITY_ORDER = {"high": 0, "medium": 1, "low": 2} _PRIORITY_ORDER = {"high": 0, "medium": 1, "low": 2}
_RUNTIME_TYPE_REFERENCES = (UUID, AsyncSession)
def _status_weight_expr() -> Any: def _status_weight_expr() -> object:
"""Return a SQL expression that sorts task statuses by configured order."""
whens = [(col(Task.status) == key, weight) for key, weight in _STATUS_ORDER.items()] whens = [(col(Task.status) == key, weight) for key, weight in _STATUS_ORDER.items()]
return case(*whens, else_=99) return case(*whens, else_=99)
def _priority_weight_expr() -> Any: def _priority_weight_expr() -> object:
whens = [(col(Task.priority) == key, weight) for key, weight in _PRIORITY_ORDER.items()] """Return a SQL expression that sorts task priorities by configured order."""
whens = [
(col(Task.priority) == key, weight)
for key, weight in _PRIORITY_ORDER.items()
]
return case(*whens, else_=99) return case(*whens, else_=99)
async def build_group_snapshot( async def _boards_for_group(
session: AsyncSession, session: AsyncSession,
*, *,
group: BoardGroup, group_id: UUID,
exclude_board_id: UUID | None = None, exclude_board_id: UUID | None = None,
include_done: bool = False, ) -> list[Board]:
per_board_task_limit: int = 5, """Return boards belonging to a board group with optional exclusion."""
) -> BoardGroupSnapshot: statement = Board.objects.filter_by(board_group_id=group_id).statement
statement = Board.objects.filter_by(board_group_id=group.id).statement
if exclude_board_id is not None: if exclude_board_id is not None:
statement = statement.where(col(Board.id) != exclude_board_id) statement = statement.where(col(Board.id) != exclude_board_id)
boards = list(await session.exec(statement.order_by(func.lower(col(Board.name)).asc()))) return list(
if not boards: await session.exec(
return BoardGroupSnapshot(group=BoardGroupRead.model_validate(group, from_attributes=True)) statement.order_by(func.lower(col(Board.name)).asc()),
),
)
boards_by_id = {board.id: board for board in boards}
board_ids = list(boards_by_id.keys())
async def _task_counts_by_board(
session: AsyncSession,
board_ids: list[UUID],
) -> dict[UUID, dict[str, int]]:
"""Return per-board task counts keyed by task status."""
task_counts: dict[UUID, dict[str, int]] = defaultdict(lambda: defaultdict(int)) task_counts: dict[UUID, dict[str, int]] = defaultdict(lambda: defaultdict(int))
for board_id, status_value, total in list( for board_id, status_value, total in list(
await session.exec( await session.exec(
select(col(Task.board_id), col(Task.status), func.count(col(Task.id))) select(col(Task.board_id), col(Task.status), func.count(col(Task.id)))
.where(col(Task.board_id).in_(board_ids)) .where(col(Task.board_id).in_(board_ids))
.group_by(col(Task.board_id), col(Task.status)) .group_by(col(Task.board_id), col(Task.status)),
) ),
): ):
if board_id is None: if board_id is None:
continue continue
task_counts[board_id][str(status_value)] = int(total or 0) task_counts[board_id][str(status_value)] = int(total or 0)
return task_counts
async def _ordered_tasks_for_boards(
session: AsyncSession,
board_ids: list[UUID],
*,
include_done: bool,
) -> list[Task]:
"""Return sorted tasks for boards, optionally excluding completed tasks."""
task_statement = select(Task).where(col(Task.board_id).in_(board_ids)) task_statement = select(Task).where(col(Task.board_id).in_(board_ids))
if not include_done: if not include_done:
task_statement = task_statement.where(col(Task.status) != "done") task_statement = task_statement.where(col(Task.status) != "done")
@@ -74,20 +94,43 @@ async def build_group_snapshot(
col(Task.updated_at).desc(), col(Task.updated_at).desc(),
col(Task.created_at).desc(), col(Task.created_at).desc(),
) )
tasks = list(await session.exec(task_statement)) return list(await session.exec(task_statement))
assigned_ids = {task.assigned_agent_id for task in tasks if task.assigned_agent_id is not None}
agent_name_by_id: dict[UUID, str] = {} async def _agent_names(
if assigned_ids: session: AsyncSession,
for agent_id, name in list( tasks: list[Task],
) -> dict[UUID, str]:
"""Return agent names keyed by assigned agent ids in the provided tasks."""
assigned_ids = {
task.assigned_agent_id
for task in tasks
if task.assigned_agent_id is not None
}
if not assigned_ids:
return {}
return dict(
list(
await session.exec( await session.exec(
select(col(Agent.id), col(Agent.name)).where(col(Agent.id).in_(assigned_ids)) select(col(Agent.id), col(Agent.name)).where(
col(Agent.id).in_(assigned_ids),
),
),
),
) )
):
agent_name_by_id[agent_id] = name
def _task_summaries_by_board(
*,
boards_by_id: dict[UUID, Board],
tasks: list[Task],
agent_name_by_id: dict[UUID, str],
per_board_task_limit: int,
) -> dict[UUID, list[BoardGroupTaskSummary]]:
"""Build limited per-board task summary lists."""
tasks_by_board: dict[UUID, list[BoardGroupTaskSummary]] = defaultdict(list) tasks_by_board: dict[UUID, list[BoardGroupTaskSummary]] = defaultdict(list)
if per_board_task_limit > 0: if per_board_task_limit <= 0:
return tasks_by_board
for task in tasks: for task in tasks:
if task.board_id is None: if task.board_id is None:
continue continue
@@ -115,21 +158,52 @@ async def build_group_snapshot(
in_progress_at=task.in_progress_at, in_progress_at=task.in_progress_at,
created_at=task.created_at, created_at=task.created_at,
updated_at=task.updated_at, updated_at=task.updated_at,
),
) )
) return tasks_by_board
snapshots: list[BoardGroupBoardSnapshot] = []
for board in boards: async def build_group_snapshot(
board_read = BoardRead.model_validate(board, from_attributes=True) session: AsyncSession,
counts = dict(task_counts.get(board.id, {})) *,
snapshots.append( group: BoardGroup,
exclude_board_id: UUID | None = None,
include_done: bool = False,
per_board_task_limit: int = 5,
) -> BoardGroupSnapshot:
"""Build a board-group snapshot with board/task summaries."""
boards = await _boards_for_group(
session,
group_id=group.id,
exclude_board_id=exclude_board_id,
)
if not boards:
return BoardGroupSnapshot(
group=BoardGroupRead.model_validate(group, from_attributes=True),
)
boards_by_id = {board.id: board for board in boards}
board_ids = list(boards_by_id.keys())
task_counts = await _task_counts_by_board(session, board_ids)
tasks = await _ordered_tasks_for_boards(
session,
board_ids,
include_done=include_done,
)
agent_name_by_id = await _agent_names(session, tasks)
tasks_by_board = _task_summaries_by_board(
boards_by_id=boards_by_id,
tasks=tasks,
agent_name_by_id=agent_name_by_id,
per_board_task_limit=per_board_task_limit,
)
snapshots = [
BoardGroupBoardSnapshot( BoardGroupBoardSnapshot(
board=board_read, board=BoardRead.model_validate(board, from_attributes=True),
task_counts=counts, task_counts=dict(task_counts.get(board.id, {})),
tasks=tasks_by_board.get(board.id, []), tasks=tasks_by_board.get(board.id, []),
) )
) for board in boards
]
return BoardGroupSnapshot( return BoardGroupSnapshot(
group=BoardGroupRead.model_validate(group, from_attributes=True), group=BoardGroupRead.model_validate(group, from_attributes=True),
boards=snapshots, boards=snapshots,
@@ -144,6 +218,7 @@ async def build_board_group_snapshot(
include_done: bool = False, include_done: bool = False,
per_board_task_limit: int = 5, per_board_task_limit: int = 5,
) -> BoardGroupSnapshot: ) -> BoardGroupSnapshot:
"""Build a board-group snapshot anchored to a board context."""
if not board.board_group_id: if not board.board_group_id:
return BoardGroupSnapshot(group=None, boards=[]) return BoardGroupSnapshot(group=None, boards=[])
group = await BoardGroup.objects.by_id(board.board_group_id).first(session) group = await BoardGroup.objects.by_id(board.board_group_id).first(session)

View File

@@ -1,5 +1,4 @@
"""Organization membership and board-access service helpers.""" """Organization membership and board-access service helpers."""
# ruff: noqa: D101, D103
from __future__ import annotations from __future__ import annotations
@@ -38,19 +37,24 @@ ROLE_RANK = {"member": 0, "admin": 1, "owner": 2}
@dataclass(frozen=True) @dataclass(frozen=True)
class OrganizationContext: class OrganizationContext:
"""Resolved organization and membership for the active user."""
organization: Organization organization: Organization
member: OrganizationMember member: OrganizationMember
def is_org_admin(member: OrganizationMember) -> bool: def is_org_admin(member: OrganizationMember) -> bool:
"""Return whether a member has admin-level organization privileges."""
return member.role in ADMIN_ROLES return member.role in ADMIN_ROLES
async def get_default_org(session: AsyncSession) -> Organization | None: async def get_default_org(session: AsyncSession) -> Organization | None:
"""Return the default personal organization if it exists."""
return await Organization.objects.filter_by(name=DEFAULT_ORG_NAME).first(session) return await Organization.objects.filter_by(name=DEFAULT_ORG_NAME).first(session)
async def ensure_default_org(session: AsyncSession) -> Organization: async def ensure_default_org(session: AsyncSession) -> Organization:
"""Ensure and return the default personal organization."""
org = await get_default_org(session) org = await get_default_org(session)
if org is not None: if org is not None:
return org return org
@@ -67,6 +71,7 @@ async def get_member(
user_id: UUID, user_id: UUID,
organization_id: UUID, organization_id: UUID,
) -> OrganizationMember | None: ) -> OrganizationMember | None:
"""Fetch a membership by user id and organization id."""
return await OrganizationMember.objects.filter_by( return await OrganizationMember.objects.filter_by(
user_id=user_id, user_id=user_id,
organization_id=organization_id, organization_id=organization_id,
@@ -76,6 +81,7 @@ async def get_member(
async def get_first_membership( async def get_first_membership(
session: AsyncSession, user_id: UUID, session: AsyncSession, user_id: UUID,
) -> OrganizationMember | None: ) -> OrganizationMember | None:
"""Return the oldest membership for a user, if any."""
return ( return (
await OrganizationMember.objects.filter_by(user_id=user_id) await OrganizationMember.objects.filter_by(user_id=user_id)
.order_by(col(OrganizationMember.created_at).asc()) .order_by(col(OrganizationMember.created_at).asc())
@@ -89,6 +95,7 @@ async def set_active_organization(
user: User, user: User,
organization_id: UUID, organization_id: UUID,
) -> OrganizationMember: ) -> OrganizationMember:
"""Set a user's active organization and return the membership."""
member = await get_member(session, user_id=user.id, organization_id=organization_id) member = await get_member(session, user_id=user.id, organization_id=organization_id)
if member is None: if member is None:
raise HTTPException( raise HTTPException(
@@ -105,6 +112,7 @@ async def get_active_membership(
session: AsyncSession, session: AsyncSession,
user: User, user: User,
) -> OrganizationMember | None: ) -> OrganizationMember | None:
"""Resolve and normalize the user's currently active membership."""
db_user = await User.objects.by_id(user.id).first(session) db_user = await User.objects.by_id(user.id).first(session)
if db_user is None: if db_user is None:
db_user = user db_user = user
@@ -151,6 +159,7 @@ async def accept_invite(
invite: OrganizationInvite, invite: OrganizationInvite,
user: User, user: User,
) -> OrganizationMember: ) -> OrganizationMember:
"""Accept an invite and create membership plus scoped board access rows."""
now = utcnow() now = utcnow()
member = OrganizationMember( member = OrganizationMember(
organization_id=invite.organization_id, organization_id=invite.organization_id,
@@ -200,6 +209,7 @@ async def accept_invite(
async def ensure_member_for_user( async def ensure_member_for_user(
session: AsyncSession, user: User, session: AsyncSession, user: User,
) -> OrganizationMember: ) -> OrganizationMember:
"""Ensure a user has some membership, creating one if necessary."""
existing = await get_active_membership(session, user) existing = await get_active_membership(session, user)
if existing is not None: if existing is not None:
return existing return existing
@@ -237,10 +247,12 @@ async def ensure_member_for_user(
def member_all_boards_read(member: OrganizationMember) -> bool: def member_all_boards_read(member: OrganizationMember) -> bool:
"""Return whether the member has organization-wide read access."""
return member.all_boards_read or member.all_boards_write return member.all_boards_read or member.all_boards_write
def member_all_boards_write(member: OrganizationMember) -> bool: def member_all_boards_write(member: OrganizationMember) -> bool:
"""Return whether the member has organization-wide write access."""
return member.all_boards_write return member.all_boards_write
@@ -251,6 +263,7 @@ async def has_board_access(
board: Board, board: Board,
write: bool, write: bool,
) -> bool: ) -> bool:
"""Return whether a member has board access for the requested mode."""
if member.organization_id != board.organization_id: if member.organization_id != board.organization_id:
return False return False
if write: if write:
@@ -276,6 +289,7 @@ async def require_board_access(
board: Board, board: Board,
write: bool, write: bool,
) -> OrganizationMember: ) -> OrganizationMember:
"""Require board access for a user and return matching membership."""
member = await get_member( member = await get_member(
session, user_id=user.id, organization_id=board.organization_id, session, user_id=user.id, organization_id=board.organization_id,
) )
@@ -293,6 +307,7 @@ async def require_board_access(
def board_access_filter( def board_access_filter(
member: OrganizationMember, *, write: bool, member: OrganizationMember, *, write: bool,
) -> ColumnElement[bool]: ) -> ColumnElement[bool]:
"""Build a SQL filter expression for boards visible to a member."""
if write and member_all_boards_write(member): if write and member_all_boards_write(member):
return col(Board.organization_id) == member.organization_id return col(Board.organization_id) == member.organization_id
if not write and member_all_boards_read(member): if not write and member_all_boards_read(member):
@@ -320,6 +335,7 @@ async def list_accessible_board_ids(
member: OrganizationMember, member: OrganizationMember,
write: bool, write: bool,
) -> list[UUID]: ) -> list[UUID]:
"""List board ids accessible to a member for read or write mode."""
if (write and member_all_boards_write(member)) or ( if (write and member_all_boards_write(member)) or (
not write and member_all_boards_read(member) not write and member_all_boards_read(member)
): ):
@@ -354,6 +370,7 @@ async def apply_member_access_update(
member: OrganizationMember, member: OrganizationMember,
update: OrganizationMemberAccessUpdate, update: OrganizationMemberAccessUpdate,
) -> None: ) -> None:
"""Replace explicit member board-access rows from an access update."""
now = utcnow() now = utcnow()
member.all_boards_read = update.all_boards_read member.all_boards_read = update.all_boards_read
member.all_boards_write = update.all_boards_write member.all_boards_write = update.all_boards_write
@@ -390,6 +407,7 @@ async def apply_invite_board_access(
invite: OrganizationInvite, invite: OrganizationInvite,
entries: Iterable[OrganizationBoardAccessSpec], entries: Iterable[OrganizationBoardAccessSpec],
) -> None: ) -> None:
"""Replace explicit invite board-access rows for an invite."""
await crud.delete_where( await crud.delete_where(
session, session,
OrganizationInviteBoardAccess, OrganizationInviteBoardAccess,
@@ -414,10 +432,12 @@ async def apply_invite_board_access(
def normalize_invited_email(email: str) -> str: def normalize_invited_email(email: str) -> str:
"""Normalize an invited email address for storage/comparison."""
return email.strip().lower() return email.strip().lower()
def normalize_role(role: str) -> str: def normalize_role(role: str) -> str:
"""Normalize a role string and default empty values to `member`."""
return role.strip().lower() or "member" return role.strip().lower() or "member"
@@ -433,6 +453,7 @@ async def apply_invite_to_member(
member: OrganizationMember, member: OrganizationMember,
invite: OrganizationInvite, invite: OrganizationInvite,
) -> None: ) -> None:
"""Apply invite role/access grants onto an existing organization member."""
now = utcnow() now = utcnow()
member_changed = False member_changed = False
invite_role = normalize_role(invite.role or "member") invite_role = normalize_role(invite.role or "member")

View File

@@ -1,3 +1,5 @@
"""Task-dependency helpers for validation, querying, and replacement."""
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
@@ -14,6 +16,7 @@ from app.models.task_dependencies import TaskDependency
from app.models.tasks import Task from app.models.tasks import Task
DONE_STATUS: Final[str] = "done" DONE_STATUS: Final[str] = "done"
_RUNTIME_TYPE_REFERENCES = (UUID, AsyncSession, Mapping, Sequence)
def _dedupe_uuid_list(values: Sequence[UUID]) -> list[UUID]: def _dedupe_uuid_list(values: Sequence[UUID]) -> list[UUID]:
@@ -34,6 +37,7 @@ async def dependency_ids_by_task_id(
board_id: UUID, board_id: UUID,
task_ids: Sequence[UUID], task_ids: Sequence[UUID],
) -> dict[UUID, list[UUID]]: ) -> dict[UUID, list[UUID]]:
"""Return dependency ids keyed by task id for tasks on a board."""
if not task_ids: if not task_ids:
return {} return {}
rows = list( rows = list(
@@ -41,8 +45,8 @@ async def dependency_ids_by_task_id(
select(col(TaskDependency.task_id), col(TaskDependency.depends_on_task_id)) select(col(TaskDependency.task_id), col(TaskDependency.depends_on_task_id))
.where(col(TaskDependency.board_id) == board_id) .where(col(TaskDependency.board_id) == board_id)
.where(col(TaskDependency.task_id).in_(task_ids)) .where(col(TaskDependency.task_id).in_(task_ids))
.order_by(col(TaskDependency.created_at).asc()) .order_by(col(TaskDependency.created_at).asc()),
) ),
) )
mapping: dict[UUID, list[UUID]] = defaultdict(list) mapping: dict[UUID, list[UUID]] = defaultdict(list)
for task_id, depends_on_task_id in rows: for task_id, depends_on_task_id in rows:
@@ -56,16 +60,17 @@ async def dependency_status_by_id(
board_id: UUID, board_id: UUID,
dependency_ids: Sequence[UUID], dependency_ids: Sequence[UUID],
) -> dict[UUID, str]: ) -> dict[UUID, str]:
"""Return dependency status values keyed by dependency task id."""
if not dependency_ids: if not dependency_ids:
return {} return {}
rows = list( rows = list(
await session.exec( await session.exec(
select(col(Task.id), col(Task.status)) select(col(Task.id), col(Task.status))
.where(col(Task.board_id) == board_id) .where(col(Task.board_id) == board_id)
.where(col(Task.id).in_(dependency_ids)) .where(col(Task.id).in_(dependency_ids)),
),
) )
) return dict(rows)
return {task_id: status_value for task_id, status_value in rows}
def blocked_by_dependency_ids( def blocked_by_dependency_ids(
@@ -73,11 +78,12 @@ def blocked_by_dependency_ids(
dependency_ids: Sequence[UUID], dependency_ids: Sequence[UUID],
status_by_id: Mapping[UUID, str], status_by_id: Mapping[UUID, str],
) -> list[UUID]: ) -> list[UUID]:
blocked: list[UUID] = [] """Return dependency ids that are not yet in the done status."""
for dep_id in dependency_ids: return [
if status_by_id.get(dep_id) != DONE_STATUS: dep_id
blocked.append(dep_id) for dep_id in dependency_ids
return blocked if status_by_id.get(dep_id) != DONE_STATUS
]
async def blocked_by_for_task( async def blocked_by_for_task(
@@ -87,6 +93,7 @@ async def blocked_by_for_task(
task_id: UUID, task_id: UUID,
dependency_ids: Sequence[UUID] | None = None, dependency_ids: Sequence[UUID] | None = None,
) -> list[UUID]: ) -> list[UUID]:
"""Return unresolved dependency ids for the provided task."""
dep_ids = list(dependency_ids or []) dep_ids = list(dependency_ids or [])
if dependency_ids is None: if dependency_ids is None:
deps_map = await dependency_ids_by_task_id( deps_map = await dependency_ids_by_task_id(
@@ -97,11 +104,16 @@ async def blocked_by_for_task(
dep_ids = deps_map.get(task_id, []) dep_ids = deps_map.get(task_id, [])
if not dep_ids: if not dep_ids:
return [] return []
status_by_id = await dependency_status_by_id(session, board_id=board_id, dependency_ids=dep_ids) status_by_id = await dependency_status_by_id(
session,
board_id=board_id,
dependency_ids=dep_ids,
)
return blocked_by_dependency_ids(dependency_ids=dep_ids, status_by_id=status_by_id) return blocked_by_dependency_ids(dependency_ids=dep_ids, status_by_id=status_by_id)
def _has_cycle(nodes: Sequence[UUID], edges: Mapping[UUID, set[UUID]]) -> bool: def _has_cycle(nodes: Sequence[UUID], edges: Mapping[UUID, set[UUID]]) -> bool:
"""Detect cycles in a directed dependency graph."""
visited: set[UUID] = set() visited: set[UUID] = set()
in_stack: set[UUID] = set() in_stack: set[UUID] = set()
@@ -118,10 +130,7 @@ def _has_cycle(nodes: Sequence[UUID], edges: Mapping[UUID, set[UUID]]) -> bool:
in_stack.remove(node) in_stack.remove(node)
return False return False
for node in nodes: return any(dfs(node) for node in nodes)
if dfs(node):
return True
return False
async def validate_dependency_update( async def validate_dependency_update(
@@ -131,6 +140,7 @@ async def validate_dependency_update(
task_id: UUID, task_id: UUID,
depends_on_task_ids: Sequence[UUID], depends_on_task_ids: Sequence[UUID],
) -> list[UUID]: ) -> list[UUID]:
"""Validate a dependency update and return normalized dependency ids."""
normalized = _dedupe_uuid_list(depends_on_task_ids) normalized = _dedupe_uuid_list(depends_on_task_ids)
if task_id in normalized: if task_id in normalized:
raise HTTPException( raise HTTPException(
@@ -145,8 +155,8 @@ async def validate_dependency_update(
await session.exec( await session.exec(
select(col(Task.id)) select(col(Task.id))
.where(col(Task.board_id) == board_id) .where(col(Task.board_id) == board_id)
.where(col(Task.id).in_(normalized)) .where(col(Task.id).in_(normalized)),
) ),
) )
missing = [dep_id for dep_id in normalized if dep_id not in existing_ids] missing = [dep_id for dep_id in normalized if dep_id not in existing_ids]
if missing: if missing:
@@ -159,13 +169,18 @@ async def validate_dependency_update(
) )
# Ensure the dependency graph is acyclic after applying the update. # Ensure the dependency graph is acyclic after applying the update.
task_ids = list(await session.exec(select(col(Task.id)).where(col(Task.board_id) == board_id))) task_ids = list(
await session.exec(
select(col(Task.id)).where(col(Task.board_id) == board_id),
),
)
rows = list( rows = list(
await session.exec( await session.exec(
select(col(TaskDependency.task_id), col(TaskDependency.depends_on_task_id)).where( select(
col(TaskDependency.board_id) == board_id col(TaskDependency.task_id),
) col(TaskDependency.depends_on_task_id),
) ).where(col(TaskDependency.board_id) == board_id),
),
) )
edges: dict[UUID, set[UUID]] = defaultdict(set) edges: dict[UUID, set[UUID]] = defaultdict(set)
for src, dst in rows: for src, dst in rows:
@@ -188,6 +203,7 @@ async def replace_task_dependencies(
task_id: UUID, task_id: UUID,
depends_on_task_ids: Sequence[UUID], depends_on_task_ids: Sequence[UUID],
) -> list[UUID]: ) -> list[UUID]:
"""Replace dependencies for a task and return the normalized dependency ids."""
normalized = await validate_dependency_update( normalized = await validate_dependency_update(
session, session,
board_id=board_id, board_id=board_id,
@@ -207,7 +223,7 @@ async def replace_task_dependencies(
board_id=board_id, board_id=board_id,
task_id=task_id, task_id=task_id,
depends_on_task_id=dep_id, depends_on_task_id=dep_id,
) ),
) )
return normalized return normalized
@@ -218,9 +234,10 @@ async def dependent_task_ids(
board_id: UUID, board_id: UUID,
dependency_task_id: UUID, dependency_task_id: UUID,
) -> list[UUID]: ) -> list[UUID]:
"""Return task ids that depend on the provided dependency task id."""
rows = await session.exec( rows = await session.exec(
select(col(TaskDependency.task_id)) select(col(TaskDependency.task_id))
.where(col(TaskDependency.board_id) == board_id) .where(col(TaskDependency.board_id) == board_id)
.where(col(TaskDependency.depends_on_task_id) == dependency_task_id) .where(col(TaskDependency.depends_on_task_id) == dependency_task_id),
) )
return list(rows) return list(rows)

View File

@@ -1,9 +1,12 @@
"""Gateway template synchronization orchestration."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import random import random
import re import re
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import TypeVar from typing import TypeVar
from uuid import UUID, uuid4 from uuid import UUID, uuid4
@@ -11,7 +14,11 @@ from sqlalchemy import func
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.agent_tokens import generate_agent_token, hash_agent_token, verify_agent_token from app.core.agent_tokens import (
generate_agent_token,
hash_agent_token,
verify_agent_token,
)
from app.core.time import utcnow from app.core.time import utcnow
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, openclaw_call from app.integrations.openclaw_gateway import OpenClawGatewayError, openclaw_call
@@ -49,6 +56,31 @@ _TRANSIENT_GATEWAY_ERROR_MARKERS = (
) )
T = TypeVar("T") T = TypeVar("T")
_SECURE_RANDOM = random.SystemRandom()
_RUNTIME_TYPE_REFERENCES = (Awaitable, Callable, AsyncSession, Gateway, User, UUID)
@dataclass(frozen=True)
class GatewayTemplateSyncOptions:
"""Runtime options controlling gateway template synchronization."""
user: User | None
include_main: bool = True
reset_sessions: bool = False
rotate_tokens: bool = False
force_bootstrap: bool = False
board_id: UUID | None = None
@dataclass(frozen=True)
class _SyncContext:
"""Shared state passed to sync helper functions."""
session: AsyncSession
gateway: Gateway
config: GatewayClientConfig
backoff: _GatewayBackoff
options: GatewayTemplateSyncOptions
def _slugify(value: str) -> str: def _slugify(value: str) -> str:
@@ -70,7 +102,10 @@ def _is_transient_gateway_error(exc: Exception) -> bool:
def _gateway_timeout_message(exc: OpenClawGatewayError) -> str: def _gateway_timeout_message(exc: OpenClawGatewayError) -> str:
return f"Gateway unreachable after 10 minutes (template sync timeout). Last error: {exc}" return (
"Gateway unreachable after 10 minutes (template sync timeout). "
f"Last error: {exc}"
)
class _GatewayBackoff: class _GatewayBackoff:
@@ -91,16 +126,25 @@ class _GatewayBackoff:
def reset(self) -> None: def reset(self) -> None:
self._delay_s = self._base_delay_s self._delay_s = self._base_delay_s
async def _attempt(
self,
fn: Callable[[], Awaitable[T]],
) -> tuple[T | None, OpenClawGatewayError | None]:
try:
return await fn(), None
except OpenClawGatewayError as exc:
return None, exc
async def run(self, fn: Callable[[], Awaitable[T]]) -> T: async def run(self, fn: Callable[[], Awaitable[T]]) -> T:
# Use per-call deadlines so long-running syncs can still tolerate a later # Use per-call deadlines so long-running syncs can still tolerate a later
# gateway restart without having an already-expired retry window. # gateway restart without having an already-expired retry window.
deadline_s = asyncio.get_running_loop().time() + self._timeout_s deadline_s = asyncio.get_running_loop().time() + self._timeout_s
while True: while True:
try: value, error = await self._attempt(fn)
value = await fn() if error is not None:
except OpenClawGatewayError as exc: exc = error
if not _is_transient_gateway_error(exc): if not _is_transient_gateway_error(exc):
raise raise exc
now = asyncio.get_running_loop().time() now = asyncio.get_running_loop().time()
remaining = deadline_s - now remaining = deadline_s - now
if remaining <= 0: if remaining <= 0:
@@ -108,11 +152,14 @@ class _GatewayBackoff:
sleep_s = min(self._delay_s, remaining) sleep_s = min(self._delay_s, remaining)
if self._jitter: if self._jitter:
sleep_s *= 1.0 + random.uniform(-self._jitter, self._jitter) sleep_s *= 1.0 + _SECURE_RANDOM.uniform(
-self._jitter,
self._jitter,
)
sleep_s = max(0.0, min(sleep_s, remaining)) sleep_s = max(0.0, min(sleep_s, remaining))
await asyncio.sleep(sleep_s) await asyncio.sleep(sleep_s)
self._delay_s = min(self._delay_s * 2.0, self._max_delay_s) self._delay_s = min(self._delay_s * 2.0, self._max_delay_s)
else: continue
self.reset() self.reset()
return value return value
@@ -138,8 +185,7 @@ def _agent_id_from_session_key(session_key: str | None) -> str | None:
return agent_id or None return agent_id or None
def _extract_agent_id(payload: object) -> str | None: def _extract_agent_id_from_list(items: object) -> str | None:
def _from_list(items: object) -> str | None:
if not isinstance(items, list): if not isinstance(items, list):
return None return None
for item in items: for item in items:
@@ -153,8 +199,11 @@ def _extract_agent_id(payload: object) -> str | None:
return raw.strip() return raw.strip()
return None return None
def _extract_agent_id(payload: object) -> str | None:
"""Extract a default gateway agent id from common list payload shapes."""
if isinstance(payload, list): if isinstance(payload, list):
return _from_list(payload) return _extract_agent_id_from_list(payload)
if not isinstance(payload, dict): if not isinstance(payload, dict):
return None return None
for key in ("defaultId", "default_id", "defaultAgentId", "default_agent_id"): for key in ("defaultId", "default_id", "defaultAgentId", "default_agent_id"):
@@ -162,7 +211,7 @@ def _extract_agent_id(payload: object) -> str | None:
if isinstance(raw, str) and raw.strip(): if isinstance(raw, str) and raw.strip():
return raw.strip() return raw.strip()
for key in ("agents", "items", "list", "data"): for key in ("agents", "items", "list", "data"):
agent_id = _from_list(payload.get(key)) agent_id = _extract_agent_id_from_list(payload.get(key))
if agent_id: if agent_id:
return agent_id return agent_id
return None return None
@@ -212,9 +261,6 @@ async def _get_agent_file(
if isinstance(payload, str): if isinstance(payload, str):
return payload return payload
if isinstance(payload, dict): if isinstance(payload, dict):
# Common shapes:
# - {"name": "...", "content": "..."}
# - {"file": {"name": "...", "content": "..." }}
content = payload.get("content") content = payload.get("content")
if isinstance(content, str): if isinstance(content, str):
return content return content
@@ -291,18 +337,53 @@ async def _paused_board_ids(session: AsyncSession, board_ids: list[UUID]) -> set
return paused return paused
async def sync_gateway_templates( def _append_sync_error(
session: AsyncSession, result: GatewayTemplatesSyncResult,
*,
message: str,
agent: Agent | None = None,
board: Board | None = None,
) -> None:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=agent.id if agent else None,
agent_name=agent.name if agent else None,
board_id=board.id if board else None,
message=message,
),
)
async def _rotate_agent_token(session: AsyncSession, agent: Agent) -> str:
token = generate_agent_token()
agent.agent_token_hash = hash_agent_token(token)
agent.updated_at = utcnow()
session.add(agent)
await session.commit()
await session.refresh(agent)
return token
async def _ping_gateway(ctx: _SyncContext, result: GatewayTemplatesSyncResult) -> bool:
try:
async def _do_ping() -> object:
return await openclaw_call("agents.list", config=ctx.config)
await ctx.backoff.run(_do_ping)
except (TimeoutError, OpenClawGatewayError) as exc:
_append_sync_error(result, message=str(exc))
return False
else:
return True
def _base_result(
gateway: Gateway, gateway: Gateway,
*, *,
user: User | None, include_main: bool,
include_main: bool = True, reset_sessions: bool,
reset_sessions: bool = False,
rotate_tokens: bool = False,
force_bootstrap: bool = False,
board_id: UUID | None = None,
) -> GatewayTemplatesSyncResult: ) -> GatewayTemplatesSyncResult:
result = GatewayTemplatesSyncResult( return GatewayTemplatesSyncResult(
gateway_id=gateway.id, gateway_id=gateway.id,
include_main=include_main, include_main=include_main,
reset_sessions=reset_sessions, reset_sessions=reset_sessions,
@@ -310,45 +391,239 @@ async def sync_gateway_templates(
agents_skipped=0, agents_skipped=0,
main_updated=False, main_updated=False,
) )
def _boards_by_id(
boards: list[Board],
*,
board_id: UUID | None,
) -> dict[UUID, Board] | None:
boards_by_id = {board.id: board for board in boards}
if board_id is None:
return boards_by_id
board = boards_by_id.get(board_id)
if board is None:
return None
return {board_id: board}
async def _resolve_agent_auth_token(
ctx: _SyncContext,
result: GatewayTemplatesSyncResult,
agent: Agent,
board: Board | None,
*,
agent_gateway_id: str,
) -> tuple[str | None, bool]:
try:
auth_token = await _get_existing_auth_token(
agent_gateway_id=agent_gateway_id,
config=ctx.config,
backoff=ctx.backoff,
)
except TimeoutError as exc:
_append_sync_error(result, agent=agent, board=board, message=str(exc))
return None, True
if not auth_token:
if not ctx.options.rotate_tokens:
result.agents_skipped += 1
_append_sync_error(
result,
agent=agent,
board=board,
message=(
"Skipping agent: unable to read AUTH_TOKEN from TOOLS.md "
"(run with rotate_tokens=true to re-key)."
),
)
return None, False
auth_token = await _rotate_agent_token(ctx.session, agent)
if agent.agent_token_hash and not verify_agent_token(
auth_token,
agent.agent_token_hash,
):
if ctx.options.rotate_tokens:
auth_token = await _rotate_agent_token(ctx.session, agent)
else:
_append_sync_error(
result,
agent=agent,
board=board,
message=(
"Warning: AUTH_TOKEN in TOOLS.md does not match backend "
"token hash (agent auth may be broken)."
),
)
return auth_token, False
async def _sync_one_agent(
ctx: _SyncContext,
result: GatewayTemplatesSyncResult,
agent: Agent,
board: Board,
) -> bool:
auth_token, fatal = await _resolve_agent_auth_token(
ctx,
result,
agent,
board,
agent_gateway_id=_gateway_agent_id(agent),
)
if fatal:
return True
if not auth_token:
return False
try:
async def _do_provision() -> None:
await provision_agent(
agent,
board,
ctx.gateway,
auth_token,
ctx.options.user,
action="update",
force_bootstrap=ctx.options.force_bootstrap,
reset_session=ctx.options.reset_sessions,
)
await _with_gateway_retry(_do_provision, backoff=ctx.backoff)
result.agents_updated += 1
except TimeoutError as exc: # pragma: no cover - gateway/network dependent
result.agents_skipped += 1
_append_sync_error(result, agent=agent, board=board, message=str(exc))
return True
except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover
result.agents_skipped += 1
_append_sync_error(
result,
agent=agent,
board=board,
message=f"Failed to sync templates: {exc}",
)
return False
else:
return False
async def _sync_main_agent(
ctx: _SyncContext,
result: GatewayTemplatesSyncResult,
) -> bool:
main_agent = (
await Agent.objects.all()
.filter(col(Agent.openclaw_session_id) == ctx.gateway.main_session_key)
.first(ctx.session)
)
if main_agent is None:
_append_sync_error(
result,
message=(
"Gateway main agent record not found; "
"skipping main agent template sync."
),
)
return True
try:
main_gateway_agent_id = await _gateway_default_agent_id(
ctx.config,
fallback_session_key=ctx.gateway.main_session_key,
backoff=ctx.backoff,
)
except TimeoutError as exc:
_append_sync_error(result, agent=main_agent, message=str(exc))
return True
if not main_gateway_agent_id:
_append_sync_error(
result,
agent=main_agent,
message="Unable to resolve gateway default agent id for main agent.",
)
return True
token, fatal = await _resolve_agent_auth_token(
ctx,
result,
main_agent,
board=None,
agent_gateway_id=main_gateway_agent_id,
)
if fatal:
return True
if not token:
_append_sync_error(
result,
agent=main_agent,
message="Skipping main agent: unable to read AUTH_TOKEN from TOOLS.md.",
)
return True
stop_sync = False
try:
async def _do_provision_main() -> None:
await provision_main_agent(
main_agent,
ctx.gateway,
token,
ctx.options.user,
action="update",
force_bootstrap=ctx.options.force_bootstrap,
reset_session=ctx.options.reset_sessions,
)
await _with_gateway_retry(_do_provision_main, backoff=ctx.backoff)
except TimeoutError as exc: # pragma: no cover - gateway/network dependent
_append_sync_error(result, agent=main_agent, message=str(exc))
stop_sync = True
except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover
_append_sync_error(
result,
agent=main_agent,
message=f"Failed to sync main agent templates: {exc}",
)
else:
result.main_updated = True
return stop_sync
async def sync_gateway_templates(
session: AsyncSession,
gateway: Gateway,
options: GatewayTemplateSyncOptions,
) -> GatewayTemplatesSyncResult:
"""Synchronize AGENTS/TOOLS/etc templates to gateway-connected agents."""
result = _base_result(
gateway,
include_main=options.include_main,
reset_sessions=options.reset_sessions,
)
if not gateway.url: if not gateway.url:
result.errors.append( _append_sync_error(
GatewayTemplatesSyncError(message="Gateway URL is not configured for this gateway.") result,
message="Gateway URL is not configured for this gateway.",
) )
return result return result
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token) ctx = _SyncContext(
backoff = _GatewayBackoff(timeout_s=10 * 60) session=session,
gateway=gateway,
# First, wait for the gateway to be reachable (e.g. while it is restarting). config=GatewayClientConfig(url=gateway.url, token=gateway.token),
try: backoff=_GatewayBackoff(timeout_s=10 * 60),
options=options,
async def _do_ping() -> object: )
return await openclaw_call("agents.list", config=client_config) if not await _ping_gateway(ctx, result):
await backoff.run(_do_ping)
except TimeoutError as exc:
result.errors.append(GatewayTemplatesSyncError(message=str(exc)))
return result
except OpenClawGatewayError as exc:
result.errors.append(GatewayTemplatesSyncError(message=str(exc)))
return result return result
boards = await Board.objects.filter_by(gateway_id=gateway.id).all(session) boards = await Board.objects.filter_by(gateway_id=gateway.id).all(session)
boards_by_id = {board.id: board for board in boards} boards_by_id = _boards_by_id(boards, board_id=options.board_id)
if board_id is not None: if boards_by_id is None:
board = boards_by_id.get(board_id) _append_sync_error(
if board is None: result,
result.errors.append(
GatewayTemplatesSyncError(
board_id=board_id,
message="Board does not belong to this gateway.", message="Board does not belong to this gateway.",
) )
)
return result return result
boards_by_id = {board_id: board}
paused_board_ids = await _paused_board_ids(session, list(boards_by_id.keys())) paused_board_ids = await _paused_board_ids(session, list(boards_by_id.keys()))
if boards_by_id: if boards_by_id:
agents = await ( agents = await (
Agent.objects.by_field_in("board_id", list(boards_by_id.keys())) Agent.objects.by_field_in("board_id", list(boards_by_id.keys()))
@@ -358,251 +633,24 @@ async def sync_gateway_templates(
else: else:
agents = [] agents = []
stop_sync = False
for agent in agents: for agent in agents:
board = boards_by_id.get(agent.board_id) if agent.board_id is not None else None board = boards_by_id.get(agent.board_id) if agent.board_id is not None else None
if board is None: if board is None:
result.agents_skipped += 1 result.agents_skipped += 1
result.errors.append( _append_sync_error(
GatewayTemplatesSyncError( result,
agent_id=agent.id, agent=agent,
agent_name=agent.name,
board_id=agent.board_id,
message="Skipping agent: board not found for agent.", message="Skipping agent: board not found for agent.",
) )
)
continue continue
if board.id in paused_board_ids: if board.id in paused_board_ids:
result.agents_skipped += 1 result.agents_skipped += 1
continue continue
stop_sync = await _sync_one_agent(ctx, result, agent, board)
if stop_sync:
break
agent_gateway_id = _gateway_agent_id(agent) if not stop_sync and options.include_main:
try: await _sync_main_agent(ctx, result)
auth_token = await _get_existing_auth_token(
agent_gateway_id=agent_gateway_id,
config=client_config,
backoff=backoff,
)
except TimeoutError as exc:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=agent.id,
agent_name=agent.name,
board_id=board.id,
message=str(exc),
)
)
return result
if not auth_token:
if not rotate_tokens:
result.agents_skipped += 1
result.errors.append(
GatewayTemplatesSyncError(
agent_id=agent.id,
agent_name=agent.name,
board_id=board.id,
message="Skipping agent: unable to read AUTH_TOKEN from TOOLS.md (run with rotate_tokens=true to re-key).",
)
)
continue
raw_token = generate_agent_token()
agent.agent_token_hash = hash_agent_token(raw_token)
agent.updated_at = utcnow()
session.add(agent)
await session.commit()
await session.refresh(agent)
auth_token = raw_token
if agent.agent_token_hash and not verify_agent_token(auth_token, agent.agent_token_hash):
# Do not block template sync on token drift; optionally re-key.
if rotate_tokens:
raw_token = generate_agent_token()
agent.agent_token_hash = hash_agent_token(raw_token)
agent.updated_at = utcnow()
session.add(agent)
await session.commit()
await session.refresh(agent)
auth_token = raw_token
else:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=agent.id,
agent_name=agent.name,
board_id=board.id,
message="Warning: AUTH_TOKEN in TOOLS.md does not match backend token hash (agent auth may be broken).",
)
)
try:
agent_item: Agent = agent
board_item: Board = board
auth_token_value: str = auth_token
async def _do_provision(
agent_item: Agent = agent_item,
board_item: Board = board_item,
auth_token_value: str = auth_token_value,
) -> None:
await provision_agent(
agent_item,
board_item,
gateway,
auth_token_value,
user,
action="update",
force_bootstrap=force_bootstrap,
reset_session=reset_sessions,
)
await _with_gateway_retry(_do_provision, backoff=backoff)
result.agents_updated += 1
except TimeoutError as exc: # pragma: no cover - gateway/network dependent
result.agents_skipped += 1
result.errors.append(
GatewayTemplatesSyncError(
agent_id=agent.id,
agent_name=agent.name,
board_id=board.id,
message=str(exc),
)
)
return result
except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover
result.agents_skipped += 1
result.errors.append(
GatewayTemplatesSyncError(
agent_id=agent.id,
agent_name=agent.name,
board_id=board.id,
message=f"Failed to sync templates: {exc}",
)
)
if include_main:
main_agent = (
await Agent.objects.all()
.filter(col(Agent.openclaw_session_id) == gateway.main_session_key)
.first(session)
)
if main_agent is None:
result.errors.append(
GatewayTemplatesSyncError(
message="Gateway main agent record not found; skipping main agent template sync.",
)
)
return result
try:
main_gateway_agent_id = await _gateway_default_agent_id(
client_config,
fallback_session_key=gateway.main_session_key,
backoff=backoff,
)
except TimeoutError as exc:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=main_agent.id,
agent_name=main_agent.name,
message=str(exc),
)
)
return result
if not main_gateway_agent_id:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=main_agent.id,
agent_name=main_agent.name,
message="Unable to resolve gateway default agent id for main agent.",
)
)
return result
try:
main_token = await _get_existing_auth_token(
agent_gateway_id=main_gateway_agent_id,
config=client_config,
backoff=backoff,
)
except TimeoutError as exc:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=main_agent.id,
agent_name=main_agent.name,
message=str(exc),
)
)
return result
if not main_token:
if rotate_tokens:
raw_token = generate_agent_token()
main_agent.agent_token_hash = hash_agent_token(raw_token)
main_agent.updated_at = utcnow()
session.add(main_agent)
await session.commit()
await session.refresh(main_agent)
main_token = raw_token
else:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=main_agent.id,
agent_name=main_agent.name,
message="Skipping main agent: unable to read AUTH_TOKEN from TOOLS.md.",
)
)
return result
if main_agent.agent_token_hash and not verify_agent_token(
main_token, main_agent.agent_token_hash
):
if rotate_tokens:
raw_token = generate_agent_token()
main_agent.agent_token_hash = hash_agent_token(raw_token)
main_agent.updated_at = utcnow()
session.add(main_agent)
await session.commit()
await session.refresh(main_agent)
main_token = raw_token
else:
result.errors.append(
GatewayTemplatesSyncError(
agent_id=main_agent.id,
agent_name=main_agent.name,
message="Warning: AUTH_TOKEN in TOOLS.md does not match backend token hash (main agent auth may be broken).",
)
)
try:
async def _do_provision_main() -> None:
await provision_main_agent(
main_agent,
gateway,
main_token,
user,
action="update",
force_bootstrap=force_bootstrap,
reset_session=reset_sessions,
)
await _with_gateway_retry(_do_provision_main, backoff=backoff)
result.main_updated = True
except TimeoutError as exc: # pragma: no cover - gateway/network dependent
result.errors.append(
GatewayTemplatesSyncError(
agent_id=main_agent.id,
agent_name=main_agent.name,
message=str(exc),
)
)
return result
except (OSError, RuntimeError, ValueError) as exc: # pragma: no cover
result.errors.append(
GatewayTemplatesSyncError(
agent_id=main_agent.id,
agent_name=main_agent.name,
message=f"Failed to sync main agent templates: {exc}",
)
)
return result return result

View File

@@ -1,3 +1,5 @@
# ruff: noqa
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass

View File

@@ -1,3 +1,5 @@
# ruff: noqa
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass

View File

@@ -1,3 +1,5 @@
# ruff: noqa
from __future__ import annotations from __future__ import annotations
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException

View File

@@ -1,3 +1,5 @@
# ruff: noqa
import hashlib import hashlib
from app.services.lead_policy import ( from app.services.lead_policy import (

View File

@@ -1,3 +1,5 @@
# ruff: noqa
from app.models.agents import Agent from app.models.agents import Agent
from app.services.mentions import extract_mentions, matches_agent_mention from app.services.mentions import extract_mentions, matches_agent_mention

View File

@@ -1,3 +1,5 @@
# ruff: noqa
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field

View File

@@ -1,3 +1,5 @@
# ruff: noqa
from __future__ import annotations from __future__ import annotations
import pytest import pytest

View File

@@ -1,3 +1,5 @@
# ruff: noqa
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field

View File

@@ -1,3 +1,5 @@
# ruff: noqa
from __future__ import annotations from __future__ import annotations
from uuid import UUID, uuid4 from uuid import UUID, uuid4