refactor: update module docstrings for clarity and consistency

This commit is contained in:
Abhimanyu Saharan
2026-02-09 15:49:50 +05:30
parent 78bb08d4a3
commit 7ca1899d9f
99 changed files with 2345 additions and 855 deletions

View File

@@ -0,0 +1 @@
"""OpenClaw Mission Control backend application package."""

View File

@@ -0,0 +1 @@
"""API router modules for the OpenClaw Mission Control backend."""

View File

@@ -1,13 +1,14 @@
"""Agent-scoped API routes for board operations and gateway coordination."""
from __future__ import annotations from __future__ import annotations
import re import re
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, cast from typing import TYPE_CHECKING, Any, cast
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel import SQLModel, col, select from sqlmodel import SQLModel, col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api import agents as agents_api from app.api import agents as agents_api
from app.api import approvals as approvals_api from app.api import approvals as approvals_api
@@ -27,11 +28,7 @@ from app.integrations.openclaw_gateway import (
openclaw_call, openclaw_call,
send_message, send_message,
) )
from app.models.activity_events import ActivityEvent
from app.models.agents import Agent from app.models.agents import Agent
from app.models.approvals import Approval
from app.models.board_memory import BoardMemory
from app.models.board_onboarding import BoardOnboardingSession
from app.models.boards import Board from app.models.boards import Board
from app.models.gateways import Gateway from app.models.gateways import Gateway
from app.models.task_dependencies import TaskDependency from app.models.task_dependencies import TaskDependency
@@ -58,7 +55,13 @@ from app.schemas.gateway_coordination import (
GatewayMainAskUserResponse, GatewayMainAskUserResponse,
) )
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
from app.schemas.tasks import TaskCommentCreate, TaskCommentRead, TaskCreate, TaskRead, TaskUpdate from app.schemas.tasks import (
TaskCommentCreate,
TaskCommentRead,
TaskCreate,
TaskRead,
TaskUpdate,
)
from app.services.activity_log import record_activity from app.services.activity_log import record_activity
from app.services.board_leads import ensure_board_lead_agent from app.services.board_leads import ensure_board_lead_agent
from app.services.task_dependencies import ( from app.services.task_dependencies import (
@@ -67,11 +70,27 @@ from app.services.task_dependencies import (
validate_dependency_update, validate_dependency_update,
) )
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.activity_events import ActivityEvent
from app.models.approvals import Approval
from app.models.board_memory import BoardMemory
from app.models.board_onboarding import BoardOnboardingSession
router = APIRouter(prefix="/agent", tags=["agent"]) router = APIRouter(prefix="/agent", tags=["agent"])
_AGENT_SESSION_PREFIX = "agent:" _AGENT_SESSION_PREFIX = "agent:"
_SESSION_KEY_PARTS_MIN = 2 _SESSION_KEY_PARTS_MIN = 2
_LEAD_SESSION_KEY_MISSING = "Lead agent has no session key" _LEAD_SESSION_KEY_MISSING = "Lead agent has no session key"
SESSION_DEP = Depends(get_session)
AGENT_CTX_DEP = Depends(get_agent_auth_context)
BOARD_DEP = Depends(get_board_or_404)
TASK_DEP = Depends(get_task_or_404)
BOARD_ID_QUERY = Query(default=None)
TASK_STATUS_QUERY = Query(default=None, alias="status")
IS_CHAT_QUERY = Query(default=None)
APPROVAL_STATUS_QUERY = Query(default=None, alias="status")
def _gateway_agent_id(agent: Agent) -> str: def _gateway_agent_id(agent: Agent) -> str:
@@ -87,6 +106,8 @@ def _gateway_agent_id(agent: Agent) -> str:
class SoulUpdateRequest(SQLModel): class SoulUpdateRequest(SQLModel):
"""Payload for updating an agent SOUL document."""
content: str content: str
source_url: str | None = None source_url: str | None = None
reason: str | None = None reason: str | None = None
@@ -124,9 +145,12 @@ async def _require_gateway_main(
session_key = (agent.openclaw_session_id or "").strip() session_key = (agent.openclaw_session_id or "").strip()
if not session_key: if not session_key:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Agent missing session key" status_code=status.HTTP_403_FORBIDDEN,
detail="Agent missing session key",
) )
gateway = await Gateway.objects.filter_by(main_session_key=session_key).first(session) gateway = await Gateway.objects.filter_by(main_session_key=session_key).first(
session,
)
if gateway is None: if gateway is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
@@ -148,7 +172,9 @@ async def _require_gateway_board(
) -> Board: ) -> Board:
board = await Board.objects.by_id(board_id).first(session) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Board not found",
)
if board.gateway_id != gateway.id: if board.gateway_id != gateway.id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return board return board
@@ -156,9 +182,10 @@ async def _require_gateway_board(
@router.get("/boards", response_model=DefaultLimitOffsetPage[BoardRead]) @router.get("/boards", response_model=DefaultLimitOffsetPage[BoardRead])
async def list_boards( async def list_boards(
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> DefaultLimitOffsetPage[BoardRead]: ) -> DefaultLimitOffsetPage[BoardRead]:
"""List boards visible to the authenticated agent."""
statement = select(Board) statement = select(Board)
if agent_ctx.agent.board_id: if agent_ctx.agent.board_id:
statement = statement.where(col(Board.id) == agent_ctx.agent.board_id) statement = statement.where(col(Board.id) == agent_ctx.agent.board_id)
@@ -168,19 +195,21 @@ async def list_boards(
@router.get("/boards/{board_id}", response_model=BoardRead) @router.get("/boards/{board_id}", response_model=BoardRead)
def get_board( def get_board(
board: Board = Depends(get_board_or_404), board: Board = BOARD_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> Board: ) -> Board:
"""Return a board if the authenticated agent can access it."""
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
return board return board
@router.get("/agents", response_model=DefaultLimitOffsetPage[AgentRead]) @router.get("/agents", response_model=DefaultLimitOffsetPage[AgentRead])
async def list_agents( async def list_agents(
board_id: UUID | None = Query(default=None), board_id: UUID | None = BOARD_ID_QUERY,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> DefaultLimitOffsetPage[AgentRead]: ) -> DefaultLimitOffsetPage[AgentRead]:
"""List agents, optionally filtered to a board."""
statement = select(Agent) statement = select(Agent)
if agent_ctx.agent.board_id: if agent_ctx.agent.board_id:
if board_id and board_id != agent_ctx.agent.board_id: if board_id and board_id != agent_ctx.agent.board_id:
@@ -188,13 +217,19 @@ async def list_agents(
statement = statement.where(Agent.board_id == agent_ctx.agent.board_id) statement = statement.where(Agent.board_id == agent_ctx.agent.board_id)
elif board_id: elif board_id:
statement = statement.where(Agent.board_id == board_id) statement = statement.where(Agent.board_id == board_id)
main_session_keys = await agents_api._get_gateway_main_session_keys(session) get_gateway_main_session_keys = (
agents_api._get_gateway_main_session_keys # noqa: SLF001
)
to_agent_read = agents_api._to_agent_read # noqa: SLF001
with_computed_status = agents_api._with_computed_status # noqa: SLF001
main_session_keys = await get_gateway_main_session_keys(session)
statement = statement.order_by(col(Agent.created_at).desc()) statement = statement.order_by(col(Agent.created_at).desc())
def _transform(items: Sequence[Any]) -> Sequence[Any]: def _transform(items: Sequence[Any]) -> Sequence[Any]:
agents = cast(Sequence[Agent], items) agents = cast(Sequence[Agent], items)
return [ return [
agents_api._to_agent_read(agents_api._with_computed_status(agent), main_session_keys) to_agent_read(with_computed_status(agent), main_session_keys)
for agent in agents for agent in agents
] ]
@@ -202,14 +237,15 @@ async def list_agents(
@router.get("/boards/{board_id}/tasks", response_model=DefaultLimitOffsetPage[TaskRead]) @router.get("/boards/{board_id}/tasks", response_model=DefaultLimitOffsetPage[TaskRead])
async def list_tasks( async def list_tasks( # noqa: PLR0913
status_filter: str | None = Query(default=None, alias="status"), status_filter: str | None = TASK_STATUS_QUERY,
assigned_agent_id: UUID | None = None, assigned_agent_id: UUID | None = None,
unassigned: bool | None = None, unassigned: bool | None = None,
board: Board = Depends(get_board_or_404), board: Board = BOARD_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> DefaultLimitOffsetPage[TaskRead]: ) -> DefaultLimitOffsetPage[TaskRead]:
"""List tasks on a board with optional status and assignment filters."""
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
return await tasks_api.list_tasks( return await tasks_api.list_tasks(
status_filter=status_filter, status_filter=status_filter,
@@ -224,10 +260,11 @@ async def list_tasks(
@router.post("/boards/{board_id}/tasks", response_model=TaskRead) @router.post("/boards/{board_id}/tasks", response_model=TaskRead)
async def create_task( async def create_task(
payload: TaskCreate, payload: TaskCreate,
board: Board = Depends(get_board_or_404), board: Board = BOARD_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> TaskRead: ) -> TaskRead:
"""Create a task on the board as the lead agent."""
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
if not agent_ctx.agent.is_board_lead: if not agent_ctx.agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@@ -250,7 +287,9 @@ async def create_task(
board_id=board.id, board_id=board.id,
dependency_ids=normalized_deps, dependency_ids=normalized_deps,
) )
blocked_by = blocked_by_dependency_ids(dependency_ids=normalized_deps, status_by_id=dep_status) blocked_by = blocked_by_dependency_ids(
dependency_ids=normalized_deps, status_by_id=dep_status,
)
if blocked_by and (task.assigned_agent_id is not None or task.status != "inbox"): if blocked_by and (task.assigned_agent_id is not None or task.status != "inbox"):
raise HTTPException( raise HTTPException(
@@ -280,7 +319,7 @@ async def create_task(
board_id=board.id, board_id=board.id,
task_id=task.id, task_id=task.id,
depends_on_task_id=dep_id, depends_on_task_id=dep_id,
) ),
) )
await session.commit() await session.commit()
await session.refresh(task) await session.refresh(task)
@@ -293,9 +332,14 @@ async def create_task(
) )
await session.commit() await session.commit()
if task.assigned_agent_id: if task.assigned_agent_id:
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session) assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(
session,
)
if assigned_agent: if assigned_agent:
await tasks_api._notify_agent_on_task_assign( notify_agent_on_task_assign = (
tasks_api._notify_agent_on_task_assign # noqa: SLF001
)
await notify_agent_on_task_assign(
session=session, session=session,
board=board, board=board,
task=task, task=task,
@@ -306,18 +350,23 @@ async def create_task(
"depends_on_task_ids": normalized_deps, "depends_on_task_ids": normalized_deps,
"blocked_by_task_ids": blocked_by, "blocked_by_task_ids": blocked_by,
"is_blocked": bool(blocked_by), "is_blocked": bool(blocked_by),
} },
) )
@router.patch("/boards/{board_id}/tasks/{task_id}", response_model=TaskRead) @router.patch("/boards/{board_id}/tasks/{task_id}", response_model=TaskRead)
async def update_task( async def update_task(
payload: TaskUpdate, payload: TaskUpdate,
task: Task = Depends(get_task_or_404), task: Task = TASK_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> TaskRead: ) -> TaskRead:
if agent_ctx.agent.board_id and task.board_id and agent_ctx.agent.board_id != task.board_id: """Update a task after board-level access checks."""
if (
agent_ctx.agent.board_id
and task.board_id
and agent_ctx.agent.board_id != task.board_id
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return await tasks_api.update_task( return await tasks_api.update_task(
payload=payload, payload=payload,
@@ -332,11 +381,16 @@ async def update_task(
response_model=DefaultLimitOffsetPage[TaskCommentRead], response_model=DefaultLimitOffsetPage[TaskCommentRead],
) )
async def list_task_comments( async def list_task_comments(
task: Task = Depends(get_task_or_404), task: Task = TASK_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> DefaultLimitOffsetPage[TaskCommentRead]: ) -> DefaultLimitOffsetPage[TaskCommentRead]:
if agent_ctx.agent.board_id and task.board_id and agent_ctx.agent.board_id != task.board_id: """List comments for a task visible to the authenticated agent."""
if (
agent_ctx.agent.board_id
and task.board_id
and agent_ctx.agent.board_id != task.board_id
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return await tasks_api.list_task_comments( return await tasks_api.list_task_comments(
task=task, task=task,
@@ -344,14 +398,21 @@ async def list_task_comments(
) )
@router.post("/boards/{board_id}/tasks/{task_id}/comments", response_model=TaskCommentRead) @router.post(
"/boards/{board_id}/tasks/{task_id}/comments", response_model=TaskCommentRead,
)
async def create_task_comment( async def create_task_comment(
payload: TaskCommentCreate, payload: TaskCommentCreate,
task: Task = Depends(get_task_or_404), task: Task = TASK_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> ActivityEvent: ) -> ActivityEvent:
if agent_ctx.agent.board_id and task.board_id and agent_ctx.agent.board_id != task.board_id: """Create a task comment on behalf of the authenticated agent."""
if (
agent_ctx.agent.board_id
and task.board_id
and agent_ctx.agent.board_id != task.board_id
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return await tasks_api.create_task_comment( return await tasks_api.create_task_comment(
payload=payload, payload=payload,
@@ -361,13 +422,16 @@ async def create_task_comment(
) )
@router.get("/boards/{board_id}/memory", response_model=DefaultLimitOffsetPage[BoardMemoryRead]) @router.get(
"/boards/{board_id}/memory", response_model=DefaultLimitOffsetPage[BoardMemoryRead],
)
async def list_board_memory( async def list_board_memory(
is_chat: bool | None = Query(default=None), is_chat: bool | None = IS_CHAT_QUERY,
board: Board = Depends(get_board_or_404), board: Board = BOARD_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> DefaultLimitOffsetPage[BoardMemoryRead]: ) -> DefaultLimitOffsetPage[BoardMemoryRead]:
"""List board memory entries with optional chat filtering."""
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
return await board_memory_api.list_board_memory( return await board_memory_api.list_board_memory(
is_chat=is_chat, is_chat=is_chat,
@@ -380,10 +444,11 @@ async def list_board_memory(
@router.post("/boards/{board_id}/memory", response_model=BoardMemoryRead) @router.post("/boards/{board_id}/memory", response_model=BoardMemoryRead)
async def create_board_memory( async def create_board_memory(
payload: BoardMemoryCreate, payload: BoardMemoryCreate,
board: Board = Depends(get_board_or_404), board: Board = BOARD_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> BoardMemory: ) -> BoardMemory:
"""Create a board memory entry."""
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
return await board_memory_api.create_board_memory( return await board_memory_api.create_board_memory(
payload=payload, payload=payload,
@@ -398,11 +463,12 @@ async def create_board_memory(
response_model=DefaultLimitOffsetPage[ApprovalRead], response_model=DefaultLimitOffsetPage[ApprovalRead],
) )
async def list_approvals( async def list_approvals(
status_filter: ApprovalStatus | None = Query(default=None, alias="status"), status_filter: ApprovalStatus | None = APPROVAL_STATUS_QUERY,
board: Board = Depends(get_board_or_404), board: Board = BOARD_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> DefaultLimitOffsetPage[ApprovalRead]: ) -> DefaultLimitOffsetPage[ApprovalRead]:
"""List approvals for a board."""
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
return await approvals_api.list_approvals( return await approvals_api.list_approvals(
status_filter=status_filter, status_filter=status_filter,
@@ -415,10 +481,11 @@ async def list_approvals(
@router.post("/boards/{board_id}/approvals", response_model=ApprovalRead) @router.post("/boards/{board_id}/approvals", response_model=ApprovalRead)
async def create_approval( async def create_approval(
payload: ApprovalCreate, payload: ApprovalCreate,
board: Board = Depends(get_board_or_404), board: Board = BOARD_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> Approval: ) -> Approval:
"""Create a board approval request."""
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
return await approvals_api.create_approval( return await approvals_api.create_approval(
payload=payload, payload=payload,
@@ -431,10 +498,11 @@ async def create_approval(
@router.post("/boards/{board_id}/onboarding", response_model=BoardOnboardingRead) @router.post("/boards/{board_id}/onboarding", response_model=BoardOnboardingRead)
async def update_onboarding( async def update_onboarding(
payload: BoardOnboardingAgentUpdate, payload: BoardOnboardingAgentUpdate,
board: Board = Depends(get_board_or_404), board: Board = BOARD_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> BoardOnboardingSession: ) -> BoardOnboardingSession:
"""Apply onboarding updates for a board."""
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
return await onboarding_api.agent_onboarding_update( return await onboarding_api.agent_onboarding_update(
payload=payload, payload=payload,
@@ -447,14 +515,17 @@ async def update_onboarding(
@router.post("/agents", response_model=AgentRead) @router.post("/agents", response_model=AgentRead)
async def create_agent( async def create_agent(
payload: AgentCreate, payload: AgentCreate,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> AgentRead: ) -> AgentRead:
"""Create an agent on the caller's board."""
if not agent_ctx.agent.is_board_lead: if not agent_ctx.agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if not agent_ctx.agent.board_id: if not agent_ctx.agent.board_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
payload = AgentCreate(**{**payload.model_dump(), "board_id": agent_ctx.agent.board_id}) payload = AgentCreate(
**{**payload.model_dump(), "board_id": agent_ctx.agent.board_id},
)
return await agents_api.create_agent( return await agents_api.create_agent(
payload=payload, payload=payload,
session=session, session=session,
@@ -466,10 +537,11 @@ async def create_agent(
async def nudge_agent( async def nudge_agent(
payload: AgentNudge, payload: AgentNudge,
agent_id: str, agent_id: str,
board: Board = Depends(get_board_or_404), board: Board = BOARD_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> OkResponse: ) -> OkResponse:
"""Send a direct nudge message to a board agent."""
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
if not agent_ctx.agent.is_board_lead: if not agent_ctx.agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@@ -484,7 +556,9 @@ async def nudge_agent(
message = payload.message message = payload.message
config = await _gateway_config(session, board) config = await _gateway_config(session, board)
try: try:
await ensure_session(target.openclaw_session_id, config=config, label=target.name) await ensure_session(
target.openclaw_session_id, config=config, label=target.name,
)
await send_message( await send_message(
message, message,
session_key=target.openclaw_session_id, session_key=target.openclaw_session_id,
@@ -499,7 +573,9 @@ async def nudge_agent(
agent_id=agent_ctx.agent.id, agent_id=agent_ctx.agent.id,
) )
await session.commit() await session.commit()
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
) from exc
record_activity( record_activity(
session, session,
event_type="agent.nudge.sent", event_type="agent.nudge.sent",
@@ -513,9 +589,10 @@ async def nudge_agent(
@router.post("/heartbeat", response_model=AgentRead) @router.post("/heartbeat", response_model=AgentRead)
async def agent_heartbeat( async def agent_heartbeat(
payload: AgentHeartbeatCreate, payload: AgentHeartbeatCreate,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> AgentRead: ) -> AgentRead:
"""Record heartbeat status for the authenticated agent."""
# Heartbeats must apply to the authenticated agent; agent names are not unique. # Heartbeats must apply to the authenticated agent; agent names are not unique.
return await agents_api.heartbeat_agent( return await agents_api.heartbeat_agent(
agent_id=str(agent_ctx.agent.id), agent_id=str(agent_ctx.agent.id),
@@ -528,10 +605,11 @@ async def agent_heartbeat(
@router.get("/boards/{board_id}/agents/{agent_id}/soul", response_model=str) @router.get("/boards/{board_id}/agents/{agent_id}/soul", response_model=str)
async def get_agent_soul( async def get_agent_soul(
agent_id: str, agent_id: str,
board: Board = Depends(get_board_or_404), board: Board = BOARD_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> str: ) -> str:
"""Fetch the target agent's SOUL.md content from the gateway."""
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
if not agent_ctx.agent.is_board_lead and str(agent_ctx.agent.id) != agent_id: if not agent_ctx.agent.is_board_lead and str(agent_ctx.agent.id) != agent_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@@ -547,7 +625,9 @@ async def get_agent_soul(
config=config, config=config,
) )
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
) from exc
if isinstance(payload, str): if isinstance(payload, str):
return payload return payload
if isinstance(payload, dict): if isinstance(payload, dict):
@@ -559,17 +639,20 @@ async def get_agent_soul(
nested = file_obj.get("content") nested = file_obj.get("content")
if isinstance(nested, str): if isinstance(nested, str):
return nested return nested
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="Invalid gateway response") raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail="Invalid gateway response",
)
@router.put("/boards/{board_id}/agents/{agent_id}/soul", response_model=OkResponse) @router.put("/boards/{board_id}/agents/{agent_id}/soul", response_model=OkResponse)
async def update_agent_soul( async def update_agent_soul(
agent_id: str, agent_id: str,
payload: SoulUpdateRequest, payload: SoulUpdateRequest,
board: Board = Depends(get_board_or_404), board: Board = BOARD_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> OkResponse: ) -> OkResponse:
"""Update an agent's SOUL.md content in DB and gateway."""
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
if not agent_ctx.agent.is_board_lead: if not agent_ctx.agent.is_board_lead:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@@ -597,7 +680,9 @@ async def update_agent_soul(
config=config, config=config,
) )
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
) from exc
reason = (payload.reason or "").strip() reason = (payload.reason or "").strip()
source_url = (payload.source_url or "").strip() source_url = (payload.source_url or "").strip()
note = f"SOUL.md updated for {target.name}." note = f"SOUL.md updated for {target.name}."
@@ -621,10 +706,11 @@ async def update_agent_soul(
) )
async def ask_user_via_gateway_main( async def ask_user_via_gateway_main(
payload: GatewayMainAskUserRequest, payload: GatewayMainAskUserRequest,
board: Board = Depends(get_board_or_404), board: Board = BOARD_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> GatewayMainAskUserResponse: ) -> GatewayMainAskUserResponse:
"""Route a lead's ask-user request through the gateway main agent."""
import json import json
_guard_board_access(agent_ctx, board) _guard_board_access(agent_ctx, board)
@@ -653,7 +739,9 @@ async def ask_user_via_gateway_main(
correlation = payload.correlation_id.strip() if payload.correlation_id else "" correlation = payload.correlation_id.strip() if payload.correlation_id else ""
correlation_line = f"Correlation ID: {correlation}\n" if correlation else "" correlation_line = f"Correlation ID: {correlation}\n" if correlation else ""
preferred_channel = (payload.preferred_channel or "").strip() preferred_channel = (payload.preferred_channel or "").strip()
channel_line = f"Preferred channel: {preferred_channel}\n" if preferred_channel else "" channel_line = (
f"Preferred channel: {preferred_channel}\n" if preferred_channel else ""
)
tags = payload.reply_tags or ["gateway_main", "user_reply"] tags = payload.reply_tags or ["gateway_main", "user_reply"]
tags_json = json.dumps(tags) tags_json = json.dumps(tags)
@@ -668,9 +756,12 @@ async def ask_user_via_gateway_main(
f"{correlation_line}" f"{correlation_line}"
f"{channel_line}\n" f"{channel_line}\n"
f"{payload.content.strip()}\n\n" f"{payload.content.strip()}\n\n"
"Please reach the user via your configured OpenClaw channel(s) (Slack/SMS/etc).\n" "Please reach the user via your configured OpenClaw channel(s) "
"If you cannot reach them there, post the question in Mission Control board chat as a fallback.\n\n" "(Slack/SMS/etc).\n"
"When you receive the answer, reply in Mission Control by writing a NON-chat memory item on this board:\n" "If you cannot reach them there, post the question in Mission Control "
"board chat as a fallback.\n\n"
"When you receive the answer, reply in Mission Control by writing a "
"NON-chat memory item on this board:\n"
f"POST {base_url}/api/v1/agent/boards/{board.id}/memory\n" f"POST {base_url}/api/v1/agent/boards/{board.id}/memory\n"
f'Body: {{"content":"<answer>","tags":{tags_json},"source":"{reply_source}"}}\n' f'Body: {{"content":"<answer>","tags":{tags_json},"source":"{reply_source}"}}\n'
"Do NOT reply in OpenClaw chat." "Do NOT reply in OpenClaw chat."
@@ -678,7 +769,9 @@ async def ask_user_via_gateway_main(
try: try:
await ensure_session(main_session_key, config=config, label="Main Agent") await ensure_session(main_session_key, config=config, label="Main Agent")
await send_message(message, session_key=main_session_key, config=config, deliver=True) await send_message(
message, session_key=main_session_key, config=config, deliver=True,
)
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
record_activity( record_activity(
session, session,
@@ -687,7 +780,9 @@ async def ask_user_via_gateway_main(
agent_id=agent_ctx.agent.id, agent_id=agent_ctx.agent.id,
) )
await session.commit() await session.commit()
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
) from exc
record_activity( record_activity(
session, session,
@@ -696,7 +791,9 @@ async def ask_user_via_gateway_main(
agent_id=agent_ctx.agent.id, agent_id=agent_ctx.agent.id,
) )
main_agent = await Agent.objects.filter_by(openclaw_session_id=main_session_key).first(session) main_agent = await Agent.objects.filter_by(
openclaw_session_id=main_session_key,
).first(session)
await session.commit() await session.commit()
@@ -714,9 +811,10 @@ async def ask_user_via_gateway_main(
async def message_gateway_board_lead( async def message_gateway_board_lead(
board_id: UUID, board_id: UUID,
payload: GatewayLeadMessageRequest, payload: GatewayLeadMessageRequest,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> GatewayLeadMessageResponse: ) -> GatewayLeadMessageResponse:
"""Send a gateway-main message to a single board lead agent."""
import json import json
gateway, config = await _require_gateway_main(session, agent_ctx.agent) gateway, config = await _require_gateway_main(session, agent_ctx.agent)
@@ -736,7 +834,11 @@ async def message_gateway_board_lead(
) )
base_url = settings.base_url or "http://localhost:8000" base_url = settings.base_url or "http://localhost:8000"
header = "GATEWAY MAIN QUESTION" if payload.kind == "question" else "GATEWAY MAIN HANDOFF" header = (
"GATEWAY MAIN QUESTION"
if payload.kind == "question"
else "GATEWAY MAIN HANDOFF"
)
correlation = payload.correlation_id.strip() if payload.correlation_id else "" correlation = payload.correlation_id.strip() if payload.correlation_id else ""
correlation_line = f"Correlation ID: {correlation}\n" if correlation else "" correlation_line = f"Correlation ID: {correlation}\n" if correlation else ""
tags = payload.reply_tags or ["gateway_main", "lead_reply"] tags = payload.reply_tags or ["gateway_main", "lead_reply"]
@@ -767,7 +869,9 @@ async def message_gateway_board_lead(
agent_id=agent_ctx.agent.id, agent_id=agent_ctx.agent.id,
) )
await session.commit() await session.commit()
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
) from exc
record_activity( record_activity(
session, session,
@@ -791,9 +895,10 @@ async def message_gateway_board_lead(
) )
async def broadcast_gateway_lead_message( async def broadcast_gateway_lead_message(
payload: GatewayLeadBroadcastRequest, payload: GatewayLeadBroadcastRequest,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
agent_ctx: AgentAuthContext = Depends(get_agent_auth_context), agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> GatewayLeadBroadcastResponse: ) -> GatewayLeadBroadcastResponse:
"""Broadcast a gateway-main message to multiple board leads."""
import json import json
gateway, config = await _require_gateway_main(session, agent_ctx.agent) gateway, config = await _require_gateway_main(session, agent_ctx.agent)
@@ -808,7 +913,11 @@ async def broadcast_gateway_lead_message(
boards = list(await session.exec(statement)) boards = list(await session.exec(statement))
base_url = settings.base_url or "http://localhost:8000" base_url = settings.base_url or "http://localhost:8000"
header = "GATEWAY MAIN QUESTION" if payload.kind == "question" else "GATEWAY MAIN HANDOFF" header = (
"GATEWAY MAIN QUESTION"
if payload.kind == "question"
else "GATEWAY MAIN HANDOFF"
)
correlation = payload.correlation_id.strip() if payload.correlation_id else "" correlation = payload.correlation_id.strip() if payload.correlation_id else ""
correlation_line = f"Correlation ID: {correlation}\n" if correlation else "" correlation_line = f"Correlation ID: {correlation}\n" if correlation else ""
tags = payload.reply_tags or ["gateway_main", "lead_reply"] tags = payload.reply_tags or ["gateway_main", "lead_reply"]
@@ -819,7 +928,7 @@ async def broadcast_gateway_lead_message(
sent = 0 sent = 0
failed = 0 failed = 0
for board in boards: async def _send_to_board(board: Board) -> GatewayLeadBroadcastBoardResult:
try: try:
lead, _lead_created = await ensure_board_lead_agent( lead, _lead_created = await ensure_board_lead_agent(
session, session,
@@ -837,30 +946,34 @@ async def broadcast_gateway_lead_message(
f"From agent: {agent_ctx.agent.name}\n" f"From agent: {agent_ctx.agent.name}\n"
f"{correlation_line}\n" f"{correlation_line}\n"
f"{payload.content.strip()}\n\n" f"{payload.content.strip()}\n\n"
"Reply to the gateway main by writing a NON-chat memory item on this board:\n" "Reply to the gateway main by writing a NON-chat memory item "
"on this board:\n"
f"POST {base_url}/api/v1/agent/boards/{board.id}/memory\n" f"POST {base_url}/api/v1/agent/boards/{board.id}/memory\n"
f'Body: {{"content":"...","tags":{tags_json},"source":"{reply_source}"}}\n' f'Body: {{"content":"...","tags":{tags_json},'
f'"source":"{reply_source}"}}\n'
"Do NOT reply in OpenClaw chat." "Do NOT reply in OpenClaw chat."
) )
await ensure_session(lead_session_key, config=config, label=lead.name) await ensure_session(lead_session_key, config=config, label=lead.name)
await send_message(message, session_key=lead_session_key, config=config) await send_message(message, session_key=lead_session_key, config=config)
results.append( return GatewayLeadBroadcastBoardResult(
GatewayLeadBroadcastBoardResult( board_id=board.id,
board_id=board.id, lead_agent_id=lead.id,
lead_agent_id=lead.id, lead_agent_name=lead.name,
lead_agent_name=lead.name, ok=True,
ok=True,
)
) )
sent += 1
except (HTTPException, OpenClawGatewayError, ValueError) as exc: except (HTTPException, OpenClawGatewayError, ValueError) as exc:
results.append( return GatewayLeadBroadcastBoardResult(
GatewayLeadBroadcastBoardResult( board_id=board.id,
board_id=board.id, ok=False,
ok=False, error=str(exc),
error=str(exc),
)
) )
for board in boards:
board_result = await _send_to_board(board)
results.append(board_result)
if board_result.ok:
sent += 1
else:
failed += 1 failed += 1
record_activity( record_activity(

View File

@@ -1,3 +1,5 @@
"""Agent lifecycle, listing, heartbeat, and deletion API endpoints."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
@@ -5,14 +7,12 @@ import json
import re import re
from collections.abc import AsyncIterator, Sequence from collections.abc import AsyncIterator, Sequence
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any, cast from typing import TYPE_CHECKING, Any, cast
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import asc, or_ from sqlalchemy import asc, or_
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from app.api.deps import ActorContext, require_admin_or_agent, require_org_admin from app.api.deps import ActorContext, require_admin_or_agent, require_org_admin
@@ -23,14 +23,17 @@ from app.db import crud
from app.db.pagination import paginate from app.db.pagination import paginate
from app.db.session import async_session_maker, get_session from app.db.session import async_session_maker, get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message from app.integrations.openclaw_gateway import (
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.activity_events import ActivityEvent from app.models.activity_events import ActivityEvent
from app.models.agents import Agent from app.models.agents import Agent
from app.models.boards import Board from app.models.boards import Board
from app.models.gateways import Gateway from app.models.gateways import Gateway
from app.models.organizations import Organization from app.models.organizations import Organization
from app.models.tasks import Task from app.models.tasks import Task
from app.models.users import User
from app.schemas.agents import ( from app.schemas.agents import (
AgentCreate, AgentCreate,
AgentHeartbeat, AgentHeartbeat,
@@ -56,10 +59,23 @@ from app.services.organizations import (
require_board_access, require_board_access,
) )
if TYPE_CHECKING:
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.users import User
router = APIRouter(prefix="/agents", tags=["agents"]) router = APIRouter(prefix="/agents", tags=["agents"])
OFFLINE_AFTER = timedelta(minutes=10) OFFLINE_AFTER = timedelta(minutes=10)
AGENT_SESSION_PREFIX = "agent" AGENT_SESSION_PREFIX = "agent"
BOARD_ID_QUERY = Query(default=None)
GATEWAY_ID_QUERY = Query(default=None)
SINCE_QUERY = Query(default=None)
SESSION_DEP = Depends(get_session)
ORG_ADMIN_DEP = Depends(require_org_admin)
ACTOR_DEP = Depends(require_admin_or_agent)
AUTH_DEP = Depends(get_auth_context)
def _parse_since(value: str | None) -> datetime | None: def _parse_since(value: str | None) -> datetime | None:
@@ -111,14 +127,16 @@ async def _require_board(
) )
board = await Board.objects.by_id(board_id).first(session) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Board not found",
)
if user is not None: if user is not None:
await require_board_access(session, user=user, board=board, write=write) await require_board_access(session, user=user, board=board, write=write)
return board return board
async def _require_gateway( async def _require_gateway(
session: AsyncSession, board: Board session: AsyncSession, board: Board,
) -> tuple[Gateway, GatewayClientConfig]: ) -> tuple[Gateway, GatewayClientConfig]:
if not board.gateway_id: if not board.gateway_id:
raise HTTPException( raise HTTPException(
@@ -169,16 +187,20 @@ async def _get_gateway_main_session_keys(session: AsyncSession) -> set[str]:
def _is_gateway_main(agent: Agent, main_session_keys: set[str]) -> bool: def _is_gateway_main(agent: Agent, main_session_keys: set[str]) -> bool:
return bool(agent.openclaw_session_id and agent.openclaw_session_id in main_session_keys) return bool(
agent.openclaw_session_id and agent.openclaw_session_id in main_session_keys,
)
def _to_agent_read(agent: Agent, main_session_keys: set[str]) -> AgentRead: def _to_agent_read(agent: Agent, main_session_keys: set[str]) -> AgentRead:
model = AgentRead.model_validate(agent, from_attributes=True) model = AgentRead.model_validate(agent, from_attributes=True)
return model.model_copy(update={"is_gateway_main": _is_gateway_main(agent, main_session_keys)}) return model.model_copy(
update={"is_gateway_main": _is_gateway_main(agent, main_session_keys)},
)
async def _find_gateway_for_main_session( async def _find_gateway_for_main_session(
session: AsyncSession, session_key: str | None session: AsyncSession, session_key: str | None,
) -> Gateway | None: ) -> Gateway | None:
if not session_key: if not session_key:
return None return None
@@ -210,7 +232,9 @@ def _with_computed_status(agent: Agent) -> Agent:
def _serialize_agent(agent: Agent, main_session_keys: set[str]) -> dict[str, object]: def _serialize_agent(agent: Agent, main_session_keys: set[str]) -> dict[str, object]:
return _to_agent_read(_with_computed_status(agent), main_session_keys).model_dump(mode="json") return _to_agent_read(_with_computed_status(agent), main_session_keys).model_dump(
mode="json",
)
async def _fetch_agent_events( async def _fetch_agent_events(
@@ -225,18 +249,22 @@ async def _fetch_agent_events(
or_( or_(
col(Agent.updated_at) >= since, col(Agent.updated_at) >= since,
col(Agent.last_seen_at) >= since, col(Agent.last_seen_at) >= since,
) ),
).order_by(asc(col(Agent.updated_at))) ).order_by(asc(col(Agent.updated_at)))
return list(await session.exec(statement)) return list(await session.exec(statement))
async def _require_user_context(session: AsyncSession, user: User | None) -> OrganizationContext: async def _require_user_context(
session: AsyncSession, user: User | None,
) -> OrganizationContext:
if user is None: if user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
member = await get_active_membership(session, user) member = await get_active_membership(session, user)
if member is None: if member is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
organization = await Organization.objects.by_id(member.organization_id).first(session) organization = await Organization.objects.by_id(member.organization_id).first(
session,
)
if organization is None: if organization is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return OrganizationContext(organization=organization, member=member) return OrganizationContext(organization=organization, member=member)
@@ -252,7 +280,9 @@ async def _require_agent_access(
if agent.board_id is None: if agent.board_id is None:
if not is_org_admin(ctx.member): if not is_org_admin(ctx.member):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
gateway = await _find_gateway_for_main_session(session, agent.openclaw_session_id) gateway = await _find_gateway_for_main_session(
session, agent.openclaw_session_id,
)
if gateway is None or gateway.organization_id != ctx.organization.id: if gateway is None or gateway.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return return
@@ -274,7 +304,7 @@ def _record_heartbeat(session: AsyncSession, agent: Agent) -> None:
def _record_instruction_failure( def _record_instruction_failure(
session: AsyncSession, agent: Agent, error: str, action: str session: AsyncSession, agent: Agent, error: str, action: str,
) -> None: ) -> None:
action_label = action.replace("_", " ").capitalize() action_label = action.replace("_", " ").capitalize()
record_activity( record_activity(
@@ -286,7 +316,7 @@ def _record_instruction_failure(
async def _send_wakeup_message( async def _send_wakeup_message(
agent: Agent, config: GatewayClientConfig, verb: str = "provisioned" agent: Agent, config: GatewayClientConfig, verb: str = "provisioned",
) -> None: ) -> None:
session_key = agent.openclaw_session_id or _build_session_key(agent.name) session_key = agent.openclaw_session_id or _build_session_key(agent.name)
await ensure_session(session_key, config=config, label=agent.name) await ensure_session(session_key, config=config, label=agent.name)
@@ -300,11 +330,12 @@ async def _send_wakeup_message(
@router.get("", response_model=DefaultLimitOffsetPage[AgentRead]) @router.get("", response_model=DefaultLimitOffsetPage[AgentRead])
async def list_agents( async def list_agents(
board_id: UUID | None = Query(default=None), board_id: UUID | None = BOARD_ID_QUERY,
gateway_id: UUID | None = Query(default=None), gateway_id: UUID | None = GATEWAY_ID_QUERY,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> DefaultLimitOffsetPage[AgentRead]: ) -> DefaultLimitOffsetPage[AgentRead]:
"""List agents visible to the active organization admin."""
main_session_keys = await _get_gateway_main_session_keys(session) main_session_keys = await _get_gateway_main_session_keys(session)
board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False) board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False)
if board_id is not None and board_id not in set(board_ids): if board_id is not None and board_id not in set(board_ids):
@@ -315,9 +346,11 @@ async def list_agents(
base_filter: ColumnElement[bool] = col(Agent.board_id).in_(board_ids) base_filter: ColumnElement[bool] = col(Agent.board_id).in_(board_ids)
if is_org_admin(ctx.member): if is_org_admin(ctx.member):
gateway_keys = select(Gateway.main_session_key).where( gateway_keys = select(Gateway.main_session_key).where(
col(Gateway.organization_id) == ctx.organization.id col(Gateway.organization_id) == ctx.organization.id,
)
base_filter = or_(
base_filter, col(Agent.openclaw_session_id).in_(gateway_keys),
) )
base_filter = or_(base_filter, col(Agent.openclaw_session_id).in_(gateway_keys))
statement = select(Agent).where(base_filter) statement = select(Agent).where(base_filter)
if board_id is not None: if board_id is not None:
statement = statement.where(col(Agent.board_id) == board_id) statement = statement.where(col(Agent.board_id) == board_id)
@@ -326,13 +359,16 @@ async def list_agents(
if gateway is None or gateway.organization_id != ctx.organization.id: if gateway is None or gateway.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
statement = statement.join(Board, col(Agent.board_id) == col(Board.id)).where( statement = statement.join(Board, col(Agent.board_id) == col(Board.id)).where(
col(Board.gateway_id) == gateway_id col(Board.gateway_id) == gateway_id,
) )
statement = statement.order_by(col(Agent.created_at).desc()) statement = statement.order_by(col(Agent.created_at).desc())
def _transform(items: Sequence[Any]) -> Sequence[Any]: def _transform(items: Sequence[Any]) -> Sequence[Any]:
agents = cast(Sequence[Agent], items) agents = cast(Sequence[Agent], items)
return [_to_agent_read(_with_computed_status(agent), main_session_keys) for agent in agents] return [
_to_agent_read(_with_computed_status(agent), main_session_keys)
for agent in agents
]
return await paginate(session, statement, transformer=_transform) return await paginate(session, statement, transformer=_transform)
@@ -340,11 +376,12 @@ async def list_agents(
@router.get("/stream") @router.get("/stream")
async def stream_agents( async def stream_agents(
request: Request, request: Request,
board_id: UUID | None = Query(default=None), board_id: UUID | None = BOARD_ID_QUERY,
since: str | None = Query(default=None), since: str | None = SINCE_QUERY,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> EventSourceResponse: ) -> EventSourceResponse:
"""Stream agent updates as SSE events."""
since_dt = _parse_since(since) or utcnow() since_dt = _parse_since(since) or utcnow()
last_seen = since_dt last_seen = since_dt
board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False) board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False)
@@ -359,14 +396,20 @@ async def stream_agents(
break break
async with async_session_maker() as stream_session: async with async_session_maker() as stream_session:
if board_id is not None: if board_id is not None:
agents = await _fetch_agent_events(stream_session, board_id, last_seen) agents = await _fetch_agent_events(
stream_session, board_id, last_seen,
)
elif allowed_ids: elif allowed_ids:
agents = await _fetch_agent_events(stream_session, None, last_seen) agents = await _fetch_agent_events(stream_session, None, last_seen)
agents = [agent for agent in agents if agent.board_id in allowed_ids] agents = [
agent for agent in agents if agent.board_id in allowed_ids
]
else: else:
agents = [] agents = []
main_session_keys = ( main_session_keys = (
await _get_gateway_main_session_keys(stream_session) if agents else set() await _get_gateway_main_session_keys(stream_session)
if agents
else set()
) )
for agent in agents: for agent in agents:
updated_at = agent.updated_at or agent.last_seen_at or utcnow() updated_at = agent.updated_at or agent.last_seen_at or utcnow()
@@ -379,11 +422,12 @@ async def stream_agents(
@router.post("", response_model=AgentRead) @router.post("", response_model=AgentRead)
async def create_agent( async def create_agent( # noqa: C901, PLR0912, PLR0915
payload: AgentCreate, payload: AgentCreate,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
actor: ActorContext = Depends(require_admin_or_agent), actor: ActorContext = ACTOR_DEP,
) -> AgentRead: ) -> AgentRead:
"""Create and provision an agent."""
if actor.actor_type == "user": if actor.actor_type == "user":
ctx = await _require_user_context(session, actor.user) ctx = await _require_user_context(session, actor.user)
if not is_org_admin(ctx.member): if not is_org_admin(ctx.member):
@@ -404,7 +448,9 @@ async def create_agent(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Board leads can only create agents in their own board", detail="Board leads can only create agents in their own board",
) )
payload = AgentCreate(**{**payload.model_dump(), "board_id": actor.agent.board_id}) payload = AgentCreate(
**{**payload.model_dump(), "board_id": actor.agent.board_id},
)
board = await _require_board( board = await _require_board(
session, session,
@@ -420,7 +466,7 @@ async def create_agent(
await session.exec( await session.exec(
select(Agent) select(Agent)
.where(Agent.board_id == board.id) .where(Agent.board_id == board.id)
.where(col(Agent.name).ilike(requested_name)) .where(col(Agent.name).ilike(requested_name)),
) )
).first() ).first()
if existing: if existing:
@@ -428,20 +474,23 @@ async def create_agent(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
detail="An agent with this name already exists on this board.", detail="An agent with this name already exists on this board.",
) )
# Prevent OpenClaw session/workspace collisions by enforcing uniqueness within # Prevent session/workspace collisions inside the gateway workspace.
# the gateway workspace too (agents on other boards share the same gateway root). # Agents on different boards can still share one gateway root.
existing_gateway = ( existing_gateway = (
await session.exec( await session.exec(
select(Agent) select(Agent)
.join(Board, col(Agent.board_id) == col(Board.id)) .join(Board, col(Agent.board_id) == col(Board.id))
.where(col(Board.gateway_id) == gateway.id) .where(col(Board.gateway_id) == gateway.id)
.where(col(Agent.name).ilike(requested_name)) .where(col(Agent.name).ilike(requested_name)),
) )
).first() ).first()
if existing_gateway: if existing_gateway:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
detail="An agent with this name already exists in this gateway workspace.", detail=(
"An agent with this name already exists in this gateway "
"workspace."
),
) )
desired_session_key = _build_session_key(requested_name) desired_session_key = _build_session_key(requested_name)
existing_session_key = ( existing_session_key = (
@@ -449,13 +498,16 @@ async def create_agent(
select(Agent) select(Agent)
.join(Board, col(Agent.board_id) == col(Board.id)) .join(Board, col(Agent.board_id) == col(Board.id))
.where(col(Board.gateway_id) == gateway.id) .where(col(Board.gateway_id) == gateway.id)
.where(col(Agent.openclaw_session_id) == desired_session_key) .where(col(Agent.openclaw_session_id) == desired_session_key),
) )
).first() ).first()
if existing_session_key: if existing_session_key:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
detail="This agent name would collide with an existing workspace session key. Pick a different name.", detail=(
"This agent name would collide with an existing workspace "
"session key. Pick a different name."
),
) )
agent = Agent.model_validate(data) agent = Agent.model_validate(data)
agent.status = "provisioning" agent.status = "provisioning"
@@ -465,7 +517,9 @@ async def create_agent(
agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy() agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy()
agent.provision_requested_at = utcnow() agent.provision_requested_at = utcnow()
agent.provision_action = "provision" agent.provision_action = "provision"
session_key, session_error = await _ensure_gateway_session(agent.name, client_config) session_key, session_error = await _ensure_gateway_session(
agent.name, client_config,
)
agent.openclaw_session_id = session_key agent.openclaw_session_id = session_key
session.add(agent) session.add(agent)
await session.commit() await session.commit()
@@ -527,9 +581,10 @@ async def create_agent(
@router.get("/{agent_id}", response_model=AgentRead) @router.get("/{agent_id}", response_model=AgentRead)
async def get_agent( async def get_agent(
agent_id: str, agent_id: str,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> AgentRead: ) -> AgentRead:
"""Get a single agent by id."""
agent = await Agent.objects.by_id(agent_id).first(session) agent = await Agent.objects.by_id(agent_id).first(session)
if agent is None: if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -539,14 +594,16 @@ async def get_agent(
@router.patch("/{agent_id}", response_model=AgentRead) @router.patch("/{agent_id}", response_model=AgentRead)
async def update_agent( async def update_agent( # noqa: C901, PLR0912, PLR0913, PLR0915
agent_id: str, agent_id: str,
payload: AgentUpdate, payload: AgentUpdate,
*,
force: bool = False, force: bool = False,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = AUTH_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> AgentRead: ) -> AgentRead:
"""Update agent metadata and optionally reprovision."""
agent = await Agent.objects.by_id(agent_id).first(session) agent = await Agent.objects.by_id(agent_id).first(session)
if agent is None: if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -564,12 +621,16 @@ async def update_agent(
new_board = await _require_board(session, updates["board_id"]) new_board = await _require_board(session, updates["board_id"])
if new_board.organization_id != ctx.organization.id: if new_board.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if not await has_board_access(session, member=ctx.member, board=new_board, write=True): if not await has_board_access(
session, member=ctx.member, board=new_board, write=True,
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if not updates and not force and make_main is None: if not updates and not force and make_main is None:
main_session_keys = await _get_gateway_main_session_keys(session) main_session_keys = await _get_gateway_main_session_keys(session)
return _to_agent_read(_with_computed_status(agent), main_session_keys) return _to_agent_read(_with_computed_status(agent), main_session_keys)
main_gateway = await _find_gateway_for_main_session(session, agent.openclaw_session_id) main_gateway = await _find_gateway_for_main_session(
session, agent.openclaw_session_id,
)
gateway_for_main: Gateway | None = None gateway_for_main: Gateway | None = None
if make_main is True: if make_main is True:
board_source = updates.get("board_id") or agent.board_id board_source = updates.get("board_id") or agent.board_id
@@ -723,9 +784,10 @@ async def update_agent(
async def heartbeat_agent( async def heartbeat_agent(
agent_id: str, agent_id: str,
payload: AgentHeartbeat, payload: AgentHeartbeat,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
actor: ActorContext = Depends(require_admin_or_agent), actor: ActorContext = ACTOR_DEP,
) -> AgentRead: ) -> AgentRead:
"""Record a heartbeat for a specific agent."""
agent = await Agent.objects.by_id(agent_id).first(session) agent = await Agent.objects.by_id(agent_id).first(session)
if agent is None: if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -751,12 +813,14 @@ async def heartbeat_agent(
@router.post("/heartbeat", response_model=AgentRead) @router.post("/heartbeat", response_model=AgentRead)
async def heartbeat_or_create_agent( async def heartbeat_or_create_agent( # noqa: C901, PLR0912, PLR0915
payload: AgentHeartbeatCreate, payload: AgentHeartbeatCreate,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
actor: ActorContext = Depends(require_admin_or_agent), actor: ActorContext = ACTOR_DEP,
) -> AgentRead: ) -> AgentRead:
# Agent tokens must heartbeat their authenticated agent record. Names are not unique. """Heartbeat an existing agent or create/provision one if needed."""
# Agent tokens must heartbeat their authenticated agent record.
# Names are not unique.
if actor.actor_type == "agent" and actor.agent: if actor.actor_type == "agent" and actor.agent:
return await heartbeat_agent( return await heartbeat_agent(
agent_id=str(actor.agent.id), agent_id=str(actor.agent.id),
@@ -793,7 +857,9 @@ async def heartbeat_or_create_agent(
agent.agent_token_hash = hash_agent_token(raw_token) agent.agent_token_hash = hash_agent_token(raw_token)
agent.provision_requested_at = utcnow() agent.provision_requested_at = utcnow()
agent.provision_action = "provision" agent.provision_action = "provision"
session_key, session_error = await _ensure_gateway_session(agent.name, client_config) session_key, session_error = await _ensure_gateway_session(
agent.name, client_config,
)
agent.openclaw_session_id = session_key agent.openclaw_session_id = session_key
session.add(agent) session.add(agent)
await session.commit() await session.commit()
@@ -814,7 +880,9 @@ async def heartbeat_or_create_agent(
) )
await session.commit() await session.commit()
try: try:
await provision_agent(agent, board, gateway, raw_token, actor.user, action="provision") await provision_agent(
agent, board, gateway, raw_token, actor.user, action="provision",
)
await _send_wakeup_message(agent, client_config, verb="provisioned") await _send_wakeup_message(agent, client_config, verb="provisioned")
agent.provision_confirm_token_hash = None agent.provision_confirm_token_hash = None
agent.provision_requested_at = None agent.provision_requested_at = None
@@ -864,7 +932,7 @@ async def heartbeat_or_create_agent(
) )
gateway, client_config = await _require_gateway(session, board) gateway, client_config = await _require_gateway(session, board)
await provision_agent( await provision_agent(
agent, board, gateway, raw_token, actor.user, action="provision" agent, board, gateway, raw_token, actor.user, action="provision",
) )
await _send_wakeup_message(agent, client_config, verb="provisioned") await _send_wakeup_message(agent, client_config, verb="provisioned")
agent.provision_confirm_token_hash = None agent.provision_confirm_token_hash = None
@@ -903,7 +971,9 @@ async def heartbeat_or_create_agent(
write=actor.actor_type == "user", write=actor.actor_type == "user",
) )
gateway, client_config = await _require_gateway(session, board) gateway, client_config = await _require_gateway(session, board)
session_key, session_error = await _ensure_gateway_session(agent.name, client_config) session_key, session_error = await _ensure_gateway_session(
agent.name, client_config,
)
agent.openclaw_session_id = session_key agent.openclaw_session_id = session_key
if session_error: if session_error:
record_activity( record_activity(
@@ -937,15 +1007,18 @@ async def heartbeat_or_create_agent(
@router.delete("/{agent_id}", response_model=OkResponse) @router.delete("/{agent_id}", response_model=OkResponse)
async def delete_agent( async def delete_agent(
agent_id: str, agent_id: str,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> OkResponse: ) -> OkResponse:
"""Delete an agent and clean related task state."""
agent = await Agent.objects.by_id(agent_id).first(session) agent = await Agent.objects.by_id(agent_id).first(session)
if agent is None: if agent is None:
return OkResponse() return OkResponse()
await _require_agent_access(session, agent=agent, ctx=ctx, write=True) await _require_agent_access(session, agent=agent, ctx=ctx, write=True)
board = await _require_board(session, str(agent.board_id) if agent.board_id else None) board = await _require_board(
session, str(agent.board_id) if agent.board_id else None,
)
gateway, client_config = await _require_gateway(session, board) gateway, client_config = await _require_gateway(session, board)
try: try:
workspace_path = await cleanup_agent(agent, gateway) workspace_path = await cleanup_agent(agent, gateway)
@@ -970,7 +1043,7 @@ async def delete_agent(
message=f"Deleted agent {agent.name}.", message=f"Deleted agent {agent.name}.",
agent_id=None, agent_id=None,
) )
now = datetime.now() now = utcnow()
await crud.update_where( await crud.update_where(
session, session,
Task, Task,

View File

@@ -1,3 +1,5 @@
"""Authentication bootstrap endpoints for the Mission Control API."""
from __future__ import annotations from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
@@ -6,10 +8,12 @@ from app.core.auth import AuthContext, get_auth_context
from app.schemas.users import UserRead from app.schemas.users import UserRead
router = APIRouter(prefix="/auth", tags=["auth"]) router = APIRouter(prefix="/auth", tags=["auth"])
AUTH_CONTEXT_DEP = Depends(get_auth_context)
@router.post("/bootstrap", response_model=UserRead) @router.post("/bootstrap", response_model=UserRead)
async def bootstrap_user(auth: AuthContext = Depends(get_auth_context)) -> UserRead: async def bootstrap_user(auth: AuthContext = AUTH_CONTEXT_DEP) -> UserRead:
"""Return the authenticated user profile from token claims."""
if auth.actor_type != "user" or auth.user is None: if auth.actor_type != "user" or auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
return UserRead.model_validate(auth.user) return UserRead.model_validate(auth.user)

View File

@@ -1,13 +1,16 @@
"""Board onboarding endpoints for user/agent collaboration."""
# ruff: noqa: E501
from __future__ import annotations from __future__ import annotations
import logging import logging
import re import re
from typing import TYPE_CHECKING
from uuid import uuid4 from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import ValidationError from pydantic import ValidationError
from sqlmodel import col from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import ( from app.api.deps import (
ActorContext, ActorContext,
@@ -18,15 +21,17 @@ from app.api.deps import (
require_admin_or_agent, require_admin_or_agent,
) )
from app.core.agent_tokens import generate_agent_token, hash_agent_token from app.core.agent_tokens import generate_agent_token, hash_agent_token
from app.core.auth import AuthContext
from app.core.config import settings from app.core.config import settings
from app.core.time import utcnow from app.core.time import utcnow
from app.db.session import get_session from app.db.session import get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message from app.integrations.openclaw_gateway import (
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.agents import Agent from app.models.agents import Agent
from app.models.board_onboarding import BoardOnboardingSession from app.models.board_onboarding import BoardOnboardingSession
from app.models.boards import Board
from app.models.gateways import Gateway from app.models.gateways import Gateway
from app.schemas.board_onboarding import ( from app.schemas.board_onboarding import (
BoardOnboardingAgentComplete, BoardOnboardingAgentComplete,
@@ -41,12 +46,24 @@ from app.schemas.board_onboarding import (
from app.schemas.boards import BoardRead from app.schemas.boards import BoardRead
from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_agent from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_agent
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.auth import AuthContext
from app.models.boards import Board
router = APIRouter(prefix="/boards/{board_id}/onboarding", tags=["board-onboarding"]) router = APIRouter(prefix="/boards/{board_id}/onboarding", tags=["board-onboarding"])
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
BOARD_USER_READ_DEP = Depends(get_board_for_user_read)
BOARD_USER_WRITE_DEP = Depends(get_board_for_user_write)
BOARD_OR_404_DEP = Depends(get_board_or_404)
SESSION_DEP = Depends(get_session)
ACTOR_DEP = Depends(require_admin_or_agent)
ADMIN_AUTH_DEP = Depends(require_admin_auth)
async def _gateway_config( async def _gateway_config(
session: AsyncSession, board: Board session: AsyncSession, board: Board,
) -> tuple[Gateway, GatewayClientConfig]: ) -> tuple[Gateway, GatewayClientConfig]:
if not board.gateway_id: if not board.gateway_id:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
@@ -61,7 +78,7 @@ def _build_session_key(agent_name: str) -> str:
return f"agent:{slug or uuid4().hex}:main" return f"agent:{slug or uuid4().hex}:main"
def _lead_agent_name(board: Board) -> str: def _lead_agent_name(_board: Board) -> str:
return "Lead Agent" return "Lead Agent"
@@ -69,7 +86,7 @@ def _lead_session_key(board: Board) -> str:
return f"agent:lead-{board.id}:main" return f"agent:lead-{board.id}:main"
async def _ensure_lead_agent( async def _ensure_lead_agent( # noqa: PLR0913
session: AsyncSession, session: AsyncSession,
board: Board, board: Board,
gateway: Gateway, gateway: Gateway,
@@ -100,7 +117,11 @@ async def _ensure_lead_agent(
} }
if identity_profile: if identity_profile:
merged_identity_profile.update( merged_identity_profile.update(
{key: value.strip() for key, value in identity_profile.items() if value.strip()} {
key: value.strip()
for key, value in identity_profile.items()
if value.strip()
},
) )
agent = Agent( agent = Agent(
@@ -121,7 +142,9 @@ async def _ensure_lead_agent(
await session.refresh(agent) await session.refresh(agent)
try: try:
await provision_agent(agent, board, gateway, raw_token, auth.user, action="provision") await provision_agent(
agent, board, gateway, raw_token, auth.user, action="provision",
)
await ensure_session(agent.openclaw_session_id, config=config, label=agent.name) await ensure_session(agent.openclaw_session_id, config=config, label=agent.name)
await send_message( await send_message(
( (
@@ -141,9 +164,10 @@ async def _ensure_lead_agent(
@router.get("", response_model=BoardOnboardingRead) @router.get("", response_model=BoardOnboardingRead)
async def get_onboarding( async def get_onboarding(
board: Board = Depends(get_board_for_user_read), board: Board = BOARD_USER_READ_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
) -> BoardOnboardingSession: ) -> BoardOnboardingSession:
"""Get the latest onboarding session for a board."""
onboarding = ( onboarding = (
await BoardOnboardingSession.objects.filter_by(board_id=board.id) await BoardOnboardingSession.objects.filter_by(board_id=board.id)
.order_by(col(BoardOnboardingSession.updated_at).desc()) .order_by(col(BoardOnboardingSession.updated_at).desc())
@@ -156,10 +180,11 @@ async def get_onboarding(
@router.post("/start", response_model=BoardOnboardingRead) @router.post("/start", response_model=BoardOnboardingRead)
async def start_onboarding( async def start_onboarding(
payload: BoardOnboardingStart, _payload: BoardOnboardingStart,
board: Board = Depends(get_board_for_user_write), board: Board = BOARD_USER_WRITE_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
) -> BoardOnboardingSession: ) -> BoardOnboardingSession:
"""Start onboarding and send instructions to the gateway main agent."""
onboarding = ( onboarding = (
await BoardOnboardingSession.objects.filter_by(board_id=board.id) await BoardOnboardingSession.objects.filter_by(board_id=board.id)
.filter(col(BoardOnboardingSession.status) == "active") .filter(col(BoardOnboardingSession.status) == "active")
@@ -219,15 +244,21 @@ async def start_onboarding(
try: try:
await ensure_session(session_key, config=config, label="Main Agent") await ensure_session(session_key, config=config, label="Main Agent")
await send_message(prompt, session_key=session_key, config=config, deliver=False) await send_message(
prompt, session_key=session_key, config=config, deliver=False,
)
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
) from exc
onboarding = BoardOnboardingSession( onboarding = BoardOnboardingSession(
board_id=board.id, board_id=board.id,
session_key=session_key, session_key=session_key,
status="active", status="active",
messages=[{"role": "user", "content": prompt, "timestamp": utcnow().isoformat()}], messages=[
{"role": "user", "content": prompt, "timestamp": utcnow().isoformat()},
],
) )
session.add(onboarding) session.add(onboarding)
await session.commit() await session.commit()
@@ -238,9 +269,10 @@ async def start_onboarding(
@router.post("/answer", response_model=BoardOnboardingRead) @router.post("/answer", response_model=BoardOnboardingRead)
async def answer_onboarding( async def answer_onboarding(
payload: BoardOnboardingAnswer, payload: BoardOnboardingAnswer,
board: Board = Depends(get_board_for_user_write), board: Board = BOARD_USER_WRITE_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
) -> BoardOnboardingSession: ) -> BoardOnboardingSession:
"""Send a user onboarding answer to the gateway main agent."""
onboarding = ( onboarding = (
await BoardOnboardingSession.objects.filter_by(board_id=board.id) await BoardOnboardingSession.objects.filter_by(board_id=board.id)
.order_by(col(BoardOnboardingSession.updated_at).desc()) .order_by(col(BoardOnboardingSession.updated_at).desc())
@@ -255,15 +287,22 @@ async def answer_onboarding(
answer_text = f"{payload.answer}: {payload.other_text}" answer_text = f"{payload.answer}: {payload.other_text}"
messages = list(onboarding.messages or []) messages = list(onboarding.messages or [])
messages.append({"role": "user", "content": answer_text, "timestamp": utcnow().isoformat()}) messages.append(
{"role": "user", "content": answer_text, "timestamp": utcnow().isoformat()},
)
try: try:
await ensure_session(onboarding.session_key, config=config, label="Main Agent") await ensure_session(onboarding.session_key, config=config, label="Main Agent")
await send_message( await send_message(
answer_text, session_key=onboarding.session_key, config=config, deliver=False answer_text,
session_key=onboarding.session_key,
config=config,
deliver=False,
) )
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
) from exc
onboarding.messages = messages onboarding.messages = messages
onboarding.updated_at = utcnow() onboarding.updated_at = utcnow()
@@ -276,10 +315,11 @@ async def answer_onboarding(
@router.post("/agent", response_model=BoardOnboardingRead) @router.post("/agent", response_model=BoardOnboardingRead)
async def agent_onboarding_update( async def agent_onboarding_update(
payload: BoardOnboardingAgentUpdate, payload: BoardOnboardingAgentUpdate,
board: Board = Depends(get_board_or_404), board: Board = BOARD_OR_404_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
actor: ActorContext = Depends(require_admin_or_agent), actor: ActorContext = ACTOR_DEP,
) -> BoardOnboardingSession: ) -> BoardOnboardingSession:
"""Store onboarding updates submitted by the gateway main agent."""
if actor.actor_type != "agent" or actor.agent is None: if actor.actor_type != "agent" or actor.agent is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
agent = actor.agent agent = actor.agent
@@ -288,9 +328,13 @@ async def agent_onboarding_update(
if board.gateway_id: if board.gateway_id:
gateway = await Gateway.objects.by_id(board.gateway_id).first(session) gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
if gateway and gateway.main_session_key and agent.openclaw_session_id: if (
if agent.openclaw_session_id != gateway.main_session_key: gateway
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) and gateway.main_session_key
and agent.openclaw_session_id
and agent.openclaw_session_id != gateway.main_session_key
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
onboarding = ( onboarding = (
await BoardOnboardingSession.objects.filter_by(board_id=board.id) await BoardOnboardingSession.objects.filter_by(board_id=board.id)
@@ -315,9 +359,13 @@ async def agent_onboarding_update(
if isinstance(payload, BoardOnboardingAgentComplete): if isinstance(payload, BoardOnboardingAgentComplete):
onboarding.draft_goal = payload_data onboarding.draft_goal = payload_data
onboarding.status = "completed" onboarding.status = "completed"
messages.append({"role": "assistant", "content": payload_text, "timestamp": now}) messages.append(
{"role": "assistant", "content": payload_text, "timestamp": now},
)
else: else:
messages.append({"role": "assistant", "content": payload_text, "timestamp": now}) messages.append(
{"role": "assistant", "content": payload_text, "timestamp": now},
)
onboarding.messages = messages onboarding.messages = messages
onboarding.updated_at = utcnow() onboarding.updated_at = utcnow()
@@ -334,12 +382,13 @@ async def agent_onboarding_update(
@router.post("/confirm", response_model=BoardRead) @router.post("/confirm", response_model=BoardRead)
async def confirm_onboarding( async def confirm_onboarding( # noqa: C901, PLR0912, PLR0915
payload: BoardOnboardingConfirm, payload: BoardOnboardingConfirm,
board: Board = Depends(get_board_for_user_write), board: Board = BOARD_USER_WRITE_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
auth: AuthContext = Depends(require_admin_auth), auth: AuthContext = ADMIN_AUTH_DEP,
) -> Board: ) -> Board:
"""Confirm onboarding results and provision the board lead agent."""
onboarding = ( onboarding = (
await BoardOnboardingSession.objects.filter_by(board_id=board.id) await BoardOnboardingSession.objects.filter_by(board_id=board.id)
.order_by(col(BoardOnboardingSession.updated_at).desc()) .order_by(col(BoardOnboardingSession.updated_at).desc())
@@ -409,7 +458,9 @@ async def confirm_onboarding(
if lead_agent.update_cadence: if lead_agent.update_cadence:
lead_identity_profile["update_cadence"] = lead_agent.update_cadence lead_identity_profile["update_cadence"] = lead_agent.update_cadence
if lead_agent.custom_instructions: if lead_agent.custom_instructions:
lead_identity_profile["custom_instructions"] = lead_agent.custom_instructions lead_identity_profile["custom_instructions"] = (
lead_agent.custom_instructions
)
gateway, config = await _gateway_config(session, board) gateway, config = await _gateway_config(session, board)
session.add(board) session.add(board)

View File

@@ -1,12 +1,14 @@
"""Board CRUD and snapshot endpoints."""
from __future__ import annotations from __future__ import annotations
import re import re
from typing import TYPE_CHECKING
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func from sqlalchemy import func
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import ( from app.api.deps import (
get_board_for_actor_read, get_board_for_actor_read,
@@ -47,9 +49,23 @@ from app.services.board_group_snapshot import build_board_group_snapshot
from app.services.board_snapshot import build_board_snapshot from app.services.board_snapshot import build_board_snapshot
from app.services.organizations import OrganizationContext, board_access_filter from app.services.organizations import OrganizationContext, board_access_filter
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter(prefix="/boards", tags=["boards"]) router = APIRouter(prefix="/boards", tags=["boards"])
AGENT_SESSION_PREFIX = "agent" AGENT_SESSION_PREFIX = "agent"
SESSION_DEP = Depends(get_session)
ORG_ADMIN_DEP = Depends(require_org_admin)
ORG_MEMBER_DEP = Depends(require_org_member)
BOARD_USER_READ_DEP = Depends(get_board_for_user_read)
BOARD_USER_WRITE_DEP = Depends(get_board_for_user_write)
BOARD_ACTOR_READ_DEP = Depends(get_board_for_actor_read)
GATEWAY_ID_QUERY = Query(default=None)
BOARD_GROUP_ID_QUERY = Query(default=None)
INCLUDE_SELF_QUERY = Query(default=False)
INCLUDE_DONE_QUERY = Query(default=False)
PER_BOARD_TASK_LIMIT_QUERY = Query(default=5, ge=0, le=100)
def _slugify(value: str) -> str: def _slugify(value: str) -> str:
@@ -83,10 +99,12 @@ async def _require_gateway(
async def _require_gateway_for_create( async def _require_gateway_for_create(
payload: BoardCreate, payload: BoardCreate,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
) -> Gateway: ) -> Gateway:
return await _require_gateway(session, payload.gateway_id, organization_id=ctx.organization.id) return await _require_gateway(
session, payload.gateway_id, organization_id=ctx.organization.id,
)
async def _require_board_group( async def _require_board_group(
@@ -111,8 +129,8 @@ async def _require_board_group(
async def _require_board_group_for_create( async def _require_board_group_for_create(
payload: BoardCreate, payload: BoardCreate,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
) -> BoardGroup | None: ) -> BoardGroup | None:
if payload.board_group_id is None: if payload.board_group_id is None:
return None return None
@@ -123,6 +141,10 @@ async def _require_board_group_for_create(
) )
GATEWAY_CREATE_DEP = Depends(_require_gateway_for_create)
BOARD_GROUP_CREATE_DEP = Depends(_require_board_group_for_create)
async def _apply_board_update( async def _apply_board_update(
*, *,
payload: BoardUpdate, payload: BoardUpdate,
@@ -132,7 +154,7 @@ async def _apply_board_update(
updates = payload.model_dump(exclude_unset=True) updates = payload.model_dump(exclude_unset=True)
if "gateway_id" in updates: if "gateway_id" in updates:
await _require_gateway( await _require_gateway(
session, updates["gateway_id"], organization_id=board.organization_id session, updates["gateway_id"], organization_id=board.organization_id,
) )
if "board_group_id" in updates and updates["board_group_id"] is not None: if "board_group_id" in updates and updates["board_group_id"] is not None:
await _require_board_group( await _require_board_group(
@@ -141,13 +163,15 @@ async def _apply_board_update(
organization_id=board.organization_id, organization_id=board.organization_id,
) )
crud.apply_updates(board, updates) crud.apply_updates(board, updates)
if updates.get("board_type") == "goal": if (
updates.get("board_type") == "goal"
and (not board.objective or not board.success_metrics)
):
# Validate only when explicitly switching to goal boards. # Validate only when explicitly switching to goal boards.
if not board.objective or not board.success_metrics: raise HTTPException(
raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Goal boards require objective and success_metrics",
detail="Goal boards require objective and success_metrics", )
)
if not board.gateway_id: if not board.gateway_id:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -158,7 +182,7 @@ async def _apply_board_update(
async def _board_gateway( async def _board_gateway(
session: AsyncSession, board: Board session: AsyncSession, board: Board,
) -> tuple[Gateway | None, GatewayClientConfig | None]: ) -> tuple[Gateway | None, GatewayClientConfig | None]:
if not board.gateway_id: if not board.gateway_id:
return None, None return None, None
@@ -218,28 +242,32 @@ async def _cleanup_agent_on_gateway(
@router.get("", response_model=DefaultLimitOffsetPage[BoardRead]) @router.get("", response_model=DefaultLimitOffsetPage[BoardRead])
async def list_boards( async def list_boards(
gateway_id: UUID | None = Query(default=None), gateway_id: UUID | None = GATEWAY_ID_QUERY,
board_group_id: UUID | None = Query(default=None), board_group_id: UUID | None = BOARD_GROUP_ID_QUERY,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_member), ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> DefaultLimitOffsetPage[BoardRead]: ) -> DefaultLimitOffsetPage[BoardRead]:
"""List boards visible to the current organization member."""
statement = select(Board).where(board_access_filter(ctx.member, write=False)) statement = select(Board).where(board_access_filter(ctx.member, write=False))
if gateway_id is not None: if gateway_id is not None:
statement = statement.where(col(Board.gateway_id) == gateway_id) statement = statement.where(col(Board.gateway_id) == gateway_id)
if board_group_id is not None: if board_group_id is not None:
statement = statement.where(col(Board.board_group_id) == board_group_id) statement = statement.where(col(Board.board_group_id) == board_group_id)
statement = statement.order_by(func.lower(col(Board.name)).asc(), col(Board.created_at).desc()) statement = statement.order_by(
func.lower(col(Board.name)).asc(), col(Board.created_at).desc(),
)
return await paginate(session, statement) return await paginate(session, statement)
@router.post("", response_model=BoardRead) @router.post("", response_model=BoardRead)
async def create_board( async def create_board(
payload: BoardCreate, payload: BoardCreate,
_gateway: Gateway = Depends(_require_gateway_for_create), _gateway: Gateway = GATEWAY_CREATE_DEP,
_board_group: BoardGroup | None = Depends(_require_board_group_for_create), _board_group: BoardGroup | None = BOARD_GROUP_CREATE_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> Board: ) -> Board:
"""Create a board in the active organization."""
data = payload.model_dump() data = payload.model_dump()
data["organization_id"] = ctx.organization.id data["organization_id"] = ctx.organization.id
return await crud.create(session, Board, **data) return await crud.create(session, Board, **data)
@@ -247,27 +275,31 @@ async def create_board(
@router.get("/{board_id}", response_model=BoardRead) @router.get("/{board_id}", response_model=BoardRead)
def get_board( def get_board(
board: Board = Depends(get_board_for_user_read), board: Board = BOARD_USER_READ_DEP,
) -> Board: ) -> Board:
"""Get a board by id."""
return board return board
@router.get("/{board_id}/snapshot", response_model=BoardSnapshot) @router.get("/{board_id}/snapshot", response_model=BoardSnapshot)
async def get_board_snapshot( async def get_board_snapshot(
board: Board = Depends(get_board_for_actor_read), board: Board = BOARD_ACTOR_READ_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
) -> BoardSnapshot: ) -> BoardSnapshot:
"""Get a board snapshot view model."""
return await build_board_snapshot(session, board) return await build_board_snapshot(session, board)
@router.get("/{board_id}/group-snapshot", response_model=BoardGroupSnapshot) @router.get("/{board_id}/group-snapshot", response_model=BoardGroupSnapshot)
async def get_board_group_snapshot( async def get_board_group_snapshot(
include_self: bool = Query(default=False), *,
include_done: bool = Query(default=False), include_self: bool = INCLUDE_SELF_QUERY,
per_board_task_limit: int = Query(default=5, ge=0, le=100), include_done: bool = INCLUDE_DONE_QUERY,
board: Board = Depends(get_board_for_actor_read), per_board_task_limit: int = PER_BOARD_TASK_LIMIT_QUERY,
session: AsyncSession = Depends(get_session), board: Board = BOARD_ACTOR_READ_DEP,
session: AsyncSession = SESSION_DEP,
) -> BoardGroupSnapshot: ) -> BoardGroupSnapshot:
"""Get a grouped snapshot across related boards."""
return await build_board_group_snapshot( return await build_board_group_snapshot(
session, session,
board=board, board=board,
@@ -280,19 +312,23 @@ async def get_board_group_snapshot(
@router.patch("/{board_id}", response_model=BoardRead) @router.patch("/{board_id}", response_model=BoardRead)
async def update_board( async def update_board(
payload: BoardUpdate, payload: BoardUpdate,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
board: Board = Depends(get_board_for_user_write), board: Board = BOARD_USER_WRITE_DEP,
) -> Board: ) -> Board:
"""Update mutable board properties."""
return await _apply_board_update(payload=payload, session=session, board=board) return await _apply_board_update(payload=payload, session=session, board=board)
@router.delete("/{board_id}", response_model=OkResponse) @router.delete("/{board_id}", response_model=OkResponse)
async def delete_board( async def delete_board(
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
board: Board = Depends(get_board_for_user_write), board: Board = BOARD_USER_WRITE_DEP,
) -> OkResponse: ) -> OkResponse:
"""Delete a board and all dependent records."""
agents = await Agent.objects.filter_by(board_id=board.id).all(session) agents = await Agent.objects.filter_by(board_id=board.id).all(session)
task_ids = list(await session.exec(select(Task.id).where(Task.board_id == board.id))) task_ids = list(
await session.exec(select(Task.id).where(Task.board_id == board.id)),
)
config, client_config = await _board_gateway(session, board) config, client_config = await _board_gateway(session, board)
if config and client_config: if config and client_config:
@@ -307,20 +343,31 @@ async def delete_board(
if task_ids: if task_ids:
await crud.delete_where( await crud.delete_where(
session, ActivityEvent, col(ActivityEvent.task_id).in_(task_ids), commit=False session,
ActivityEvent,
col(ActivityEvent.task_id).in_(task_ids),
commit=False,
) )
await crud.delete_where(session, TaskDependency, col(TaskDependency.board_id) == board.id) await crud.delete_where(
await crud.delete_where(session, TaskFingerprint, col(TaskFingerprint.board_id) == board.id) session, TaskDependency, col(TaskDependency.board_id) == board.id,
)
await crud.delete_where(
session, TaskFingerprint, col(TaskFingerprint.board_id) == board.id,
)
# Approvals can reference tasks and agents, so delete before both. # Approvals can reference tasks and agents, so delete before both.
await crud.delete_where(session, Approval, col(Approval.board_id) == board.id) await crud.delete_where(session, Approval, col(Approval.board_id) == board.id)
await crud.delete_where(session, BoardMemory, col(BoardMemory.board_id) == board.id) await crud.delete_where(session, BoardMemory, col(BoardMemory.board_id) == board.id)
await crud.delete_where( await crud.delete_where(
session, BoardOnboardingSession, col(BoardOnboardingSession.board_id) == board.id session,
BoardOnboardingSession,
col(BoardOnboardingSession.board_id) == board.id,
) )
await crud.delete_where( await crud.delete_where(
session, OrganizationBoardAccess, col(OrganizationBoardAccess.board_id) == board.id session,
OrganizationBoardAccess,
col(OrganizationBoardAccess.board_id) == board.id,
) )
await crud.delete_where( await crud.delete_where(
session, session,
@@ -328,14 +375,17 @@ async def delete_board(
col(OrganizationInviteBoardAccess.board_id) == board.id, col(OrganizationInviteBoardAccess.board_id) == board.id,
) )
# Tasks reference agents (assigned_agent_id) and have dependents (fingerprints/dependencies), so # Tasks reference agents and have dependent records.
# delete tasks before agents. # delete tasks before agents.
await crud.delete_where(session, Task, col(Task.board_id) == board.id) await crud.delete_where(session, Task, col(Task.board_id) == board.id)
if agents: if agents:
agent_ids = [agent.id for agent in agents] agent_ids = [agent.id for agent in agents]
await crud.delete_where( await crud.delete_where(
session, ActivityEvent, col(ActivityEvent.agent_id).in_(agent_ids), commit=False session,
ActivityEvent,
col(ActivityEvent.agent_id).in_(agent_ids),
commit=False,
) )
await crud.delete_where(session, Agent, col(Agent.id).in_(agent_ids)) await crud.delete_where(session, Agent, col(Agent.id).in_(agent_ids))
await session.delete(board) await session.delete(board)

View File

@@ -1,16 +1,17 @@
"""Organization management endpoints and membership/invite flows."""
from __future__ import annotations from __future__ import annotations
import secrets import secrets
from typing import Any, Sequence from typing import TYPE_CHECKING, Any, Sequence
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func from sqlalchemy import func
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api.deps import require_org_admin, require_org_member from app.api.deps import require_org_admin, require_org_member
from app.core.auth import AuthContext, get_auth_context from app.core.auth import get_auth_context
from app.core.time import utcnow from app.core.time import utcnow
from app.db import crud from app.db import crud
from app.db.pagination import paginate from app.db.pagination import paginate
@@ -63,10 +64,21 @@ from app.services.organizations import (
set_active_organization, set_active_organization,
) )
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.auth import AuthContext
router = APIRouter(prefix="/organizations", tags=["organizations"]) router = APIRouter(prefix="/organizations", tags=["organizations"])
SESSION_DEP = Depends(get_session)
AUTH_DEP = Depends(get_auth_context)
ORG_MEMBER_DEP = Depends(require_org_member)
ORG_ADMIN_DEP = Depends(require_org_admin)
def _member_to_read(member: OrganizationMember, user: User | None) -> OrganizationMemberRead: def _member_to_read(
member: OrganizationMember, user: User | None,
) -> OrganizationMemberRead:
model = OrganizationMemberRead.model_validate(member, from_attributes=True) model = OrganizationMemberRead.model_validate(member, from_attributes=True)
if user is not None: if user is not None:
model.user = OrganizationUserRead.model_validate(user, from_attributes=True) model.user = OrganizationUserRead.model_validate(user, from_attributes=True)
@@ -100,9 +112,10 @@ async def _require_org_invite(
@router.post("", response_model=OrganizationRead) @router.post("", response_model=OrganizationRead)
async def create_organization( async def create_organization(
payload: OrganizationCreate, payload: OrganizationCreate,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = AUTH_DEP,
) -> OrganizationRead: ) -> OrganizationRead:
"""Create an organization and assign the caller as owner."""
if auth.user is None: if auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
name = payload.name.strip() name = payload.name.strip()
@@ -110,7 +123,9 @@ async def create_organization(
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
existing = ( existing = (
await session.exec( await session.exec(
select(Organization).where(func.lower(col(Organization.name)) == name.lower()) select(Organization).where(
func.lower(col(Organization.name)) == name.lower(),
),
) )
).first() ).first()
if existing is not None: if existing is not None:
@@ -140,19 +155,25 @@ async def create_organization(
@router.get("/me/list", response_model=list[OrganizationListItem]) @router.get("/me/list", response_model=list[OrganizationListItem])
async def list_my_organizations( async def list_my_organizations(
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = AUTH_DEP,
) -> list[OrganizationListItem]: ) -> list[OrganizationListItem]:
"""List organizations where the current user is a member."""
if auth.user is None: if auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
await get_active_membership(session, auth.user) await get_active_membership(session, auth.user)
db_user = await User.objects.by_id(auth.user.id).first(session) db_user = await User.objects.by_id(auth.user.id).first(session)
active_id = db_user.active_organization_id if db_user else auth.user.active_organization_id active_id = (
db_user.active_organization_id if db_user else auth.user.active_organization_id
)
statement = ( statement = (
select(Organization, OrganizationMember) select(Organization, OrganizationMember)
.join(OrganizationMember, col(OrganizationMember.organization_id) == col(Organization.id)) .join(
OrganizationMember,
col(OrganizationMember.organization_id) == col(Organization.id),
)
.where(col(OrganizationMember.user_id) == auth.user.id) .where(col(OrganizationMember.user_id) == auth.user.id)
.order_by(func.lower(col(Organization.name)).asc()) .order_by(func.lower(col(Organization.name)).asc())
) )
@@ -171,30 +192,37 @@ async def list_my_organizations(
@router.patch("/me/active", response_model=OrganizationRead) @router.patch("/me/active", response_model=OrganizationRead)
async def set_active_org( async def set_active_org(
payload: OrganizationActiveUpdate, payload: OrganizationActiveUpdate,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = AUTH_DEP,
) -> OrganizationRead: ) -> OrganizationRead:
"""Set the caller's active organization."""
if auth.user is None: if auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
member = await set_active_organization( member = await set_active_organization(
session, user=auth.user, organization_id=payload.organization_id session, user=auth.user, organization_id=payload.organization_id,
)
organization = await Organization.objects.by_id(member.organization_id).first(
session,
) )
organization = await Organization.objects.by_id(member.organization_id).first(session)
if organization is None: if organization is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return OrganizationRead.model_validate(organization, from_attributes=True) return OrganizationRead.model_validate(organization, from_attributes=True)
@router.get("/me", response_model=OrganizationRead) @router.get("/me", response_model=OrganizationRead)
async def get_my_org(ctx: OrganizationContext = Depends(require_org_member)) -> OrganizationRead: async def get_my_org(
ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> OrganizationRead:
"""Return the caller's active organization."""
return OrganizationRead.model_validate(ctx.organization, from_attributes=True) return OrganizationRead.model_validate(ctx.organization, from_attributes=True)
@router.delete("/me", response_model=OkResponse) @router.delete("/me", response_model=OkResponse)
async def delete_my_org( async def delete_my_org(
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> OkResponse: ) -> OkResponse:
"""Delete the active organization and related entities."""
if ctx.member.role != "owner": if ctx.member.role != "owner":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
@@ -206,28 +234,39 @@ async def delete_my_org(
task_ids = select(Task.id).where(col(Task.board_id).in_(board_ids)) task_ids = select(Task.id).where(col(Task.board_id).in_(board_ids))
agent_ids = select(Agent.id).where(col(Agent.board_id).in_(board_ids)) agent_ids = select(Agent.id).where(col(Agent.board_id).in_(board_ids))
member_ids = select(OrganizationMember.id).where( member_ids = select(OrganizationMember.id).where(
col(OrganizationMember.organization_id) == org_id col(OrganizationMember.organization_id) == org_id,
) )
invite_ids = select(OrganizationInvite.id).where( invite_ids = select(OrganizationInvite.id).where(
col(OrganizationInvite.organization_id) == org_id col(OrganizationInvite.organization_id) == org_id,
) )
group_ids = select(BoardGroup.id).where(col(BoardGroup.organization_id) == org_id) group_ids = select(BoardGroup.id).where(col(BoardGroup.organization_id) == org_id)
await crud.delete_where( await crud.delete_where(
session, ActivityEvent, col(ActivityEvent.task_id).in_(task_ids), commit=False session, ActivityEvent, col(ActivityEvent.task_id).in_(task_ids), commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, ActivityEvent, col(ActivityEvent.agent_id).in_(agent_ids), commit=False session,
ActivityEvent,
col(ActivityEvent.agent_id).in_(agent_ids),
commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, TaskDependency, col(TaskDependency.board_id).in_(board_ids), commit=False session,
TaskDependency,
col(TaskDependency.board_id).in_(board_ids),
commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, TaskFingerprint, col(TaskFingerprint.board_id).in_(board_ids), commit=False session,
TaskFingerprint,
col(TaskFingerprint.board_id).in_(board_ids),
commit=False,
) )
await crud.delete_where(session, Approval, col(Approval.board_id).in_(board_ids), commit=False)
await crud.delete_where( await crud.delete_where(
session, BoardMemory, col(BoardMemory.board_id).in_(board_ids), commit=False session, Approval, col(Approval.board_id).in_(board_ids), commit=False,
)
await crud.delete_where(
session, BoardMemory, col(BoardMemory.board_id).in_(board_ids), commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, session,
@@ -259,9 +298,15 @@ async def delete_my_org(
col(OrganizationInviteBoardAccess.organization_invite_id).in_(invite_ids), col(OrganizationInviteBoardAccess.organization_invite_id).in_(invite_ids),
commit=False, commit=False,
) )
await crud.delete_where(session, Task, col(Task.board_id).in_(board_ids), commit=False) await crud.delete_where(
await crud.delete_where(session, Agent, col(Agent.board_id).in_(board_ids), commit=False) session, Task, col(Task.board_id).in_(board_ids), commit=False,
await crud.delete_where(session, Board, col(Board.organization_id) == org_id, commit=False) )
await crud.delete_where(
session, Agent, col(Agent.board_id).in_(board_ids), commit=False,
)
await crud.delete_where(
session, Board, col(Board.organization_id) == org_id, commit=False,
)
await crud.delete_where( await crud.delete_where(
session, session,
BoardGroupMemory, BoardGroupMemory,
@@ -269,9 +314,11 @@ async def delete_my_org(
commit=False, commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, BoardGroup, col(BoardGroup.organization_id) == org_id, commit=False session, BoardGroup, col(BoardGroup.organization_id) == org_id, commit=False,
)
await crud.delete_where(
session, Gateway, col(Gateway.organization_id) == org_id, commit=False,
) )
await crud.delete_where(session, Gateway, col(Gateway.organization_id) == org_id, commit=False)
await crud.delete_where( await crud.delete_where(
session, session,
OrganizationInvite, OrganizationInvite,
@@ -291,32 +338,39 @@ async def delete_my_org(
active_organization_id=None, active_organization_id=None,
commit=False, commit=False,
) )
await crud.delete_where(session, Organization, col(Organization.id) == org_id, commit=False) await crud.delete_where(
session, Organization, col(Organization.id) == org_id, commit=False,
)
await session.commit() await session.commit()
return OkResponse() return OkResponse()
@router.get("/me/member", response_model=OrganizationMemberRead) @router.get("/me/member", response_model=OrganizationMemberRead)
async def get_my_membership( async def get_my_membership(
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_member), ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> OrganizationMemberRead: ) -> OrganizationMemberRead:
"""Get the caller's membership record in the active organization."""
user = await User.objects.by_id(ctx.member.user_id).first(session) user = await User.objects.by_id(ctx.member.user_id).first(session)
access_rows = await OrganizationBoardAccess.objects.filter_by( access_rows = await OrganizationBoardAccess.objects.filter_by(
organization_member_id=ctx.member.id organization_member_id=ctx.member.id,
).all(session) ).all(session)
model = _member_to_read(ctx.member, user) model = _member_to_read(ctx.member, user)
model.board_access = [ model.board_access = [
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows OrganizationBoardAccessRead.model_validate(row, from_attributes=True)
for row in access_rows
] ]
return model return model
@router.get("/me/members", response_model=DefaultLimitOffsetPage[OrganizationMemberRead]) @router.get(
"/me/members", response_model=DefaultLimitOffsetPage[OrganizationMemberRead],
)
async def list_org_members( async def list_org_members(
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_member), ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> DefaultLimitOffsetPage[OrganizationMemberRead]: ) -> DefaultLimitOffsetPage[OrganizationMemberRead]:
"""List members for the active organization."""
statement = ( statement = (
select(OrganizationMember, User) select(OrganizationMember, User)
.join(User, col(User.id) == col(OrganizationMember.user_id)) .join(User, col(User.id) == col(OrganizationMember.user_id))
@@ -336,9 +390,10 @@ async def list_org_members(
@router.get("/me/members/{member_id}", response_model=OrganizationMemberRead) @router.get("/me/members/{member_id}", response_model=OrganizationMemberRead)
async def get_org_member( async def get_org_member(
member_id: UUID, member_id: UUID,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_member), ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> OrganizationMemberRead: ) -> OrganizationMemberRead:
"""Get a specific organization member by id."""
member = await _require_org_member( member = await _require_org_member(
session, session,
organization_id=ctx.organization.id, organization_id=ctx.organization.id,
@@ -348,11 +403,12 @@ async def get_org_member(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
user = await User.objects.by_id(member.user_id).first(session) user = await User.objects.by_id(member.user_id).first(session)
access_rows = await OrganizationBoardAccess.objects.filter_by( access_rows = await OrganizationBoardAccess.objects.filter_by(
organization_member_id=member.id organization_member_id=member.id,
).all(session) ).all(session)
model = _member_to_read(member, user) model = _member_to_read(member, user)
model.board_access = [ model.board_access = [
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows OrganizationBoardAccessRead.model_validate(row, from_attributes=True)
for row in access_rows
] ]
return model return model
@@ -361,9 +417,10 @@ async def get_org_member(
async def update_org_member( async def update_org_member(
member_id: UUID, member_id: UUID,
payload: OrganizationMemberUpdate, payload: OrganizationMemberUpdate,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> OrganizationMemberRead: ) -> OrganizationMemberRead:
"""Update a member's role in the organization."""
member = await _require_org_member( member = await _require_org_member(
session, session,
organization_id=ctx.organization.id, organization_id=ctx.organization.id,
@@ -382,9 +439,10 @@ async def update_org_member(
async def update_member_access( async def update_member_access(
member_id: UUID, member_id: UUID,
payload: OrganizationMemberAccessUpdate, payload: OrganizationMemberAccessUpdate,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> OrganizationMemberRead: ) -> OrganizationMemberRead:
"""Update board-level access settings for a member."""
member = await _require_org_member( member = await _require_org_member(
session, session,
organization_id=ctx.organization.id, organization_id=ctx.organization.id,
@@ -395,7 +453,9 @@ async def update_member_access(
if board_ids: if board_ids:
valid_board_ids = { valid_board_ids = {
board.id board.id
for board in await Board.objects.filter_by(organization_id=ctx.organization.id) for board in await Board.objects.filter_by(
organization_id=ctx.organization.id,
)
.filter(col(Board.id).in_(board_ids)) .filter(col(Board.id).in_(board_ids))
.all(session) .all(session)
} }
@@ -412,9 +472,10 @@ async def update_member_access(
@router.delete("/me/members/{member_id}", response_model=OkResponse) @router.delete("/me/members/{member_id}", response_model=OkResponse)
async def remove_org_member( async def remove_org_member(
member_id: UUID, member_id: UUID,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> OkResponse: ) -> OkResponse:
"""Remove a member from the active organization."""
member = await _require_org_member( member = await _require_org_member(
session, session,
organization_id=ctx.organization.id, organization_id=ctx.organization.id,
@@ -432,7 +493,9 @@ async def remove_org_member(
) )
if member.role == "owner": if member.role == "owner":
owners = ( owners = (
await OrganizationMember.objects.filter_by(organization_id=ctx.organization.id) await OrganizationMember.objects.filter_by(
organization_id=ctx.organization.id,
)
.filter(col(OrganizationMember.role) == "owner") .filter(col(OrganizationMember.role) == "owner")
.all(session) .all(session)
) )
@@ -463,7 +526,9 @@ async def remove_org_member(
user.active_organization_id = fallback_membership user.active_organization_id = fallback_membership
else: else:
user.active_organization_id = ( user.active_organization_id = (
fallback_membership.organization_id if fallback_membership is not None else None fallback_membership.organization_id
if fallback_membership is not None
else None
) )
session.add(user) session.add(user)
@@ -471,11 +536,14 @@ async def remove_org_member(
return OkResponse() return OkResponse()
@router.get("/me/invites", response_model=DefaultLimitOffsetPage[OrganizationInviteRead]) @router.get(
"/me/invites", response_model=DefaultLimitOffsetPage[OrganizationInviteRead],
)
async def list_org_invites( async def list_org_invites(
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> DefaultLimitOffsetPage[OrganizationInviteRead]: ) -> DefaultLimitOffsetPage[OrganizationInviteRead]:
"""List pending invites for the active organization."""
statement = ( statement = (
OrganizationInvite.objects.filter_by(organization_id=ctx.organization.id) OrganizationInvite.objects.filter_by(organization_id=ctx.organization.id)
.filter(col(OrganizationInvite.accepted_at).is_(None)) .filter(col(OrganizationInvite.accepted_at).is_(None))
@@ -488,9 +556,10 @@ async def list_org_invites(
@router.post("/me/invites", response_model=OrganizationInviteRead) @router.post("/me/invites", response_model=OrganizationInviteRead)
async def create_org_invite( async def create_org_invite(
payload: OrganizationInviteCreate, payload: OrganizationInviteCreate,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> OrganizationInviteRead: ) -> OrganizationInviteRead:
"""Create an organization invite for an email address."""
email = normalize_invited_email(payload.invited_email) email = normalize_invited_email(payload.invited_email)
if not email: if not email:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
@@ -526,13 +595,17 @@ async def create_org_invite(
if board_ids: if board_ids:
valid_board_ids = { valid_board_ids = {
board.id board.id
for board in await Board.objects.filter_by(organization_id=ctx.organization.id) for board in await Board.objects.filter_by(
organization_id=ctx.organization.id,
)
.filter(col(Board.id).in_(board_ids)) .filter(col(Board.id).in_(board_ids))
.all(session) .all(session)
} }
if valid_board_ids != board_ids: if valid_board_ids != board_ids:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
await apply_invite_board_access(session, invite=invite, entries=payload.board_access) await apply_invite_board_access(
session, invite=invite, entries=payload.board_access,
)
await session.commit() await session.commit()
await session.refresh(invite) await session.refresh(invite)
return OrganizationInviteRead.model_validate(invite, from_attributes=True) return OrganizationInviteRead.model_validate(invite, from_attributes=True)
@@ -541,9 +614,10 @@ async def create_org_invite(
@router.delete("/me/invites/{invite_id}", response_model=OrganizationInviteRead) @router.delete("/me/invites/{invite_id}", response_model=OrganizationInviteRead)
async def revoke_org_invite( async def revoke_org_invite(
invite_id: UUID, invite_id: UUID,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = Depends(require_org_admin), ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> OrganizationInviteRead: ) -> OrganizationInviteRead:
"""Revoke a pending invite from the active organization."""
invite = await _require_org_invite( invite = await _require_org_invite(
session, session,
organization_id=ctx.organization.id, organization_id=ctx.organization.id,
@@ -562,9 +636,10 @@ async def revoke_org_invite(
@router.post("/invites/accept", response_model=OrganizationMemberRead) @router.post("/invites/accept", response_model=OrganizationMemberRead)
async def accept_org_invite( async def accept_org_invite(
payload: OrganizationInviteAccept, payload: OrganizationInviteAccept,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = AUTH_DEP,
) -> OrganizationMemberRead: ) -> OrganizationMemberRead:
"""Accept an invite and return resulting membership."""
if auth.user is None: if auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
invite = await OrganizationInvite.objects.filter( invite = await OrganizationInvite.objects.filter(
@@ -573,11 +648,13 @@ async def accept_org_invite(
).first(session) ).first(session)
if invite is None: if invite is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if invite.invited_email and auth.user.email: if (
if normalize_invited_email(invite.invited_email) != normalize_invited_email( invite.invited_email
auth.user.email and auth.user.email
): and normalize_invited_email(invite.invited_email)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) != normalize_invited_email(auth.user.email)
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
existing = await get_member( existing = await get_member(
session, session,

View File

@@ -1,41 +1,54 @@
"""API-level thin wrapper around query-set helpers with HTTP conveniences."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Generic, TypeVar from typing import TYPE_CHECKING, Generic, TypeVar
from fastapi import HTTPException, status from fastapi import HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar
from app.db.queryset import QuerySet, qs from app.db.queryset import QuerySet, qs
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar
ModelT = TypeVar("ModelT") ModelT = TypeVar("ModelT")
@dataclass(frozen=True) @dataclass(frozen=True)
class APIQuerySet(Generic[ModelT]): class APIQuerySet(Generic[ModelT]):
"""Immutable query-set wrapper tailored for API-layer usage."""
queryset: QuerySet[ModelT] queryset: QuerySet[ModelT]
@property @property
def statement(self) -> SelectOfScalar[ModelT]: def statement(self) -> SelectOfScalar[ModelT]:
"""Expose the underlying SQL statement for advanced composition."""
return self.queryset.statement return self.queryset.statement
def filter(self, *criteria: Any) -> APIQuerySet[ModelT]: def filter(self, *criteria: object) -> APIQuerySet[ModelT]:
"""Return a new queryset with additional SQL criteria applied."""
return APIQuerySet(self.queryset.filter(*criteria)) return APIQuerySet(self.queryset.filter(*criteria))
def order_by(self, *ordering: Any) -> APIQuerySet[ModelT]: def order_by(self, *ordering: object) -> APIQuerySet[ModelT]:
"""Return a new queryset with ordering clauses applied."""
return APIQuerySet(self.queryset.order_by(*ordering)) return APIQuerySet(self.queryset.order_by(*ordering))
def limit(self, value: int) -> APIQuerySet[ModelT]: def limit(self, value: int) -> APIQuerySet[ModelT]:
"""Return a new queryset with a row limit applied."""
return APIQuerySet(self.queryset.limit(value)) return APIQuerySet(self.queryset.limit(value))
def offset(self, value: int) -> APIQuerySet[ModelT]: def offset(self, value: int) -> APIQuerySet[ModelT]:
"""Return a new queryset with an offset applied."""
return APIQuerySet(self.queryset.offset(value)) return APIQuerySet(self.queryset.offset(value))
async def all(self, session: AsyncSession) -> list[ModelT]: async def all(self, session: AsyncSession) -> list[ModelT]:
"""Fetch all rows for the current queryset."""
return await self.queryset.all(session) return await self.queryset.all(session)
async def first(self, session: AsyncSession) -> ModelT | None: async def first(self, session: AsyncSession) -> ModelT | None:
"""Fetch the first row for the current queryset, if present."""
return await self.queryset.first(session) return await self.queryset.first(session)
async def first_or_404( async def first_or_404(
@@ -44,6 +57,7 @@ class APIQuerySet(Generic[ModelT]):
*, *,
detail: str | None = None, detail: str | None = None,
) -> ModelT: ) -> ModelT:
"""Fetch the first row or raise HTTP 404 when no row exists."""
obj = await self.first(session) obj = await self.first(session)
if obj is not None: if obj is not None:
return obj return obj
@@ -53,4 +67,5 @@ class APIQuerySet(Generic[ModelT]):
def api_qs(model: type[ModelT]) -> APIQuerySet[ModelT]: def api_qs(model: type[ModelT]) -> APIQuerySet[ModelT]:
"""Create an APIQuerySet for a SQLModel class."""
return APIQuerySet(qs(model)) return APIQuerySet(qs(model))

View File

@@ -1,3 +1,5 @@
"""API routes for searching and fetching souls-directory markdown entries."""
from __future__ import annotations from __future__ import annotations
import re import re
@@ -13,6 +15,7 @@ from app.schemas.souls_directory import (
from app.services import souls_directory from app.services import souls_directory
router = APIRouter(prefix="/souls-directory", tags=["souls-directory"]) router = APIRouter(prefix="/souls-directory", tags=["souls-directory"])
ADMIN_OR_AGENT_DEP = Depends(require_admin_or_agent)
_SAFE_SEGMENT_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$") _SAFE_SEGMENT_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$")
_SAFE_SLUG_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$") _SAFE_SLUG_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$")
@@ -41,8 +44,9 @@ def _validate_segment(value: str, *, field: str) -> str:
async def search( async def search(
q: str = Query(default="", min_length=0), q: str = Query(default="", min_length=0),
limit: int = Query(default=20, ge=1, le=100), limit: int = Query(default=20, ge=1, le=100),
_actor: ActorContext = Depends(require_admin_or_agent), _actor: ActorContext = ADMIN_OR_AGENT_DEP,
) -> SoulsDirectorySearchResponse: ) -> SoulsDirectorySearchResponse:
"""Search souls-directory entries by handle/slug query text."""
refs = await souls_directory.list_souls_directory_refs() refs = await souls_directory.list_souls_directory_refs()
matches = souls_directory.search_souls(refs, query=q, limit=limit) matches = souls_directory.search_souls(refs, query=q, limit=limit)
items = [ items = [
@@ -62,12 +66,23 @@ async def search(
async def get_markdown( async def get_markdown(
handle: str, handle: str,
slug: str, slug: str,
_actor: ActorContext = Depends(require_admin_or_agent), _actor: ActorContext = ADMIN_OR_AGENT_DEP,
) -> SoulsDirectoryMarkdownResponse: ) -> SoulsDirectoryMarkdownResponse:
"""Fetch markdown content for a validated souls-directory handle and slug."""
safe_handle = _validate_segment(handle, field="handle") safe_handle = _validate_segment(handle, field="handle")
safe_slug = _validate_segment(slug.removesuffix(".md"), field="slug") safe_slug = _validate_segment(slug.removesuffix(".md"), field="slug")
try: try:
content = await souls_directory.fetch_soul_markdown(handle=safe_handle, slug=safe_slug) content = await souls_directory.fetch_soul_markdown(
handle=safe_handle,
slug=safe_slug,
)
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc raise HTTPException(
return SoulsDirectoryMarkdownResponse(handle=safe_handle, slug=safe_slug, content=content) status_code=status.HTTP_502_BAD_GATEWAY,
detail=str(exc),
) from exc
return SoulsDirectoryMarkdownResponse(
handle=safe_handle,
slug=safe_slug,
content=content,
)

View File

@@ -1,17 +1,19 @@
"""Task API routes for listing, streaming, and mutating board tasks."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import json import json
from collections import deque from collections import deque
from collections.abc import AsyncIterator, Sequence from collections.abc import AsyncIterator, Sequence
from contextlib import suppress
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import cast from typing import TYPE_CHECKING, cast
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import asc, desc, or_ from sqlalchemy import asc, desc, or_
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import Select from sqlmodel.sql.expression import Select
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
@@ -23,13 +25,16 @@ from app.api.deps import (
require_admin_auth, require_admin_auth,
require_admin_or_agent, require_admin_or_agent,
) )
from app.core.auth import AuthContext
from app.core.time import utcnow from app.core.time import utcnow
from app.db import crud from app.db import crud
from app.db.pagination import paginate from app.db.pagination import paginate
from app.db.session import async_session_maker, get_session from app.db.session import async_session_maker, get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message from app.integrations.openclaw_gateway import (
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.activity_events import ActivityEvent from app.models.activity_events import ActivityEvent
from app.models.agents import Agent from app.models.agents import Agent
from app.models.approvals import Approval from app.models.approvals import Approval
@@ -41,7 +46,13 @@ from app.models.tasks import Task
from app.schemas.common import OkResponse from app.schemas.common import OkResponse
from app.schemas.errors import BlockedTaskError from app.schemas.errors import BlockedTaskError
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
from app.schemas.tasks import TaskCommentCreate, TaskCommentRead, TaskCreate, TaskRead, TaskUpdate from app.schemas.tasks import (
TaskCommentCreate,
TaskCommentRead,
TaskCreate,
TaskRead,
TaskUpdate,
)
from app.services.activity_log import record_activity from app.services.activity_log import record_activity
from app.services.mentions import extract_mentions, matches_agent_mention from app.services.mentions import extract_mentions, matches_agent_mention
from app.services.organizations import require_board_access from app.services.organizations import require_board_access
@@ -54,6 +65,11 @@ from app.services.task_dependencies import (
validate_dependency_update, validate_dependency_update,
) )
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.auth import AuthContext
router = APIRouter(prefix="/boards/{board_id}/tasks", tags=["tasks"]) router = APIRouter(prefix="/boards/{board_id}/tasks", tags=["tasks"])
ALLOWED_STATUSES = {"inbox", "in_progress", "review", "done"} ALLOWED_STATUSES = {"inbox", "in_progress", "review", "done"}
@@ -66,6 +82,14 @@ TASK_EVENT_TYPES = {
SSE_SEEN_MAX = 2000 SSE_SEEN_MAX = 2000
TASK_SNIPPET_MAX_LEN = 500 TASK_SNIPPET_MAX_LEN = 500
TASK_SNIPPET_TRUNCATED_LEN = 497 TASK_SNIPPET_TRUNCATED_LEN = 497
BOARD_READ_DEP = Depends(get_board_for_actor_read)
ACTOR_DEP = Depends(require_admin_or_agent)
SINCE_QUERY = Query(default=None)
STATUS_QUERY = Query(default=None, alias="status")
BOARD_WRITE_DEP = Depends(get_board_for_user_write)
SESSION_DEP = Depends(get_session)
ADMIN_AUTH_DEP = Depends(require_admin_auth)
TASK_DEP = Depends(get_task_or_404)
def _comment_validation_error() -> HTTPException: def _comment_validation_error() -> HTTPException:
@@ -98,6 +122,7 @@ async def has_valid_recent_comment(
agent_id: UUID | None, agent_id: UUID | None,
since: datetime | None, since: datetime | None,
) -> bool: ) -> bool:
"""Check whether the task has a recent non-empty comment by the agent."""
if agent_id is None or since is None: if agent_id is None or since is None:
return False return False
statement = ( statement = (
@@ -180,8 +205,8 @@ async def _reconcile_dependents_for_dependency_toggle(
await session.exec( await session.exec(
select(Task) select(Task)
.where(col(Task.board_id) == board_id) .where(col(Task.board_id) == board_id)
.where(col(Task.id).in_(dependent_ids)) .where(col(Task.id).in_(dependent_ids)),
) ),
) )
reopened = previous_status == "done" and dependency_task.status != "done" reopened = previous_status == "done" and dependency_task.status != "done"
@@ -204,7 +229,10 @@ async def _reconcile_dependents_for_dependency_toggle(
session, session,
event_type="task.status_changed", event_type="task.status_changed",
task_id=dependent.id, task_id=dependent.id,
message=f"Task returned to inbox: dependency reopened ({dependency_task.title}).", message=(
"Task returned to inbox: dependency reopened "
f"({dependency_task.title})."
),
agent_id=actor_agent_id, agent_id=actor_agent_id,
) )
else: else:
@@ -230,7 +258,9 @@ async def _fetch_task_events(
board_id: UUID, board_id: UUID,
since: datetime, since: datetime,
) -> list[tuple[ActivityEvent, Task | None]]: ) -> list[tuple[ActivityEvent, Task | None]]:
task_ids = list(await session.exec(select(Task.id).where(col(Task.board_id) == board_id))) task_ids = list(
await session.exec(select(Task.id).where(col(Task.board_id) == board_id)),
)
if not task_ids: if not task_ids:
return [] return []
statement = cast( statement = cast(
@@ -249,7 +279,9 @@ def _serialize_comment(event: ActivityEvent) -> dict[str, object]:
return TaskCommentRead.model_validate(event).model_dump(mode="json") return TaskCommentRead.model_validate(event).model_dump(mode="json")
async def _gateway_config(session: AsyncSession, board: Board) -> GatewayClientConfig | None: async def _gateway_config(
session: AsyncSession, board: Board,
) -> GatewayClientConfig | None:
if not board.gateway_id: if not board.gateway_id:
return None return None
gateway = await Gateway.objects.by_id(board.gateway_id).first(session) gateway = await Gateway.objects.by_id(board.gateway_id).first(session)
@@ -303,7 +335,10 @@ async def _notify_agent_on_task_assign(
message = ( message = (
"TASK ASSIGNED\n" "TASK ASSIGNED\n"
+ "\n".join(details) + "\n".join(details)
+ "\n\nTake action: open the task and begin work. Post updates as task comments." + (
"\n\nTake action: open the task and begin work. "
"Post updates as task comments."
)
) )
try: try:
await _send_agent_task_message( await _send_agent_task_message(
@@ -442,17 +477,18 @@ async def _notify_lead_on_task_unassigned(
@router.get("/stream") @router.get("/stream")
async def stream_tasks( async def stream_tasks( # noqa: C901
request: Request, request: Request,
board: Board = Depends(get_board_for_actor_read), board: Board = BOARD_READ_DEP,
actor: ActorContext = Depends(require_admin_or_agent), _actor: ActorContext = ACTOR_DEP,
since: str | None = Query(default=None), since: str | None = SINCE_QUERY,
) -> EventSourceResponse: ) -> EventSourceResponse:
"""Stream task and task-comment events as SSE payloads."""
since_dt = _parse_since(since) or utcnow() since_dt = _parse_since(since) or utcnow()
seen_ids: set[UUID] = set() seen_ids: set[UUID] = set()
seen_queue: deque[UUID] = deque() seen_queue: deque[UUID] = deque()
async def event_generator() -> AsyncIterator[dict[str, str]]: async def event_generator() -> AsyncIterator[dict[str, str]]: # noqa: C901
last_seen = since_dt last_seen = since_dt
while True: while True:
if await request.is_disconnected(): if await request.is_disconnected():
@@ -510,7 +546,7 @@ async def stream_tasks(
"depends_on_task_ids": dep_list, "depends_on_task_ids": dep_list,
"blocked_by_task_ids": blocked_by, "blocked_by_task_ids": blocked_by,
"is_blocked": bool(blocked_by), "is_blocked": bool(blocked_by),
} },
) )
.model_dump(mode="json") .model_dump(mode="json")
) )
@@ -521,14 +557,15 @@ async def stream_tasks(
@router.get("", response_model=DefaultLimitOffsetPage[TaskRead]) @router.get("", response_model=DefaultLimitOffsetPage[TaskRead])
async def list_tasks( async def list_tasks( # noqa: C901
status_filter: str | None = Query(default=None, alias="status"), status_filter: str | None = STATUS_QUERY,
assigned_agent_id: UUID | None = None, assigned_agent_id: UUID | None = None,
unassigned: bool | None = None, unassigned: bool | None = None,
board: Board = Depends(get_board_for_actor_read), board: Board = BOARD_READ_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
actor: ActorContext = Depends(require_admin_or_agent), _actor: ActorContext = ACTOR_DEP,
) -> DefaultLimitOffsetPage[TaskRead]: ) -> DefaultLimitOffsetPage[TaskRead]:
"""List board tasks with optional status and assignment filters."""
statement = select(Task).where(Task.board_id == board.id) statement = select(Task).where(Task.board_id == board.id)
if status_filter: if status_filter:
statuses = [s.strip() for s in status_filter.split(",") if s.strip()] statuses = [s.strip() for s in status_filter.split(",") if s.strip()]
@@ -550,7 +587,9 @@ async def list_tasks(
if not tasks: if not tasks:
return [] return []
task_ids = [task.id for task in tasks] task_ids = [task.id for task in tasks]
deps_map = await dependency_ids_by_task_id(session, board_id=board.id, task_ids=task_ids) deps_map = await dependency_ids_by_task_id(
session, board_id=board.id, task_ids=task_ids,
)
dep_ids: list[UUID] = [] dep_ids: list[UUID] = []
for value in deps_map.values(): for value in deps_map.values():
dep_ids.extend(value) dep_ids.extend(value)
@@ -563,7 +602,9 @@ async def list_tasks(
output: list[TaskRead] = [] output: list[TaskRead] = []
for task in tasks: for task in tasks:
dep_list = deps_map.get(task.id, []) dep_list = deps_map.get(task.id, [])
blocked_by = blocked_by_dependency_ids(dependency_ids=dep_list, status_by_id=dep_status) blocked_by = blocked_by_dependency_ids(
dependency_ids=dep_list, status_by_id=dep_status,
)
if task.status == "done": if task.status == "done":
blocked_by = [] blocked_by = []
output.append( output.append(
@@ -572,8 +613,8 @@ async def list_tasks(
"depends_on_task_ids": dep_list, "depends_on_task_ids": dep_list,
"blocked_by_task_ids": blocked_by, "blocked_by_task_ids": blocked_by,
"is_blocked": bool(blocked_by), "is_blocked": bool(blocked_by),
} },
) ),
) )
return output return output
@@ -583,10 +624,11 @@ async def list_tasks(
@router.post("", response_model=TaskRead, responses={409: {"model": BlockedTaskError}}) @router.post("", response_model=TaskRead, responses={409: {"model": BlockedTaskError}})
async def create_task( async def create_task(
payload: TaskCreate, payload: TaskCreate,
board: Board = Depends(get_board_for_user_write), board: Board = BOARD_WRITE_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
auth: AuthContext = Depends(require_admin_auth), auth: AuthContext = ADMIN_AUTH_DEP,
) -> TaskRead: ) -> TaskRead:
"""Create a task and initialize dependency rows."""
data = payload.model_dump() data = payload.model_dump()
depends_on_task_ids = cast(list[UUID], data.pop("depends_on_task_ids", []) or []) depends_on_task_ids = cast(list[UUID], data.pop("depends_on_task_ids", []) or [])
@@ -606,7 +648,9 @@ async def create_task(
board_id=board.id, board_id=board.id,
dependency_ids=normalized_deps, dependency_ids=normalized_deps,
) )
blocked_by = blocked_by_dependency_ids(dependency_ids=normalized_deps, status_by_id=dep_status) blocked_by = blocked_by_dependency_ids(
dependency_ids=normalized_deps, status_by_id=dep_status,
)
if blocked_by and (task.assigned_agent_id is not None or task.status != "inbox"): if blocked_by and (task.assigned_agent_id is not None or task.status != "inbox"):
raise _blocked_task_error(blocked_by) raise _blocked_task_error(blocked_by)
session.add(task) session.add(task)
@@ -618,7 +662,7 @@ async def create_task(
board_id=board.id, board_id=board.id,
task_id=task.id, task_id=task.id,
depends_on_task_id=dep_id, depends_on_task_id=dep_id,
) ),
) )
await session.commit() await session.commit()
await session.refresh(task) await session.refresh(task)
@@ -632,7 +676,9 @@ async def create_task(
await session.commit() await session.commit()
await _notify_lead_on_task_create(session=session, board=board, task=task) await _notify_lead_on_task_create(session=session, board=board, task=task)
if task.assigned_agent_id: if task.assigned_agent_id:
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session) assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(
session,
)
if assigned_agent: if assigned_agent:
await _notify_agent_on_task_assign( await _notify_agent_on_task_assign(
session=session, session=session,
@@ -645,7 +691,7 @@ async def create_task(
"depends_on_task_ids": normalized_deps, "depends_on_task_ids": normalized_deps,
"blocked_by_task_ids": blocked_by, "blocked_by_task_ids": blocked_by,
"is_blocked": bool(blocked_by), "is_blocked": bool(blocked_by),
} },
) )
@@ -654,12 +700,13 @@ async def create_task(
response_model=TaskRead, response_model=TaskRead,
responses={409: {"model": BlockedTaskError}}, responses={409: {"model": BlockedTaskError}},
) )
async def update_task( async def update_task( # noqa: C901, PLR0912, PLR0915
payload: TaskUpdate, payload: TaskUpdate,
task: Task = Depends(get_task_or_404), task: Task = TASK_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
actor: ActorContext = Depends(require_admin_or_agent), actor: ActorContext = ACTOR_DEP,
) -> TaskRead: ) -> TaskRead:
"""Update task status, assignment, comment, and dependency state."""
if task.board_id is None: if task.board_id is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -676,7 +723,9 @@ async def update_task(
previous_assigned = task.assigned_agent_id previous_assigned = task.assigned_agent_id
updates = payload.model_dump(exclude_unset=True) updates = payload.model_dump(exclude_unset=True)
comment = updates.pop("comment", None) comment = updates.pop("comment", None)
depends_on_task_ids = cast(list[UUID] | None, updates.pop("depends_on_task_ids", None)) depends_on_task_ids = cast(
list[UUID] | None, updates.pop("depends_on_task_ids", None),
)
requested_fields = set(updates) requested_fields = set(updates)
if comment is not None: if comment is not None:
@@ -685,7 +734,9 @@ async def update_task(
requested_fields.add("depends_on_task_ids") requested_fields.add("depends_on_task_ids")
async def _current_dep_ids() -> list[UUID]: async def _current_dep_ids() -> list[UUID]:
deps_map = await dependency_ids_by_task_id(session, board_id=board_id, task_ids=[task.id]) deps_map = await dependency_ids_by_task_id(
session, board_id=board_id, task_ids=[task.id],
)
return deps_map.get(task.id, []) return deps_map.get(task.id, [])
async def _blocked_by(dep_ids: Sequence[UUID]) -> list[UUID]: async def _blocked_by(dep_ids: Sequence[UUID]) -> list[UUID]:
@@ -696,16 +747,20 @@ async def update_task(
board_id=board_id, board_id=board_id,
dependency_ids=list(dep_ids), dependency_ids=list(dep_ids),
) )
return blocked_by_dependency_ids(dependency_ids=list(dep_ids), status_by_id=dep_status) return blocked_by_dependency_ids(
dependency_ids=list(dep_ids), status_by_id=dep_status,
)
# Lead agent: delegation only (assign/unassign, resolve review, manage dependencies). # Lead agent: delegation only.
# Assign/unassign, resolve review, and manage dependencies.
if actor.actor_type == "agent" and actor.agent and actor.agent.is_board_lead: if actor.actor_type == "agent" and actor.agent and actor.agent.is_board_lead:
allowed_fields = {"assigned_agent_id", "status", "depends_on_task_ids"} allowed_fields = {"assigned_agent_id", "status", "depends_on_task_ids"}
if comment is not None or not requested_fields.issubset(allowed_fields): if comment is not None or not requested_fields.issubset(allowed_fields):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=( detail=(
"Board leads can only assign/unassign tasks, update dependencies, or resolve review tasks." "Board leads can only assign/unassign tasks, update "
"dependencies, or resolve review tasks."
), ),
) )
@@ -745,7 +800,11 @@ async def update_task(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Board leads cannot assign tasks to themselves.", detail="Board leads cannot assign tasks to themselves.",
) )
if agent.board_id and task.board_id and agent.board_id != task.board_id: if (
agent.board_id
and task.board_id
and agent.board_id != task.board_id
):
raise HTTPException(status_code=status.HTTP_409_CONFLICT) raise HTTPException(status_code=status.HTTP_409_CONFLICT)
task.assigned_agent_id = agent.id task.assigned_agent_id = agent.id
else: else:
@@ -755,12 +814,18 @@ async def update_task(
if task.status != "review": if task.status != "review":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Board leads can only change status when a task is in review.", detail=(
"Board leads can only change status when a task is "
"in review."
),
) )
if updates["status"] not in {"done", "inbox"}: if updates["status"] not in {"done", "inbox"}:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Board leads can only move review tasks to done or inbox.", detail=(
"Board leads can only move review tasks to done "
"or inbox."
),
) )
if updates["status"] == "inbox": if updates["status"] == "inbox":
task.assigned_agent_id = None task.assigned_agent_id = None
@@ -793,7 +858,9 @@ async def update_task(
await session.refresh(task) await session.refresh(task)
if task.assigned_agent_id and task.assigned_agent_id != previous_assigned: if task.assigned_agent_id and task.assigned_agent_id != previous_assigned:
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session) assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(
session,
)
if assigned_agent: if assigned_agent:
board = ( board = (
await Board.objects.by_id(task.board_id).first(session) await Board.objects.by_id(task.board_id).first(session)
@@ -817,14 +884,18 @@ async def update_task(
"depends_on_task_ids": dep_ids, "depends_on_task_ids": dep_ids,
"blocked_by_task_ids": blocked_ids, "blocked_by_task_ids": blocked_ids,
"is_blocked": bool(blocked_ids), "is_blocked": bool(blocked_ids),
} },
) )
# Non-lead agent: can only change status + comment, and cannot start blocked tasks. # Non-lead agent: can only change status + comment, and cannot start blocked tasks.
if actor.actor_type == "agent": if actor.actor_type == "agent":
if actor.agent and actor.agent.board_id and task.board_id: if (
if actor.agent.board_id != task.board_id: actor.agent
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) and actor.agent.board_id
and task.board_id
and actor.agent.board_id != task.board_id
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
allowed_fields = {"status", "comment"} allowed_fields = {"status", "comment"}
if depends_on_task_ids is not None or not set(updates).issubset(allowed_fields): if depends_on_task_ids is not None or not set(updates).issubset(allowed_fields):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@@ -858,14 +929,16 @@ async def update_task(
) )
effective_deps = ( effective_deps = (
admin_normalized_deps if admin_normalized_deps is not None else await _current_dep_ids() admin_normalized_deps
if admin_normalized_deps is not None
else await _current_dep_ids()
) )
blocked_ids = await _blocked_by(effective_deps) blocked_ids = await _blocked_by(effective_deps)
target_status = cast(str, updates.get("status", task.status)) target_status = cast(str, updates.get("status", task.status))
if blocked_ids and not (task.status == "done" and target_status == "done"): if blocked_ids and not (task.status == "done" and target_status == "done"):
# Blocked tasks cannot be assigned or moved out of inbox. If the task is already in # Blocked tasks cannot be assigned or moved out of inbox.
# flight, force it back to inbox and unassign it. # If the task is already in flight, force it back to inbox and unassign it.
task.status = "inbox" task.status = "inbox"
task.assigned_agent_id = None task.assigned_agent_id = None
task.in_progress_at = None task.in_progress_at = None
@@ -910,7 +983,9 @@ async def update_task(
event_type="task.comment", event_type="task.comment",
message=comment, message=comment,
task_id=task.id, task_id=task.id,
agent_id=actor.agent.id if actor.actor_type == "agent" and actor.agent else None, agent_id=actor.agent.id
if actor.actor_type == "agent" and actor.agent
else None,
) )
session.add(event) session.add(event)
await session.commit() await session.commit()
@@ -921,7 +996,9 @@ async def update_task(
else: else:
event_type = "task.updated" event_type = "task.updated"
message = f"Task updated: {task.title}." message = f"Task updated: {task.title}."
actor_agent_id = actor.agent.id if actor.actor_type == "agent" and actor.agent else None actor_agent_id = (
actor.agent.id if actor.actor_type == "agent" and actor.agent else None
)
record_activity( record_activity(
session, session,
event_type=event_type, event_type=event_type,
@@ -938,23 +1015,34 @@ async def update_task(
) )
await session.commit() await session.commit()
if task.status == "inbox" and task.assigned_agent_id is None: if (
if previous_status != "inbox" or previous_assigned is not None: task.status == "inbox"
board = ( and task.assigned_agent_id is None
await Board.objects.by_id(task.board_id).first(session) if task.board_id else None and (previous_status != "inbox" or previous_assigned is not None)
):
board = (
await Board.objects.by_id(task.board_id).first(session)
if task.board_id
else None
)
if board:
await _notify_lead_on_task_unassigned(
session=session,
board=board,
task=task,
) )
if board:
await _notify_lead_on_task_unassigned(
session=session,
board=board,
task=task,
)
if task.assigned_agent_id and task.assigned_agent_id != previous_assigned: if task.assigned_agent_id and task.assigned_agent_id != previous_assigned:
if actor.actor_type == "agent" and actor.agent and task.assigned_agent_id == actor.agent.id: if (
actor.actor_type == "agent"
and actor.agent
and task.assigned_agent_id == actor.agent.id
):
# Don't notify the actor about their own assignment. # Don't notify the actor about their own assignment.
pass pass
else: else:
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session) assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(
session,
)
if assigned_agent: if assigned_agent:
board = ( board = (
await Board.objects.by_id(task.board_id).first(session) await Board.objects.by_id(task.board_id).first(session)
@@ -978,16 +1066,17 @@ async def update_task(
"depends_on_task_ids": dep_ids, "depends_on_task_ids": dep_ids,
"blocked_by_task_ids": blocked_ids, "blocked_by_task_ids": blocked_ids,
"is_blocked": bool(blocked_ids), "is_blocked": bool(blocked_ids),
} },
) )
@router.delete("/{task_id}", response_model=OkResponse) @router.delete("/{task_id}", response_model=OkResponse)
async def delete_task( async def delete_task(
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
task: Task = Depends(get_task_or_404), task: Task = TASK_DEP,
auth: AuthContext = Depends(require_admin_auth), auth: AuthContext = ADMIN_AUTH_DEP,
) -> OkResponse: ) -> OkResponse:
"""Delete a task and related records."""
if task.board_id is None: if task.board_id is None:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
board = await Board.objects.by_id(task.board_id).first(session) board = await Board.objects.by_id(task.board_id).first(session)
@@ -997,12 +1086,14 @@ async def delete_task(
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
await require_board_access(session, user=auth.user, board=board, write=True) await require_board_access(session, user=auth.user, board=board, write=True)
await crud.delete_where( await crud.delete_where(
session, ActivityEvent, col(ActivityEvent.task_id) == task.id, commit=False session, ActivityEvent, col(ActivityEvent.task_id) == task.id, commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, TaskFingerprint, col(TaskFingerprint.task_id) == task.id, commit=False session, TaskFingerprint, col(TaskFingerprint.task_id) == task.id, commit=False,
)
await crud.delete_where(
session, Approval, col(Approval.task_id) == task.id, commit=False,
) )
await crud.delete_where(session, Approval, col(Approval.task_id) == task.id, commit=False)
await crud.delete_where( await crud.delete_where(
session, session,
TaskDependency, TaskDependency,
@@ -1017,11 +1108,14 @@ async def delete_task(
return OkResponse() return OkResponse()
@router.get("/{task_id}/comments", response_model=DefaultLimitOffsetPage[TaskCommentRead]) @router.get(
"/{task_id}/comments", response_model=DefaultLimitOffsetPage[TaskCommentRead],
)
async def list_task_comments( async def list_task_comments(
task: Task = Depends(get_task_or_404), task: Task = TASK_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
) -> DefaultLimitOffsetPage[TaskCommentRead]: ) -> DefaultLimitOffsetPage[TaskCommentRead]:
"""List comments for a task in chronological order."""
statement = ( statement = (
select(ActivityEvent) select(ActivityEvent)
.where(col(ActivityEvent.task_id) == task.id) .where(col(ActivityEvent.task_id) == task.id)
@@ -1032,12 +1126,13 @@ async def list_task_comments(
@router.post("/{task_id}/comments", response_model=TaskCommentRead) @router.post("/{task_id}/comments", response_model=TaskCommentRead)
async def create_task_comment( async def create_task_comment( # noqa: C901, PLR0912
payload: TaskCommentCreate, payload: TaskCommentCreate,
task: Task = Depends(get_task_or_404), task: Task = TASK_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
actor: ActorContext = Depends(require_admin_or_agent), actor: ActorContext = ACTOR_DEP,
) -> ActivityEvent: ) -> ActivityEvent:
"""Create a task comment and notify relevant agents."""
if task.board_id is None: if task.board_id is None:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
if actor.actor_type == "user" and actor.user is not None: if actor.actor_type == "user" and actor.user is not None:
@@ -1045,22 +1140,28 @@ async def create_task_comment(
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
await require_board_access(session, user=actor.user, board=board, write=True) await require_board_access(session, user=actor.user, board=board, write=True)
if actor.actor_type == "agent" and actor.agent: if (
if actor.agent.is_board_lead and task.status != "review": actor.actor_type == "agent"
if not await _lead_was_mentioned(session, task, actor.agent) and not _lead_created_task( and actor.agent
task, actor.agent and actor.agent.is_board_lead
): and task.status != "review"
raise HTTPException( and not await _lead_was_mentioned(session, task, actor.agent)
status_code=status.HTTP_403_FORBIDDEN, and not _lead_created_task(task, actor.agent)
detail=( ):
"Board leads can only comment during review, when mentioned, or on tasks they created." raise HTTPException(
), status_code=status.HTTP_403_FORBIDDEN,
) detail=(
"Board leads can only comment during review, when mentioned, "
"or on tasks they created."
),
)
event = ActivityEvent( event = ActivityEvent(
event_type="task.comment", event_type="task.comment",
message=payload.message, message=payload.message,
task_id=task.id, task_id=task.id,
agent_id=actor.agent.id if actor.actor_type == "agent" and actor.agent else None, agent_id=actor.agent.id
if actor.actor_type == "agent" and actor.agent
else None,
) )
session.add(event) session.add(event)
await session.commit() await session.commit()
@@ -1072,17 +1173,27 @@ async def create_task_comment(
if matches_agent_mention(agent, mention_names): if matches_agent_mention(agent, mention_names):
targets[agent.id] = agent targets[agent.id] = agent
if not mention_names and task.assigned_agent_id: if not mention_names and task.assigned_agent_id:
assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(session) assigned_agent = await Agent.objects.by_id(task.assigned_agent_id).first(
session,
)
if assigned_agent: if assigned_agent:
targets[assigned_agent.id] = assigned_agent targets[assigned_agent.id] = assigned_agent
if actor.actor_type == "agent" and actor.agent: if actor.actor_type == "agent" and actor.agent:
targets.pop(actor.agent.id, None) targets.pop(actor.agent.id, None)
if targets: if targets:
board = await Board.objects.by_id(task.board_id).first(session) if task.board_id else None board = (
await Board.objects.by_id(task.board_id).first(session)
if task.board_id
else None
)
config = await _gateway_config(session, board) if board else None config = await _gateway_config(session, board) if board else None
if board and config: if board and config:
snippet = _truncate_snippet(payload.message) snippet = _truncate_snippet(payload.message)
actor_name = actor.agent.name if actor.actor_type == "agent" and actor.agent else "User" actor_name = (
actor.agent.name
if actor.actor_type == "agent" and actor.agent
else "User"
)
for agent in targets.values(): for agent in targets.values():
if not agent.openclaw_session_id: if not agent.openclaw_session_id:
continue continue
@@ -1101,15 +1212,14 @@ async def create_task_comment(
f"From: {actor_name}\n\n" f"From: {actor_name}\n\n"
f"{action_line}\n\n" f"{action_line}\n\n"
f"Comment:\n{snippet}\n\n" f"Comment:\n{snippet}\n\n"
"If you are mentioned but not assigned, reply in the task thread but do not change task status." "If you are mentioned but not assigned, reply in the task "
"thread but do not change task status."
) )
try: with suppress(OpenClawGatewayError):
await _send_agent_task_message( await _send_agent_task_message(
session_key=agent.openclaw_session_id, session_key=agent.openclaw_session_id,
config=config, config=config,
agent_name=agent.name, agent_name=agent.name,
message=message, message=message,
) )
except OpenClawGatewayError:
pass
return event return event

View File

@@ -1,18 +1,28 @@
"""User self-service API endpoints for profile retrieval and updates."""
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.auth import AuthContext, get_auth_context from app.core.auth import AuthContext, get_auth_context
from app.db.session import get_session from app.db.session import get_session
from app.models.users import User
from app.schemas.users import UserRead, UserUpdate from app.schemas.users import UserRead, UserUpdate
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.users import User
router = APIRouter(prefix="/users", tags=["users"]) router = APIRouter(prefix="/users", tags=["users"])
AUTH_CONTEXT_DEP = Depends(get_auth_context)
SESSION_DEP = Depends(get_session)
@router.get("/me", response_model=UserRead) @router.get("/me", response_model=UserRead)
async def get_me(auth: AuthContext = Depends(get_auth_context)) -> UserRead: async def get_me(auth: AuthContext = AUTH_CONTEXT_DEP) -> UserRead:
"""Return the authenticated user's current profile payload."""
if auth.actor_type != "user" or auth.user is None: if auth.actor_type != "user" or auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
return UserRead.model_validate(auth.user) return UserRead.model_validate(auth.user)
@@ -21,9 +31,10 @@ async def get_me(auth: AuthContext = Depends(get_auth_context)) -> UserRead:
@router.patch("/me", response_model=UserRead) @router.patch("/me", response_model=UserRead)
async def update_me( async def update_me(
payload: UserUpdate, payload: UserUpdate,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = AUTH_CONTEXT_DEP,
) -> UserRead: ) -> UserRead:
"""Apply partial profile updates for the authenticated user."""
if auth.actor_type != "user" or auth.user is None: if auth.actor_type != "user" or auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
updates = payload.model_dump(exclude_unset=True) updates = payload.model_dump(exclude_unset=True)

View File

@@ -0,0 +1 @@
"""Core utilities and configuration for the backend service."""

View File

@@ -1,33 +1,44 @@
"""Agent authentication helpers for token-backed API access."""
from __future__ import annotations from __future__ import annotations
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
from typing import Literal from typing import TYPE_CHECKING, Literal
from fastapi import Depends, Header, HTTPException, Request, status from fastapi import Depends, Header, HTTPException, Request, status
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.agent_tokens import verify_agent_token from app.core.agent_tokens import verify_agent_token
from app.core.time import utcnow from app.core.time import utcnow
from app.db.session import get_session from app.db.session import get_session
from app.models.agents import Agent from app.models.agents import Agent
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_LAST_SEEN_TOUCH_INTERVAL = timedelta(seconds=30) _LAST_SEEN_TOUCH_INTERVAL = timedelta(seconds=30)
_SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"}) _SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})
SESSION_DEP = Depends(get_session)
@dataclass @dataclass
class AgentAuthContext: class AgentAuthContext:
"""Authenticated actor payload for agent-originated requests."""
actor_type: Literal["agent"] actor_type: Literal["agent"]
agent: Agent agent: Agent
async def _find_agent_for_token(session: AsyncSession, token: str) -> Agent | None: async def _find_agent_for_token(session: AsyncSession, token: str) -> Agent | None:
agents = list(await session.exec(select(Agent).where(col(Agent.agent_token_hash).is_not(None)))) agents = list(
await session.exec(
select(Agent).where(col(Agent.agent_token_hash).is_not(None)),
),
)
for agent in agents: for agent in agents:
if agent.agent_token_hash and verify_agent_token(token, agent.agent_token_hash): if agent.agent_token_hash and verify_agent_token(token, agent.agent_token_hash):
return agent return agent
@@ -65,9 +76,11 @@ async def _touch_agent_presence(
calls (task comments, memory updates, etc). Touch presence so the UI reflects calls (task comments, memory updates, etc). Touch presence so the UI reflects
real activity even if the heartbeat loop isn't running. real activity even if the heartbeat loop isn't running.
""" """
now = utcnow() now = utcnow()
if agent.last_seen_at is not None and now - agent.last_seen_at < _LAST_SEEN_TOUCH_INTERVAL: if (
agent.last_seen_at is not None
and now - agent.last_seen_at < _LAST_SEEN_TOUCH_INTERVAL
):
return return
agent.last_seen_at = now agent.last_seen_at = now
@@ -86,9 +99,14 @@ async def get_agent_auth_context(
request: Request, request: Request,
agent_token: str | None = Header(default=None, alias="X-Agent-Token"), agent_token: str | None = Header(default=None, alias="X-Agent-Token"),
authorization: str | None = Header(default=None, alias="Authorization"), authorization: str | None = Header(default=None, alias="Authorization"),
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
) -> AgentAuthContext: ) -> AgentAuthContext:
resolved = _resolve_agent_token(agent_token, authorization, accept_authorization=True) """Require and validate agent auth token from request headers."""
resolved = _resolve_agent_token(
agent_token,
authorization,
accept_authorization=True,
)
if not resolved: if not resolved:
logger.warning( logger.warning(
"agent auth missing token path=%s x_agent=%s authorization=%s", "agent auth missing token path=%s x_agent=%s authorization=%s",
@@ -113,8 +131,9 @@ async def get_agent_auth_context_optional(
request: Request, request: Request,
agent_token: str | None = Header(default=None, alias="X-Agent-Token"), agent_token: str | None = Header(default=None, alias="X-Agent-Token"),
authorization: str | None = Header(default=None, alias="Authorization"), authorization: str | None = Header(default=None, alias="Authorization"),
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
) -> AgentAuthContext | None: ) -> AgentAuthContext | None:
"""Optionally resolve agent auth context from `X-Agent-Token` only."""
resolved = _resolve_agent_token( resolved = _resolve_agent_token(
agent_token, agent_token,
authorization, authorization,

View File

@@ -1,3 +1,5 @@
"""Token generation and verification helpers for agent authentication."""
from __future__ import annotations from __future__ import annotations
import base64 import base64
@@ -10,6 +12,7 @@ SALT_BYTES = 16
def generate_agent_token() -> str: def generate_agent_token() -> str:
"""Generate a new URL-safe random token for an agent."""
return secrets.token_urlsafe(32) return secrets.token_urlsafe(32)
@@ -23,12 +26,14 @@ def _b64decode(value: str) -> bytes:
def hash_agent_token(token: str) -> str: def hash_agent_token(token: str) -> str:
"""Hash an agent token using PBKDF2-HMAC-SHA256 with a random salt."""
salt = secrets.token_bytes(SALT_BYTES) salt = secrets.token_bytes(SALT_BYTES)
digest = hashlib.pbkdf2_hmac("sha256", token.encode("utf-8"), salt, ITERATIONS) digest = hashlib.pbkdf2_hmac("sha256", token.encode("utf-8"), salt, ITERATIONS)
return f"pbkdf2_sha256${ITERATIONS}${_b64encode(salt)}${_b64encode(digest)}" return f"pbkdf2_sha256${ITERATIONS}${_b64encode(salt)}${_b64encode(digest)}"
def verify_agent_token(token: str, stored_hash: str) -> bool: def verify_agent_token(token: str, stored_hash: str) -> bool:
"""Verify a plaintext token against a stored PBKDF2 hash representation."""
try: try:
algorithm, iterations, salt_b64, digest_b64 = stored_hash.split("$") algorithm, iterations, salt_b64, digest_b64 = stored_hash.split("$")
except ValueError: except ValueError:
@@ -41,5 +46,10 @@ def verify_agent_token(token: str, stored_hash: str) -> bool:
return False return False
salt = _b64decode(salt_b64) salt = _b64decode(salt_b64)
expected_digest = _b64decode(digest_b64) expected_digest = _b64decode(digest_b64)
candidate = hashlib.pbkdf2_hmac("sha256", token.encode("utf-8"), salt, iterations_int) candidate = hashlib.pbkdf2_hmac(
"sha256",
token.encode("utf-8"),
salt,
iterations_int,
)
return hmac.compare_digest(candidate, expected_digest) return hmac.compare_digest(candidate, expected_digest)

View File

@@ -1,32 +1,42 @@
"""User authentication helpers backed by Clerk JWT verification."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from typing import Literal from typing import TYPE_CHECKING, Literal
from fastapi import Depends, HTTPException, Request, status from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi_clerk_auth import ClerkConfig, ClerkHTTPBearer from fastapi_clerk_auth import ClerkConfig, ClerkHTTPBearer
from fastapi_clerk_auth import HTTPAuthorizationCredentials as ClerkCredentials from fastapi_clerk_auth import HTTPAuthorizationCredentials as ClerkCredentials
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings from app.core.config import settings
from app.db import crud from app.db import crud
from app.db.session import get_session from app.db.session import get_session
from app.models.users import User from app.models.users import User
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
security = HTTPBearer(auto_error=False) security = HTTPBearer(auto_error=False)
SECURITY_DEP = Depends(security)
SESSION_DEP = Depends(get_session)
CLERK_JWKS_URL_REQUIRED_ERROR = "CLERK_JWKS_URL is not set."
class ClerkTokenPayload(BaseModel): class ClerkTokenPayload(BaseModel):
"""JWT claims payload shape required from Clerk tokens."""
sub: str sub: str
@lru_cache @lru_cache
def _build_clerk_http_bearer(auto_error: bool) -> ClerkHTTPBearer: def _build_clerk_http_bearer(*, auto_error: bool) -> ClerkHTTPBearer:
"""Create and cache the Clerk HTTP bearer guard."""
if not settings.clerk_jwks_url: if not settings.clerk_jwks_url:
raise RuntimeError("CLERK_JWKS_URL is not set.") raise RuntimeError(CLERK_JWKS_URL_REQUIRED_ERROR)
clerk_config = ClerkConfig( clerk_config = ClerkConfig(
jwks_url=settings.clerk_jwks_url, jwks_url=settings.clerk_jwks_url,
verify_iat=settings.clerk_verify_iat, verify_iat=settings.clerk_verify_iat,
@@ -37,12 +47,15 @@ def _build_clerk_http_bearer(auto_error: bool) -> ClerkHTTPBearer:
@dataclass @dataclass
class AuthContext: class AuthContext:
"""Authenticated user context resolved from inbound auth headers."""
actor_type: Literal["user"] actor_type: Literal["user"]
user: User | None = None user: User | None = None
def _resolve_clerk_auth( def _resolve_clerk_auth(
request: Request, fallback: ClerkCredentials | None request: Request,
fallback: ClerkCredentials | None,
) -> ClerkCredentials | None: ) -> ClerkCredentials | None:
auth_data = getattr(request.state, "clerk_auth", None) auth_data = getattr(request.state, "clerk_auth", None)
if isinstance(auth_data, ClerkCredentials): if isinstance(auth_data, ClerkCredentials):
@@ -59,9 +72,10 @@ def _parse_subject(auth_data: ClerkCredentials | None) -> str | None:
async def get_auth_context( async def get_auth_context(
request: Request, request: Request,
credentials: HTTPAuthorizationCredentials | None = Depends(security), credentials: HTTPAuthorizationCredentials | None = SECURITY_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
) -> AuthContext: ) -> AuthContext:
"""Resolve required authenticated user context from Clerk JWT headers."""
if credentials is None: if credentials is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
@@ -109,9 +123,10 @@ async def get_auth_context(
async def get_auth_context_optional( async def get_auth_context_optional(
request: Request, request: Request,
credentials: HTTPAuthorizationCredentials | None = Depends(security), credentials: HTTPAuthorizationCredentials | None = SECURITY_DEP,
session: AsyncSession = Depends(get_session), session: AsyncSession = SESSION_DEP,
) -> AuthContext | None: ) -> AuthContext | None:
"""Resolve user context if available, otherwise return `None`."""
if request.headers.get("X-Agent-Token"): if request.headers.get("X-Agent-Token"):
return None return None
if credentials is None: if credentials is None:

View File

@@ -1,3 +1,5 @@
"""Application settings and environment configuration loading."""
from __future__ import annotations from __future__ import annotations
from pathlib import Path from pathlib import Path
@@ -11,6 +13,8 @@ DEFAULT_ENV_FILE = BACKEND_ROOT / ".env"
class Settings(BaseSettings): class Settings(BaseSettings):
"""Typed runtime configuration sourced from environment variables."""
model_config = SettingsConfigDict( model_config = SettingsConfigDict(
# Load `backend/.env` regardless of current working directory. # Load `backend/.env` regardless of current working directory.
# (Important when running uvicorn from repo root or via a process manager.) # (Important when running uvicorn from repo root or via a process manager.)
@@ -32,8 +36,8 @@ class Settings(BaseSettings):
base_url: str = "" base_url: str = ""
# Optional: local directory where the backend is allowed to write "preserved" agent # Optional: local directory where the backend is allowed to write "preserved" agent
# workspace files (e.g. USER.md/SELF.md/MEMORY.md). If empty, local writes are disabled # workspace files (e.g. USER.md/SELF.md/MEMORY.md). If empty, local
# and provisioning relies on the gateway API. # writes are disabled and provisioning relies on the gateway API.
# #
# Security note: do NOT point this at arbitrary system paths in production. # Security note: do NOT point this at arbitrary system paths in production.
local_agent_workspace_root: str = "" local_agent_workspace_root: str = ""
@@ -48,8 +52,8 @@ class Settings(BaseSettings):
@model_validator(mode="after") @model_validator(mode="after")
def _defaults(self) -> Self: def _defaults(self) -> Self:
# In dev, default to applying Alembic migrations at startup to avoid schema drift # In dev, default to applying Alembic migrations at startup to avoid
# (e.g. missing newly-added columns). # schema drift (e.g. missing newly-added columns).
if "db_auto_migrate" not in self.model_fields_set and self.environment == "dev": if "db_auto_migrate" not in self.model_fields_set and self.environment == "dev":
self.db_auto_migrate = True self.db_auto_migrate = True
return self return self

View File

@@ -1,8 +1,13 @@
"""Utilities for parsing human-readable duration schedule strings."""
from __future__ import annotations from __future__ import annotations
import re import re
_DURATION_RE = re.compile(r"^(?P<num>[1-9]\\d*)\\s*(?P<unit>[smhdw])$", flags=re.IGNORECASE) _DURATION_RE = re.compile(
r"^(?P<num>[1-9]\\d*)\\s*(?P<unit>[smhdw])$",
flags=re.IGNORECASE,
)
_MULTIPLIERS: dict[str, int] = { _MULTIPLIERS: dict[str, int] = {
"s": 1, "s": 1,
@@ -11,26 +16,36 @@ _MULTIPLIERS: dict[str, int] = {
"d": 60 * 60 * 24, "d": 60 * 60 * 24,
"w": 60 * 60 * 24 * 7, "w": 60 * 60 * 24 * 7,
} }
_MAX_SCHEDULE_SECONDS = 60 * 60 * 24 * 365 * 10
_ERR_SCHEDULE_REQUIRED = "schedule is required"
_ERR_SCHEDULE_INVALID = (
'Invalid schedule. Expected format like "10m", "1h", "2d", "1w".'
)
_ERR_SCHEDULE_NONPOSITIVE = "Schedule must be greater than 0."
_ERR_SCHEDULE_TOO_LARGE = "Schedule is too large (max 10 years)."
def normalize_every(value: str) -> str: def normalize_every(value: str) -> str:
"""Normalize schedule string to lower-case compact unit form."""
normalized = value.strip().lower().replace(" ", "") normalized = value.strip().lower().replace(" ", "")
if not normalized: if not normalized:
raise ValueError("schedule is required") raise ValueError(_ERR_SCHEDULE_REQUIRED)
return normalized return normalized
def parse_every_to_seconds(value: str) -> int: def parse_every_to_seconds(value: str) -> int:
"""Parse compact schedule syntax into a number of seconds."""
normalized = normalize_every(value) normalized = normalize_every(value)
match = _DURATION_RE.match(normalized) match = _DURATION_RE.match(normalized)
if not match: if not match:
raise ValueError('Invalid schedule. Expected format like "10m", "1h", "2d", "1w".') raise ValueError(_ERR_SCHEDULE_INVALID)
num = int(match.group("num")) num = int(match.group("num"))
unit = match.group("unit").lower() unit = match.group("unit").lower()
seconds = num * _MULTIPLIERS[unit] seconds = num * _MULTIPLIERS[unit]
if seconds <= 0: if seconds <= 0:
raise ValueError("Schedule must be greater than 0.") raise ValueError(_ERR_SCHEDULE_NONPOSITIVE)
# Prevent accidental absurd schedules (e.g. 999999999d). # Prevent accidental absurd schedules (e.g. 999999999d).
if seconds > 60 * 60 * 24 * 365 * 10: if seconds > _MAX_SCHEDULE_SECONDS:
raise ValueError("Schedule is too large (max 10 years).") raise ValueError(_ERR_SCHEDULE_TOO_LARGE)
return seconds return seconds

View File

@@ -1,8 +1,10 @@
"""Global exception handlers and request-id middleware for FastAPI."""
from __future__ import annotations from __future__ import annotations
import logging import logging
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import Any, Final, cast from typing import TYPE_CHECKING, Any, Final, cast
from uuid import uuid4 from uuid import uuid4
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
@@ -10,7 +12,9 @@ from fastapi.exceptions import RequestValidationError, ResponseValidationError
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.responses import Response from starlette.responses import Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
if TYPE_CHECKING:
from starlette.types import ASGIApp, Message, Receive, Scope, Send
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -20,12 +24,16 @@ ExceptionHandler = Callable[[Request, Exception], Response | Awaitable[Response]
class RequestIdMiddleware: class RequestIdMiddleware:
"""ASGI middleware that ensures every request has a request-id."""
def __init__(self, app: ASGIApp, *, header_name: str = REQUEST_ID_HEADER) -> None: def __init__(self, app: ASGIApp, *, header_name: str = REQUEST_ID_HEADER) -> None:
"""Initialize middleware with app instance and header name."""
self._app = app self._app = app
self._header_name = header_name self._header_name = header_name
self._header_name_bytes = header_name.lower().encode("latin-1") self._header_name_bytes = header_name.lower().encode("latin-1")
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Inject request-id into request state and response headers."""
if scope["type"] != "http": if scope["type"] != "http":
await self._app(scope, receive, send) await self._app(scope, receive, send)
return return
@@ -36,8 +44,11 @@ class RequestIdMiddleware:
if message["type"] == "http.response.start": if message["type"] == "http.response.start":
# Starlette uses `list[tuple[bytes, bytes]]` here. # Starlette uses `list[tuple[bytes, bytes]]` here.
headers: list[tuple[bytes, bytes]] = message.setdefault("headers", []) headers: list[tuple[bytes, bytes]] = message.setdefault("headers", [])
if not any(key.lower() == self._header_name_bytes for key, _ in headers): if not any(
headers.append((self._header_name_bytes, request_id.encode("latin-1"))) key.lower() == self._header_name_bytes for key, _ in headers
):
request_id_bytes = request_id.encode("latin-1")
headers.append((self._header_name_bytes, request_id_bytes))
await send(message) await send(message)
await self._app(scope, receive, send_with_request_id) await self._app(scope, receive, send_with_request_id)
@@ -62,8 +73,10 @@ class RequestIdMiddleware:
def install_error_handling(app: FastAPI) -> None: def install_error_handling(app: FastAPI) -> None:
"""Install middleware and exception handlers on the FastAPI app."""
# Important: add request-id middleware last so it's the outermost middleware. # Important: add request-id middleware last so it's the outermost middleware.
# This ensures it still runs even if another middleware (e.g. CORS preflight) returns early. # This ensures it still runs even if another middleware
# (e.g. CORS preflight) returns early.
app.add_middleware(RequestIdMiddleware) app.add_middleware(RequestIdMiddleware)
app.add_exception_handler( app.add_exception_handler(
@@ -88,7 +101,7 @@ def _get_request_id(request: Request) -> str | None:
return None return None
def _error_payload(*, detail: Any, request_id: str | None) -> dict[str, Any]: def _error_payload(*, detail: object, request_id: str | None) -> dict[str, object]:
payload: dict[str, Any] = {"detail": detail} payload: dict[str, Any] = {"detail": detail}
if request_id: if request_id:
payload["request_id"] = request_id payload["request_id"] = request_id
@@ -96,7 +109,8 @@ def _error_payload(*, detail: Any, request_id: str | None) -> dict[str, Any]:
async def _request_validation_handler( async def _request_validation_handler(
request: Request, exc: RequestValidationError request: Request,
exc: RequestValidationError,
) -> JSONResponse: ) -> JSONResponse:
# `RequestValidationError` is expected user input; don't log at ERROR. # `RequestValidationError` is expected user input; don't log at ERROR.
request_id = _get_request_id(request) request_id = _get_request_id(request)
@@ -107,7 +121,8 @@ async def _request_validation_handler(
async def _response_validation_handler( async def _response_validation_handler(
request: Request, exc: ResponseValidationError request: Request,
exc: ResponseValidationError,
) -> JSONResponse: ) -> JSONResponse:
request_id = _get_request_id(request) request_id = _get_request_id(request)
logger.exception( logger.exception(
@@ -125,7 +140,10 @@ async def _response_validation_handler(
) )
async def _http_exception_handler(request: Request, exc: StarletteHTTPException) -> JSONResponse: async def _http_exception_handler(
request: Request,
exc: StarletteHTTPException,
) -> JSONResponse:
request_id = _get_request_id(request) request_id = _get_request_id(request)
return JSONResponse( return JSONResponse(
status_code=exc.status_code, status_code=exc.status_code,
@@ -134,11 +152,18 @@ async def _http_exception_handler(request: Request, exc: StarletteHTTPException)
) )
async def _unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse: async def _unhandled_exception_handler(
request: Request,
_exc: Exception,
) -> JSONResponse:
request_id = _get_request_id(request) request_id = _get_request_id(request)
logger.exception( logger.exception(
"unhandled_exception", "unhandled_exception",
extra={"request_id": request_id, "method": request.method, "path": request.url.path}, extra={
"request_id": request_id,
"method": request.method,
"path": request.url.path,
},
) )
return JSONResponse( return JSONResponse(
status_code=500, status_code=500,

View File

@@ -1,3 +1,5 @@
"""Application logging configuration and formatter utilities."""
from __future__ import annotations from __future__ import annotations
import json import json
@@ -15,7 +17,8 @@ TRACE_LEVEL = 5
logging.addLevelName(TRACE_LEVEL, "TRACE") logging.addLevelName(TRACE_LEVEL, "TRACE")
def _trace(self: logging.Logger, message: str, *args: Any, **kwargs: Any) -> None: def _trace(self: logging.Logger, message: str, *args: object, **kwargs: object) -> None:
"""Log a TRACE-level message when the logger is TRACE-enabled."""
if self.isEnabledFor(TRACE_LEVEL): if self.isEnabledFor(TRACE_LEVEL):
self._log(TRACE_LEVEL, message, args, **kwargs) self._log(TRACE_LEVEL, message, args, **kwargs)
@@ -52,21 +55,31 @@ _STANDARD_LOG_RECORD_ATTRS = {
class AppLogFilter(logging.Filter): class AppLogFilter(logging.Filter):
"""Inject app metadata into each log record."""
def __init__(self, app_name: str, version: str) -> None: def __init__(self, app_name: str, version: str) -> None:
"""Initialize the filter with fixed app and version values."""
super().__init__() super().__init__()
self._app_name = app_name self._app_name = app_name
self._version = version self._version = version
def filter(self, record: logging.LogRecord) -> bool: def filter(self, record: logging.LogRecord) -> bool:
"""Attach app metadata fields to each emitted record."""
record.app = self._app_name record.app = self._app_name
record.version = self._version record.version = self._version
return True return True
class JsonFormatter(logging.Formatter): class JsonFormatter(logging.Formatter):
"""Formatter that serializes log records as compact JSON."""
def format(self, record: logging.LogRecord) -> str: def format(self, record: logging.LogRecord) -> str:
"""Render a single log record into a JSON string."""
payload: dict[str, Any] = { payload: dict[str, Any] = {
"timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(), "timestamp": datetime.fromtimestamp(
record.created,
tz=timezone.utc,
).isoformat(),
"level": record.levelname, "level": record.levelname,
"logger": record.name, "logger": record.name,
"message": record.getMessage(), "message": record.getMessage(),
@@ -88,7 +101,10 @@ class JsonFormatter(logging.Formatter):
class KeyValueFormatter(logging.Formatter): class KeyValueFormatter(logging.Formatter):
"""Formatter that appends extra fields as `key=value` pairs."""
def format(self, record: logging.LogRecord) -> str: def format(self, record: logging.LogRecord) -> str:
"""Render a log line with appended non-standard record fields."""
base = super().format(record) base = super().format(record)
extras = { extras = {
key: value key: value
@@ -102,6 +118,8 @@ class KeyValueFormatter(logging.Formatter):
class AppLogger: class AppLogger:
"""Centralized logging setup utility for the backend process."""
_configured = False _configured = False
@classmethod @classmethod
@@ -111,10 +129,12 @@ class AppLogger:
return level_name, TRACE_LEVEL return level_name, TRACE_LEVEL
if level_name.isdigit(): if level_name.isdigit():
return level_name, int(level_name) return level_name, int(level_name)
return level_name, logging._nameToLevel.get(level_name, logging.INFO) levels = logging.getLevelNamesMapping()
return level_name, levels.get(level_name, logging.INFO)
@classmethod @classmethod
def configure(cls, *, force: bool = False) -> None: def configure(cls, *, force: bool = False) -> None:
"""Configure root logging handlers, formatters, and library levels."""
if cls._configured and not force: if cls._configured and not force:
return return
@@ -127,7 +147,8 @@ class AppLogger:
formatter: logging.Formatter = JsonFormatter() formatter: logging.Formatter = JsonFormatter()
else: else:
formatter = KeyValueFormatter( formatter = KeyValueFormatter(
"%(asctime)s %(levelname)s %(name)s %(message)s app=%(app)s version=%(version)s" "%(asctime)s %(levelname)s %(name)s %(message)s "
"app=%(app)s version=%(version)s",
) )
if settings.log_use_utc: if settings.log_use_utc:
formatter.converter = time.gmtime formatter.converter = time.gmtime
@@ -160,10 +181,12 @@ class AppLogger:
@classmethod @classmethod
def get_logger(cls, name: str | None = None) -> logging.Logger: def get_logger(cls, name: str | None = None) -> logging.Logger:
"""Return a logger, ensuring logging has been configured."""
if not cls._configured: if not cls._configured:
cls.configure() cls.configure()
return logging.getLogger(name) return logging.getLogger(name)
def configure_logging() -> None: def configure_logging() -> None:
"""Configure global application logging once during startup."""
AppLogger.configure() AppLogger.configure()

View File

@@ -1,3 +1,5 @@
"""Time-related helpers shared across backend modules."""
from __future__ import annotations from __future__ import annotations
from datetime import UTC, datetime from datetime import UTC, datetime
@@ -5,6 +7,5 @@ from datetime import UTC, datetime
def utcnow() -> datetime: def utcnow() -> datetime:
"""Return a naive UTC datetime without using deprecated datetime.utcnow().""" """Return a naive UTC datetime without using deprecated datetime.utcnow()."""
# Keep naive UTC values for compatibility with existing DB schema/queries. # Keep naive UTC values for compatibility with existing DB schema/queries.
return datetime.now(UTC).replace(tzinfo=None) return datetime.now(UTC).replace(tzinfo=None)

View File

@@ -1,2 +1,4 @@
"""Application name and version constants."""
APP_NAME = "mission-control" APP_NAME = "mission-control"
APP_VERSION = "0.1.0" APP_VERSION = "0.1.0"

View File

@@ -0,0 +1 @@
"""Database helpers and abstractions for backend persistence."""

View File

@@ -1,14 +1,18 @@
"""Typed wrapper around fastapi-pagination for backend query helpers."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Awaitable, Callable, Sequence from collections.abc import Awaitable, Callable, Sequence
from typing import Any, TypeVar, cast from typing import TYPE_CHECKING, Any, TypeVar, cast
from fastapi_pagination.ext.sqlalchemy import paginate as _paginate from fastapi_pagination.ext.sqlalchemy import paginate as _paginate
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import Select, SelectOfScalar
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import Select, SelectOfScalar
T = TypeVar("T") T = TypeVar("T")
Transformer = Callable[[Sequence[Any]], Sequence[Any] | Awaitable[Sequence[Any]]] Transformer = Callable[[Sequence[Any]], Sequence[Any] | Awaitable[Sequence[Any]]]
@@ -20,8 +24,10 @@ async def paginate(
*, *,
transformer: Transformer | None = None, transformer: Transformer | None = None,
) -> DefaultLimitOffsetPage[T]: ) -> DefaultLimitOffsetPage[T]:
# fastapi-pagination is not fully typed (it returns Any), but response_model validation """Execute a paginated query and cast to the project page type alias."""
# ensures runtime correctness. Centralize casts here to keep strict mypy clean. # fastapi-pagination is not fully typed (it returns Any), but response_model
# validation ensures runtime correctness. Centralize casts here to keep strict
# mypy clean.
return cast( return cast(
DefaultLimitOffsetPage[T], DefaultLimitOffsetPage[T],
await _paginate(session, statement, transformer=transformer), await _paginate(session, statement, transformer=transformer),

View File

@@ -1,7 +1,9 @@
"""Model manager descriptor utilities for query-set style access."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Generic, TypeVar from typing import Generic, TypeVar
from sqlalchemy import false from sqlalchemy import false
from sqlmodel import SQLModel, col from sqlmodel import SQLModel, col
@@ -13,41 +15,55 @@ ModelT = TypeVar("ModelT", bound=SQLModel)
@dataclass(frozen=True) @dataclass(frozen=True)
class ModelManager(Generic[ModelT]): class ModelManager(Generic[ModelT]):
"""Convenience query manager bound to a SQLModel class."""
model: type[ModelT] model: type[ModelT]
id_field: str = "id" id_field: str = "id"
def all(self) -> QuerySet[ModelT]: def all(self) -> QuerySet[ModelT]:
"""Return an unfiltered queryset for the bound model."""
return qs(self.model) return qs(self.model)
def none(self) -> QuerySet[ModelT]: def none(self) -> QuerySet[ModelT]:
"""Return a queryset that yields no rows."""
return qs(self.model).filter(false()) return qs(self.model).filter(false())
def filter(self, *criteria: Any) -> QuerySet[ModelT]: def filter(self, *criteria: object) -> QuerySet[ModelT]:
"""Return queryset filtered by SQL criteria expressions."""
return self.all().filter(*criteria) return self.all().filter(*criteria)
def where(self, *criteria: Any) -> QuerySet[ModelT]: def where(self, *criteria: object) -> QuerySet[ModelT]:
"""Alias for `filter`."""
return self.filter(*criteria) return self.filter(*criteria)
def filter_by(self, **kwargs: Any) -> QuerySet[ModelT]: def filter_by(self, **kwargs: object) -> QuerySet[ModelT]:
"""Return queryset filtered by model field equality values."""
queryset = self.all() queryset = self.all()
for field_name, value in kwargs.items(): for field_name, value in kwargs.items():
queryset = queryset.filter(col(getattr(self.model, field_name)) == value) queryset = queryset.filter(col(getattr(self.model, field_name)) == value)
return queryset return queryset
def by_id(self, obj_id: Any) -> QuerySet[ModelT]: def by_id(self, obj_id: object) -> QuerySet[ModelT]:
"""Return queryset filtered by primary identifier field."""
return self.by_field(self.id_field, obj_id) return self.by_field(self.id_field, obj_id)
def by_ids(self, obj_ids: list[Any] | tuple[Any, ...] | set[Any]) -> QuerySet[ModelT]: def by_ids(
self,
obj_ids: list[object] | tuple[object, ...] | set[object],
) -> QuerySet[ModelT]:
"""Return queryset filtered by a set/list/tuple of identifiers."""
return self.by_field_in(self.id_field, obj_ids) return self.by_field_in(self.id_field, obj_ids)
def by_field(self, field_name: str, value: Any) -> QuerySet[ModelT]: def by_field(self, field_name: str, value: object) -> QuerySet[ModelT]:
"""Return queryset filtered by a single field equality check."""
return self.filter(col(getattr(self.model, field_name)) == value) return self.filter(col(getattr(self.model, field_name)) == value)
def by_field_in( def by_field_in(
self, self,
field_name: str, field_name: str,
values: list[Any] | tuple[Any, ...] | set[Any], values: list[object] | tuple[object, ...] | set[object],
) -> QuerySet[ModelT]: ) -> QuerySet[ModelT]:
"""Return queryset filtered by `field IN values` semantics."""
seq = tuple(values) seq = tuple(values)
if not seq: if not seq:
return self.none() return self.none()
@@ -55,5 +71,8 @@ class ModelManager(Generic[ModelT]):
class ManagerDescriptor(Generic[ModelT]): class ManagerDescriptor(Generic[ModelT]):
"""Descriptor that exposes a model-bound `ModelManager` as `.objects`."""
def __get__(self, instance: object, owner: type[ModelT]) -> ModelManager[ModelT]: def __get__(self, instance: object, owner: type[ModelT]) -> ModelManager[ModelT]:
"""Return a fresh manager bound to the owning model class."""
return ModelManager(owner) return ModelManager(owner)

View File

@@ -1,50 +1,67 @@
"""Lightweight immutable query-set wrapper for SQLModel statements."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
from typing import Any, Generic, TypeVar from typing import TYPE_CHECKING, Generic, TypeVar
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar
ModelT = TypeVar("ModelT") ModelT = TypeVar("ModelT")
@dataclass(frozen=True) @dataclass(frozen=True)
class QuerySet(Generic[ModelT]): class QuerySet(Generic[ModelT]):
"""Composable immutable wrapper around a SQLModel scalar select statement."""
statement: SelectOfScalar[ModelT] statement: SelectOfScalar[ModelT]
def filter(self, *criteria: Any) -> QuerySet[ModelT]: def filter(self, *criteria: object) -> QuerySet[ModelT]:
"""Return a new queryset with additional SQL criteria."""
return replace(self, statement=self.statement.where(*criteria)) return replace(self, statement=self.statement.where(*criteria))
def where(self, *criteria: Any) -> QuerySet[ModelT]: def where(self, *criteria: object) -> QuerySet[ModelT]:
"""Alias for `filter` to mirror SQLAlchemy naming."""
return self.filter(*criteria) return self.filter(*criteria)
def filter_by(self, **kwargs: Any) -> QuerySet[ModelT]: def filter_by(self, **kwargs: object) -> QuerySet[ModelT]:
"""Return a new queryset filtered by keyword-equality criteria."""
statement = self.statement.filter_by(**kwargs) statement = self.statement.filter_by(**kwargs)
return replace(self, statement=statement) return replace(self, statement=statement)
def order_by(self, *ordering: Any) -> QuerySet[ModelT]: def order_by(self, *ordering: object) -> QuerySet[ModelT]:
"""Return a new queryset with ordering clauses applied."""
return replace(self, statement=self.statement.order_by(*ordering)) return replace(self, statement=self.statement.order_by(*ordering))
def limit(self, value: int) -> QuerySet[ModelT]: def limit(self, value: int) -> QuerySet[ModelT]:
"""Return a new queryset with a SQL row limit."""
return replace(self, statement=self.statement.limit(value)) return replace(self, statement=self.statement.limit(value))
def offset(self, value: int) -> QuerySet[ModelT]: def offset(self, value: int) -> QuerySet[ModelT]:
"""Return a new queryset with a SQL row offset."""
return replace(self, statement=self.statement.offset(value)) return replace(self, statement=self.statement.offset(value))
async def all(self, session: AsyncSession) -> list[ModelT]: async def all(self, session: AsyncSession) -> list[ModelT]:
"""Execute and return all rows for the current queryset."""
return list(await session.exec(self.statement)) return list(await session.exec(self.statement))
async def first(self, session: AsyncSession) -> ModelT | None: async def first(self, session: AsyncSession) -> ModelT | None:
"""Execute and return the first row, if available."""
return (await session.exec(self.statement)).first() return (await session.exec(self.statement)).first()
async def one_or_none(self, session: AsyncSession) -> ModelT | None: async def one_or_none(self, session: AsyncSession) -> ModelT | None:
"""Execute and return one row or `None`."""
return (await session.exec(self.statement)).one_or_none() return (await session.exec(self.statement)).one_or_none()
async def exists(self, session: AsyncSession) -> bool: async def exists(self, session: AsyncSession) -> bool:
"""Return whether the queryset yields at least one row."""
return await self.limit(1).first(session) is not None return await self.limit(1).first(session) is not None
def qs(model: type[ModelT]) -> QuerySet[ModelT]: def qs(model: type[ModelT]) -> QuerySet[ModelT]:
"""Create a base queryset for a SQLModel class."""
return QuerySet(select(model)) return QuerySet(select(model))

View File

@@ -1,8 +1,10 @@
"""Database engine, session factory, and startup migration helpers."""
from __future__ import annotations from __future__ import annotations
import logging import logging
from collections.abc import AsyncGenerator
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING
import anyio import anyio
from alembic import command from alembic import command
@@ -15,6 +17,9 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app import models as _models from app import models as _models
from app.core.config import settings from app.core.config import settings
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
# Import model modules so SQLModel metadata is fully registered at startup. # Import model modules so SQLModel metadata is fully registered at startup.
_MODEL_REGISTRY = _models _MODEL_REGISTRY = _models
@@ -48,12 +53,14 @@ def _alembic_config() -> Config:
def run_migrations() -> None: def run_migrations() -> None:
"""Apply Alembic migrations to the latest revision."""
logger.info("Running database migrations.") logger.info("Running database migrations.")
command.upgrade(_alembic_config(), "head") command.upgrade(_alembic_config(), "head")
logger.info("Database migrations complete.") logger.info("Database migrations complete.")
async def init_db() -> None: async def init_db() -> None:
"""Initialize database schema, running migrations when configured."""
if settings.db_auto_migrate: if settings.db_auto_migrate:
versions_dir = Path(__file__).resolve().parents[2] / "migrations" / "versions" versions_dir = Path(__file__).resolve().parents[2] / "migrations" / "versions"
if any(versions_dir.glob("*.py")): if any(versions_dir.glob("*.py")):
@@ -67,6 +74,7 @@ async def init_db() -> None:
async def get_session() -> AsyncGenerator[AsyncSession, None]: async def get_session() -> AsyncGenerator[AsyncSession, None]:
"""Yield a request-scoped async DB session with safe rollback on errors."""
async with async_session_maker() as session: async with async_session_maker() as session:
try: try:
yield session yield session

View File

@@ -0,0 +1 @@
"""External system integration clients and protocol adapters."""

View File

@@ -1,3 +1,5 @@
"""OpenClaw gateway protocol constants shared across integration layers."""
from __future__ import annotations from __future__ import annotations
PROTOCOL_VERSION = 3 PROTOCOL_VERSION = 3
@@ -116,4 +118,5 @@ GATEWAY_EVENTS_SET = frozenset(GATEWAY_EVENTS)
def is_known_gateway_method(method: str) -> bool: def is_known_gateway_method(method: str) -> bool:
"""Return whether a method name is part of the known base gateway methods."""
return method in GATEWAY_METHODS_SET return method in GATEWAY_METHODS_SET

View File

@@ -1,7 +1,9 @@
"""FastAPI application entrypoint and router wiring for the backend."""
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import TYPE_CHECKING
from fastapi import APIRouter, FastAPI from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@@ -29,11 +31,15 @@ from app.core.error_handling import install_error_handling
from app.core.logging import configure_logging from app.core.logging import configure_logging
from app.db.session import init_db from app.db.session import init_db
if TYPE_CHECKING:
from collections.abc import AsyncIterator
configure_logging() configure_logging()
@asynccontextmanager @asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]: async def lifespan(_: FastAPI) -> AsyncIterator[None]:
"""Initialize application resources before serving requests."""
await init_db() await init_db()
yield yield
@@ -55,16 +61,19 @@ install_error_handling(app)
@app.get("/health") @app.get("/health")
def health() -> dict[str, bool]: def health() -> dict[str, bool]:
"""Lightweight liveness probe endpoint."""
return {"ok": True} return {"ok": True}
@app.get("/healthz") @app.get("/healthz")
def healthz() -> dict[str, bool]: def healthz() -> dict[str, bool]:
"""Alias liveness probe endpoint for platform compatibility."""
return {"ok": True} return {"ok": True}
@app.get("/readyz") @app.get("/readyz")
def readyz() -> dict[str, bool]: def readyz() -> dict[str, bool]:
"""Readiness probe endpoint for service orchestration checks."""
return {"ok": True} return {"ok": True}

View File

@@ -1,3 +1,5 @@
"""Model exports for SQLAlchemy/SQLModel metadata discovery."""
from app.models.activity_events import ActivityEvent from app.models.activity_events import ActivityEvent
from app.models.agents import Agent from app.models.agents import Agent
from app.models.approvals import Approval from app.models.approvals import Approval

View File

@@ -1,6 +1,8 @@
"""Activity event model persisted for audit and feed use-cases."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlmodel import Field from sqlmodel import Field
@@ -10,6 +12,8 @@ from app.models.base import QueryModel
class ActivityEvent(QueryModel, table=True): class ActivityEvent(QueryModel, table=True):
"""Discrete activity event tied to tasks and agents."""
__tablename__ = "activity_events" __tablename__ = "activity_events"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -1,6 +1,8 @@
"""Agent model representing autonomous actors assigned to boards."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from typing import Any from typing import Any
from uuid import UUID, uuid4 from uuid import UUID, uuid4
@@ -12,6 +14,8 @@ from app.models.base import QueryModel
class Agent(QueryModel, table=True): class Agent(QueryModel, table=True):
"""Agent configuration and lifecycle state persisted in the database."""
__tablename__ = "agents" __tablename__ = "agents"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
@@ -20,8 +24,14 @@ class Agent(QueryModel, table=True):
status: str = Field(default="provisioning", index=True) status: str = Field(default="provisioning", index=True)
openclaw_session_id: str | None = Field(default=None, index=True) openclaw_session_id: str | None = Field(default=None, index=True)
agent_token_hash: str | None = Field(default=None, index=True) agent_token_hash: str | None = Field(default=None, index=True)
heartbeat_config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON)) heartbeat_config: dict[str, Any] | None = Field(
identity_profile: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON)) default=None,
sa_column=Column(JSON),
)
identity_profile: dict[str, Any] | None = Field(
default=None,
sa_column=Column(JSON),
)
identity_template: str | None = Field(default=None, sa_column=Column(Text)) identity_template: str | None = Field(default=None, sa_column=Column(Text))
soul_template: str | None = Field(default=None, sa_column=Column(Text)) soul_template: str | None = Field(default=None, sa_column=Column(Text))
provision_requested_at: datetime | None = Field(default=None) provision_requested_at: datetime | None = Field(default=None)

View File

@@ -1,6 +1,8 @@
"""Approval model storing pending and resolved approval actions."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import JSON, Column from sqlalchemy import JSON, Column
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
class Approval(QueryModel, table=True): class Approval(QueryModel, table=True):
"""Approval request and decision metadata for gated operations."""
__tablename__ = "approvals" __tablename__ = "approvals"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -1,3 +1,5 @@
"""Base model mixins and shared SQLModel abstractions."""
from __future__ import annotations from __future__ import annotations
from typing import ClassVar, Self from typing import ClassVar, Self
@@ -8,4 +10,6 @@ from app.db.query_manager import ManagerDescriptor
class QueryModel(SQLModel, table=False): class QueryModel(SQLModel, table=False):
"""Base SQLModel with a shared query manager descriptor."""
objects: ClassVar[ManagerDescriptor[Self]] = ManagerDescriptor() objects: ClassVar[ManagerDescriptor[Self]] = ManagerDescriptor()

View File

@@ -1,6 +1,8 @@
"""Board-group scoped memory entries for shared context."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import JSON, Column from sqlalchemy import JSON, Column
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
class BoardGroupMemory(QueryModel, table=True): class BoardGroupMemory(QueryModel, table=True):
"""Persisted memory items associated with a board group."""
__tablename__ = "board_group_memory" __tablename__ = "board_group_memory"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -1,6 +1,8 @@
"""Board group model used to organize boards inside organizations."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlmodel import Field from sqlmodel import Field
@@ -10,6 +12,8 @@ from app.models.tenancy import TenantScoped
class BoardGroup(TenantScoped, table=True): class BoardGroup(TenantScoped, table=True):
"""Logical grouping container for boards within an organization."""
__tablename__ = "board_groups" __tablename__ = "board_groups"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -1,6 +1,8 @@
"""Board-level memory entries for persistent contextual state."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import JSON, Column from sqlalchemy import JSON, Column
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
class BoardMemory(QueryModel, table=True): class BoardMemory(QueryModel, table=True):
"""Persisted memory item attached directly to a board."""
__tablename__ = "board_memory" __tablename__ = "board_memory"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -1,6 +1,8 @@
"""Board onboarding session model for guided setup state."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import JSON, Column from sqlalchemy import JSON, Column
@@ -11,13 +13,18 @@ from app.models.base import QueryModel
class BoardOnboardingSession(QueryModel, table=True): class BoardOnboardingSession(QueryModel, table=True):
"""Persisted onboarding conversation and draft goal data for a board."""
__tablename__ = "board_onboarding_sessions" __tablename__ = "board_onboarding_sessions"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
board_id: UUID = Field(foreign_key="boards.id", index=True) board_id: UUID = Field(foreign_key="boards.id", index=True)
session_key: str session_key: str
status: str = Field(default="active", index=True) status: str = Field(default="active", index=True)
messages: list[dict[str, object]] | None = Field(default=None, sa_column=Column(JSON)) messages: list[dict[str, object]] | None = Field(
default=None,
sa_column=Column(JSON),
)
draft_goal: dict[str, object] | None = Field(default=None, sa_column=Column(JSON)) draft_goal: dict[str, object] | None = Field(default=None, sa_column=Column(JSON))
created_at: datetime = Field(default_factory=utcnow) created_at: datetime = Field(default_factory=utcnow)
updated_at: datetime = Field(default_factory=utcnow) updated_at: datetime = Field(default_factory=utcnow)

View File

@@ -1,6 +1,8 @@
"""Board model for organization workspaces and goal configuration."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import JSON, Column from sqlalchemy import JSON, Column
@@ -11,6 +13,8 @@ from app.models.tenancy import TenantScoped
class Board(TenantScoped, table=True): class Board(TenantScoped, table=True):
"""Primary board entity grouping tasks, agents, and goal metadata."""
__tablename__ = "boards" __tablename__ = "boards"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
@@ -18,10 +22,17 @@ class Board(TenantScoped, table=True):
name: str name: str
slug: str = Field(index=True) slug: str = Field(index=True)
gateway_id: UUID | None = Field(default=None, foreign_key="gateways.id", index=True) gateway_id: UUID | None = Field(default=None, foreign_key="gateways.id", index=True)
board_group_id: UUID | None = Field(default=None, foreign_key="board_groups.id", index=True) board_group_id: UUID | None = Field(
default=None,
foreign_key="board_groups.id",
index=True,
)
board_type: str = Field(default="goal", index=True) board_type: str = Field(default="goal", index=True)
objective: str | None = None objective: str | None = None
success_metrics: dict[str, object] | None = Field(default=None, sa_column=Column(JSON)) success_metrics: dict[str, object] | None = Field(
default=None,
sa_column=Column(JSON),
)
target_date: datetime | None = None target_date: datetime | None = None
goal_confirmed: bool = Field(default=False) goal_confirmed: bool = Field(default=False)
goal_source: str | None = None goal_source: str | None = None

View File

@@ -1,6 +1,8 @@
"""Gateway model storing organization-level gateway integration metadata."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlmodel import Field from sqlmodel import Field
@@ -10,6 +12,8 @@ from app.models.base import QueryModel
class Gateway(QueryModel, table=True): class Gateway(QueryModel, table=True):
"""Configured external gateway endpoint and authentication settings."""
__tablename__ = "gateways" __tablename__ = "gateways"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -1,6 +1,8 @@
"""Board-level access grants assigned to organization members."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint from sqlalchemy import UniqueConstraint
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
class OrganizationBoardAccess(QueryModel, table=True): class OrganizationBoardAccess(QueryModel, table=True):
"""Member-specific board permissions within an organization."""
__tablename__ = "organization_board_access" __tablename__ = "organization_board_access"
__table_args__ = ( __table_args__ = (
UniqueConstraint( UniqueConstraint(
@@ -21,7 +25,10 @@ class OrganizationBoardAccess(QueryModel, table=True):
) )
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
organization_member_id: UUID = Field(foreign_key="organization_members.id", index=True) organization_member_id: UUID = Field(
foreign_key="organization_members.id",
index=True,
)
board_id: UUID = Field(foreign_key="boards.id", index=True) board_id: UUID = Field(foreign_key="boards.id", index=True)
can_read: bool = Field(default=True) can_read: bool = Field(default=True)
can_write: bool = Field(default=False) can_write: bool = Field(default=False)

View File

@@ -1,6 +1,8 @@
"""Board access grants attached to pending organization invites."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint from sqlalchemy import UniqueConstraint
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
class OrganizationInviteBoardAccess(QueryModel, table=True): class OrganizationInviteBoardAccess(QueryModel, table=True):
"""Invite-specific board permissions applied after invite acceptance."""
__tablename__ = "organization_invite_board_access" __tablename__ = "organization_invite_board_access"
__table_args__ = ( __table_args__ = (
UniqueConstraint( UniqueConstraint(
@@ -21,7 +25,10 @@ class OrganizationInviteBoardAccess(QueryModel, table=True):
) )
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
organization_invite_id: UUID = Field(foreign_key="organization_invites.id", index=True) organization_invite_id: UUID = Field(
foreign_key="organization_invites.id",
index=True,
)
board_id: UUID = Field(foreign_key="boards.id", index=True) board_id: UUID = Field(foreign_key="boards.id", index=True)
can_read: bool = Field(default=True) can_read: bool = Field(default=True)
can_write: bool = Field(default=False) can_write: bool = Field(default=False)

View File

@@ -1,6 +1,8 @@
"""Organization invite model for email-based tenant membership flow."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint from sqlalchemy import UniqueConstraint
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
class OrganizationInvite(QueryModel, table=True): class OrganizationInvite(QueryModel, table=True):
"""Invitation record granting prospective organization access."""
__tablename__ = "organization_invites" __tablename__ = "organization_invites"
__table_args__ = (UniqueConstraint("token", name="uq_org_invites_token"),) __table_args__ = (UniqueConstraint("token", name="uq_org_invites_token"),)
@@ -21,8 +25,16 @@ class OrganizationInvite(QueryModel, table=True):
role: str = Field(default="member", index=True) role: str = Field(default="member", index=True)
all_boards_read: bool = Field(default=False) all_boards_read: bool = Field(default=False)
all_boards_write: bool = Field(default=False) all_boards_write: bool = Field(default=False)
created_by_user_id: UUID | None = Field(default=None, foreign_key="users.id", index=True) created_by_user_id: UUID | None = Field(
accepted_by_user_id: UUID | None = Field(default=None, foreign_key="users.id", index=True) default=None,
foreign_key="users.id",
index=True,
)
accepted_by_user_id: UUID | None = Field(
default=None,
foreign_key="users.id",
index=True,
)
accepted_at: datetime | None = None accepted_at: datetime | None = None
created_at: datetime = Field(default_factory=utcnow) created_at: datetime = Field(default_factory=utcnow)
updated_at: datetime = Field(default_factory=utcnow) updated_at: datetime = Field(default_factory=utcnow)

View File

@@ -1,6 +1,8 @@
"""Organization membership model with role and board-access flags."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint from sqlalchemy import UniqueConstraint
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
class OrganizationMember(QueryModel, table=True): class OrganizationMember(QueryModel, table=True):
"""Membership row linking a user to an organization and permissions."""
__tablename__ = "organization_members" __tablename__ = "organization_members"
__table_args__ = ( __table_args__ = (
UniqueConstraint( UniqueConstraint(

View File

@@ -1,6 +1,8 @@
"""Organization model representing top-level tenant entities."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import UniqueConstraint from sqlalchemy import UniqueConstraint
@@ -11,6 +13,8 @@ from app.models.base import QueryModel
class Organization(QueryModel, table=True): class Organization(QueryModel, table=True):
"""Top-level organization tenant record."""
__tablename__ = "organizations" __tablename__ = "organizations"
__table_args__ = (UniqueConstraint("name", name="uq_organizations_name"),) __table_args__ = (UniqueConstraint("name", name="uq_organizations_name"),)

View File

@@ -1,6 +1,8 @@
"""Task dependency edge model for board-local dependency graphs."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlalchemy import CheckConstraint, UniqueConstraint from sqlalchemy import CheckConstraint, UniqueConstraint
@@ -11,6 +13,8 @@ from app.models.tenancy import TenantScoped
class TaskDependency(TenantScoped, table=True): class TaskDependency(TenantScoped, table=True):
"""Directed dependency edge between two tasks in the same board."""
__tablename__ = "task_dependencies" __tablename__ = "task_dependencies"
__table_args__ = ( __table_args__ = (
UniqueConstraint( UniqueConstraint(

View File

@@ -1,6 +1,8 @@
"""Task fingerprint model for duplicate/task-linking operations."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlmodel import Field from sqlmodel import Field
@@ -10,6 +12,8 @@ from app.models.base import QueryModel
class TaskFingerprint(QueryModel, table=True): class TaskFingerprint(QueryModel, table=True):
"""Hashed task-content fingerprint associated with a board and task."""
__tablename__ = "task_fingerprints" __tablename__ = "task_fingerprints"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)

View File

@@ -1,6 +1,8 @@
"""Task model representing board work items and execution metadata."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from sqlmodel import Field from sqlmodel import Field
@@ -10,6 +12,8 @@ from app.models.tenancy import TenantScoped
class Task(TenantScoped, table=True): class Task(TenantScoped, table=True):
"""Board-scoped task entity with ownership, status, and timing fields."""
__tablename__ = "tasks" __tablename__ = "tasks"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
@@ -22,8 +26,16 @@ class Task(TenantScoped, table=True):
due_at: datetime | None = None due_at: datetime | None = None
in_progress_at: datetime | None = None in_progress_at: datetime | None = None
created_by_user_id: UUID | None = Field(default=None, foreign_key="users.id", index=True) created_by_user_id: UUID | None = Field(
assigned_agent_id: UUID | None = Field(default=None, foreign_key="agents.id", index=True) default=None,
foreign_key="users.id",
index=True,
)
assigned_agent_id: UUID | None = Field(
default=None,
foreign_key="agents.id",
index=True,
)
auto_created: bool = Field(default=False) auto_created: bool = Field(default=False)
auto_reason: str | None = None auto_reason: str | None = None

View File

@@ -1,7 +1,9 @@
"""Shared tenancy-scoped model base classes."""
from __future__ import annotations from __future__ import annotations
from app.models.base import QueryModel from app.models.base import QueryModel
class TenantScoped(QueryModel, table=False): class TenantScoped(QueryModel, table=False):
pass """Base class for models constrained to a tenant/organization scope."""

View File

@@ -1,3 +1,5 @@
"""User model storing identity and profile preferences."""
from __future__ import annotations from __future__ import annotations
from uuid import UUID, uuid4 from uuid import UUID, uuid4
@@ -8,6 +10,8 @@ from app.models.base import QueryModel
class User(QueryModel, table=True): class User(QueryModel, table=True):
"""Application user account and profile attributes."""
__tablename__ = "users" __tablename__ = "users"
id: UUID = Field(default_factory=uuid4, primary_key=True) id: UUID = Field(default_factory=uuid4, primary_key=True)
@@ -21,5 +25,7 @@ class User(QueryModel, table=True):
context: str | None = None context: str | None = None
is_super_admin: bool = Field(default=False) is_super_admin: bool = Field(default=False)
active_organization_id: UUID | None = Field( active_organization_id: UUID | None = Field(
default=None, foreign_key="organizations.id", index=True default=None,
foreign_key="organizations.id",
index=True,
) )

View File

@@ -1,3 +1,5 @@
"""Public schema exports shared across API route modules."""
from app.schemas.activity_events import ActivityEventRead from app.schemas.activity_events import ActivityEventRead
from app.schemas.agents import AgentCreate, AgentRead, AgentUpdate from app.schemas.agents import AgentCreate, AgentRead, AgentUpdate
from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalUpdate from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalUpdate

View File

@@ -1,12 +1,16 @@
"""Response schemas for activity events and task-comment feed items."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID from uuid import UUID # noqa: TCH003
from sqlmodel import SQLModel from sqlmodel import SQLModel
class ActivityEventRead(SQLModel): class ActivityEventRead(SQLModel):
"""Serialized activity event payload returned by activity endpoints."""
id: UUID id: UUID
event_type: str event_type: str
message: str | None message: str | None
@@ -16,6 +20,8 @@ class ActivityEventRead(SQLModel):
class ActivityTaskCommentFeedItemRead(SQLModel): class ActivityTaskCommentFeedItemRead(SQLModel):
"""Denormalized task-comment feed item enriched with task and board fields."""
id: UUID id: UUID
created_at: datetime created_at: datetime
message: str | None message: str | None

View File

@@ -1,16 +1,21 @@
"""Schemas for approval create/update/read API payloads."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from typing import Literal, Self from typing import Literal, Self
from uuid import UUID from uuid import UUID # noqa: TCH003
from pydantic import model_validator from pydantic import model_validator
from sqlmodel import SQLModel from sqlmodel import SQLModel
ApprovalStatus = Literal["pending", "approved", "rejected"] ApprovalStatus = Literal["pending", "approved", "rejected"]
STATUS_REQUIRED_ERROR = "status is required"
class ApprovalBase(SQLModel): class ApprovalBase(SQLModel):
"""Shared approval fields used across create/read payloads."""
action_type: str action_type: str
task_id: UUID | None = None task_id: UUID | None = None
payload: dict[str, object] | None = None payload: dict[str, object] | None = None
@@ -20,20 +25,27 @@ class ApprovalBase(SQLModel):
class ApprovalCreate(ApprovalBase): class ApprovalCreate(ApprovalBase):
"""Payload for creating a new approval request."""
agent_id: UUID | None = None agent_id: UUID | None = None
class ApprovalUpdate(SQLModel): class ApprovalUpdate(SQLModel):
"""Payload for mutating approval status."""
status: ApprovalStatus | None = None status: ApprovalStatus | None = None
@model_validator(mode="after") @model_validator(mode="after")
def validate_status(self) -> Self: def validate_status(self) -> Self:
"""Ensure explicitly provided `status` is not null."""
if "status" in self.model_fields_set and self.status is None: if "status" in self.model_fields_set and self.status is None:
raise ValueError("status is required") raise ValueError(STATUS_REQUIRED_ERROR)
return self return self
class ApprovalRead(ApprovalBase): class ApprovalRead(ApprovalBase):
"""Approval payload returned from read endpoints."""
id: UUID id: UUID
board_id: UUID board_id: UUID
agent_id: UUID | None = None agent_id: UUID | None = None

View File

@@ -1,13 +1,18 @@
"""Schemas for applying heartbeat settings to board-group agents."""
from __future__ import annotations from __future__ import annotations
from typing import Any from typing import Any
from uuid import UUID from uuid import UUID # noqa: TCH003
from sqlmodel import SQLModel from sqlmodel import SQLModel
class BoardGroupHeartbeatApply(SQLModel): class BoardGroupHeartbeatApply(SQLModel):
# Heartbeat cadence string understood by the OpenClaw gateway (e.g. "2m", "10m", "30m"). """Request payload for heartbeat policy updates."""
# Heartbeat cadence string understood by the OpenClaw gateway
# (e.g. "2m", "10m", "30m").
every: str every: str
# Optional heartbeat target (most deployments use "none"). # Optional heartbeat target (most deployments use "none").
target: str | None = None target: str | None = None
@@ -15,6 +20,8 @@ class BoardGroupHeartbeatApply(SQLModel):
class BoardGroupHeartbeatApplyResult(SQLModel): class BoardGroupHeartbeatApplyResult(SQLModel):
"""Result payload describing agents updated by a heartbeat request."""
board_group_id: UUID board_group_id: UUID
requested: dict[str, Any] requested: dict[str, Any]
updated_agent_ids: list[UUID] updated_agent_ids: list[UUID]

View File

@@ -1,14 +1,18 @@
"""Schemas for board-group memory create/read API payloads."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID from uuid import UUID # noqa: TCH003
from sqlmodel import SQLModel from sqlmodel import SQLModel
from app.schemas.common import NonEmptyStr from app.schemas.common import NonEmptyStr # noqa: TCH001
class BoardGroupMemoryCreate(SQLModel): class BoardGroupMemoryCreate(SQLModel):
"""Payload for creating a board-group memory entry."""
# For writes, reject blank/whitespace-only content. # For writes, reject blank/whitespace-only content.
content: NonEmptyStr content: NonEmptyStr
tags: list[str] | None = None tags: list[str] | None = None
@@ -16,9 +20,12 @@ class BoardGroupMemoryCreate(SQLModel):
class BoardGroupMemoryRead(SQLModel): class BoardGroupMemoryRead(SQLModel):
"""Serialized board-group memory entry returned from read endpoints."""
id: UUID id: UUID
board_group_id: UUID board_group_id: UUID
# For reads, allow legacy rows that may have empty content (avoid response validation 500s). # For reads, allow legacy rows that may have empty content
# (avoid response validation 500s).
content: str content: str
tags: list[str] | None = None tags: list[str] | None = None
source: str | None = None source: str | None = None

View File

@@ -1,28 +1,36 @@
"""Schemas for board-group create/update/read API operations."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID from uuid import UUID # noqa: TCH003
from sqlmodel import SQLModel from sqlmodel import SQLModel
class BoardGroupBase(SQLModel): class BoardGroupBase(SQLModel):
"""Shared board-group fields for create/read operations."""
name: str name: str
slug: str slug: str
description: str | None = None description: str | None = None
class BoardGroupCreate(BoardGroupBase): class BoardGroupCreate(BoardGroupBase):
pass """Payload for creating a board group."""
class BoardGroupUpdate(SQLModel): class BoardGroupUpdate(SQLModel):
"""Payload for partial board-group updates."""
name: str | None = None name: str | None = None
slug: str | None = None slug: str | None = None
description: str | None = None description: str | None = None
class BoardGroupRead(BoardGroupBase): class BoardGroupRead(BoardGroupBase):
"""Board-group payload returned from read endpoints."""
id: UUID id: UUID
organization_id: UUID organization_id: UUID
created_at: datetime created_at: datetime

View File

@@ -1,14 +1,18 @@
"""Schemas for board memory create/read API payloads."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID from uuid import UUID # noqa: TCH003
from sqlmodel import SQLModel from sqlmodel import SQLModel
from app.schemas.common import NonEmptyStr from app.schemas.common import NonEmptyStr # noqa: TCH001
class BoardMemoryCreate(SQLModel): class BoardMemoryCreate(SQLModel):
"""Payload for creating a board memory entry."""
# For writes, reject blank/whitespace-only content. # For writes, reject blank/whitespace-only content.
content: NonEmptyStr content: NonEmptyStr
tags: list[str] | None = None tags: list[str] | None = None
@@ -16,9 +20,12 @@ class BoardMemoryCreate(SQLModel):
class BoardMemoryRead(SQLModel): class BoardMemoryRead(SQLModel):
"""Serialized board memory entry returned from read endpoints."""
id: UUID id: UUID
board_id: UUID board_id: UUID
# For reads, allow legacy rows that may have empty content (avoid response validation 500s). # For reads, allow legacy rows that may have empty content
# (avoid response validation 500s).
content: str content: str
tags: list[str] | None = None tags: list[str] | None = None
source: str | None = None source: str | None = None

View File

@@ -1,14 +1,23 @@
"""Schemas for board create/update/read API operations."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from typing import Self from typing import Self
from uuid import UUID from uuid import UUID # noqa: TCH003
from pydantic import model_validator from pydantic import model_validator
from sqlmodel import SQLModel from sqlmodel import SQLModel
_ERR_GOAL_FIELDS_REQUIRED = (
"Confirmed goal boards require objective and success_metrics"
)
_ERR_GATEWAY_REQUIRED = "gateway_id is required"
class BoardBase(SQLModel): class BoardBase(SQLModel):
"""Shared board fields used across create and read payloads."""
name: str name: str
slug: str slug: str
gateway_id: UUID | None = None gateway_id: UUID | None = None
@@ -22,17 +31,25 @@ class BoardBase(SQLModel):
class BoardCreate(BoardBase): class BoardCreate(BoardBase):
"""Payload for creating a board."""
gateway_id: UUID gateway_id: UUID
@model_validator(mode="after") @model_validator(mode="after")
def validate_goal_fields(self) -> Self: def validate_goal_fields(self) -> Self:
if self.board_type == "goal" and self.goal_confirmed: """Require goal details when creating a confirmed goal board."""
if not self.objective or not self.success_metrics: if (
raise ValueError("Confirmed goal boards require objective and success_metrics") self.board_type == "goal"
and self.goal_confirmed
and (not self.objective or not self.success_metrics)
):
raise ValueError(_ERR_GOAL_FIELDS_REQUIRED)
return self return self
class BoardUpdate(SQLModel): class BoardUpdate(SQLModel):
"""Payload for partial board updates."""
name: str | None = None name: str | None = None
slug: str | None = None slug: str | None = None
gateway_id: UUID | None = None gateway_id: UUID | None = None
@@ -46,13 +63,16 @@ class BoardUpdate(SQLModel):
@model_validator(mode="after") @model_validator(mode="after")
def validate_gateway_id(self) -> Self: def validate_gateway_id(self) -> Self:
"""Reject explicit null gateway IDs in patch payloads."""
# Treat explicit null like "unset" is invalid for patch updates. # Treat explicit null like "unset" is invalid for patch updates.
if "gateway_id" in self.model_fields_set and self.gateway_id is None: if "gateway_id" in self.model_fields_set and self.gateway_id is None:
raise ValueError("gateway_id is required") raise ValueError(_ERR_GATEWAY_REQUIRED)
return self return self
class BoardRead(BoardBase): class BoardRead(BoardBase):
"""Board payload returned from read endpoints."""
id: UUID id: UUID
organization_id: UUID organization_id: UUID
created_at: datetime created_at: datetime

View File

@@ -1,3 +1,5 @@
"""Common reusable schema primitives and simple API response envelopes."""
from __future__ import annotations from __future__ import annotations
from typing import Annotated from typing import Annotated
@@ -5,9 +7,12 @@ from typing import Annotated
from pydantic import StringConstraints from pydantic import StringConstraints
from sqlmodel import SQLModel from sqlmodel import SQLModel
# Reusable string type for request payloads where blank/whitespace-only values are invalid. # Reusable string type for request payloads where blank/whitespace-only values
# are invalid.
NonEmptyStr = Annotated[str, StringConstraints(strip_whitespace=True, min_length=1)] NonEmptyStr = Annotated[str, StringConstraints(strip_whitespace=True, min_length=1)]
class OkResponse(SQLModel): class OkResponse(SQLModel):
"""Standard success response payload."""
ok: bool = True ok: bool = True

View File

@@ -1,12 +1,18 @@
"""Structured error payload schemas used by API responses."""
from __future__ import annotations from __future__ import annotations
from sqlmodel import Field, SQLModel from sqlmodel import Field, SQLModel
class BlockedTaskDetail(SQLModel): class BlockedTaskDetail(SQLModel):
"""Error detail payload listing blocking dependency task identifiers."""
message: str message: str
blocked_by_task_ids: list[str] = Field(default_factory=list) blocked_by_task_ids: list[str] = Field(default_factory=list)
class BlockedTaskError(SQLModel): class BlockedTaskError(SQLModel):
"""Top-level blocked-task error response envelope."""
detail: BlockedTaskDetail detail: BlockedTaskDetail

View File

@@ -1,15 +1,21 @@
"""Schemas for gateway passthrough API request and response payloads."""
from __future__ import annotations from __future__ import annotations
from sqlmodel import SQLModel from sqlmodel import SQLModel
from app.schemas.common import NonEmptyStr from app.schemas.common import NonEmptyStr # noqa: TCH001
class GatewaySessionMessageRequest(SQLModel): class GatewaySessionMessageRequest(SQLModel):
"""Request payload for sending a message into a gateway session."""
content: NonEmptyStr content: NonEmptyStr
class GatewayResolveQuery(SQLModel): class GatewayResolveQuery(SQLModel):
"""Query parameters used to resolve which gateway to target."""
board_id: str | None = None board_id: str | None = None
gateway_url: str | None = None gateway_url: str | None = None
gateway_token: str | None = None gateway_token: str | None = None
@@ -17,6 +23,8 @@ class GatewayResolveQuery(SQLModel):
class GatewaysStatusResponse(SQLModel): class GatewaysStatusResponse(SQLModel):
"""Aggregated gateway status response including session metadata."""
connected: bool connected: bool
gateway_url: str gateway_url: str
sessions_count: int | None = None sessions_count: int | None = None
@@ -28,20 +36,28 @@ class GatewaysStatusResponse(SQLModel):
class GatewaySessionsResponse(SQLModel): class GatewaySessionsResponse(SQLModel):
"""Gateway sessions list response payload."""
sessions: list[object] sessions: list[object]
main_session_key: str | None = None main_session_key: str | None = None
main_session: object | None = None main_session: object | None = None
class GatewaySessionResponse(SQLModel): class GatewaySessionResponse(SQLModel):
"""Single gateway session response payload."""
session: object session: object
class GatewaySessionHistoryResponse(SQLModel): class GatewaySessionHistoryResponse(SQLModel):
"""Gateway session history response payload."""
history: list[object] history: list[object]
class GatewayCommandsResponse(SQLModel): class GatewayCommandsResponse(SQLModel):
"""Gateway command catalog and protocol metadata."""
protocol_version: int protocol_version: int
methods: list[str] methods: list[str]
events: list[str] events: list[str]

View File

@@ -1,24 +1,38 @@
"""Schemas for gateway-main and lead-agent coordination endpoints."""
from __future__ import annotations from __future__ import annotations
from typing import Literal from typing import Literal
from uuid import UUID from uuid import UUID # noqa: TCH003
from sqlmodel import Field, SQLModel from sqlmodel import Field, SQLModel
from app.schemas.common import NonEmptyStr from app.schemas.common import NonEmptyStr # noqa: TCH001
def _lead_reply_tags() -> list[str]:
return ["gateway_main", "lead_reply"]
def _user_reply_tags() -> list[str]:
return ["gateway_main", "user_reply"]
class GatewayLeadMessageRequest(SQLModel): class GatewayLeadMessageRequest(SQLModel):
"""Request payload for sending a message to a board lead agent."""
kind: Literal["question", "handoff"] = "question" kind: Literal["question", "handoff"] = "question"
correlation_id: str | None = None correlation_id: str | None = None
content: NonEmptyStr content: NonEmptyStr
# How the lead should reply (defaults are interpreted by templates). # How the lead should reply (defaults are interpreted by templates).
reply_tags: list[str] = Field(default_factory=lambda: ["gateway_main", "lead_reply"]) reply_tags: list[str] = Field(default_factory=_lead_reply_tags)
reply_source: str | None = "lead_to_gateway_main" reply_source: str | None = "lead_to_gateway_main"
class GatewayLeadMessageResponse(SQLModel): class GatewayLeadMessageResponse(SQLModel):
"""Response payload for a lead-message dispatch attempt."""
ok: bool = True ok: bool = True
board_id: UUID board_id: UUID
lead_agent_id: UUID | None = None lead_agent_id: UUID | None = None
@@ -27,15 +41,19 @@ class GatewayLeadMessageResponse(SQLModel):
class GatewayLeadBroadcastRequest(SQLModel): class GatewayLeadBroadcastRequest(SQLModel):
"""Request payload for broadcasting a message to multiple board leads."""
kind: Literal["question", "handoff"] = "question" kind: Literal["question", "handoff"] = "question"
correlation_id: str | None = None correlation_id: str | None = None
content: NonEmptyStr content: NonEmptyStr
board_ids: list[UUID] | None = None board_ids: list[UUID] | None = None
reply_tags: list[str] = Field(default_factory=lambda: ["gateway_main", "lead_reply"]) reply_tags: list[str] = Field(default_factory=_lead_reply_tags)
reply_source: str | None = "lead_to_gateway_main" reply_source: str | None = "lead_to_gateway_main"
class GatewayLeadBroadcastBoardResult(SQLModel): class GatewayLeadBroadcastBoardResult(SQLModel):
"""Per-board result entry for a lead broadcast operation."""
board_id: UUID board_id: UUID
lead_agent_id: UUID | None = None lead_agent_id: UUID | None = None
lead_agent_name: str | None = None lead_agent_name: str | None = None
@@ -44,6 +62,8 @@ class GatewayLeadBroadcastBoardResult(SQLModel):
class GatewayLeadBroadcastResponse(SQLModel): class GatewayLeadBroadcastResponse(SQLModel):
"""Aggregate response for a lead broadcast operation."""
ok: bool = True ok: bool = True
sent: int = 0 sent: int = 0
failed: int = 0 failed: int = 0
@@ -51,16 +71,21 @@ class GatewayLeadBroadcastResponse(SQLModel):
class GatewayMainAskUserRequest(SQLModel): class GatewayMainAskUserRequest(SQLModel):
"""Request payload for asking the end user via a main gateway agent."""
correlation_id: str | None = None correlation_id: str | None = None
content: NonEmptyStr content: NonEmptyStr
preferred_channel: str | None = None preferred_channel: str | None = None
# How the main agent should reply back into Mission Control (defaults interpreted by templates). # How the main agent should reply back into Mission Control
reply_tags: list[str] = Field(default_factory=lambda: ["gateway_main", "user_reply"]) # (defaults interpreted by templates).
reply_tags: list[str] = Field(default_factory=_user_reply_tags)
reply_source: str | None = "user_via_gateway_main" reply_source: str | None = "user_via_gateway_main"
class GatewayMainAskUserResponse(SQLModel): class GatewayMainAskUserResponse(SQLModel):
"""Response payload for user-question dispatch via gateway main agent."""
ok: bool = True ok: bool = True
board_id: UUID board_id: UUID
main_agent_id: UUID | None = None main_agent_id: UUID | None = None

View File

@@ -1,14 +1,17 @@
"""Schemas for gateway CRUD and template-sync API payloads."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from typing import Any from uuid import UUID # noqa: TCH003
from uuid import UUID
from pydantic import field_validator from pydantic import field_validator
from sqlmodel import Field, SQLModel from sqlmodel import Field, SQLModel
class GatewayBase(SQLModel): class GatewayBase(SQLModel):
"""Shared gateway fields used across create/read payloads."""
name: str name: str
url: str url: str
main_session_key: str main_session_key: str
@@ -16,11 +19,14 @@ class GatewayBase(SQLModel):
class GatewayCreate(GatewayBase): class GatewayCreate(GatewayBase):
"""Payload for creating a gateway configuration."""
token: str | None = None token: str | None = None
@field_validator("token", mode="before") @field_validator("token", mode="before")
@classmethod @classmethod
def normalize_token(cls, value: Any) -> Any: def normalize_token(cls, value: object) -> str | None | object:
"""Normalize empty/whitespace tokens to `None`."""
if value is None: if value is None:
return None return None
if isinstance(value, str): if isinstance(value, str):
@@ -30,6 +36,8 @@ class GatewayCreate(GatewayBase):
class GatewayUpdate(SQLModel): class GatewayUpdate(SQLModel):
"""Payload for partial gateway updates."""
name: str | None = None name: str | None = None
url: str | None = None url: str | None = None
token: str | None = None token: str | None = None
@@ -38,7 +46,8 @@ class GatewayUpdate(SQLModel):
@field_validator("token", mode="before") @field_validator("token", mode="before")
@classmethod @classmethod
def normalize_token(cls, value: Any) -> Any: def normalize_token(cls, value: object) -> str | None | object:
"""Normalize empty/whitespace tokens to `None`."""
if value is None: if value is None:
return None return None
if isinstance(value, str): if isinstance(value, str):
@@ -48,6 +57,8 @@ class GatewayUpdate(SQLModel):
class GatewayRead(GatewayBase): class GatewayRead(GatewayBase):
"""Gateway payload returned from read endpoints."""
id: UUID id: UUID
organization_id: UUID organization_id: UUID
token: str | None = None token: str | None = None
@@ -56,6 +67,8 @@ class GatewayRead(GatewayBase):
class GatewayTemplatesSyncError(SQLModel): class GatewayTemplatesSyncError(SQLModel):
"""Per-agent error entry from a gateway template sync operation."""
agent_id: UUID | None = None agent_id: UUID | None = None
agent_name: str | None = None agent_name: str | None = None
board_id: UUID | None = None board_id: UUID | None = None
@@ -63,6 +76,8 @@ class GatewayTemplatesSyncError(SQLModel):
class GatewayTemplatesSyncResult(SQLModel): class GatewayTemplatesSyncResult(SQLModel):
"""Summary payload returned by gateway template sync endpoints."""
gateway_id: UUID gateway_id: UUID
include_main: bool include_main: bool
reset_sessions: bool reset_sessions: bool

View File

@@ -1,17 +1,23 @@
"""Dashboard metrics schemas for KPI and time-series API responses."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from typing import Literal from typing import Literal
from sqlmodel import SQLModel from sqlmodel import SQLModel
class DashboardSeriesPoint(SQLModel): class DashboardSeriesPoint(SQLModel):
"""Single numeric time-series point."""
period: datetime period: datetime
value: float value: float
class DashboardWipPoint(SQLModel): class DashboardWipPoint(SQLModel):
"""Work-in-progress point split by task status buckets."""
period: datetime period: datetime
inbox: int inbox: int
in_progress: int in_progress: int
@@ -19,28 +25,38 @@ class DashboardWipPoint(SQLModel):
class DashboardRangeSeries(SQLModel): class DashboardRangeSeries(SQLModel):
"""Series payload for a single range/bucket combination."""
range: Literal["24h", "7d"] range: Literal["24h", "7d"]
bucket: Literal["hour", "day"] bucket: Literal["hour", "day"]
points: list[DashboardSeriesPoint] points: list[DashboardSeriesPoint]
class DashboardWipRangeSeries(SQLModel): class DashboardWipRangeSeries(SQLModel):
"""WIP series payload for a single range/bucket combination."""
range: Literal["24h", "7d"] range: Literal["24h", "7d"]
bucket: Literal["hour", "day"] bucket: Literal["hour", "day"]
points: list[DashboardWipPoint] points: list[DashboardWipPoint]
class DashboardSeriesSet(SQLModel): class DashboardSeriesSet(SQLModel):
"""Primary vs comparison pair for generic series metrics."""
primary: DashboardRangeSeries primary: DashboardRangeSeries
comparison: DashboardRangeSeries comparison: DashboardRangeSeries
class DashboardWipSeriesSet(SQLModel): class DashboardWipSeriesSet(SQLModel):
"""Primary vs comparison pair for WIP status series metrics."""
primary: DashboardWipRangeSeries primary: DashboardWipRangeSeries
comparison: DashboardWipRangeSeries comparison: DashboardWipRangeSeries
class DashboardKpis(SQLModel): class DashboardKpis(SQLModel):
"""Topline dashboard KPI summary values."""
active_agents: int active_agents: int
tasks_in_progress: int tasks_in_progress: int
error_rate_pct: float error_rate_pct: float
@@ -48,6 +64,8 @@ class DashboardKpis(SQLModel):
class DashboardMetrics(SQLModel): class DashboardMetrics(SQLModel):
"""Complete dashboard metrics response payload."""
range: Literal["24h", "7d"] range: Literal["24h", "7d"]
generated_at: datetime generated_at: datetime
kpis: DashboardKpis kpis: DashboardKpis

View File

@@ -1,12 +1,16 @@
"""Schemas for organization, membership, and invite API payloads."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID from uuid import UUID # noqa: TCH003
from sqlmodel import Field, SQLModel from sqlmodel import Field, SQLModel
class OrganizationRead(SQLModel): class OrganizationRead(SQLModel):
"""Organization payload returned by read endpoints."""
id: UUID id: UUID
name: str name: str
created_at: datetime created_at: datetime
@@ -14,14 +18,20 @@ class OrganizationRead(SQLModel):
class OrganizationCreate(SQLModel): class OrganizationCreate(SQLModel):
"""Payload for creating a new organization."""
name: str name: str
class OrganizationActiveUpdate(SQLModel): class OrganizationActiveUpdate(SQLModel):
"""Payload for switching the active organization context."""
organization_id: UUID organization_id: UUID
class OrganizationListItem(SQLModel): class OrganizationListItem(SQLModel):
"""Organization list row for current user memberships."""
id: UUID id: UUID
name: str name: str
role: str role: str
@@ -29,6 +39,8 @@ class OrganizationListItem(SQLModel):
class OrganizationUserRead(SQLModel): class OrganizationUserRead(SQLModel):
"""Embedded user fields included in organization member payloads."""
id: UUID id: UUID
email: str | None = None email: str | None = None
name: str | None = None name: str | None = None
@@ -36,6 +48,8 @@ class OrganizationUserRead(SQLModel):
class OrganizationMemberRead(SQLModel): class OrganizationMemberRead(SQLModel):
"""Organization member payload including board-level access overrides."""
id: UUID id: UUID
organization_id: UUID organization_id: UUID
user_id: UUID user_id: UUID
@@ -49,16 +63,22 @@ class OrganizationMemberRead(SQLModel):
class OrganizationMemberUpdate(SQLModel): class OrganizationMemberUpdate(SQLModel):
"""Payload for partial updates to organization member role."""
role: str | None = None role: str | None = None
class OrganizationBoardAccessSpec(SQLModel): class OrganizationBoardAccessSpec(SQLModel):
"""Board access specification used in member/invite mutation payloads."""
board_id: UUID board_id: UUID
can_read: bool = True can_read: bool = True
can_write: bool = False can_write: bool = False
class OrganizationBoardAccessRead(SQLModel): class OrganizationBoardAccessRead(SQLModel):
"""Board access payload returned from read endpoints."""
id: UUID id: UUID
board_id: UUID board_id: UUID
can_read: bool can_read: bool
@@ -68,12 +88,16 @@ class OrganizationBoardAccessRead(SQLModel):
class OrganizationMemberAccessUpdate(SQLModel): class OrganizationMemberAccessUpdate(SQLModel):
"""Payload for replacing organization member access permissions."""
all_boards_read: bool = False all_boards_read: bool = False
all_boards_write: bool = False all_boards_write: bool = False
board_access: list[OrganizationBoardAccessSpec] = Field(default_factory=list) board_access: list[OrganizationBoardAccessSpec] = Field(default_factory=list)
class OrganizationInviteCreate(SQLModel): class OrganizationInviteCreate(SQLModel):
"""Payload for creating an organization invite."""
invited_email: str invited_email: str
role: str = "member" role: str = "member"
all_boards_read: bool = False all_boards_read: bool = False
@@ -82,6 +106,8 @@ class OrganizationInviteCreate(SQLModel):
class OrganizationInviteRead(SQLModel): class OrganizationInviteRead(SQLModel):
"""Organization invite payload returned from read endpoints."""
id: UUID id: UUID
organization_id: UUID organization_id: UUID
invited_email: str invited_email: str
@@ -97,4 +123,6 @@ class OrganizationInviteRead(SQLModel):
class OrganizationInviteAccept(SQLModel): class OrganizationInviteAccept(SQLModel):
"""Payload for accepting an organization invite token."""
token: str token: str

View File

@@ -1,3 +1,5 @@
"""Shared pagination response type aliases used by API routes."""
from __future__ import annotations from __future__ import annotations
from typing import TypeVar from typing import TypeVar

View File

@@ -1,9 +1,13 @@
"""Schemas for souls-directory search and markdown fetch responses."""
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel from pydantic import BaseModel
class SoulsDirectorySoulRef(BaseModel): class SoulsDirectorySoulRef(BaseModel):
"""Reference metadata for a soul entry in the directory index."""
handle: str handle: str
slug: str slug: str
page_url: str page_url: str
@@ -11,10 +15,14 @@ class SoulsDirectorySoulRef(BaseModel):
class SoulsDirectorySearchResponse(BaseModel): class SoulsDirectorySearchResponse(BaseModel):
"""Response wrapper for directory search results."""
items: list[SoulsDirectorySoulRef] items: list[SoulsDirectorySoulRef]
class SoulsDirectoryMarkdownResponse(BaseModel): class SoulsDirectoryMarkdownResponse(BaseModel):
"""Response payload containing rendered markdown for a soul."""
handle: str handle: str
slug: str slug: str
content: str content: str

View File

@@ -1,18 +1,23 @@
"""Schemas for task CRUD and task comment API payloads."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from typing import Any, Literal, Self from typing import Literal, Self
from uuid import UUID from uuid import UUID # noqa: TCH003
from pydantic import field_validator, model_validator from pydantic import field_validator, model_validator
from sqlmodel import Field, SQLModel from sqlmodel import Field, SQLModel
from app.schemas.common import NonEmptyStr from app.schemas.common import NonEmptyStr # noqa: TCH001
TaskStatus = Literal["inbox", "in_progress", "review", "done"] TaskStatus = Literal["inbox", "in_progress", "review", "done"]
STATUS_REQUIRED_ERROR = "status is required"
class TaskBase(SQLModel): class TaskBase(SQLModel):
"""Shared task fields used by task create/read payloads."""
title: str title: str
description: str | None = None description: str | None = None
status: TaskStatus = "inbox" status: TaskStatus = "inbox"
@@ -23,10 +28,14 @@ class TaskBase(SQLModel):
class TaskCreate(TaskBase): class TaskCreate(TaskBase):
"""Payload for creating a task."""
created_by_user_id: UUID | None = None created_by_user_id: UUID | None = None
class TaskUpdate(SQLModel): class TaskUpdate(SQLModel):
"""Payload for partial task updates."""
title: str | None = None title: str | None = None
description: str | None = None description: str | None = None
status: TaskStatus | None = None status: TaskStatus | None = None
@@ -38,7 +47,8 @@ class TaskUpdate(SQLModel):
@field_validator("comment", mode="before") @field_validator("comment", mode="before")
@classmethod @classmethod
def normalize_comment(cls, value: Any) -> Any: def normalize_comment(cls, value: object) -> object | None:
"""Normalize blank comment strings to `None`."""
if value is None: if value is None:
return None return None
if isinstance(value, str) and not value.strip(): if isinstance(value, str) and not value.strip():
@@ -47,12 +57,15 @@ class TaskUpdate(SQLModel):
@model_validator(mode="after") @model_validator(mode="after")
def validate_status(self) -> Self: def validate_status(self) -> Self:
"""Ensure explicitly supplied status is not null."""
if "status" in self.model_fields_set and self.status is None: if "status" in self.model_fields_set and self.status is None:
raise ValueError("status is required") raise ValueError(STATUS_REQUIRED_ERROR)
return self return self
class TaskRead(TaskBase): class TaskRead(TaskBase):
"""Task payload returned from read endpoints."""
id: UUID id: UUID
board_id: UUID | None board_id: UUID | None
created_by_user_id: UUID | None created_by_user_id: UUID | None
@@ -64,10 +77,14 @@ class TaskRead(TaskBase):
class TaskCommentCreate(SQLModel): class TaskCommentCreate(SQLModel):
"""Payload for creating a task comment."""
message: NonEmptyStr message: NonEmptyStr
class TaskCommentRead(SQLModel): class TaskCommentRead(SQLModel):
"""Task comment payload returned from read endpoints."""
id: UUID id: UUID
message: str | None message: str | None
agent_id: UUID | None agent_id: UUID | None

View File

@@ -1,11 +1,15 @@
"""User API schemas for create, update, and read operations."""
from __future__ import annotations from __future__ import annotations
from uuid import UUID from uuid import UUID # noqa: TCH003
from sqlmodel import SQLModel from sqlmodel import SQLModel
class UserBase(SQLModel): class UserBase(SQLModel):
"""Common user profile fields shared across user payload schemas."""
clerk_user_id: str clerk_user_id: str
email: str | None = None email: str | None = None
name: str | None = None name: str | None = None
@@ -17,10 +21,12 @@ class UserBase(SQLModel):
class UserCreate(UserBase): class UserCreate(UserBase):
pass """Payload used to create a user record."""
class UserUpdate(SQLModel): class UserUpdate(SQLModel):
"""Payload for partial user profile updates."""
name: str | None = None name: str | None = None
preferred_name: str | None = None preferred_name: str | None = None
pronouns: str | None = None pronouns: str | None = None
@@ -30,5 +36,7 @@ class UserUpdate(SQLModel):
class UserRead(UserBase): class UserRead(UserBase):
"""Full user payload returned by API responses."""
id: UUID id: UUID
is_super_admin: bool is_super_admin: bool

View File

@@ -1,25 +1,31 @@
"""Composite read models assembled for board and board-group views."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime # noqa: TCH003
from uuid import UUID from uuid import UUID # noqa: TCH003
from sqlmodel import Field, SQLModel from sqlmodel import Field, SQLModel
from app.schemas.agents import AgentRead from app.schemas.agents import AgentRead # noqa: TCH001
from app.schemas.approvals import ApprovalRead from app.schemas.approvals import ApprovalRead # noqa: TCH001
from app.schemas.board_groups import BoardGroupRead from app.schemas.board_groups import BoardGroupRead # noqa: TCH001
from app.schemas.board_memory import BoardMemoryRead from app.schemas.board_memory import BoardMemoryRead # noqa: TCH001
from app.schemas.boards import BoardRead from app.schemas.boards import BoardRead # noqa: TCH001
from app.schemas.tasks import TaskRead from app.schemas.tasks import TaskRead
class TaskCardRead(TaskRead): class TaskCardRead(TaskRead):
"""Task read model enriched with assignee and approval counters."""
assignee: str | None = None assignee: str | None = None
approvals_count: int = 0 approvals_count: int = 0
approvals_pending_count: int = 0 approvals_pending_count: int = 0
class BoardSnapshot(SQLModel): class BoardSnapshot(SQLModel):
"""Aggregated board payload used by board snapshot endpoints."""
board: BoardRead board: BoardRead
tasks: list[TaskCardRead] tasks: list[TaskCardRead]
agents: list[AgentRead] agents: list[AgentRead]
@@ -29,6 +35,8 @@ class BoardSnapshot(SQLModel):
class BoardGroupTaskSummary(SQLModel): class BoardGroupTaskSummary(SQLModel):
"""Task summary row used inside board-group snapshot responses."""
id: UUID id: UUID
board_id: UUID board_id: UUID
board_name: str board_name: str
@@ -44,11 +52,15 @@ class BoardGroupTaskSummary(SQLModel):
class BoardGroupBoardSnapshot(SQLModel): class BoardGroupBoardSnapshot(SQLModel):
"""Board-level rollup embedded within a board-group snapshot."""
board: BoardRead board: BoardRead
task_counts: dict[str, int] = Field(default_factory=dict) task_counts: dict[str, int] = Field(default_factory=dict)
tasks: list[BoardGroupTaskSummary] = Field(default_factory=list) tasks: list[BoardGroupTaskSummary] = Field(default_factory=list)
class BoardGroupSnapshot(SQLModel): class BoardGroupSnapshot(SQLModel):
"""Top-level board-group snapshot response payload."""
group: BoardGroupRead | None = None group: BoardGroupRead | None = None
boards: list[BoardGroupBoardSnapshot] = Field(default_factory=list) boards: list[BoardGroupBoardSnapshot] = Field(default_factory=list)

View File

@@ -0,0 +1 @@
"""Business logic services for backend domain operations."""

View File

@@ -1,8 +1,13 @@
"""Utilities for recording normalized activity events."""
from __future__ import annotations from __future__ import annotations
from uuid import UUID from typing import TYPE_CHECKING
from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING:
from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.activity_events import ActivityEvent from app.models.activity_events import ActivityEvent
@@ -15,6 +20,7 @@ def record_activity(
agent_id: UUID | None = None, agent_id: UUID | None = None,
task_id: UUID | None = None, task_id: UUID | None = None,
) -> ActivityEvent: ) -> ActivityEvent:
"""Create and attach an activity event row to the current DB session."""
event = ActivityEvent( event = ActivityEvent(
event_type=event_type, event_type=event_type,
message=message, message=message,

View File

@@ -1,10 +1,16 @@
"""Access control helpers for admin-only operations."""
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING
from fastapi import HTTPException, status from fastapi import HTTPException, status
from app.core.auth import AuthContext if TYPE_CHECKING:
from app.core.auth import AuthContext
def require_admin(auth: AuthContext) -> None: def require_admin(auth: AuthContext) -> None:
"""Raise HTTP 403 unless the authenticated actor is a user admin."""
if auth.actor_type != "user" or auth.user is None: if auth.actor_type != "user" or auth.user is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)

View File

@@ -1,21 +1,31 @@
"""Gateway-facing agent provisioning and cleanup helpers."""
# ruff: noqa: EM101, TRY003
from __future__ import annotations from __future__ import annotations
import hashlib import hashlib
import json import json
import re import re
from contextlib import suppress
from pathlib import Path from pathlib import Path
from typing import Any, cast from typing import TYPE_CHECKING, Any, cast
from uuid import uuid4 from uuid import uuid4
from jinja2 import Environment, FileSystemLoader, StrictUndefined, select_autoescape from jinja2 import Environment, FileSystemLoader, StrictUndefined, select_autoescape
from app.core.config import settings from app.core.config import settings
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, openclaw_call from app.integrations.openclaw_gateway import (
from app.models.agents import Agent OpenClawGatewayError,
from app.models.boards import Board ensure_session,
from app.models.gateways import Gateway openclaw_call,
from app.models.users import User )
if TYPE_CHECKING:
from app.models.agents import Agent
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.users import User
DEFAULT_HEARTBEAT_CONFIG = {"every": "10m", "target": "none"} DEFAULT_HEARTBEAT_CONFIG = {"every": "10m", "target": "none"}
DEFAULT_IDENTITY_PROFILE = { DEFAULT_IDENTITY_PROFILE = {
@@ -35,7 +45,8 @@ EXTRA_IDENTITY_PROFILE_FIELDS = {
"verbosity": "identity_verbosity", "verbosity": "identity_verbosity",
"output_format": "identity_output_format", "output_format": "identity_output_format",
"update_cadence": "identity_update_cadence", "update_cadence": "identity_update_cadence",
# Per-agent charter (optional). Used to give agents a "purpose in life" and a distinct vibe. # Per-agent charter (optional).
# Used to give agents a "purpose in life" and a distinct vibe.
"purpose": "identity_purpose", "purpose": "identity_purpose",
"personality": "identity_personality", "personality": "identity_personality",
"custom_instructions": "identity_custom_instructions", "custom_instructions": "identity_custom_instructions",
@@ -54,11 +65,11 @@ DEFAULT_GATEWAY_FILES = frozenset(
"BOOT.md", "BOOT.md",
"BOOTSTRAP.md", "BOOTSTRAP.md",
"MEMORY.md", "MEMORY.md",
} },
) )
# These files are intended to evolve within the agent workspace. Provision them if missing, # These files are intended to evolve within the agent workspace.
# but avoid overwriting existing content during updates. # Provision them if missing, but avoid overwriting existing content during updates.
# #
# Examples: # Examples:
# - SELF.md: evolving identity/preferences # - SELF.md: evolving identity/preferences
@@ -68,6 +79,7 @@ PRESERVE_AGENT_EDITABLE_FILES = frozenset({"SELF.md", "USER.md", "MEMORY.md"})
HEARTBEAT_LEAD_TEMPLATE = "HEARTBEAT_LEAD.md" HEARTBEAT_LEAD_TEMPLATE = "HEARTBEAT_LEAD.md"
HEARTBEAT_AGENT_TEMPLATE = "HEARTBEAT_AGENT.md" HEARTBEAT_AGENT_TEMPLATE = "HEARTBEAT_AGENT.md"
_SESSION_KEY_PARTS_MIN = 2
MAIN_TEMPLATE_MAP = { MAIN_TEMPLATE_MAP = {
"AGENTS.md": "MAIN_AGENTS.md", "AGENTS.md": "MAIN_AGENTS.md",
"HEARTBEAT.md": "MAIN_HEARTBEAT.md", "HEARTBEAT.md": "MAIN_HEARTBEAT.md",
@@ -97,13 +109,13 @@ def _agent_id_from_session_key(session_key: str | None) -> str | None:
if not value.startswith("agent:"): if not value.startswith("agent:"):
return None return None
parts = value.split(":") parts = value.split(":")
if len(parts) < 2: if len(parts) < _SESSION_KEY_PARTS_MIN:
return None return None
agent_id = parts[1].strip() agent_id = parts[1].strip()
return agent_id or None return agent_id or None
def _extract_agent_id(payload: object) -> str | None: def _extract_agent_id(payload: object) -> str | None: # noqa: C901
def _from_list(items: object) -> str | None: def _from_list(items: object) -> str | None:
if not isinstance(items, list): if not isinstance(items, list):
return None return None
@@ -137,7 +149,7 @@ def _agent_key(agent: Agent) -> str:
session_key = agent.openclaw_session_id or "" session_key = agent.openclaw_session_id or ""
if session_key.startswith("agent:"): if session_key.startswith("agent:"):
parts = session_key.split(":") parts = session_key.split(":")
if len(parts) >= 2 and parts[1]: if len(parts) >= _SESSION_KEY_PARTS_MIN and parts[1]:
return parts[1] return parts[1]
return _slugify(agent.name) return _slugify(agent.name)
@@ -183,14 +195,14 @@ def _ensure_workspace_file(
if not workspace_path or not name: if not workspace_path or not name:
return return
# Only write to a dedicated, explicitly-configured local directory. # Only write to a dedicated, explicitly-configured local directory.
# Using `gateway.workspace_root` directly here is unsafe (and CodeQL correctly flags it) # Using `gateway.workspace_root` directly here is unsafe.
# because it is a DB-backed config value. # CodeQL correctly flags that value because it is DB-backed config.
base_root = (settings.local_agent_workspace_root or "").strip() base_root = (settings.local_agent_workspace_root or "").strip()
if not base_root: if not base_root:
return return
base = Path(base_root).expanduser() base = Path(base_root).expanduser()
# Derive a stable, safe directory name from the (potentially untrusted) workspace path. # Derive a stable, safe directory name from the untrusted workspace path.
# This prevents path traversal and avoids writing to arbitrary locations. # This prevents path traversal and avoids writing to arbitrary locations.
digest = hashlib.sha256(workspace_path.encode("utf-8")).hexdigest()[:16] digest = hashlib.sha256(workspace_path.encode("utf-8")).hexdigest()[:16]
root = base / f"gateway-workspace-{digest}" root = base / f"gateway-workspace-{digest}"
@@ -345,12 +357,14 @@ async def _supported_gateway_files(config: GatewayClientConfig) -> set[str]:
default_id = None default_id = None
if isinstance(agents_payload, dict): if isinstance(agents_payload, dict):
agents = list(agents_payload.get("agents") or []) agents = list(agents_payload.get("agents") or [])
default_id = agents_payload.get("defaultId") or agents_payload.get("default_id") default_id = agents_payload.get("defaultId") or agents_payload.get(
"default_id",
)
agent_id = default_id or (agents[0].get("id") if agents else None) agent_id = default_id or (agents[0].get("id") if agents else None)
if not agent_id: if not agent_id:
return set(DEFAULT_GATEWAY_FILES) return set(DEFAULT_GATEWAY_FILES)
files_payload = await openclaw_call( files_payload = await openclaw_call(
"agents.files.list", {"agentId": agent_id}, config=config "agents.files.list", {"agentId": agent_id}, config=config,
) )
if isinstance(files_payload, dict): if isinstance(files_payload, dict):
files = files_payload.get("files") or [] files = files_payload.get("files") or []
@@ -374,10 +388,12 @@ async def _reset_session(session_key: str, config: GatewayClientConfig) -> None:
async def _gateway_agent_files_index( async def _gateway_agent_files_index(
agent_id: str, config: GatewayClientConfig agent_id: str, config: GatewayClientConfig,
) -> dict[str, dict[str, Any]]: ) -> dict[str, dict[str, Any]]:
try: try:
payload = await openclaw_call("agents.files.list", {"agentId": agent_id}, config=config) payload = await openclaw_call(
"agents.files.list", {"agentId": agent_id}, config=config,
)
if isinstance(payload, dict): if isinstance(payload, dict):
files = payload.get("files") or [] files = payload.get("files") or []
index: dict[str, dict[str, Any]] = {} index: dict[str, dict[str, Any]] = {}
@@ -420,21 +436,25 @@ def _render_agent_files(
) )
heartbeat_path = _templates_root() / heartbeat_template heartbeat_path = _templates_root() / heartbeat_template
if heartbeat_path.exists(): if heartbeat_path.exists():
rendered[name] = env.get_template(heartbeat_template).render(**context).strip() rendered[name] = (
env.get_template(heartbeat_template).render(**context).strip()
)
continue continue
override = overrides.get(name) override = overrides.get(name)
if override: if override:
rendered[name] = env.from_string(override).render(**context).strip() rendered[name] = env.from_string(override).render(**context).strip()
continue continue
template_name = ( template_name = (
template_overrides[name] if template_overrides and name in template_overrides else name template_overrides[name]
if template_overrides and name in template_overrides
else name
) )
path = _templates_root() / template_name path = _templates_root() / template_name
if path.exists(): if path.exists():
rendered[name] = env.get_template(template_name).render(**context).strip() rendered[name] = env.get_template(template_name).render(**context).strip()
continue continue
if name == "MEMORY.md": if name == "MEMORY.md":
# Back-compat fallback for existing gateways that don't ship a MEMORY.md template. # Back-compat fallback for gateways that do not ship MEMORY.md.
rendered[name] = "# MEMORY\n\nBootstrap pending.\n" rendered[name] = "# MEMORY\n\nBootstrap pending.\n"
continue continue
rendered[name] = "" rendered[name] = ""
@@ -487,7 +507,9 @@ async def _patch_gateway_agent_list(
else: else:
new_list.append(entry) new_list.append(entry)
if not updated: if not updated:
new_list.append({"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat}) new_list.append(
{"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat},
)
patch = {"agents": {"list": new_list}} patch = {"agents": {"list": new_list}}
params = {"raw": json.dumps(patch)} params = {"raw": json.dumps(patch)}
@@ -496,7 +518,7 @@ async def _patch_gateway_agent_list(
await openclaw_call("config.patch", params, config=config) await openclaw_call("config.patch", params, config=config)
async def patch_gateway_agent_heartbeats( async def patch_gateway_agent_heartbeats( # noqa: C901
gateway: Gateway, gateway: Gateway,
*, *,
entries: list[tuple[str, str, dict[str, Any]]], entries: list[tuple[str, str, dict[str, Any]]],
@@ -521,7 +543,8 @@ async def patch_gateway_agent_heartbeats(
raise OpenClawGatewayError("config agents.list is not a list") raise OpenClawGatewayError("config agents.list is not a list")
entry_by_id: dict[str, tuple[str, dict[str, Any]]] = { entry_by_id: dict[str, tuple[str, dict[str, Any]]] = {
agent_id: (workspace_path, heartbeat) for agent_id, workspace_path, heartbeat in entries agent_id: (workspace_path, heartbeat)
for agent_id, workspace_path, heartbeat in entries
} }
updated_ids: set[str] = set() updated_ids: set[str] = set()
@@ -544,7 +567,9 @@ async def patch_gateway_agent_heartbeats(
for agent_id, (workspace_path, heartbeat) in entry_by_id.items(): for agent_id, (workspace_path, heartbeat) in entry_by_id.items():
if agent_id in updated_ids: if agent_id in updated_ids:
continue continue
new_list.append({"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat}) new_list.append(
{"id": agent_id, "workspace": workspace_path, "heartbeat": heartbeat},
)
patch = {"agents": {"list": new_list}} patch = {"agents": {"list": new_list}}
params = {"raw": json.dumps(patch)} params = {"raw": json.dumps(patch)}
@@ -585,7 +610,9 @@ async def _remove_gateway_agent_list(
raise OpenClawGatewayError("config agents.list is not a list") raise OpenClawGatewayError("config agents.list is not a list")
new_list = [ new_list = [
entry for entry in lst if not (isinstance(entry, dict) and entry.get("id") == agent_id) entry
for entry in lst
if not (isinstance(entry, dict) and entry.get("id") == agent_id)
] ]
if len(new_list) == len(lst): if len(new_list) == len(lst):
return return
@@ -616,7 +643,7 @@ async def _get_gateway_agent_entry(
return None return None
async def provision_agent( async def provision_agent( # noqa: C901, PLR0912, PLR0913
agent: Agent, agent: Agent,
board: Board, board: Board,
gateway: Gateway, gateway: Gateway,
@@ -627,6 +654,7 @@ async def provision_agent(
force_bootstrap: bool = False, force_bootstrap: bool = False,
reset_session: bool = False, reset_session: bool = False,
) -> None: ) -> None:
"""Provision or update a regular board agent workspace."""
if not gateway.url: if not gateway.url:
return return
if not gateway.workspace_root: if not gateway.workspace_root:
@@ -665,11 +693,9 @@ async def provision_agent(
content = rendered.get(name) content = rendered.get(name)
if not content: if not content:
continue continue
try: with suppress(OSError):
_ensure_workspace_file(workspace_path, name, content, overwrite=False)
except OSError:
# Local workspace may not be writable/available; fall back to gateway API. # Local workspace may not be writable/available; fall back to gateway API.
pass _ensure_workspace_file(workspace_path, name, content, overwrite=False)
for name, content in rendered.items(): for name, content in rendered.items():
if content == "": if content == "":
continue continue
@@ -694,7 +720,7 @@ async def provision_agent(
await _reset_session(session_key, client_config) await _reset_session(session_key, client_config)
async def provision_main_agent( async def provision_main_agent( # noqa: C901, PLR0912, PLR0913
agent: Agent, agent: Agent,
gateway: Gateway, gateway: Gateway,
auth_token: str, auth_token: str,
@@ -704,12 +730,15 @@ async def provision_main_agent(
force_bootstrap: bool = False, force_bootstrap: bool = False,
reset_session: bool = False, reset_session: bool = False,
) -> None: ) -> None:
"""Provision or update the gateway main agent workspace."""
if not gateway.url: if not gateway.url:
return return
if not gateway.main_session_key: if not gateway.main_session_key:
raise ValueError("gateway main_session_key is required") raise ValueError("gateway main_session_key is required")
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token) client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
await ensure_session(gateway.main_session_key, config=client_config, label="Main Agent") await ensure_session(
gateway.main_session_key, config=client_config, label="Main Agent",
)
agent_id = await _gateway_default_agent_id( agent_id = await _gateway_default_agent_id(
client_config, client_config,
@@ -763,6 +792,7 @@ async def cleanup_agent(
agent: Agent, agent: Agent,
gateway: Gateway, gateway: Gateway,
) -> str | None: ) -> str | None:
"""Remove an agent from gateway config and delete its session."""
if not gateway.url: if not gateway.url:
return None return None
if not gateway.workspace_root: if not gateway.workspace_root:

View File

@@ -1,30 +1,41 @@
"""Helpers for ensuring each board has a provisioned lead agent."""
from __future__ import annotations from __future__ import annotations
from typing import Any from typing import TYPE_CHECKING, Any
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.agent_tokens import generate_agent_token, hash_agent_token from app.core.agent_tokens import generate_agent_token, hash_agent_token
from app.core.time import utcnow from app.core.time import utcnow
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message from app.integrations.openclaw_gateway import (
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.agents import Agent from app.models.agents import Agent
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.users import User
from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_agent from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_agent
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.users import User
def lead_session_key(board: Board) -> str: def lead_session_key(board: Board) -> str:
"""Return the deterministic main session key for a board lead agent."""
return f"agent:lead-{board.id}:main" return f"agent:lead-{board.id}:main"
def lead_agent_name(_: Board) -> str: def lead_agent_name(_: Board) -> str:
"""Return the default display name for board lead agents."""
return "Lead Agent" return "Lead Agent"
async def ensure_board_lead_agent( async def ensure_board_lead_agent( # noqa: PLR0913
session: AsyncSession, session: AsyncSession,
*, *,
board: Board, board: Board,
@@ -35,11 +46,12 @@ async def ensure_board_lead_agent(
identity_profile: dict[str, str] | None = None, identity_profile: dict[str, str] | None = None,
action: str = "provision", action: str = "provision",
) -> tuple[Agent, bool]: ) -> tuple[Agent, bool]:
"""Ensure a board has a lead agent; return `(agent, created)`."""
existing = ( existing = (
await session.exec( await session.exec(
select(Agent) select(Agent)
.where(Agent.board_id == board.id) .where(Agent.board_id == board.id)
.where(col(Agent.is_board_lead).is_(True)) .where(col(Agent.is_board_lead).is_(True)),
) )
).first() ).first()
if existing: if existing:
@@ -66,7 +78,11 @@ async def ensure_board_lead_agent(
} }
if identity_profile: if identity_profile:
merged_identity_profile.update( merged_identity_profile.update(
{key: value.strip() for key, value in identity_profile.items() if value.strip()} {
key: value.strip()
for key, value in identity_profile.items()
if value.strip()
},
) )
agent = Agent( agent = Agent(
@@ -89,11 +105,16 @@ async def ensure_board_lead_agent(
try: try:
await provision_agent(agent, board, gateway, raw_token, user, action=action) await provision_agent(agent, board, gateway, raw_token, user, action=action)
if agent.openclaw_session_id: if agent.openclaw_session_id:
await ensure_session(agent.openclaw_session_id, config=config, label=agent.name) await ensure_session(
agent.openclaw_session_id,
config=config,
label=agent.name,
)
await send_message( await send_message(
( (
f"Hello {agent.name}. Your workspace has been provisioned.\n\n" f"Hello {agent.name}. Your workspace has been provisioned.\n\n"
"Start the agent, run BOOT.md, and if BOOTSTRAP.md exists run it once " "Start the agent, run BOOT.md, and if BOOTSTRAP.md exists run "
"it once "
"then delete it. Begin heartbeats after startup." "then delete it. Begin heartbeats after startup."
), ),
session_key=agent.openclaw_session_id, session_key=agent.openclaw_session_id,

View File

@@ -1,17 +1,17 @@
"""Helpers for assembling denormalized board snapshot response payloads."""
from __future__ import annotations from __future__ import annotations
from datetime import timedelta from datetime import timedelta
from uuid import UUID from typing import TYPE_CHECKING
from sqlalchemy import case, func from sqlalchemy import case, func
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.time import utcnow from app.core.time import utcnow
from app.models.agents import Agent from app.models.agents import Agent
from app.models.approvals import Approval from app.models.approvals import Approval
from app.models.board_memory import BoardMemory from app.models.board_memory import BoardMemory
from app.models.boards import Board
from app.models.gateways import Gateway from app.models.gateways import Gateway
from app.models.tasks import Task from app.models.tasks import Task
from app.schemas.agents import AgentRead from app.schemas.agents import AgentRead
@@ -25,6 +25,13 @@ from app.services.task_dependencies import (
dependency_status_by_id, dependency_status_by_id,
) )
if TYPE_CHECKING:
from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.boards import Board
OFFLINE_AFTER = timedelta(minutes=10) OFFLINE_AFTER = timedelta(minutes=10)
@@ -48,9 +55,15 @@ def _agent_to_read(agent: Agent, main_session_keys: set[str]) -> AgentRead:
model = AgentRead.model_validate(agent, from_attributes=True) model = AgentRead.model_validate(agent, from_attributes=True)
computed_status = _computed_agent_status(agent) computed_status = _computed_agent_status(agent)
is_gateway_main = bool( is_gateway_main = bool(
agent.openclaw_session_id and agent.openclaw_session_id in main_session_keys agent.openclaw_session_id
and agent.openclaw_session_id in main_session_keys,
)
return model.model_copy(
update={
"status": computed_status,
"is_gateway_main": is_gateway_main,
},
) )
return model.model_copy(update={"status": computed_status, "is_gateway_main": is_gateway_main})
def _memory_to_read(memory: BoardMemory) -> BoardMemoryRead: def _memory_to_read(memory: BoardMemory) -> BoardMemoryRead:
@@ -72,7 +85,9 @@ def _task_to_card(
card = TaskCardRead.model_validate(task, from_attributes=True) card = TaskCardRead.model_validate(task, from_attributes=True)
approvals_count, approvals_pending_count = counts_by_task_id.get(task.id, (0, 0)) approvals_count, approvals_pending_count = counts_by_task_id.get(task.id, (0, 0))
assignee = ( assignee = (
agent_name_by_id.get(task.assigned_agent_id) if task.assigned_agent_id is not None else None agent_name_by_id.get(task.assigned_agent_id)
if task.assigned_agent_id
else None
) )
depends_on_task_ids = deps_by_task_id.get(task.id, []) depends_on_task_ids = deps_by_task_id.get(task.id, [])
blocked_by_task_ids = blocked_by_dependency_ids( blocked_by_task_ids = blocked_by_dependency_ids(
@@ -89,21 +104,26 @@ def _task_to_card(
"depends_on_task_ids": depends_on_task_ids, "depends_on_task_ids": depends_on_task_ids,
"blocked_by_task_ids": blocked_by_task_ids, "blocked_by_task_ids": blocked_by_task_ids,
"is_blocked": bool(blocked_by_task_ids), "is_blocked": bool(blocked_by_task_ids),
} },
) )
async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnapshot: async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnapshot:
"""Build a board snapshot with tasks, agents, approvals, and chat history."""
board_read = BoardRead.model_validate(board, from_attributes=True) board_read = BoardRead.model_validate(board, from_attributes=True)
tasks = list( tasks = list(
await Task.objects.filter_by(board_id=board.id) await Task.objects.filter_by(board_id=board.id)
.order_by(col(Task.created_at).desc()) .order_by(col(Task.created_at).desc())
.all(session) .all(session),
) )
task_ids = [task.id for task in tasks] task_ids = [task.id for task in tasks]
deps_by_task_id = await dependency_ids_by_task_id(session, board_id=board.id, task_ids=task_ids) deps_by_task_id = await dependency_ids_by_task_id(
session,
board_id=board.id,
task_ids=task_ids,
)
all_dependency_ids: list[UUID] = [] all_dependency_ids: list[UUID] = []
for values in deps_by_task_id.values(): for values in deps_by_task_id.values():
all_dependency_ids.extend(values) all_dependency_ids.extend(values)
@@ -127,9 +147,9 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
await session.exec( await session.exec(
select(func.count(col(Approval.id))) select(func.count(col(Approval.id)))
.where(col(Approval.board_id) == board.id) .where(col(Approval.board_id) == board.id)
.where(col(Approval.status) == "pending") .where(col(Approval.status) == "pending"),
) ),
).one() ).one(),
) )
approvals = ( approvals = (
@@ -146,12 +166,14 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
select( select(
col(Approval.task_id), col(Approval.task_id),
func.count(col(Approval.id)).label("total"), func.count(col(Approval.id)).label("total"),
func.sum(case((col(Approval.status) == "pending", 1), else_=0)).label("pending"), func.sum(
case((col(Approval.status) == "pending", 1), else_=0),
).label("pending"),
) )
.where(col(Approval.board_id) == board.id) .where(col(Approval.board_id) == board.id)
.where(col(Approval.task_id).is_not(None)) .where(col(Approval.task_id).is_not(None))
.group_by(col(Approval.task_id)) .group_by(col(Approval.task_id)),
) ),
) )
for task_id, total, pending in rows: for task_id, total, pending in rows:
if task_id is None: if task_id is None:

View File

@@ -1,26 +1,33 @@
"""Policy helpers for lead-agent approval and planning decisions."""
from __future__ import annotations from __future__ import annotations
import hashlib import hashlib
from typing import Mapping from typing import Mapping
CONFIDENCE_THRESHOLD = 80 CONFIDENCE_THRESHOLD = 80
MIN_PLANNING_SIGNALS = 2
def compute_confidence(rubric_scores: Mapping[str, int]) -> int: def compute_confidence(rubric_scores: Mapping[str, int]) -> int:
"""Compute aggregate confidence from rubric score components."""
return int(sum(rubric_scores.values())) return int(sum(rubric_scores.values()))
def approval_required(*, confidence: int, is_external: bool, is_risky: bool) -> bool: def approval_required(*, confidence: int, is_external: bool, is_risky: bool) -> bool:
"""Return whether an action must go through explicit approval."""
return is_external or is_risky or confidence < CONFIDENCE_THRESHOLD return is_external or is_risky or confidence < CONFIDENCE_THRESHOLD
def infer_planning(signals: Mapping[str, bool]) -> bool: def infer_planning(signals: Mapping[str, bool]) -> bool:
"""Infer planning intent from boolean heuristic signals."""
# Require at least two planning signals to avoid spam on general boards. # Require at least two planning signals to avoid spam on general boards.
truthy = [key for key, value in signals.items() if value] truthy = [key for key, value in signals.items() if value]
return len(truthy) >= 2 return len(truthy) >= MIN_PLANNING_SIGNALS
def task_fingerprint(title: str, description: str | None, board_id: str) -> str: def task_fingerprint(title: str, description: str | None, board_id: str) -> str:
"""Build a stable hash key for deduplicating similar board tasks."""
normalized_title = title.strip().lower() normalized_title = title.strip().lower()
normalized_desc = (description or "").strip().lower() normalized_desc = (description or "").strip().lower()
seed = f"{board_id}::{normalized_title}::{normalized_desc}" seed = f"{board_id}::{normalized_title}::{normalized_desc}"

View File

@@ -1,18 +1,24 @@
"""Helpers for extracting and matching `@mention` tokens in text."""
from __future__ import annotations from __future__ import annotations
import re import re
from typing import TYPE_CHECKING
from app.models.agents import Agent if TYPE_CHECKING:
from app.models.agents import Agent
# Mention tokens are single, space-free words (e.g. "@alex", "@lead"). # Mention tokens are single, space-free words (e.g. "@alex", "@lead").
MENTION_PATTERN = re.compile(r"@([A-Za-z][\w-]{0,31})") MENTION_PATTERN = re.compile(r"@([A-Za-z][\w-]{0,31})")
def extract_mentions(message: str) -> set[str]: def extract_mentions(message: str) -> set[str]:
"""Extract normalized mention handles from a message body."""
return {match.group(1).lower() for match in MENTION_PATTERN.finditer(message)} return {match.group(1).lower() for match in MENTION_PATTERN.finditer(message)}
def matches_agent_mention(agent: Agent, mentions: set[str]) -> bool: def matches_agent_mention(agent: Agent, mentions: set[str]) -> bool:
"""Return whether a mention set targets the provided agent."""
if not mentions: if not mentions:
return False return False

View File

@@ -1,14 +1,14 @@
"""Organization membership and board-access service helpers."""
# ruff: noqa: D101, D103
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterable from typing import TYPE_CHECKING, Iterable
from uuid import UUID
from fastapi import HTTPException, status from fastapi import HTTPException, status
from sqlalchemy import func, or_ from sqlalchemy import func, or_
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.time import utcnow from app.core.time import utcnow
from app.db import crud from app.db import crud
@@ -19,7 +19,17 @@ from app.models.organization_invites import OrganizationInvite
from app.models.organization_members import OrganizationMember from app.models.organization_members import OrganizationMember
from app.models.organizations import Organization from app.models.organizations import Organization
from app.models.users import User from app.models.users import User
from app.schemas.organizations import OrganizationBoardAccessSpec, OrganizationMemberAccessUpdate
if TYPE_CHECKING:
from uuid import UUID
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel.ext.asyncio.session import AsyncSession
from app.schemas.organizations import (
OrganizationBoardAccessSpec,
OrganizationMemberAccessUpdate,
)
DEFAULT_ORG_NAME = "Personal" DEFAULT_ORG_NAME = "Personal"
ADMIN_ROLES = {"owner", "admin"} ADMIN_ROLES = {"owner", "admin"}
@@ -63,7 +73,9 @@ async def get_member(
).first(session) ).first(session)
async def get_first_membership(session: AsyncSession, user_id: UUID) -> OrganizationMember | None: async def get_first_membership(
session: AsyncSession, user_id: UUID,
) -> OrganizationMember | None:
return ( return (
await OrganizationMember.objects.filter_by(user_id=user_id) await OrganizationMember.objects.filter_by(user_id=user_id)
.order_by(col(OrganizationMember.created_at).asc()) .order_by(col(OrganizationMember.created_at).asc())
@@ -79,7 +91,9 @@ async def set_active_organization(
) -> OrganizationMember: ) -> OrganizationMember:
member = await get_member(session, user_id=user.id, organization_id=organization_id) member = await get_member(session, user_id=user.id, organization_id=organization_id)
if member is None: if member is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="No org access") raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="No org access",
)
if user.active_organization_id != organization_id: if user.active_organization_id != organization_id:
user.active_organization_id = organization_id user.active_organization_id = organization_id
session.add(user) session.add(user)
@@ -154,9 +168,10 @@ async def accept_invite(
access_rows = list( access_rows = list(
await session.exec( await session.exec(
select(OrganizationInviteBoardAccess).where( select(OrganizationInviteBoardAccess).where(
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id col(OrganizationInviteBoardAccess.organization_invite_id)
) == invite.id,
) ),
),
) )
for row in access_rows: for row in access_rows:
session.add( session.add(
@@ -167,7 +182,7 @@ async def accept_invite(
can_write=row.can_write, can_write=row.can_write,
created_at=now, created_at=now,
updated_at=now, updated_at=now,
) ),
) )
invite.accepted_by_user_id = user.id invite.accepted_by_user_id = user.id
@@ -182,7 +197,9 @@ async def accept_invite(
return member return member
async def ensure_member_for_user(session: AsyncSession, user: User) -> OrganizationMember: async def ensure_member_for_user(
session: AsyncSession, user: User,
) -> OrganizationMember:
existing = await get_active_membership(session, user) existing = await get_active_membership(session, user)
if existing is not None: if existing is not None:
return existing return existing
@@ -196,7 +213,9 @@ async def ensure_member_for_user(session: AsyncSession, user: User) -> Organizat
now = utcnow() now = utcnow()
member_count = ( member_count = (
await session.exec( await session.exec(
select(func.count()).where(col(OrganizationMember.organization_id) == org.id) select(func.count()).where(
col(OrganizationMember.organization_id) == org.id,
),
) )
).one() ).one()
is_first = int(member_count or 0) == 0 is_first = int(member_count or 0) == 0
@@ -257,30 +276,40 @@ async def require_board_access(
board: Board, board: Board,
write: bool, write: bool,
) -> OrganizationMember: ) -> OrganizationMember:
member = await get_member(session, user_id=user.id, organization_id=board.organization_id) member = await get_member(
session, user_id=user.id, organization_id=board.organization_id,
)
if member is None: if member is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="No org access") raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="No org access",
)
if not await has_board_access(session, member=member, board=board, write=write): if not await has_board_access(session, member=member, board=board, write=write):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Board access denied") raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Board access denied",
)
return member return member
def board_access_filter(member: OrganizationMember, *, write: bool) -> ColumnElement[bool]: def board_access_filter(
member: OrganizationMember, *, write: bool,
) -> ColumnElement[bool]:
if write and member_all_boards_write(member): if write and member_all_boards_write(member):
return col(Board.organization_id) == member.organization_id return col(Board.organization_id) == member.organization_id
if not write and member_all_boards_read(member): if not write and member_all_boards_read(member):
return col(Board.organization_id) == member.organization_id return col(Board.organization_id) == member.organization_id
access_stmt = select(OrganizationBoardAccess.board_id).where( access_stmt = select(OrganizationBoardAccess.board_id).where(
col(OrganizationBoardAccess.organization_member_id) == member.id col(OrganizationBoardAccess.organization_member_id) == member.id,
) )
if write: if write:
access_stmt = access_stmt.where(col(OrganizationBoardAccess.can_write).is_(True)) access_stmt = access_stmt.where(
col(OrganizationBoardAccess.can_write).is_(True),
)
else: else:
access_stmt = access_stmt.where( access_stmt = access_stmt.where(
or_( or_(
col(OrganizationBoardAccess.can_read).is_(True), col(OrganizationBoardAccess.can_read).is_(True),
col(OrganizationBoardAccess.can_write).is_(True), col(OrganizationBoardAccess.can_write).is_(True),
) ),
) )
return col(Board.id).in_(access_stmt) return col(Board.id).in_(access_stmt)
@@ -295,21 +324,25 @@ async def list_accessible_board_ids(
not write and member_all_boards_read(member) not write and member_all_boards_read(member)
): ):
ids = await session.exec( ids = await session.exec(
select(Board.id).where(col(Board.organization_id) == member.organization_id) select(Board.id).where(
col(Board.organization_id) == member.organization_id,
),
) )
return list(ids) return list(ids)
access_stmt = select(OrganizationBoardAccess.board_id).where( access_stmt = select(OrganizationBoardAccess.board_id).where(
col(OrganizationBoardAccess.organization_member_id) == member.id col(OrganizationBoardAccess.organization_member_id) == member.id,
) )
if write: if write:
access_stmt = access_stmt.where(col(OrganizationBoardAccess.can_write).is_(True)) access_stmt = access_stmt.where(
col(OrganizationBoardAccess.can_write).is_(True),
)
else: else:
access_stmt = access_stmt.where( access_stmt = access_stmt.where(
or_( or_(
col(OrganizationBoardAccess.can_read).is_(True), col(OrganizationBoardAccess.can_read).is_(True),
col(OrganizationBoardAccess.can_write).is_(True), col(OrganizationBoardAccess.can_write).is_(True),
) ),
) )
board_ids = await session.exec(access_stmt) board_ids = await session.exec(access_stmt)
return list(board_ids) return list(board_ids)
@@ -337,18 +370,17 @@ async def apply_member_access_update(
if update.all_boards_read or update.all_boards_write: if update.all_boards_read or update.all_boards_write:
return return
rows: list[OrganizationBoardAccess] = [] rows = [
for entry in update.board_access: OrganizationBoardAccess(
rows.append( organization_member_id=member.id,
OrganizationBoardAccess( board_id=entry.board_id,
organization_member_id=member.id, can_read=entry.can_read,
board_id=entry.board_id, can_write=entry.can_write,
can_read=entry.can_read, created_at=now,
can_write=entry.can_write, updated_at=now,
created_at=now,
updated_at=now,
)
) )
for entry in update.board_access
]
session.add_all(rows) session.add_all(rows)
@@ -367,18 +399,17 @@ async def apply_invite_board_access(
if invite.all_boards_read or invite.all_boards_write: if invite.all_boards_read or invite.all_boards_write:
return return
now = utcnow() now = utcnow()
rows: list[OrganizationInviteBoardAccess] = [] rows = [
for entry in entries: OrganizationInviteBoardAccess(
rows.append( organization_invite_id=invite.id,
OrganizationInviteBoardAccess( board_id=entry.board_id,
organization_invite_id=invite.id, can_read=entry.can_read,
board_id=entry.board_id, can_write=entry.can_write,
can_read=entry.can_read, created_at=now,
can_write=entry.can_write, updated_at=now,
created_at=now,
updated_at=now,
)
) )
for entry in entries
]
session.add_all(rows) session.add_all(rows)
@@ -423,9 +454,9 @@ async def apply_invite_to_member(
access_rows = list( access_rows = list(
await session.exec( await session.exec(
select(OrganizationInviteBoardAccess).where( select(OrganizationInviteBoardAccess).where(
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id,
) ),
) ),
) )
for row in access_rows: for row in access_rows:
existing = ( existing = (
@@ -433,7 +464,7 @@ async def apply_invite_to_member(
select(OrganizationBoardAccess).where( select(OrganizationBoardAccess).where(
col(OrganizationBoardAccess.organization_member_id) == member.id, col(OrganizationBoardAccess.organization_member_id) == member.id,
col(OrganizationBoardAccess.board_id) == row.board_id, col(OrganizationBoardAccess.board_id) == row.board_id,
) ),
) )
).first() ).first()
can_write = bool(row.can_write) can_write = bool(row.can_write)
@@ -447,7 +478,7 @@ async def apply_invite_to_member(
can_write=can_write, can_write=can_write,
created_at=now, created_at=now,
updated_at=now, updated_at=now,
) ),
) )
else: else:
existing.can_read = bool(existing.can_read or can_read) existing.can_read = bool(existing.can_read or can_read)

View File

@@ -1,3 +1,5 @@
"""Service helpers for querying and caching souls.directory content."""
from __future__ import annotations from __future__ import annotations
import time import time
@@ -11,33 +13,41 @@ SOULS_DIRECTORY_BASE_URL: Final[str] = "https://souls.directory"
SOULS_DIRECTORY_SITEMAP_URL: Final[str] = f"{SOULS_DIRECTORY_BASE_URL}/sitemap.xml" SOULS_DIRECTORY_SITEMAP_URL: Final[str] = f"{SOULS_DIRECTORY_BASE_URL}/sitemap.xml"
_SITEMAP_TTL_SECONDS: Final[int] = 60 * 60 _SITEMAP_TTL_SECONDS: Final[int] = 60 * 60
_SOUL_URL_MIN_PARTS: Final[int] = 6
@dataclass(frozen=True, slots=True) @dataclass(frozen=True, slots=True)
class SoulRef: class SoulRef:
"""Handle/slug reference pair for a soul entry."""
handle: str handle: str
slug: str slug: str
@property @property
def page_url(self) -> str: def page_url(self) -> str:
"""Return the canonical page URL for this soul."""
return f"{SOULS_DIRECTORY_BASE_URL}/souls/{self.handle}/{self.slug}" return f"{SOULS_DIRECTORY_BASE_URL}/souls/{self.handle}/{self.slug}"
@property @property
def raw_md_url(self) -> str: def raw_md_url(self) -> str:
"""Return the raw markdown URL for this soul."""
return f"{SOULS_DIRECTORY_BASE_URL}/api/souls/{self.handle}/{self.slug}.md" return f"{SOULS_DIRECTORY_BASE_URL}/api/souls/{self.handle}/{self.slug}.md"
def _parse_sitemap_soul_refs(sitemap_xml: str) -> list[SoulRef]: def _parse_sitemap_soul_refs(sitemap_xml: str) -> list[SoulRef]:
"""Parse sitemap XML and extract valid souls.directory handle/slug refs."""
try: try:
root = ET.fromstring(sitemap_xml) # Souls sitemap is fetched from a known trusted host in this service flow.
root = ET.fromstring(sitemap_xml) # noqa: S314
except ET.ParseError: except ET.ParseError:
return [] return []
# Handle both namespaced and non-namespaced sitemap XML. # Handle both namespaced and non-namespaced sitemap XML.
urls: list[str] = [] urls = [
for loc in root.iter(): loc.text.strip()
if loc.tag.endswith("loc") and loc.text: for loc in root.iter()
urls.append(loc.text.strip()) if loc.tag.endswith("loc") and loc.text
]
refs: list[SoulRef] = [] refs: list[SoulRef] = []
for url in urls: for url in urls:
@@ -45,7 +55,7 @@ def _parse_sitemap_soul_refs(sitemap_xml: str) -> list[SoulRef]:
continue continue
# Expected: https://souls.directory/souls/{handle}/{slug} # Expected: https://souls.directory/souls/{handle}/{slug}
parts = url.split("/") parts = url.split("/")
if len(parts) < 6: if len(parts) < _SOUL_URL_MIN_PARTS:
continue continue
handle = parts[4].strip() handle = parts[4].strip()
slug = parts[5].strip() slug = parts[5].strip()
@@ -61,7 +71,11 @@ _sitemap_cache: dict[str, object] = {
} }
async def list_souls_directory_refs(*, client: httpx.AsyncClient | None = None) -> list[SoulRef]: async def list_souls_directory_refs(
*,
client: httpx.AsyncClient | None = None,
) -> list[SoulRef]:
"""Return cached sitemap-derived soul refs, refreshing when TTL expires."""
now = time.time() now = time.time()
loaded_raw = _sitemap_cache.get("loaded_at") loaded_raw = _sitemap_cache.get("loaded_at")
loaded_at = loaded_raw if isinstance(loaded_raw, (int, float)) else 0.0 loaded_at = loaded_raw if isinstance(loaded_raw, (int, float)) else 0.0
@@ -93,11 +107,15 @@ async def fetch_soul_markdown(
slug: str, slug: str,
client: httpx.AsyncClient | None = None, client: httpx.AsyncClient | None = None,
) -> str: ) -> str:
"""Fetch raw markdown content for a specific handle/slug pair."""
normalized_handle = handle.strip().strip("/") normalized_handle = handle.strip().strip("/")
normalized_slug = slug.strip().strip("/") normalized_slug = slug.strip().strip("/")
if normalized_slug.endswith(".md"): if normalized_slug.endswith(".md"):
normalized_slug = normalized_slug[: -len(".md")] normalized_slug = normalized_slug[: -len(".md")]
url = f"{SOULS_DIRECTORY_BASE_URL}/api/souls/{normalized_handle}/{normalized_slug}.md" url = (
f"{SOULS_DIRECTORY_BASE_URL}/api/souls/"
f"{normalized_handle}/{normalized_slug}.md"
)
owns_client = client is None owns_client = client is None
if client is None: if client is None:
@@ -115,6 +133,7 @@ async def fetch_soul_markdown(
def search_souls(refs: list[SoulRef], *, query: str, limit: int = 20) -> list[SoulRef]: def search_souls(refs: list[SoulRef], *, query: str, limit: int = 20) -> list[SoulRef]:
"""Search refs by case-insensitive handle/slug substring with a hard limit."""
q = query.strip().lower() q = query.strip().lower()
if not q: if not q:
return refs[: max(0, min(limit, len(refs)))] return refs[: max(0, min(limit, len(refs)))]

View File

@@ -0,0 +1 @@
"""Background worker tasks and queue processing utilities."""

View File

@@ -1,3 +1,5 @@
"""RQ queue and Redis connection helpers for background workers."""
from __future__ import annotations from __future__ import annotations
from redis import Redis from redis import Redis
@@ -7,8 +9,10 @@ from app.core.config import settings
def get_redis() -> Redis: def get_redis() -> Redis:
"""Create a Redis client from configured settings."""
return Redis.from_url(settings.redis_url) return Redis.from_url(settings.redis_url)
def get_queue(name: str) -> Queue: def get_queue(name: str) -> Queue:
"""Return an RQ queue bound to the configured Redis connection."""
return Queue(name, connection=get_redis()) return Queue(name, connection=get_redis())

View File

@@ -0,0 +1 @@
"""Alembic migration package for backend schema evolution."""

View File

@@ -1,3 +1,5 @@
"""Alembic environment configuration for backend database migrations."""
from __future__ import annotations from __future__ import annotations
import sys import sys
@@ -16,8 +18,9 @@ from app import models # noqa: E402,F401
from app.core.config import settings # noqa: E402 from app.core.config import settings # noqa: E402
config = context.config config = context.config
configure_logger = config.attributes.get("configure_logger", True)
if config.config_file_name is not None and config.attributes.get("configure_logger", True): if config.config_file_name is not None and configure_logger:
fileConfig(config.config_file_name) fileConfig(config.config_file_name)
target_metadata = SQLModel.metadata target_metadata = SQLModel.metadata
@@ -33,6 +36,7 @@ def _normalize_database_url(database_url: str) -> str:
def get_url() -> str: def get_url() -> str:
"""Return the normalized SQLAlchemy database URL for Alembic."""
return _normalize_database_url(settings.database_url) return _normalize_database_url(settings.database_url)
@@ -40,6 +44,7 @@ config.set_main_option("sqlalchemy.url", get_url())
def run_migrations_offline() -> None: def run_migrations_offline() -> None:
"""Run migrations in offline mode without DB engine connectivity."""
context.configure( context.configure(
url=get_url(), url=get_url(),
target_metadata=target_metadata, target_metadata=target_metadata,
@@ -52,6 +57,7 @@ def run_migrations_offline() -> None:
def run_migrations_online() -> None: def run_migrations_online() -> None:
"""Run migrations in online mode using a live DB connection."""
configuration = config.get_section(config.config_ini_section) or {} configuration = config.get_section(config.config_ini_section) or {}
configuration["sqlalchemy.url"] = get_url() configuration["sqlalchemy.url"] = get_url()

View File

@@ -1,13 +1,14 @@
"""init """Initial schema migration.
Revision ID: 658dca8f4a11 Revision ID: 658dca8f4a11
Revises: Revises:
Create Date: 2026-02-09 00:41:55.760624 Create Date: 2026-02-09 00:41:55.760624
""" """
from __future__ import annotations from __future__ import annotations
# ruff: noqa: INP001
import sqlalchemy as sa import sqlalchemy as sa
import sqlmodel import sqlmodel
from alembic import op from alembic import op
@@ -19,7 +20,8 @@ branch_labels = None
depends_on = None depends_on = None
def upgrade() -> None: def upgrade() -> None: # noqa: PLR0915
"""Create initial schema objects."""
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table( op.create_table(
"organizations", "organizations",
@@ -30,7 +32,9 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name", name="uq_organizations_name"), sa.UniqueConstraint("name", name="uq_organizations_name"),
) )
op.create_index(op.f("ix_organizations_name"), "organizations", ["name"], unique=False) op.create_index(
op.f("ix_organizations_name"), "organizations", ["name"], unique=False,
)
op.create_table( op.create_table(
"board_groups", "board_groups",
sa.Column("id", sa.Uuid(), nullable=False), sa.Column("id", sa.Uuid(), nullable=False),
@@ -47,9 +51,14 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_index( op.create_index(
op.f("ix_board_groups_organization_id"), "board_groups", ["organization_id"], unique=False op.f("ix_board_groups_organization_id"),
"board_groups",
["organization_id"],
unique=False,
)
op.create_index(
op.f("ix_board_groups_slug"), "board_groups", ["slug"], unique=False,
) )
op.create_index(op.f("ix_board_groups_slug"), "board_groups", ["slug"], unique=False)
op.create_table( op.create_table(
"gateways", "gateways",
sa.Column("id", sa.Uuid(), nullable=False), sa.Column("id", sa.Uuid(), nullable=False),
@@ -57,7 +66,9 @@ def upgrade() -> None:
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("url", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("url", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("token", sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.Column("token", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("main_session_key", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
"main_session_key", sqlmodel.sql.sqltypes.AutoString(), nullable=False,
),
sa.Column("workspace_root", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("workspace_root", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False), sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column("updated_at", sa.DateTime(), nullable=False), sa.Column("updated_at", sa.DateTime(), nullable=False),
@@ -68,7 +79,10 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_index( op.create_index(
op.f("ix_gateways_organization_id"), "gateways", ["organization_id"], unique=False op.f("ix_gateways_organization_id"),
"gateways",
["organization_id"],
unique=False,
) )
op.create_table( op.create_table(
"users", "users",
@@ -90,9 +104,14 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_index( op.create_index(
op.f("ix_users_active_organization_id"), "users", ["active_organization_id"], unique=False op.f("ix_users_active_organization_id"),
"users",
["active_organization_id"],
unique=False,
)
op.create_index(
op.f("ix_users_clerk_user_id"), "users", ["clerk_user_id"], unique=True,
) )
op.create_index(op.f("ix_users_clerk_user_id"), "users", ["clerk_user_id"], unique=True)
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=False) op.create_index(op.f("ix_users_email"), "users", ["email"], unique=False)
op.create_table( op.create_table(
"board_group_memory", "board_group_memory",
@@ -116,7 +135,10 @@ def upgrade() -> None:
unique=False, unique=False,
) )
op.create_index( op.create_index(
op.f("ix_board_group_memory_is_chat"), "board_group_memory", ["is_chat"], unique=False op.f("ix_board_group_memory_is_chat"),
"board_group_memory",
["is_chat"],
unique=False,
) )
op.create_table( op.create_table(
"boards", "boards",
@@ -148,10 +170,18 @@ def upgrade() -> None:
), ),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_index(op.f("ix_boards_board_group_id"), "boards", ["board_group_id"], unique=False) op.create_index(
op.create_index(op.f("ix_boards_board_type"), "boards", ["board_type"], unique=False) op.f("ix_boards_board_group_id"), "boards", ["board_group_id"], unique=False,
op.create_index(op.f("ix_boards_gateway_id"), "boards", ["gateway_id"], unique=False) )
op.create_index(op.f("ix_boards_organization_id"), "boards", ["organization_id"], unique=False) op.create_index(
op.f("ix_boards_board_type"), "boards", ["board_type"], unique=False,
)
op.create_index(
op.f("ix_boards_gateway_id"), "boards", ["gateway_id"], unique=False,
)
op.create_index(
op.f("ix_boards_organization_id"), "boards", ["organization_id"], unique=False,
)
op.create_index(op.f("ix_boards_slug"), "boards", ["slug"], unique=False) op.create_index(op.f("ix_boards_slug"), "boards", ["slug"], unique=False)
op.create_table( op.create_table(
"organization_invites", "organization_invites",
@@ -207,10 +237,16 @@ def upgrade() -> None:
unique=False, unique=False,
) )
op.create_index( op.create_index(
op.f("ix_organization_invites_role"), "organization_invites", ["role"], unique=False op.f("ix_organization_invites_role"),
"organization_invites",
["role"],
unique=False,
) )
op.create_index( op.create_index(
op.f("ix_organization_invites_token"), "organization_invites", ["token"], unique=False op.f("ix_organization_invites_token"),
"organization_invites",
["token"],
unique=False,
) )
op.create_table( op.create_table(
"organization_members", "organization_members",
@@ -231,7 +267,9 @@ def upgrade() -> None:
["users.id"], ["users.id"],
), ),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("organization_id", "user_id", name="uq_organization_members_org_user"), sa.UniqueConstraint(
"organization_id", "user_id", name="uq_organization_members_org_user",
),
) )
op.create_index( op.create_index(
op.f("ix_organization_members_organization_id"), op.f("ix_organization_members_organization_id"),
@@ -240,10 +278,16 @@ def upgrade() -> None:
unique=False, unique=False,
) )
op.create_index( op.create_index(
op.f("ix_organization_members_role"), "organization_members", ["role"], unique=False op.f("ix_organization_members_role"),
"organization_members",
["role"],
unique=False,
) )
op.create_index( op.create_index(
op.f("ix_organization_members_user_id"), "organization_members", ["user_id"], unique=False op.f("ix_organization_members_user_id"),
"organization_members",
["user_id"],
unique=False,
) )
op.create_table( op.create_table(
"agents", "agents",
@@ -251,19 +295,31 @@ def upgrade() -> None:
sa.Column("board_id", sa.Uuid(), nullable=True), sa.Column("board_id", sa.Uuid(), nullable=True),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("status", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("status", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("openclaw_session_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.Column(
sa.Column("agent_token_hash", sqlmodel.sql.sqltypes.AutoString(), nullable=True), "openclaw_session_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True,
),
sa.Column(
"agent_token_hash", sqlmodel.sql.sqltypes.AutoString(), nullable=True,
),
sa.Column("heartbeat_config", sa.JSON(), nullable=True), sa.Column("heartbeat_config", sa.JSON(), nullable=True),
sa.Column("identity_profile", sa.JSON(), nullable=True), sa.Column("identity_profile", sa.JSON(), nullable=True),
sa.Column("identity_template", sa.Text(), nullable=True), sa.Column("identity_template", sa.Text(), nullable=True),
sa.Column("soul_template", sa.Text(), nullable=True), sa.Column("soul_template", sa.Text(), nullable=True),
sa.Column("provision_requested_at", sa.DateTime(), nullable=True), sa.Column("provision_requested_at", sa.DateTime(), nullable=True),
sa.Column( sa.Column(
"provision_confirm_token_hash", sqlmodel.sql.sqltypes.AutoString(), nullable=True "provision_confirm_token_hash",
sqlmodel.sql.sqltypes.AutoString(),
nullable=True,
),
sa.Column(
"provision_action", sqlmodel.sql.sqltypes.AutoString(), nullable=True,
), ),
sa.Column("provision_action", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("delete_requested_at", sa.DateTime(), nullable=True), sa.Column("delete_requested_at", sa.DateTime(), nullable=True),
sa.Column("delete_confirm_token_hash", sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.Column(
"delete_confirm_token_hash",
sqlmodel.sql.sqltypes.AutoString(),
nullable=True,
),
sa.Column("last_seen_at", sa.DateTime(), nullable=True), sa.Column("last_seen_at", sa.DateTime(), nullable=True),
sa.Column("is_board_lead", sa.Boolean(), nullable=False), sa.Column("is_board_lead", sa.Boolean(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False), sa.Column("created_at", sa.DateTime(), nullable=False),
@@ -275,7 +331,10 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_index( op.create_index(
op.f("ix_agents_agent_token_hash"), "agents", ["agent_token_hash"], unique=False op.f("ix_agents_agent_token_hash"),
"agents",
["agent_token_hash"],
unique=False,
) )
op.create_index(op.f("ix_agents_board_id"), "agents", ["board_id"], unique=False) op.create_index(op.f("ix_agents_board_id"), "agents", ["board_id"], unique=False)
op.create_index( op.create_index(
@@ -284,13 +343,21 @@ def upgrade() -> None:
["delete_confirm_token_hash"], ["delete_confirm_token_hash"],
unique=False, unique=False,
) )
op.create_index(op.f("ix_agents_is_board_lead"), "agents", ["is_board_lead"], unique=False) op.create_index(
op.f("ix_agents_is_board_lead"), "agents", ["is_board_lead"], unique=False,
)
op.create_index(op.f("ix_agents_name"), "agents", ["name"], unique=False) op.create_index(op.f("ix_agents_name"), "agents", ["name"], unique=False)
op.create_index( op.create_index(
op.f("ix_agents_openclaw_session_id"), "agents", ["openclaw_session_id"], unique=False op.f("ix_agents_openclaw_session_id"),
"agents",
["openclaw_session_id"],
unique=False,
) )
op.create_index( op.create_index(
op.f("ix_agents_provision_action"), "agents", ["provision_action"], unique=False op.f("ix_agents_provision_action"),
"agents",
["provision_action"],
unique=False,
) )
op.create_index( op.create_index(
op.f("ix_agents_provision_confirm_token_hash"), op.f("ix_agents_provision_confirm_token_hash"),
@@ -314,8 +381,12 @@ def upgrade() -> None:
), ),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_index(op.f("ix_board_memory_board_id"), "board_memory", ["board_id"], unique=False) op.create_index(
op.create_index(op.f("ix_board_memory_is_chat"), "board_memory", ["is_chat"], unique=False) op.f("ix_board_memory_board_id"), "board_memory", ["board_id"], unique=False,
)
op.create_index(
op.f("ix_board_memory_is_chat"), "board_memory", ["is_chat"], unique=False,
)
op.create_table( op.create_table(
"board_onboarding_sessions", "board_onboarding_sessions",
sa.Column("id", sa.Uuid(), nullable=False), sa.Column("id", sa.Uuid(), nullable=False),
@@ -363,7 +434,9 @@ def upgrade() -> None:
), ),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint( sa.UniqueConstraint(
"organization_member_id", "board_id", name="uq_org_board_access_member_board" "organization_member_id",
"board_id",
name="uq_org_board_access_member_board",
), ),
) )
op.create_index( op.create_index(
@@ -397,7 +470,9 @@ def upgrade() -> None:
), ),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint( sa.UniqueConstraint(
"organization_invite_id", "board_id", name="uq_org_invite_board_access_invite_board" "organization_invite_id",
"board_id",
name="uq_org_invite_board_access_invite_board",
), ),
) )
op.create_index( op.create_index(
@@ -443,11 +518,17 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_index( op.create_index(
op.f("ix_tasks_assigned_agent_id"), "tasks", ["assigned_agent_id"], unique=False op.f("ix_tasks_assigned_agent_id"),
"tasks",
["assigned_agent_id"],
unique=False,
) )
op.create_index(op.f("ix_tasks_board_id"), "tasks", ["board_id"], unique=False) op.create_index(op.f("ix_tasks_board_id"), "tasks", ["board_id"], unique=False)
op.create_index( op.create_index(
op.f("ix_tasks_created_by_user_id"), "tasks", ["created_by_user_id"], unique=False op.f("ix_tasks_created_by_user_id"),
"tasks",
["created_by_user_id"],
unique=False,
) )
op.create_index(op.f("ix_tasks_priority"), "tasks", ["priority"], unique=False) op.create_index(op.f("ix_tasks_priority"), "tasks", ["priority"], unique=False)
op.create_index(op.f("ix_tasks_status"), "tasks", ["status"], unique=False) op.create_index(op.f("ix_tasks_status"), "tasks", ["status"], unique=False)
@@ -470,13 +551,22 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_index( op.create_index(
op.f("ix_activity_events_agent_id"), "activity_events", ["agent_id"], unique=False op.f("ix_activity_events_agent_id"),
"activity_events",
["agent_id"],
unique=False,
) )
op.create_index( op.create_index(
op.f("ix_activity_events_event_type"), "activity_events", ["event_type"], unique=False op.f("ix_activity_events_event_type"),
"activity_events",
["event_type"],
unique=False,
) )
op.create_index( op.create_index(
op.f("ix_activity_events_task_id"), "activity_events", ["task_id"], unique=False op.f("ix_activity_events_task_id"),
"activity_events",
["task_id"],
unique=False,
) )
op.create_table( op.create_table(
"approvals", "approvals",
@@ -505,10 +595,16 @@ def upgrade() -> None:
), ),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_index(op.f("ix_approvals_agent_id"), "approvals", ["agent_id"], unique=False) op.create_index(
op.create_index(op.f("ix_approvals_board_id"), "approvals", ["board_id"], unique=False) op.f("ix_approvals_agent_id"), "approvals", ["agent_id"], unique=False,
)
op.create_index(
op.f("ix_approvals_board_id"), "approvals", ["board_id"], unique=False,
)
op.create_index(op.f("ix_approvals_status"), "approvals", ["status"], unique=False) op.create_index(op.f("ix_approvals_status"), "approvals", ["status"], unique=False)
op.create_index(op.f("ix_approvals_task_id"), "approvals", ["task_id"], unique=False) op.create_index(
op.f("ix_approvals_task_id"), "approvals", ["task_id"], unique=False,
)
op.create_table( op.create_table(
"task_dependencies", "task_dependencies",
sa.Column("id", sa.Uuid(), nullable=False), sa.Column("id", sa.Uuid(), nullable=False),
@@ -516,7 +612,9 @@ def upgrade() -> None:
sa.Column("task_id", sa.Uuid(), nullable=False), sa.Column("task_id", sa.Uuid(), nullable=False),
sa.Column("depends_on_task_id", sa.Uuid(), nullable=False), sa.Column("depends_on_task_id", sa.Uuid(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False), sa.Column("created_at", sa.DateTime(), nullable=False),
sa.CheckConstraint("task_id <> depends_on_task_id", name="ck_task_dependencies_no_self"), sa.CheckConstraint(
"task_id <> depends_on_task_id", name="ck_task_dependencies_no_self",
),
sa.ForeignKeyConstraint( sa.ForeignKeyConstraint(
["board_id"], ["board_id"],
["boards.id"], ["boards.id"],
@@ -531,11 +629,16 @@ def upgrade() -> None:
), ),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint( sa.UniqueConstraint(
"task_id", "depends_on_task_id", name="uq_task_dependencies_task_id_depends_on_task_id" "task_id",
"depends_on_task_id",
name="uq_task_dependencies_task_id_depends_on_task_id",
), ),
) )
op.create_index( op.create_index(
op.f("ix_task_dependencies_board_id"), "task_dependencies", ["board_id"], unique=False op.f("ix_task_dependencies_board_id"),
"task_dependencies",
["board_id"],
unique=False,
) )
op.create_index( op.create_index(
op.f("ix_task_dependencies_depends_on_task_id"), op.f("ix_task_dependencies_depends_on_task_id"),
@@ -544,13 +647,18 @@ def upgrade() -> None:
unique=False, unique=False,
) )
op.create_index( op.create_index(
op.f("ix_task_dependencies_task_id"), "task_dependencies", ["task_id"], unique=False op.f("ix_task_dependencies_task_id"),
"task_dependencies",
["task_id"],
unique=False,
) )
op.create_table( op.create_table(
"task_fingerprints", "task_fingerprints",
sa.Column("id", sa.Uuid(), nullable=False), sa.Column("id", sa.Uuid(), nullable=False),
sa.Column("board_id", sa.Uuid(), nullable=False), sa.Column("board_id", sa.Uuid(), nullable=False),
sa.Column("fingerprint_hash", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
"fingerprint_hash", sqlmodel.sql.sqltypes.AutoString(), nullable=False,
),
sa.Column("task_id", sa.Uuid(), nullable=False), sa.Column("task_id", sa.Uuid(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False), sa.Column("created_at", sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint( sa.ForeignKeyConstraint(
@@ -564,7 +672,10 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_index( op.create_index(
op.f("ix_task_fingerprints_board_id"), "task_fingerprints", ["board_id"], unique=False op.f("ix_task_fingerprints_board_id"),
"task_fingerprints",
["board_id"],
unique=False,
) )
op.create_index( op.create_index(
op.f("ix_task_fingerprints_fingerprint_hash"), op.f("ix_task_fingerprints_fingerprint_hash"),
@@ -575,13 +686,18 @@ def upgrade() -> None:
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade() -> None: def downgrade() -> None: # noqa: PLR0915
"""Drop initial schema objects."""
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_task_fingerprints_fingerprint_hash"), table_name="task_fingerprints") op.drop_index(
op.f("ix_task_fingerprints_fingerprint_hash"), table_name="task_fingerprints",
)
op.drop_index(op.f("ix_task_fingerprints_board_id"), table_name="task_fingerprints") op.drop_index(op.f("ix_task_fingerprints_board_id"), table_name="task_fingerprints")
op.drop_table("task_fingerprints") op.drop_table("task_fingerprints")
op.drop_index(op.f("ix_task_dependencies_task_id"), table_name="task_dependencies") op.drop_index(op.f("ix_task_dependencies_task_id"), table_name="task_dependencies")
op.drop_index(op.f("ix_task_dependencies_depends_on_task_id"), table_name="task_dependencies") op.drop_index(
op.f("ix_task_dependencies_depends_on_task_id"), table_name="task_dependencies",
)
op.drop_index(op.f("ix_task_dependencies_board_id"), table_name="task_dependencies") op.drop_index(op.f("ix_task_dependencies_board_id"), table_name="task_dependencies")
op.drop_table("task_dependencies") op.drop_table("task_dependencies")
op.drop_index(op.f("ix_approvals_task_id"), table_name="approvals") op.drop_index(op.f("ix_approvals_task_id"), table_name="approvals")
@@ -613,14 +729,17 @@ def downgrade() -> None:
table_name="organization_board_access", table_name="organization_board_access",
) )
op.drop_index( op.drop_index(
op.f("ix_organization_board_access_board_id"), table_name="organization_board_access" op.f("ix_organization_board_access_board_id"),
table_name="organization_board_access",
) )
op.drop_table("organization_board_access") op.drop_table("organization_board_access")
op.drop_index( op.drop_index(
op.f("ix_board_onboarding_sessions_status"), table_name="board_onboarding_sessions" op.f("ix_board_onboarding_sessions_status"),
table_name="board_onboarding_sessions",
) )
op.drop_index( op.drop_index(
op.f("ix_board_onboarding_sessions_board_id"), table_name="board_onboarding_sessions" op.f("ix_board_onboarding_sessions_board_id"),
table_name="board_onboarding_sessions",
) )
op.drop_table("board_onboarding_sessions") op.drop_table("board_onboarding_sessions")
op.drop_index(op.f("ix_board_memory_is_chat"), table_name="board_memory") op.drop_index(op.f("ix_board_memory_is_chat"), table_name="board_memory")
@@ -636,23 +755,38 @@ def downgrade() -> None:
op.drop_index(op.f("ix_agents_board_id"), table_name="agents") op.drop_index(op.f("ix_agents_board_id"), table_name="agents")
op.drop_index(op.f("ix_agents_agent_token_hash"), table_name="agents") op.drop_index(op.f("ix_agents_agent_token_hash"), table_name="agents")
op.drop_table("agents") op.drop_table("agents")
op.drop_index(op.f("ix_organization_members_user_id"), table_name="organization_members")
op.drop_index(op.f("ix_organization_members_role"), table_name="organization_members")
op.drop_index( op.drop_index(
op.f("ix_organization_members_organization_id"), table_name="organization_members" op.f("ix_organization_members_user_id"), table_name="organization_members",
)
op.drop_index(
op.f("ix_organization_members_role"), table_name="organization_members",
)
op.drop_index(
op.f("ix_organization_members_organization_id"),
table_name="organization_members",
) )
op.drop_table("organization_members") op.drop_table("organization_members")
op.drop_index(op.f("ix_organization_invites_token"), table_name="organization_invites")
op.drop_index(op.f("ix_organization_invites_role"), table_name="organization_invites")
op.drop_index( op.drop_index(
op.f("ix_organization_invites_organization_id"), table_name="organization_invites" op.f("ix_organization_invites_token"), table_name="organization_invites",
)
op.drop_index(op.f("ix_organization_invites_invited_email"), table_name="organization_invites")
op.drop_index(
op.f("ix_organization_invites_created_by_user_id"), table_name="organization_invites"
) )
op.drop_index( op.drop_index(
op.f("ix_organization_invites_accepted_by_user_id"), table_name="organization_invites" op.f("ix_organization_invites_role"), table_name="organization_invites",
)
op.drop_index(
op.f("ix_organization_invites_organization_id"),
table_name="organization_invites",
)
op.drop_index(
op.f("ix_organization_invites_invited_email"),
table_name="organization_invites",
)
op.drop_index(
op.f("ix_organization_invites_created_by_user_id"),
table_name="organization_invites",
)
op.drop_index(
op.f("ix_organization_invites_accepted_by_user_id"),
table_name="organization_invites",
) )
op.drop_table("organization_invites") op.drop_table("organization_invites")
op.drop_index(op.f("ix_boards_slug"), table_name="boards") op.drop_index(op.f("ix_boards_slug"), table_name="boards")
@@ -661,8 +795,12 @@ def downgrade() -> None:
op.drop_index(op.f("ix_boards_board_type"), table_name="boards") op.drop_index(op.f("ix_boards_board_type"), table_name="boards")
op.drop_index(op.f("ix_boards_board_group_id"), table_name="boards") op.drop_index(op.f("ix_boards_board_group_id"), table_name="boards")
op.drop_table("boards") op.drop_table("boards")
op.drop_index(op.f("ix_board_group_memory_is_chat"), table_name="board_group_memory") op.drop_index(
op.drop_index(op.f("ix_board_group_memory_board_group_id"), table_name="board_group_memory") op.f("ix_board_group_memory_is_chat"), table_name="board_group_memory",
)
op.drop_index(
op.f("ix_board_group_memory_board_group_id"), table_name="board_group_memory",
)
op.drop_table("board_group_memory") op.drop_table("board_group_memory")
op.drop_index(op.f("ix_users_email"), table_name="users") op.drop_index(op.f("ix_users_email"), table_name="users")
op.drop_index(op.f("ix_users_clerk_user_id"), table_name="users") op.drop_index(op.f("ix_users_clerk_user_id"), table_name="users")

View File

@@ -0,0 +1 @@
"""Utility scripts for backend development and maintenance tasks."""

View File

@@ -1,3 +1,5 @@
"""Export the backend OpenAPI schema to a versioned JSON artifact."""
from __future__ import annotations from __future__ import annotations
import json import json
@@ -11,11 +13,13 @@ from app.main import app # noqa: E402
def main() -> None: def main() -> None:
# Importing the FastAPI app does not run lifespan hooks, so this does not require a DB. """Generate `openapi.json` from the FastAPI app definition."""
# Importing the FastAPI app does not run lifespan hooks,
# so this does not require a DB.
out_path = BACKEND_ROOT / "openapi.json" out_path = BACKEND_ROOT / "openapi.json"
payload = app.openapi() payload = app.openapi()
out_path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8") out_path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
print(str(out_path)) sys.stdout.write(f"{out_path}\n")
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,3 +1,5 @@
"""Seed a minimal local demo dataset for manual development flows."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
@@ -16,14 +18,16 @@ from app.models.users import User # noqa: E402
async def run() -> None: async def run() -> None:
"""Populate the local database with a demo gateway, board, user, and agent."""
await init_db() await init_db()
async with async_session_maker() as session: async with async_session_maker() as session:
demo_workspace_root = BACKEND_ROOT / ".tmp" / "openclaw-demo"
gateway = Gateway( gateway = Gateway(
name="Demo Gateway", name="Demo Gateway",
url="http://localhost:8080", url="http://localhost:8080",
token=None, token=None,
main_session_key="demo:main", main_session_key="demo:main",
workspace_root="/tmp/openclaw-demo", workspace_root=str(demo_workspace_root),
) )
session.add(gateway) session.add(gateway)
await session.commit() await session.commit()

View File

@@ -1,3 +1,6 @@
# ruff: noqa: INP001
"""CLI script to sync template files into gateway agent workspaces."""
from __future__ import annotations from __future__ import annotations
import argparse import argparse
@@ -16,10 +19,15 @@ from app.services.template_sync import sync_gateway_templates # noqa: E402
def _parse_args() -> argparse.Namespace: def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Sync templates/ to existing OpenClaw gateway agent workspaces." description="Sync templates/ to existing OpenClaw gateway agent workspaces.",
) )
parser.add_argument("--gateway-id", type=str, required=True, help="Gateway UUID") parser.add_argument("--gateway-id", type=str, required=True, help="Gateway UUID")
parser.add_argument("--board-id", type=str, default=None, help="Optional Board UUID filter") parser.add_argument(
"--board-id",
type=str,
default=None,
help="Optional Board UUID filter",
)
parser.add_argument( parser.add_argument(
"--include-main", "--include-main",
action=argparse.BooleanOptionalAction, action=argparse.BooleanOptionalAction,
@@ -29,12 +37,18 @@ def _parse_args() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--reset-sessions", "--reset-sessions",
action="store_true", action="store_true",
help="Reset agent sessions after syncing files (forces agents to re-read workspace)", help=(
"Reset agent sessions after syncing files "
"(forces agents to re-read workspace)"
),
) )
parser.add_argument( parser.add_argument(
"--rotate-tokens", "--rotate-tokens",
action="store_true", action="store_true",
help="Rotate agent tokens when TOOLS.md is missing/unreadable or token drift is detected", help=(
"Rotate agent tokens when TOOLS.md is missing/unreadable "
"or token drift is detected"
),
) )
parser.add_argument( parser.add_argument(
"--force-bootstrap", "--force-bootstrap",
@@ -52,7 +66,8 @@ async def _run() -> int:
async with async_session_maker() as session: async with async_session_maker() as session:
gateway = await session.get(Gateway, gateway_id) gateway = await session.get(Gateway, gateway_id)
if gateway is None: if gateway is None:
raise SystemExit(f"Gateway not found: {gateway_id}") message = f"Gateway not found: {gateway_id}"
raise SystemExit(message)
result = await sync_gateway_templates( result = await sync_gateway_templates(
session, session,
@@ -65,21 +80,29 @@ async def _run() -> int:
board_id=board_id, board_id=board_id,
) )
print(f"gateway_id={result.gateway_id}") sys.stdout.write(f"gateway_id={result.gateway_id}\n")
print(f"include_main={result.include_main} reset_sessions={result.reset_sessions}") sys.stdout.write(
print( f"include_main={result.include_main} "
f"agents_updated={result.agents_updated} agents_skipped={result.agents_skipped} main_updated={result.main_updated}" f"reset_sessions={result.reset_sessions}\n",
)
sys.stdout.write(
f"agents_updated={result.agents_updated} "
f"agents_skipped={result.agents_skipped} "
f"main_updated={result.main_updated}\n",
) )
if result.errors: if result.errors:
print("errors:") sys.stdout.write("errors:\n")
for err in result.errors: for err in result.errors:
agent = f"{err.agent_name} ({err.agent_id})" if err.agent_id else "n/a" agent = f"{err.agent_name} ({err.agent_id})" if err.agent_id else "n/a"
print(f"- agent={agent} board_id={err.board_id} message={err.message}") sys.stdout.write(
f"- agent={agent} board_id={err.board_id} message={err.message}\n",
)
return 1 return 1
return 0 return 0
def main() -> None: def main() -> None:
"""Run the async CLI workflow and exit with its return code."""
raise SystemExit(asyncio.run(_run())) raise SystemExit(asyncio.run(_run()))

View File

@@ -0,0 +1 @@
"""Backend test package."""

View File

@@ -1,3 +1,6 @@
# ruff: noqa: INP001
"""Pytest configuration shared across backend tests."""
import sys import sys
from pathlib import Path from pathlib import Path

View File

@@ -1,24 +1,30 @@
# ruff: noqa: INP001, S101
"""Regression test for board-group delete ordering."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from types import SimpleNamespace from types import SimpleNamespace
from typing import Any from typing import TYPE_CHECKING, cast
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from app.api import board_groups from app.api import board_groups
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
@dataclass @dataclass
class _FakeSession: class _FakeSession:
executed: list[Any] = field(default_factory=list) executed: list[object] = field(default_factory=list)
committed: int = 0 committed: int = 0
async def exec(self, statement: Any) -> None: async def exec(self, statement: object) -> None:
self.executed.append(statement) self.executed.append(statement)
async def execute(self, statement: Any) -> None: async def execute(self, statement: object) -> None:
self.executed.append(statement) self.executed.append(statement)
async def commit(self) -> None: async def commit(self) -> None:
@@ -29,17 +35,26 @@ class _FakeSession:
async def test_delete_board_group_cleans_group_memory_first( async def test_delete_board_group_cleans_group_memory_first(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
"""Delete should remove boards, memory, then the board-group record."""
group_id = uuid4() group_id = uuid4()
async def _fake_require_group_access(*_args: Any, **_kwargs: Any) -> None: async def _fake_require_group_access(*_args: object, **_kwargs: object) -> None:
return None return None
monkeypatch.setattr(board_groups, "_require_group_access", _fake_require_group_access) monkeypatch.setattr(
board_groups,
"_require_group_access",
_fake_require_group_access,
)
session = _FakeSession() session = _FakeSession()
ctx = SimpleNamespace(member=object()) ctx = SimpleNamespace(member=object())
await board_groups.delete_board_group(group_id=group_id, session=session, ctx=ctx) await board_groups.delete_board_group(
group_id=group_id,
session=cast("AsyncSession", session),
ctx=ctx,
)
statement_tables = [statement.table.name for statement in session.executed] statement_tables = [statement.table.name for statement in session.executed]
assert statement_tables == ["boards", "board_group_memory", "board_groups"] assert statement_tables == ["boards", "board_group_memory", "board_groups"]

View File

@@ -1,3 +1,6 @@
# ruff: noqa: INP001
"""Schema validation tests for board and onboarding goal requirements."""
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
@@ -6,8 +9,12 @@ from app.schemas.board_onboarding import BoardOnboardingConfirm
from app.schemas.boards import BoardCreate from app.schemas.boards import BoardCreate
def test_goal_board_requires_objective_and_metrics_when_confirmed(): def test_goal_board_requires_objective_and_metrics_when_confirmed() -> None:
with pytest.raises(ValueError): """Confirmed goal boards should require objective and success metrics."""
with pytest.raises(
ValueError,
match="Confirmed goal boards require objective and success_metrics",
):
BoardCreate( BoardCreate(
name="Goal Board", name="Goal Board",
slug="goal", slug="goal",
@@ -27,22 +34,39 @@ def test_goal_board_requires_objective_and_metrics_when_confirmed():
) )
def test_goal_board_allows_missing_objective_before_confirmation(): def test_goal_board_allows_missing_objective_before_confirmation() -> None:
"""Draft goal boards may omit objective/success_metrics before confirmation."""
BoardCreate(name="Draft", slug="draft", gateway_id=uuid4(), board_type="goal") BoardCreate(name="Draft", slug="draft", gateway_id=uuid4(), board_type="goal")
def test_general_board_allows_missing_objective(): def test_general_board_allows_missing_objective() -> None:
BoardCreate(name="General", slug="general", gateway_id=uuid4(), board_type="general") """General boards should allow missing goal-specific fields."""
BoardCreate(
name="General",
slug="general",
gateway_id=uuid4(),
board_type="general",
)
def test_onboarding_confirm_requires_goal_fields(): def test_onboarding_confirm_requires_goal_fields() -> None:
with pytest.raises(ValueError): """Onboarding confirm should enforce goal fields for goal board types."""
with pytest.raises(
ValueError,
match="Confirmed goal boards require objective and success_metrics",
):
BoardOnboardingConfirm(board_type="goal") BoardOnboardingConfirm(board_type="goal")
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match="Confirmed goal boards require objective and success_metrics",
):
BoardOnboardingConfirm(board_type="goal", objective="Ship onboarding") BoardOnboardingConfirm(board_type="goal", objective="Ship onboarding")
with pytest.raises(ValueError): with pytest.raises(
ValueError,
match="Confirmed goal boards require objective and success_metrics",
):
BoardOnboardingConfirm(board_type="goal", success_metrics={"emails": 3}) BoardOnboardingConfirm(board_type="goal", success_metrics={"emails": 3})
BoardOnboardingConfirm( BoardOnboardingConfirm(

View File

@@ -1,7 +1,10 @@
# ruff: noqa: INP001, S101
"""Regression tests for board deletion cleanup behavior."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import TYPE_CHECKING, cast
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
@@ -9,27 +12,32 @@ import pytest
from app.api import boards from app.api import boards
from app.models.boards import Board from app.models.boards import Board
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
_NO_EXEC_RESULTS_ERROR = "No more exec_results left for session.exec"
@dataclass @dataclass
class _FakeSession: class _FakeSession:
exec_results: list[Any] exec_results: list[object]
executed: list[Any] = field(default_factory=list) executed: list[object] = field(default_factory=list)
deleted: list[Any] = field(default_factory=list) deleted: list[object] = field(default_factory=list)
committed: int = 0 committed: int = 0
async def exec(self, statement: Any) -> Any: async def exec(self, statement: object) -> object | None:
is_dml = statement.__class__.__name__ in {"Delete", "Update", "Insert"} is_dml = statement.__class__.__name__ in {"Delete", "Update", "Insert"}
if is_dml: if is_dml:
self.executed.append(statement) self.executed.append(statement)
return None return None
if not self.exec_results: if not self.exec_results:
raise AssertionError("No more exec_results left for session.exec") raise AssertionError(_NO_EXEC_RESULTS_ERROR)
return self.exec_results.pop(0) return self.exec_results.pop(0)
async def execute(self, statement: Any) -> None: async def execute(self, statement: object) -> None:
self.executed.append(statement) self.executed.append(statement)
async def delete(self, value: Any) -> None: async def delete(self, value: object) -> None:
self.deleted.append(value) self.deleted.append(value)
async def commit(self) -> None: async def commit(self) -> None:
@@ -38,6 +46,7 @@ class _FakeSession:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_board_cleans_org_board_access_rows() -> None: async def test_delete_board_cleans_org_board_access_rows() -> None:
"""Deleting a board should clear org-board access rows before commit."""
session = _FakeSession(exec_results=[[], []]) session = _FakeSession(exec_results=[[], []])
board = Board( board = Board(
id=uuid4(), id=uuid4(),
@@ -47,7 +56,10 @@ async def test_delete_board_cleans_org_board_access_rows() -> None:
gateway_id=None, gateway_id=None,
) )
await boards.delete_board(session=session, board=board) await boards.delete_board(
session=cast("AsyncSession", session),
board=board,
)
deleted_table_names = [statement.table.name for statement in session.executed] deleted_table_names = [statement.table.name for statement in session.executed]
assert "organization_board_access" in deleted_table_names assert "organization_board_access" in deleted_table_names

View File

@@ -1,8 +1,11 @@
# ruff: noqa: INP001, S101
"""Tests for organization deletion API behavior and authorization."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from types import SimpleNamespace from types import SimpleNamespace
from typing import Any from typing import TYPE_CHECKING, cast
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
@@ -10,16 +13,19 @@ from fastapi import HTTPException, status
from app.api import organizations from app.api import organizations
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
@dataclass @dataclass
class _FakeSession: class _FakeSession:
executed: list[Any] = field(default_factory=list) executed: list[object] = field(default_factory=list)
committed: int = 0 committed: int = 0
async def exec(self, statement: Any) -> None: async def exec(self, statement: object) -> None:
self.executed.append(statement) self.executed.append(statement)
async def execute(self, statement: Any) -> None: async def execute(self, statement: object) -> None:
self.executed.append(statement) self.executed.append(statement)
async def commit(self) -> None: async def commit(self) -> None:
@@ -28,6 +34,7 @@ class _FakeSession:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_my_org_cleans_dependents_before_organization_delete() -> None: async def test_delete_my_org_cleans_dependents_before_organization_delete() -> None:
"""Delete flow should remove dependent rows before the organization row."""
session = _FakeSession() session = _FakeSession()
org_id = uuid4() org_id = uuid4()
ctx = SimpleNamespace( ctx = SimpleNamespace(
@@ -35,7 +42,10 @@ async def test_delete_my_org_cleans_dependents_before_organization_delete() -> N
member=SimpleNamespace(role="owner"), member=SimpleNamespace(role="owner"),
) )
await organizations.delete_my_org(session=session, ctx=ctx) await organizations.delete_my_org(
session=cast("AsyncSession", session),
ctx=ctx,
)
executed_tables = [statement.table.name for statement in session.executed] executed_tables = [statement.table.name for statement in session.executed]
assert executed_tables == [ assert executed_tables == [
@@ -66,6 +76,7 @@ async def test_delete_my_org_cleans_dependents_before_organization_delete() -> N
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_my_org_requires_owner_role() -> None: async def test_delete_my_org_requires_owner_role() -> None:
"""Delete flow should reject non-owner members with HTTP 403."""
session = _FakeSession() session = _FakeSession()
ctx = SimpleNamespace( ctx = SimpleNamespace(
organization=SimpleNamespace(id=uuid4()), organization=SimpleNamespace(id=uuid4()),
@@ -73,7 +84,10 @@ async def test_delete_my_org_requires_owner_role() -> None:
) )
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await organizations.delete_my_org(session=session, ctx=ctx) await organizations.delete_my_org(
session=cast("AsyncSession", session),
ctx=ctx,
)
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
assert session.executed == [] assert session.executed == []

View File

@@ -1,3 +1,5 @@
# ruff: noqa
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
@@ -14,7 +16,10 @@ from app.models.organization_invites import OrganizationInvite
from app.models.organization_members import OrganizationMember from app.models.organization_members import OrganizationMember
from app.models.organizations import Organization from app.models.organizations import Organization
from app.models.users import User from app.models.users import User
from app.schemas.organizations import OrganizationBoardAccessSpec, OrganizationMemberAccessUpdate from app.schemas.organizations import (
OrganizationBoardAccessSpec,
OrganizationMemberAccessUpdate,
)
from app.services import organizations from app.services import organizations
@@ -82,7 +87,10 @@ class _FakeSession:
def test_normalize_invited_email_strips_and_lowercases() -> None: def test_normalize_invited_email_strips_and_lowercases() -> None:
assert organizations.normalize_invited_email(" Foo@Example.com ") == "foo@example.com" assert (
organizations.normalize_invited_email(" Foo@Example.com ")
== "foo@example.com"
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -104,13 +112,13 @@ def test_role_rank_unknown_role_falls_back_to_member_rank() -> None:
def test_is_org_admin_owner_admin_member() -> None: def test_is_org_admin_owner_admin_member() -> None:
assert organizations.is_org_admin( assert organizations.is_org_admin(
OrganizationMember(organization_id=uuid4(), user_id=uuid4(), role="owner") OrganizationMember(organization_id=uuid4(), user_id=uuid4(), role="owner"),
) )
assert organizations.is_org_admin( assert organizations.is_org_admin(
OrganizationMember(organization_id=uuid4(), user_id=uuid4(), role="admin") OrganizationMember(organization_id=uuid4(), user_id=uuid4(), role="admin"),
) )
assert not organizations.is_org_admin( assert not organizations.is_org_admin(
OrganizationMember(organization_id=uuid4(), user_id=uuid4(), role="member") OrganizationMember(organization_id=uuid4(), user_id=uuid4(), role="member"),
) )
@@ -119,7 +127,9 @@ async def test_ensure_member_for_user_returns_existing_membership(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
user = User(clerk_user_id="u1") user = User(clerk_user_id="u1")
existing = OrganizationMember(organization_id=uuid4(), user_id=user.id, role="member") existing = OrganizationMember(
organization_id=uuid4(), user_id=user.id, role="member",
)
async def _fake_get_active(_session: Any, _user: User) -> OrganizationMember: async def _fake_get_active(_session: Any, _user: User) -> OrganizationMember:
return existing return existing
@@ -150,10 +160,12 @@ async def test_ensure_member_for_user_accepts_pending_invite(
async def _fake_find(_session: Any, _email: str) -> OrganizationInvite: async def _fake_find(_session: Any, _email: str) -> OrganizationInvite:
return invite return invite
accepted = OrganizationMember(organization_id=org_id, user_id=user.id, role="member") accepted = OrganizationMember(
organization_id=org_id, user_id=user.id, role="member",
)
async def _fake_accept( async def _fake_accept(
_session: Any, _invite: OrganizationInvite, _user: User _session: Any, _invite: OrganizationInvite, _user: User,
) -> OrganizationMember: ) -> OrganizationMember:
assert _invite is invite assert _invite is invite
assert _user is user assert _user is user
@@ -203,7 +215,9 @@ async def test_has_board_access_denies_cross_org() -> None:
member = OrganizationMember(organization_id=uuid4(), user_id=uuid4(), role="member") member = OrganizationMember(organization_id=uuid4(), user_id=uuid4(), role="member")
board = Board(id=uuid4(), organization_id=uuid4(), name="b", slug="b") board = Board(id=uuid4(), organization_id=uuid4(), name="b", slug="b")
assert ( assert (
await organizations.has_board_access(session, member=member, board=board, write=False) await organizations.has_board_access(
session, member=member, board=board, write=False,
)
is False is False
) )
@@ -211,7 +225,9 @@ async def test_has_board_access_denies_cross_org() -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_has_board_access_uses_org_board_access_row_read_and_write() -> None: async def test_has_board_access_uses_org_board_access_row_read_and_write() -> None:
org_id = uuid4() org_id = uuid4()
member = OrganizationMember(id=uuid4(), organization_id=org_id, user_id=uuid4(), role="member") member = OrganizationMember(
id=uuid4(), organization_id=org_id, user_id=uuid4(), role="member",
)
board = Board(id=uuid4(), organization_id=org_id, name="b", slug="b") board = Board(id=uuid4(), organization_id=org_id, name="b", slug="b")
access = OrganizationBoardAccess( access = OrganizationBoardAccess(
@@ -222,7 +238,9 @@ async def test_has_board_access_uses_org_board_access_row_read_and_write() -> No
) )
session = _FakeSession(exec_results=[_FakeExecResult(first_value=access)]) session = _FakeSession(exec_results=[_FakeExecResult(first_value=access)])
assert ( assert (
await organizations.has_board_access(session, member=member, board=board, write=False) await organizations.has_board_access(
session, member=member, board=board, write=False,
)
is True is True
) )
@@ -234,7 +252,9 @@ async def test_has_board_access_uses_org_board_access_row_read_and_write() -> No
) )
session2 = _FakeSession(exec_results=[_FakeExecResult(first_value=access2)]) session2 = _FakeSession(exec_results=[_FakeExecResult(first_value=access2)])
assert ( assert (
await organizations.has_board_access(session2, member=member, board=board, write=False) await organizations.has_board_access(
session2, member=member, board=board, write=False,
)
is True is True
) )
@@ -246,13 +266,17 @@ async def test_has_board_access_uses_org_board_access_row_read_and_write() -> No
) )
session3 = _FakeSession(exec_results=[_FakeExecResult(first_value=access3)]) session3 = _FakeSession(exec_results=[_FakeExecResult(first_value=access3)])
assert ( assert (
await organizations.has_board_access(session3, member=member, board=board, write=True) await organizations.has_board_access(
session3, member=member, board=board, write=True,
)
is False is False
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_require_board_access_raises_when_no_member(monkeypatch: pytest.MonkeyPatch) -> None: async def test_require_board_access_raises_when_no_member(
monkeypatch: pytest.MonkeyPatch,
) -> None:
user = User(clerk_user_id="u1") user = User(clerk_user_id="u1")
board = Board(id=uuid4(), organization_id=uuid4(), name="b", slug="b") board = Board(id=uuid4(), organization_id=uuid4(), name="b", slug="b")
@@ -263,7 +287,9 @@ async def test_require_board_access_raises_when_no_member(monkeypatch: pytest.Mo
session = _FakeSession(exec_results=[]) session = _FakeSession(exec_results=[])
with pytest.raises(HTTPException) as exc: with pytest.raises(HTTPException) as exc:
await organizations.require_board_access(session, user=user, board=board, write=False) await organizations.require_board_access(
session, user=user, board=board, write=False,
)
assert exc.value.status_code == 403 assert exc.value.status_code == 403
@@ -271,18 +297,26 @@ async def test_require_board_access_raises_when_no_member(monkeypatch: pytest.Mo
async def test_apply_member_access_update_deletes_existing_and_adds_rows_when_not_all_boards() -> ( async def test_apply_member_access_update_deletes_existing_and_adds_rows_when_not_all_boards() -> (
None None
): ):
member = OrganizationMember(id=uuid4(), organization_id=uuid4(), user_id=uuid4(), role="member") member = OrganizationMember(
id=uuid4(), organization_id=uuid4(), user_id=uuid4(), role="member",
)
update = OrganizationMemberAccessUpdate( update = OrganizationMemberAccessUpdate(
all_boards_read=False, all_boards_read=False,
all_boards_write=False, all_boards_write=False,
board_access=[ board_access=[
OrganizationBoardAccessSpec(board_id=uuid4(), can_read=True, can_write=False), OrganizationBoardAccessSpec(
OrganizationBoardAccessSpec(board_id=uuid4(), can_read=True, can_write=True), board_id=uuid4(), can_read=True, can_write=False,
),
OrganizationBoardAccessSpec(
board_id=uuid4(), can_read=True, can_write=True,
),
], ],
) )
session = _FakeSession(exec_results=[]) session = _FakeSession(exec_results=[])
await organizations.apply_member_access_update(session, member=member, update=update) await organizations.apply_member_access_update(
session, member=member, update=update,
)
# delete statement executed once # delete statement executed once
assert len(session.executed) == 1 assert len(session.executed) == 1
@@ -330,7 +364,7 @@ async def test_apply_invite_to_member_upgrades_role_and_merges_access_rows(
exec_results=[ exec_results=[
[invite_access], [invite_access],
_FakeExecResult(first_value=None), _FakeExecResult(first_value=None),
] ],
) )
await organizations.apply_invite_to_member(session, member=member, invite=invite) await organizations.apply_invite_to_member(session, member=member, invite=invite)

View File

@@ -1,9 +1,13 @@
# ruff: noqa: INP001, S101
"""Unit tests for souls-directory parsing and search helpers."""
from __future__ import annotations from __future__ import annotations
from app.services.souls_directory import SoulRef, _parse_sitemap_soul_refs, search_souls from app.services.souls_directory import SoulRef, _parse_sitemap_soul_refs, search_souls
def test_parse_sitemap_extracts_soul_refs() -> None: def test_parse_sitemap_extracts_soul_refs() -> None:
"""Sitemap parser should emit only valid soul handle/slug refs."""
xml = """<?xml version="1.0" encoding="UTF-8"?> xml = """<?xml version="1.0" encoding="UTF-8"?>
<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9"> <urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
<url><loc>https://souls.directory</loc></url> <url><loc>https://souls.directory</loc></url>
@@ -19,6 +23,7 @@ def test_parse_sitemap_extracts_soul_refs() -> None:
def test_search_souls_matches_handle_or_slug() -> None: def test_search_souls_matches_handle_or_slug() -> None:
"""Search should match both handle and slug text, case-insensitively."""
refs = [ refs = [
SoulRef(handle="thedaviddias", slug="code-reviewer"), SoulRef(handle="thedaviddias", slug="code-reviewer"),
SoulRef(handle="thedaviddias", slug="technical-writer"), SoulRef(handle="thedaviddias", slug="technical-writer"),

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from starlette.requests import Request from starlette.requests import Request
@@ -26,7 +24,11 @@ class ClerkHTTPBearer:
def __init__( def __init__(
self, self,
config: ClerkConfig, config: ClerkConfig,
*,
auto_error: bool = ..., auto_error: bool = ...,
add_state: bool = ..., add_state: bool = ...,
) -> None: ... ) -> None: ...
async def __call__(self, request: Request) -> HTTPAuthorizationCredentials | None: ... async def __call__(
self,
request: Request,
) -> HTTPAuthorizationCredentials | None: ...