refactor: simplify code formatting and improve readability across multiple files

This commit is contained in:
Abhimanyu Saharan
2026-02-09 20:44:05 +05:30
parent 020d02fa22
commit 8f6347dc8d
33 changed files with 393 additions and 427 deletions

View File

@@ -22,10 +22,7 @@ from app.models.activity_events import ActivityEvent
from app.models.agents import Agent from app.models.agents import Agent
from app.models.boards import Board from app.models.boards import Board
from app.models.tasks import Task from app.models.tasks import Task
from app.schemas.activity_events import ( from app.schemas.activity_events import ActivityEventRead, ActivityTaskCommentFeedItemRead
ActivityEventRead,
ActivityTaskCommentFeedItemRead,
)
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
from app.services.organizations import ( from app.services.organizations import (
OrganizationContext, OrganizationContext,
@@ -198,10 +195,7 @@ async def list_task_comment_feed(
def _transform(items: Sequence[Any]) -> Sequence[Any]: def _transform(items: Sequence[Any]) -> Sequence[Any]:
rows = _coerce_task_comment_rows(items) rows = _coerce_task_comment_rows(items)
return [ return [_feed_item(event, task, board, agent) for event, task, board, agent in rows]
_feed_item(event, task, board, agent)
for event, task, board, agent in rows
]
return await paginate(session, statement, transformer=_transform) return await paginate(session, statement, transformer=_transform)

View File

@@ -53,19 +53,9 @@ from app.schemas.gateway_coordination import (
GatewayMainAskUserResponse, GatewayMainAskUserResponse,
) )
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
from app.schemas.tasks import ( from app.schemas.tasks import TaskCommentCreate, TaskCommentRead, TaskCreate, TaskRead, TaskUpdate
TaskCommentCreate,
TaskCommentRead,
TaskCreate,
TaskRead,
TaskUpdate,
)
from app.services.activity_log import record_activity from app.services.activity_log import record_activity
from app.services.board_leads import ( from app.services.board_leads import LeadAgentOptions, LeadAgentRequest, ensure_board_lead_agent
LeadAgentOptions,
LeadAgentRequest,
ensure_board_lead_agent,
)
from app.services.task_dependencies import ( from app.services.task_dependencies import (
blocked_by_dependency_ids, blocked_by_dependency_ids,
dependency_status_by_id, dependency_status_by_id,
@@ -212,7 +202,8 @@ async def _require_gateway_board(
board = await Board.objects.by_id(board_id).first(session) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Board not found", status_code=status.HTTP_404_NOT_FOUND,
detail="Board not found",
) )
if board.gateway_id != gateway.id: if board.gateway_id != gateway.id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@@ -322,7 +313,8 @@ async def create_task(
dependency_ids=normalized_deps, dependency_ids=normalized_deps,
) )
blocked_by = blocked_by_dependency_ids( blocked_by = blocked_by_dependency_ids(
dependency_ids=normalized_deps, status_by_id=dep_status, dependency_ids=normalized_deps,
status_by_id=dep_status,
) )
if blocked_by and (task.assigned_agent_id is not None or task.status != "inbox"): if blocked_by and (task.assigned_agent_id is not None or task.status != "inbox"):
@@ -393,11 +385,7 @@ async def update_task(
agent_ctx: AgentAuthContext = AGENT_CTX_DEP, agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> TaskRead: ) -> TaskRead:
"""Update a task after board-level access checks.""" """Update a task after board-level access checks."""
if ( if agent_ctx.agent.board_id and task.board_id and agent_ctx.agent.board_id != task.board_id:
agent_ctx.agent.board_id
and task.board_id
and agent_ctx.agent.board_id != task.board_id
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return await tasks_api.update_task( return await tasks_api.update_task(
payload=payload, payload=payload,
@@ -417,11 +405,7 @@ async def list_task_comments(
agent_ctx: AgentAuthContext = AGENT_CTX_DEP, agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> LimitOffsetPage[TaskCommentRead]: ) -> LimitOffsetPage[TaskCommentRead]:
"""List comments for a task visible to the authenticated agent.""" """List comments for a task visible to the authenticated agent."""
if ( if agent_ctx.agent.board_id and task.board_id and agent_ctx.agent.board_id != task.board_id:
agent_ctx.agent.board_id
and task.board_id
and agent_ctx.agent.board_id != task.board_id
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return await tasks_api.list_task_comments( return await tasks_api.list_task_comments(
task=task, task=task,
@@ -430,7 +414,8 @@ async def list_task_comments(
@router.post( @router.post(
"/boards/{board_id}/tasks/{task_id}/comments", response_model=TaskCommentRead, "/boards/{board_id}/tasks/{task_id}/comments",
response_model=TaskCommentRead,
) )
async def create_task_comment( async def create_task_comment(
payload: TaskCommentCreate, payload: TaskCommentCreate,
@@ -439,11 +424,7 @@ async def create_task_comment(
agent_ctx: AgentAuthContext = AGENT_CTX_DEP, agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
) -> ActivityEvent: ) -> ActivityEvent:
"""Create a task comment on behalf of the authenticated agent.""" """Create a task comment on behalf of the authenticated agent."""
if ( if agent_ctx.agent.board_id and task.board_id and agent_ctx.agent.board_id != task.board_id:
agent_ctx.agent.board_id
and task.board_id
and agent_ctx.agent.board_id != task.board_id
):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return await tasks_api.create_task_comment( return await tasks_api.create_task_comment(
payload=payload, payload=payload,
@@ -454,7 +435,8 @@ async def create_task_comment(
@router.get( @router.get(
"/boards/{board_id}/memory", response_model=DefaultLimitOffsetPage[BoardMemoryRead], "/boards/{board_id}/memory",
response_model=DefaultLimitOffsetPage[BoardMemoryRead],
) )
async def list_board_memory( async def list_board_memory(
is_chat: bool | None = IS_CHAT_QUERY, is_chat: bool | None = IS_CHAT_QUERY,
@@ -588,7 +570,9 @@ async def nudge_agent(
config = await _gateway_config(session, board) config = await _gateway_config(session, board)
try: try:
await ensure_session( await ensure_session(
target.openclaw_session_id, config=config, label=target.name, target.openclaw_session_id,
config=config,
label=target.name,
) )
await send_message( await send_message(
message, message,
@@ -605,7 +589,8 @@ async def nudge_agent(
) )
await session.commit() await session.commit()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc), status_code=status.HTTP_502_BAD_GATEWAY,
detail=str(exc),
) from exc ) from exc
record_activity( record_activity(
session, session,
@@ -657,7 +642,8 @@ async def get_agent_soul(
) )
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc), status_code=status.HTTP_502_BAD_GATEWAY,
detail=str(exc),
) from exc ) from exc
if isinstance(payload, str): if isinstance(payload, str):
return payload return payload
@@ -671,7 +657,8 @@ async def get_agent_soul(
if isinstance(nested, str): if isinstance(nested, str):
return nested return nested
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail="Invalid gateway response", status_code=status.HTTP_502_BAD_GATEWAY,
detail="Invalid gateway response",
) )
@@ -712,7 +699,8 @@ async def update_agent_soul(
) )
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc), status_code=status.HTTP_502_BAD_GATEWAY,
detail=str(exc),
) from exc ) from exc
reason = (payload.reason or "").strip() reason = (payload.reason or "").strip()
source_url = (payload.source_url or "").strip() source_url = (payload.source_url or "").strip()
@@ -770,9 +758,7 @@ async def ask_user_via_gateway_main(
correlation = payload.correlation_id.strip() if payload.correlation_id else "" correlation = payload.correlation_id.strip() if payload.correlation_id else ""
correlation_line = f"Correlation ID: {correlation}\n" if correlation else "" correlation_line = f"Correlation ID: {correlation}\n" if correlation else ""
preferred_channel = (payload.preferred_channel or "").strip() preferred_channel = (payload.preferred_channel or "").strip()
channel_line = ( channel_line = f"Preferred channel: {preferred_channel}\n" if preferred_channel else ""
f"Preferred channel: {preferred_channel}\n" if preferred_channel else ""
)
tags = payload.reply_tags or ["gateway_main", "user_reply"] tags = payload.reply_tags or ["gateway_main", "user_reply"]
tags_json = json.dumps(tags) tags_json = json.dumps(tags)
@@ -801,7 +787,10 @@ async def ask_user_via_gateway_main(
try: try:
await ensure_session(main_session_key, config=config, label="Main Agent") await ensure_session(main_session_key, config=config, label="Main Agent")
await send_message( await send_message(
message, session_key=main_session_key, config=config, deliver=True, message,
session_key=main_session_key,
config=config,
deliver=True,
) )
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
record_activity( record_activity(
@@ -812,7 +801,8 @@ async def ask_user_via_gateway_main(
) )
await session.commit() await session.commit()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc), status_code=status.HTTP_502_BAD_GATEWAY,
detail=str(exc),
) from exc ) from exc
record_activity( record_activity(
@@ -867,11 +857,7 @@ async def message_gateway_board_lead(
) )
base_url = settings.base_url or "http://localhost:8000" base_url = settings.base_url or "http://localhost:8000"
header = ( header = "GATEWAY MAIN QUESTION" if payload.kind == "question" else "GATEWAY MAIN HANDOFF"
"GATEWAY MAIN QUESTION"
if payload.kind == "question"
else "GATEWAY MAIN HANDOFF"
)
correlation = payload.correlation_id.strip() if payload.correlation_id else "" correlation = payload.correlation_id.strip() if payload.correlation_id else ""
correlation_line = f"Correlation ID: {correlation}\n" if correlation else "" correlation_line = f"Correlation ID: {correlation}\n" if correlation else ""
tags = payload.reply_tags or ["gateway_main", "lead_reply"] tags = payload.reply_tags or ["gateway_main", "lead_reply"]
@@ -903,7 +889,8 @@ async def message_gateway_board_lead(
) )
await session.commit() await session.commit()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc), status_code=status.HTTP_502_BAD_GATEWAY,
detail=str(exc),
) from exc ) from exc
record_activity( record_activity(
@@ -946,11 +933,7 @@ async def broadcast_gateway_lead_message(
boards = list(await session.exec(statement)) boards = list(await session.exec(statement))
base_url = settings.base_url or "http://localhost:8000" base_url = settings.base_url or "http://localhost:8000"
header = ( header = "GATEWAY MAIN QUESTION" if payload.kind == "question" else "GATEWAY MAIN HANDOFF"
"GATEWAY MAIN QUESTION"
if payload.kind == "question"
else "GATEWAY MAIN HANDOFF"
)
correlation = payload.correlation_id.strip() if payload.correlation_id else "" correlation = payload.correlation_id.strip() if payload.correlation_id else ""
correlation_line = f"Correlation ID: {correlation}\n" if correlation else "" correlation_line = f"Correlation ID: {correlation}\n" if correlation else ""
tags = payload.reply_tags or ["gateway_main", "lead_reply"] tags = payload.reply_tags or ["gateway_main", "lead_reply"]

View File

@@ -23,11 +23,7 @@ from app.db import crud
from app.db.pagination import paginate from app.db.pagination import paginate
from app.db.session import async_session_maker, get_session from app.db.session import async_session_maker, get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import ( from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.activity_events import ActivityEvent from app.models.activity_events import ActivityEvent
from app.models.agents import Agent from app.models.agents import Agent
from app.models.boards import Board from app.models.boards import Board
@@ -154,7 +150,8 @@ async def _require_board(
board = await Board.objects.by_id(board_id).first(session) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Board not found", status_code=status.HTTP_404_NOT_FOUND,
detail="Board not found",
) )
if user is not None: if user is not None:
await require_board_access(session, user=user, board=board, write=write) await require_board_access(session, user=user, board=board, write=write)
@@ -162,7 +159,8 @@ async def _require_board(
async def _require_gateway( async def _require_gateway(
session: AsyncSession, board: Board, session: AsyncSession,
board: Board,
) -> tuple[Gateway, GatewayClientConfig]: ) -> tuple[Gateway, GatewayClientConfig]:
if not board.gateway_id: if not board.gateway_id:
raise HTTPException( raise HTTPException(
@@ -246,7 +244,8 @@ def _coerce_agent_items(items: Sequence[Any]) -> list[Agent]:
async def _find_gateway_for_main_session( async def _find_gateway_for_main_session(
session: AsyncSession, session_key: str | None, session: AsyncSession,
session_key: str | None,
) -> Gateway | None: ) -> Gateway | None:
if not session_key: if not session_key:
return None return None
@@ -306,7 +305,8 @@ async def _fetch_agent_events(
async def _require_user_context( async def _require_user_context(
session: AsyncSession, user: User | None, session: AsyncSession,
user: User | None,
) -> OrganizationContext: ) -> OrganizationContext:
if user is None: if user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
@@ -332,7 +332,8 @@ async def _require_agent_access(
if not is_org_admin(ctx.member): if not is_org_admin(ctx.member):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
gateway = await _find_gateway_for_main_session( gateway = await _find_gateway_for_main_session(
session, agent.openclaw_session_id, session,
agent.openclaw_session_id,
) )
if gateway is None or gateway.organization_id != ctx.organization.id: if gateway is None or gateway.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -355,7 +356,10 @@ def _record_heartbeat(session: AsyncSession, agent: Agent) -> None:
def _record_instruction_failure( def _record_instruction_failure(
session: AsyncSession, agent: Agent, error: str, action: str, session: AsyncSession,
agent: Agent,
error: str,
action: str,
) -> None: ) -> None:
action_label = action.replace("_", " ").capitalize() action_label = action.replace("_", " ").capitalize()
record_activity( record_activity(
@@ -432,10 +436,7 @@ async def _ensure_unique_agent_name(
if existing_gateway: if existing_gateway:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
detail=( detail=("An agent with this name already exists in this gateway " "workspace."),
"An agent with this name already exists in this gateway "
"workspace."
),
) )
desired_session_key = _build_session_key(requested_name) desired_session_key = _build_session_key(requested_name)
@@ -938,7 +939,9 @@ async def _commit_heartbeat(
async def _send_wakeup_message( async def _send_wakeup_message(
agent: Agent, config: GatewayClientConfig, verb: str = "provisioned", agent: Agent,
config: GatewayClientConfig,
verb: str = "provisioned",
) -> None: ) -> None:
session_key = agent.openclaw_session_id or _build_session_key(agent.name) session_key = agent.openclaw_session_id or _build_session_key(agent.name)
await ensure_session(session_key, config=config, label=agent.name) await ensure_session(session_key, config=config, label=agent.name)
@@ -971,7 +974,8 @@ async def list_agents(
col(Gateway.organization_id) == ctx.organization.id, col(Gateway.organization_id) == ctx.organization.id,
) )
base_filter = or_( base_filter = or_(
base_filter, col(Agent.openclaw_session_id).in_(gateway_keys), base_filter,
col(Agent.openclaw_session_id).in_(gateway_keys),
) )
statement = select(Agent).where(base_filter) statement = select(Agent).where(base_filter)
if board_id is not None: if board_id is not None:
@@ -987,10 +991,7 @@ async def list_agents(
def _transform(items: Sequence[Any]) -> Sequence[Any]: def _transform(items: Sequence[Any]) -> Sequence[Any]:
agents = _coerce_agent_items(items) agents = _coerce_agent_items(items)
return [ return [_to_agent_read(_with_computed_status(agent), main_session_keys) for agent in agents]
_to_agent_read(_with_computed_status(agent), main_session_keys)
for agent in agents
]
return await paginate(session, statement, transformer=_transform) return await paginate(session, statement, transformer=_transform)
@@ -1019,19 +1020,17 @@ async def stream_agents(
async with async_session_maker() as stream_session: async with async_session_maker() as stream_session:
if board_id is not None: if board_id is not None:
agents = await _fetch_agent_events( agents = await _fetch_agent_events(
stream_session, board_id, last_seen, stream_session,
board_id,
last_seen,
) )
elif allowed_ids: elif allowed_ids:
agents = await _fetch_agent_events(stream_session, None, last_seen) agents = await _fetch_agent_events(stream_session, None, last_seen)
agents = [ agents = [agent for agent in agents if agent.board_id in allowed_ids]
agent for agent in agents if agent.board_id in allowed_ids
]
else: else:
agents = [] agents = []
main_session_keys = ( main_session_keys = (
await _get_gateway_main_session_keys(stream_session) await _get_gateway_main_session_keys(stream_session) if agents else set()
if agents
else set()
) )
for agent in agents: for agent in agents:
updated_at = agent.updated_at or agent.last_seen_at or utcnow() updated_at = agent.updated_at or agent.last_seen_at or utcnow()
@@ -1252,7 +1251,8 @@ async def delete_agent(
await _require_agent_access(session, agent=agent, ctx=ctx, write=True) await _require_agent_access(session, agent=agent, ctx=ctx, write=True)
board = await _require_board( board = await _require_board(
session, str(agent.board_id) if agent.board_id else None, session,
str(agent.board_id) if agent.board_id else None,
) )
gateway, client_config = await _require_gateway(session, board) gateway, client_config = await _require_gateway(session, board)
try: try:

View File

@@ -24,12 +24,7 @@ from app.core.time import utcnow
from app.db.pagination import paginate from app.db.pagination import paginate
from app.db.session import async_session_maker, get_session from app.db.session import async_session_maker, get_session
from app.models.approvals import Approval from app.models.approvals import Approval
from app.schemas.approvals import ( from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalStatus, ApprovalUpdate
ApprovalCreate,
ApprovalRead,
ApprovalStatus,
ApprovalUpdate,
)
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -156,9 +151,7 @@ async def stream_approvals(
).one(), ).one(),
) )
task_ids = { task_ids = {
approval.task_id approval.task_id for approval in approvals if approval.task_id is not None
for approval in approvals
if approval.task_id is not None
} }
counts_by_task_id: dict[UUID, tuple[int, int]] = {} counts_by_task_id: dict[UUID, tuple[int, int]] = {}
if task_ids: if task_ids:

View File

@@ -26,21 +26,14 @@ from app.core.time import utcnow
from app.db.pagination import paginate from app.db.pagination import paginate
from app.db.session import async_session_maker, get_session from app.db.session import async_session_maker, get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import ( from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.agents import Agent from app.models.agents import Agent
from app.models.board_group_memory import BoardGroupMemory from app.models.board_group_memory import BoardGroupMemory
from app.models.board_groups import BoardGroup from app.models.board_groups import BoardGroup
from app.models.boards import Board from app.models.boards import Board
from app.models.gateways import Gateway from app.models.gateways import Gateway
from app.models.users import User from app.models.users import User
from app.schemas.board_group_memory import ( from app.schemas.board_group_memory import BoardGroupMemoryCreate, BoardGroupMemoryRead
BoardGroupMemoryCreate,
BoardGroupMemoryRead,
)
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
from app.services.mentions import extract_mentions, matches_agent_mention from app.services.mentions import extract_mentions, matches_agent_mention
from app.services.organizations import ( from app.services.organizations import (

View File

@@ -10,12 +10,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func from sqlalchemy import func
from sqlmodel import col, select from sqlmodel import col, select
from app.api.deps import ( from app.api.deps import ActorContext, require_admin_or_agent, require_org_admin, require_org_member
ActorContext,
require_admin_or_agent,
require_org_admin,
require_org_member,
)
from app.core.time import utcnow from app.core.time import utcnow
from app.db import crud from app.db import crud
from app.db.pagination import paginate from app.db.pagination import paginate
@@ -34,10 +29,7 @@ from app.schemas.board_groups import BoardGroupCreate, BoardGroupRead, BoardGrou
from app.schemas.common import OkResponse from app.schemas.common import OkResponse
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
from app.schemas.view_models import BoardGroupSnapshot from app.schemas.view_models import BoardGroupSnapshot
from app.services.agent_provisioning import ( from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, sync_gateway_agent_heartbeats
DEFAULT_HEARTBEAT_CONFIG,
sync_gateway_agent_heartbeats,
)
from app.services.board_group_snapshot import build_group_snapshot from app.services.board_group_snapshot import build_group_snapshot
from app.services.organizations import ( from app.services.organizations import (
OrganizationContext, OrganizationContext,
@@ -86,8 +78,7 @@ async def _require_group_access(
return group return group
board_ids = [ board_ids = [
board.id board.id for board in await Board.objects.filter_by(board_group_id=group_id).all(session)
for board in await Board.objects.filter_by(board_group_id=group_id).all(session)
] ]
if not board_ids: if not board_ids:
if is_org_admin(member): if is_org_admin(member):
@@ -144,7 +135,10 @@ async def get_board_group(
) -> BoardGroup: ) -> BoardGroup:
"""Get a board group by id.""" """Get a board group by id."""
return await _require_group_access( return await _require_group_access(
session, group_id=group_id, member=ctx.member, write=False, session,
group_id=group_id,
member=ctx.member,
write=False,
) )
@@ -159,7 +153,10 @@ async def get_board_group_snapshot(
) -> BoardGroupSnapshot: ) -> BoardGroupSnapshot:
"""Get a snapshot across boards in a group.""" """Get a snapshot across boards in a group."""
group = await _require_group_access( group = await _require_group_access(
session, group_id=group_id, member=ctx.member, write=False, session,
group_id=group_id,
member=ctx.member,
write=False,
) )
if per_board_task_limit < 0: if per_board_task_limit < 0:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
@@ -174,9 +171,7 @@ async def get_board_group_snapshot(
allowed_ids = set( allowed_ids = set(
await list_accessible_board_ids(session, member=ctx.member, write=False), await list_accessible_board_ids(session, member=ctx.member, write=False),
) )
snapshot.boards = [ snapshot.boards = [item for item in snapshot.boards if item.board.id in allowed_ids]
item for item in snapshot.boards if item.board.id in allowed_ids
]
return snapshot return snapshot
@@ -339,14 +334,13 @@ async def update_board_group(
) -> BoardGroup: ) -> BoardGroup:
"""Update a board group.""" """Update a board group."""
group = await _require_group_access( group = await _require_group_access(
session, group_id=group_id, member=ctx.member, write=True, session,
group_id=group_id,
member=ctx.member,
write=True,
) )
updates = payload.model_dump(exclude_unset=True) updates = payload.model_dump(exclude_unset=True)
if ( if "slug" in updates and updates["slug"] is not None and not updates["slug"].strip():
"slug" in updates
and updates["slug"] is not None
and not updates["slug"].strip()
):
updates["slug"] = _slugify(updates.get("name") or group.name) updates["slug"] = _slugify(updates.get("name") or group.name)
updates["updated_at"] = utcnow() updates["updated_at"] = utcnow()
return await crud.patch(session, group, updates) return await crud.patch(session, group, updates)
@@ -360,7 +354,10 @@ async def delete_board_group(
) -> OkResponse: ) -> OkResponse:
"""Delete a board group.""" """Delete a board group."""
await _require_group_access( await _require_group_access(
session, group_id=group_id, member=ctx.member, write=True, session,
group_id=group_id,
member=ctx.member,
write=True,
) )
# Boards reference groups, so clear the FK first to keep deletes simple. # Boards reference groups, so clear the FK first to keep deletes simple.
@@ -378,7 +375,10 @@ async def delete_board_group(
commit=False, commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, BoardGroup, col(BoardGroup.id) == group_id, commit=False, session,
BoardGroup,
col(BoardGroup.id) == group_id,
commit=False,
) )
await session.commit() await session.commit()
return OkResponse() return OkResponse()

View File

@@ -24,11 +24,7 @@ from app.core.time import utcnow
from app.db.pagination import paginate from app.db.pagination import paginate
from app.db.session import async_session_maker, get_session from app.db.session import async_session_maker, get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import ( from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.agents import Agent from app.models.agents import Agent
from app.models.board_memory import BoardMemory from app.models.board_memory import BoardMemory
from app.models.gateways import Gateway from app.models.gateways import Gateway

View File

@@ -21,11 +21,7 @@ from app.core.config import settings
from app.core.time import utcnow from app.core.time import utcnow
from app.db.session import get_session from app.db.session import get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import ( from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.board_onboarding import BoardOnboardingSession from app.models.board_onboarding import BoardOnboardingSession
from app.models.gateways import Gateway from app.models.gateways import Gateway
from app.schemas.board_onboarding import ( from app.schemas.board_onboarding import (
@@ -39,11 +35,7 @@ from app.schemas.board_onboarding import (
BoardOnboardingUserProfile, BoardOnboardingUserProfile,
) )
from app.schemas.boards import BoardRead from app.schemas.boards import BoardRead
from app.services.board_leads import ( from app.services.board_leads import LeadAgentOptions, LeadAgentRequest, ensure_board_lead_agent
LeadAgentOptions,
LeadAgentRequest,
ensure_board_lead_agent,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -62,7 +54,8 @@ ADMIN_AUTH_DEP = Depends(require_admin_auth)
async def _gateway_config( async def _gateway_config(
session: AsyncSession, board: Board, session: AsyncSession,
board: Board,
) -> tuple[Gateway, GatewayClientConfig]: ) -> tuple[Gateway, GatewayClientConfig]:
if not board.gateway_id: if not board.gateway_id:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
@@ -255,11 +248,15 @@ async def start_onboarding(
try: try:
await ensure_session(session_key, config=config, label="Main Agent") await ensure_session(session_key, config=config, label="Main Agent")
await send_message( await send_message(
prompt, session_key=session_key, config=config, deliver=False, prompt,
session_key=session_key,
config=config,
deliver=False,
) )
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc), status_code=status.HTTP_502_BAD_GATEWAY,
detail=str(exc),
) from exc ) from exc
onboarding = BoardOnboardingSession( onboarding = BoardOnboardingSession(
@@ -311,7 +308,8 @@ async def answer_onboarding(
) )
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc), status_code=status.HTTP_502_BAD_GATEWAY,
detail=str(exc),
) from exc ) from exc
onboarding.messages = messages onboarding.messages = messages

View File

@@ -104,7 +104,9 @@ async def _require_gateway_for_create(
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
) -> Gateway: ) -> Gateway:
return await _require_gateway( return await _require_gateway(
session, payload.gateway_id, organization_id=ctx.organization.id, session,
payload.gateway_id,
organization_id=ctx.organization.id,
) )
@@ -155,7 +157,9 @@ async def _apply_board_update(
updates = payload.model_dump(exclude_unset=True) updates = payload.model_dump(exclude_unset=True)
if "gateway_id" in updates: if "gateway_id" in updates:
await _require_gateway( await _require_gateway(
session, updates["gateway_id"], organization_id=board.organization_id, session,
updates["gateway_id"],
organization_id=board.organization_id,
) )
if "board_group_id" in updates and updates["board_group_id"] is not None: if "board_group_id" in updates and updates["board_group_id"] is not None:
await _require_board_group( await _require_board_group(
@@ -164,10 +168,7 @@ async def _apply_board_update(
organization_id=board.organization_id, organization_id=board.organization_id,
) )
crud.apply_updates(board, updates) crud.apply_updates(board, updates)
if ( if updates.get("board_type") == "goal" and (not board.objective or not board.success_metrics):
updates.get("board_type") == "goal"
and (not board.objective or not board.success_metrics)
):
# Validate only when explicitly switching to goal boards. # Validate only when explicitly switching to goal boards.
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -183,7 +184,8 @@ async def _apply_board_update(
async def _board_gateway( async def _board_gateway(
session: AsyncSession, board: Board, session: AsyncSession,
board: Board,
) -> tuple[Gateway | None, GatewayClientConfig | None]: ) -> tuple[Gateway | None, GatewayClientConfig | None]:
if not board.gateway_id: if not board.gateway_id:
return None, None return None, None
@@ -255,7 +257,8 @@ async def list_boards(
if board_group_id is not None: if board_group_id is not None:
statement = statement.where(col(Board.board_group_id) == board_group_id) statement = statement.where(col(Board.board_group_id) == board_group_id)
statement = statement.order_by( statement = statement.order_by(
func.lower(col(Board.name)).asc(), col(Board.created_at).desc(), func.lower(col(Board.name)).asc(),
col(Board.created_at).desc(),
) )
return await paginate(session, statement) return await paginate(session, statement)
@@ -350,10 +353,14 @@ async def delete_board(
commit=False, commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, TaskDependency, col(TaskDependency.board_id) == board.id, session,
TaskDependency,
col(TaskDependency.board_id) == board.id,
) )
await crud.delete_where( await crud.delete_where(
session, TaskFingerprint, col(TaskFingerprint.board_id) == board.id, session,
TaskFingerprint,
col(TaskFingerprint.board_id) == board.id,
) )
# Approvals can reference tasks and agents, so delete before both. # Approvals can reference tasks and agents, so delete before both.

View File

@@ -91,7 +91,8 @@ async def _resolve_gateway(
board = await Board.objects.by_id(params.board_id).first(session) board = await Board.objects.by_id(params.board_id).first(session)
if board is None: if board is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Board not found", status_code=status.HTTP_404_NOT_FOUND,
detail="Board not found",
) )
if user is not None: if user is not None:
await require_board_access(session, user=user, board=board, write=False) await require_board_access(session, user=user, board=board, write=False)
@@ -119,7 +120,10 @@ async def _resolve_gateway(
async def _require_gateway( async def _require_gateway(
session: AsyncSession, board_id: str | None, *, user: User | None = None, session: AsyncSession,
board_id: str | None,
*,
user: User | None = None,
) -> tuple[Board, GatewayClientConfig, str | None]: ) -> tuple[Board, GatewayClientConfig, str | None]:
params = GatewayResolveQuery(board_id=board_id) params = GatewayResolveQuery(board_id=board_id)
board, config, main_session = await _resolve_gateway( board, config, main_session = await _resolve_gateway(
@@ -161,7 +165,9 @@ async def gateways_status(
if main_session: if main_session:
try: try:
ensured = await ensure_session( ensured = await ensure_session(
main_session, config=config, label="Main Agent", main_session,
config=config,
label="Main Agent",
) )
if isinstance(ensured, dict): if isinstance(ensured, dict):
main_session_entry = ensured.get("entry") or ensured main_session_entry = ensured.get("entry") or ensured
@@ -178,7 +184,9 @@ async def gateways_status(
) )
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
return GatewaysStatusResponse( return GatewaysStatusResponse(
connected=False, gateway_url=config.url, error=str(exc), connected=False,
gateway_url=config.url,
error=str(exc),
) )
@@ -202,7 +210,8 @@ async def list_gateway_sessions(
sessions = await openclaw_call("sessions.list", config=config) sessions = await openclaw_call("sessions.list", config=config)
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc), status_code=status.HTTP_502_BAD_GATEWAY,
detail=str(exc),
) from exc ) from exc
if isinstance(sessions, dict): if isinstance(sessions, dict):
sessions_list = _as_object_list(sessions.get("sessions")) sessions_list = _as_object_list(sessions.get("sessions"))
@@ -213,7 +222,9 @@ async def list_gateway_sessions(
if main_session: if main_session:
try: try:
ensured = await ensure_session( ensured = await ensure_session(
main_session, config=config, label="Main Agent", main_session,
config=config,
label="Main Agent",
) )
if isinstance(ensured, dict): if isinstance(ensured, dict):
main_session_entry = ensured.get("entry") or ensured main_session_entry = ensured.get("entry") or ensured
@@ -233,11 +244,7 @@ async def _list_sessions(config: GatewayClientConfig) -> list[dict[str, object]]
raw_items = _as_object_list(sessions.get("sessions")) raw_items = _as_object_list(sessions.get("sessions"))
else: else:
raw_items = _as_object_list(sessions) raw_items = _as_object_list(sessions)
return [ return [item for item in raw_items if isinstance(item, dict)]
item
for item in raw_items
if isinstance(item, dict)
]
async def _with_main_session( async def _with_main_session(
@@ -246,9 +253,7 @@ async def _with_main_session(
config: GatewayClientConfig, config: GatewayClientConfig,
main_session: str | None, main_session: str | None,
) -> list[dict[str, object]]: ) -> list[dict[str, object]]:
if not main_session or any( if not main_session or any(item.get("key") == main_session for item in sessions_list):
item.get("key") == main_session for item in sessions_list
):
return sessions_list return sessions_list
try: try:
await ensure_session(main_session, config=config, label="Main Agent") await ensure_session(main_session, config=config, label="Main Agent")
@@ -278,7 +283,8 @@ async def get_gateway_session(
sessions_list = await _list_sessions(config) sessions_list = await _list_sessions(config)
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc), status_code=status.HTTP_502_BAD_GATEWAY,
detail=str(exc),
) from exc ) from exc
sessions_list = await _with_main_session( sessions_list = await _with_main_session(
sessions_list, sessions_list,
@@ -286,12 +292,15 @@ async def get_gateway_session(
main_session=main_session, main_session=main_session,
) )
session_entry = next( session_entry = next(
(item for item in sessions_list if item.get("key") == session_id), None, (item for item in sessions_list if item.get("key") == session_id),
None,
) )
if session_entry is None and main_session and session_id == main_session: if session_entry is None and main_session and session_id == main_session:
try: try:
ensured = await ensure_session( ensured = await ensure_session(
main_session, config=config, label="Main Agent", main_session,
config=config,
label="Main Agent",
) )
if isinstance(ensured, dict): if isinstance(ensured, dict):
session_entry = ensured.get("entry") or ensured session_entry = ensured.get("entry") or ensured
@@ -299,13 +308,15 @@ async def get_gateway_session(
session_entry = None session_entry = None
if session_entry is None: if session_entry is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Session not found", status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found",
) )
return GatewaySessionResponse(session=session_entry) return GatewaySessionResponse(session=session_entry)
@router.get( @router.get(
"/sessions/{session_id}/history", response_model=GatewaySessionHistoryResponse, "/sessions/{session_id}/history",
response_model=GatewaySessionHistoryResponse,
) )
async def get_session_history( async def get_session_history(
session_id: str, session_id: str,
@@ -322,7 +333,8 @@ async def get_session_history(
history = await get_chat_history(session_id, config=config) history = await get_chat_history(session_id, config=config)
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc), status_code=status.HTTP_502_BAD_GATEWAY,
detail=str(exc),
) from exc ) from exc
if isinstance(history, dict) and isinstance(history.get("messages"), list): if isinstance(history, dict) and isinstance(history.get("messages"), list):
return GatewaySessionHistoryResponse(history=history["messages"]) return GatewaySessionHistoryResponse(history=history["messages"])
@@ -339,7 +351,9 @@ async def send_gateway_session_message(
) -> OkResponse: ) -> OkResponse:
"""Send a message into a specific gateway session.""" """Send a message into a specific gateway session."""
board, config, main_session = await _require_gateway( board, config, main_session = await _require_gateway(
session, board_id, user=auth.user, session,
board_id,
user=auth.user,
) )
if auth.user is None: if auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
@@ -350,7 +364,8 @@ async def send_gateway_session_message(
await send_message(payload.content, session_key=session_id, config=config) await send_message(payload.content, session_key=session_id, config=config)
except OpenClawGatewayError as exc: except OpenClawGatewayError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc), status_code=status.HTTP_502_BAD_GATEWAY,
detail=str(exc),
) from exc ) from exc
return OkResponse() return OkResponse()

View File

@@ -17,11 +17,7 @@ from app.db import crud
from app.db.pagination import paginate from app.db.pagination import paginate
from app.db.session import get_session from app.db.session import get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import ( from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.agents import Agent from app.models.agents import Agent
from app.models.gateways import Gateway from app.models.gateways import Gateway
from app.schemas.common import OkResponse from app.schemas.common import OkResponse
@@ -38,12 +34,8 @@ from app.services.agent_provisioning import (
ProvisionOptions, ProvisionOptions,
provision_main_agent, provision_main_agent,
) )
from app.services.template_sync import ( from app.services.template_sync import GatewayTemplateSyncOptions
GatewayTemplateSyncOptions, from app.services.template_sync import sync_gateway_templates as sync_gateway_templates_service
)
from app.services.template_sync import (
sync_gateway_templates as sync_gateway_templates_service,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from fastapi_pagination.limit_offset import LimitOffsetPage from fastapi_pagination.limit_offset import LimitOffsetPage
@@ -109,7 +101,8 @@ async def _require_gateway(
) )
if gateway is None: if gateway is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found", status_code=status.HTTP_404_NOT_FOUND,
detail="Gateway not found",
) )
return gateway return gateway

View File

@@ -8,8 +8,9 @@ from typing import Literal
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from sqlalchemy import DateTime, case, func from sqlalchemy import DateTime, case
from sqlalchemy import cast as sql_cast from sqlalchemy import cast as sql_cast
from sqlalchemy import func
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession

View File

@@ -80,7 +80,8 @@ ORG_ADMIN_DEP = Depends(require_org_admin)
def _member_to_read( def _member_to_read(
member: OrganizationMember, user: User | None, member: OrganizationMember,
user: User | None,
) -> OrganizationMemberRead: ) -> OrganizationMemberRead:
model = OrganizationMemberRead.model_validate(member, from_attributes=True) model = OrganizationMemberRead.model_validate(member, from_attributes=True)
if user is not None: if user is not None:
@@ -167,9 +168,7 @@ async def list_my_organizations(
await get_active_membership(session, auth.user) await get_active_membership(session, auth.user)
db_user = await User.objects.by_id(auth.user.id).first(session) db_user = await User.objects.by_id(auth.user.id).first(session)
active_id = ( active_id = db_user.active_organization_id if db_user else auth.user.active_organization_id
db_user.active_organization_id if db_user else auth.user.active_organization_id
)
statement = ( statement = (
select(Organization, OrganizationMember) select(Organization, OrganizationMember)
@@ -202,7 +201,9 @@ async def set_active_org(
if auth.user is None: if auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
member = await set_active_organization( member = await set_active_organization(
session, user=auth.user, organization_id=payload.organization_id, session,
user=auth.user,
organization_id=payload.organization_id,
) )
organization = await Organization.objects.by_id(member.organization_id).first( organization = await Organization.objects.by_id(member.organization_id).first(
session, session,
@@ -245,7 +246,10 @@ async def delete_my_org(
group_ids = select(BoardGroup.id).where(col(BoardGroup.organization_id) == org_id) group_ids = select(BoardGroup.id).where(col(BoardGroup.organization_id) == org_id)
await crud.delete_where( await crud.delete_where(
session, ActivityEvent, col(ActivityEvent.task_id).in_(task_ids), commit=False, session,
ActivityEvent,
col(ActivityEvent.task_id).in_(task_ids),
commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, session,
@@ -266,10 +270,16 @@ async def delete_my_org(
commit=False, commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, Approval, col(Approval.board_id).in_(board_ids), commit=False, session,
Approval,
col(Approval.board_id).in_(board_ids),
commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, BoardMemory, col(BoardMemory.board_id).in_(board_ids), commit=False, session,
BoardMemory,
col(BoardMemory.board_id).in_(board_ids),
commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, session,
@@ -302,13 +312,22 @@ async def delete_my_org(
commit=False, commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, Task, col(Task.board_id).in_(board_ids), commit=False, session,
Task,
col(Task.board_id).in_(board_ids),
commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, Agent, col(Agent.board_id).in_(board_ids), commit=False, session,
Agent,
col(Agent.board_id).in_(board_ids),
commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, Board, col(Board.organization_id) == org_id, commit=False, session,
Board,
col(Board.organization_id) == org_id,
commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, session,
@@ -317,10 +336,16 @@ async def delete_my_org(
commit=False, commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, BoardGroup, col(BoardGroup.organization_id) == org_id, commit=False, session,
BoardGroup,
col(BoardGroup.organization_id) == org_id,
commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, Gateway, col(Gateway.organization_id) == org_id, commit=False, session,
Gateway,
col(Gateway.organization_id) == org_id,
commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, session,
@@ -342,7 +367,10 @@ async def delete_my_org(
commit=False, commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, Organization, col(Organization.id) == org_id, commit=False, session,
Organization,
col(Organization.id) == org_id,
commit=False,
) )
await session.commit() await session.commit()
return OkResponse() return OkResponse()
@@ -360,14 +388,14 @@ async def get_my_membership(
).all(session) ).all(session)
model = _member_to_read(ctx.member, user) model = _member_to_read(ctx.member, user)
model.board_access = [ model.board_access = [
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows
for row in access_rows
] ]
return model return model
@router.get( @router.get(
"/me/members", response_model=DefaultLimitOffsetPage[OrganizationMemberRead], "/me/members",
response_model=DefaultLimitOffsetPage[OrganizationMemberRead],
) )
async def list_org_members( async def list_org_members(
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
@@ -410,8 +438,7 @@ async def get_org_member(
).all(session) ).all(session)
model = _member_to_read(member, user) model = _member_to_read(member, user)
model.board_access = [ model.board_access = [
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows
for row in access_rows
] ]
return model return model
@@ -529,9 +556,7 @@ async def remove_org_member(
user.active_organization_id = fallback_membership user.active_organization_id = fallback_membership
else: else:
user.active_organization_id = ( user.active_organization_id = (
fallback_membership.organization_id fallback_membership.organization_id if fallback_membership is not None else None
if fallback_membership is not None
else None
) )
session.add(user) session.add(user)
@@ -540,7 +565,8 @@ async def remove_org_member(
@router.get( @router.get(
"/me/invites", response_model=DefaultLimitOffsetPage[OrganizationInviteRead], "/me/invites",
response_model=DefaultLimitOffsetPage[OrganizationInviteRead],
) )
async def list_org_invites( async def list_org_invites(
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
@@ -607,7 +633,9 @@ async def create_org_invite(
if valid_board_ids != board_ids: if valid_board_ids != board_ids:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
await apply_invite_board_access( await apply_invite_board_access(
session, invite=invite, entries=payload.board_access, session,
invite=invite,
entries=payload.board_access,
) )
await session.commit() await session.commit()
await session.refresh(invite) await session.refresh(invite)

View File

@@ -29,11 +29,7 @@ from app.db import crud
from app.db.pagination import paginate from app.db.pagination import paginate
from app.db.session import async_session_maker, get_session from app.db.session import async_session_maker, get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import ( from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.activity_events import ActivityEvent from app.models.activity_events import ActivityEvent
from app.models.agents import Agent from app.models.agents import Agent
from app.models.approvals import Approval from app.models.approvals import Approval
@@ -45,13 +41,7 @@ from app.models.tasks import Task
from app.schemas.common import OkResponse from app.schemas.common import OkResponse
from app.schemas.errors import BlockedTaskError from app.schemas.errors import BlockedTaskError
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
from app.schemas.tasks import ( from app.schemas.tasks import TaskCommentCreate, TaskCommentRead, TaskCreate, TaskRead, TaskUpdate
TaskCommentCreate,
TaskCommentRead,
TaskCreate,
TaskRead,
TaskUpdate,
)
from app.services.activity_log import record_activity from app.services.activity_log import record_activity
from app.services.mentions import extract_mentions, matches_agent_mention from app.services.mentions import extract_mentions, matches_agent_mention
from app.services.organizations import require_board_access from app.services.organizations import require_board_access
@@ -263,8 +253,7 @@ async def _reconcile_dependents_for_dependency_toggle(
event_type="task.status_changed", event_type="task.status_changed",
task_id=dependent.id, task_id=dependent.id,
message=( message=(
"Task returned to inbox: dependency reopened " "Task returned to inbox: dependency reopened " f"({dependency_task.title})."
f"({dependency_task.title})."
), ),
agent_id=actor_agent_id, agent_id=actor_agent_id,
) )
@@ -313,7 +302,8 @@ def _serialize_comment(event: ActivityEvent) -> dict[str, object]:
async def _gateway_config( async def _gateway_config(
session: AsyncSession, board: Board, session: AsyncSession,
board: Board,
) -> GatewayClientConfig | None: ) -> GatewayClientConfig | None:
if not board.gateway_id: if not board.gateway_id:
return None return None
@@ -368,10 +358,7 @@ async def _notify_agent_on_task_assign(
message = ( message = (
"TASK ASSIGNED\n" "TASK ASSIGNED\n"
+ "\n".join(details) + "\n".join(details)
+ ( + ("\n\nTake action: open the task and begin work. " "Post updates as task comments.")
"\n\nTake action: open the task and begin work. "
"Post updates as task comments."
)
) )
try: try:
await _send_agent_task_message( await _send_agent_task_message(
@@ -607,9 +594,7 @@ async def _stream_dependency_state(
rows: list[tuple[ActivityEvent, Task | None]], rows: list[tuple[ActivityEvent, Task | None]],
) -> tuple[dict[UUID, list[UUID]], dict[UUID, str]]: ) -> tuple[dict[UUID, list[UUID]], dict[UUID, str]]:
task_ids = [ task_ids = [
task.id task.id for event, task in rows if task is not None and event.event_type != "task.comment"
for event, task in rows
if task is not None and event.event_type != "task.comment"
] ]
if not task_ids: if not task_ids:
return {}, {} return {}, {}
@@ -786,7 +771,8 @@ async def create_task(
dependency_ids=normalized_deps, dependency_ids=normalized_deps,
) )
blocked_by = blocked_by_dependency_ids( blocked_by = blocked_by_dependency_ids(
dependency_ids=normalized_deps, status_by_id=dep_status, dependency_ids=normalized_deps,
status_by_id=dep_status,
) )
if blocked_by and (task.assigned_agent_id is not None or task.status != "inbox"): if blocked_by and (task.assigned_agent_id is not None or task.status != "inbox"):
raise _blocked_task_error(blocked_by) raise _blocked_task_error(blocked_by)
@@ -861,9 +847,7 @@ async def update_task(
updates = payload.model_dump(exclude_unset=True) updates = payload.model_dump(exclude_unset=True)
comment = payload.comment if "comment" in payload.model_fields_set else None comment = payload.comment if "comment" in payload.model_fields_set else None
depends_on_task_ids = ( depends_on_task_ids = (
payload.depends_on_task_ids payload.depends_on_task_ids if "depends_on_task_ids" in payload.model_fields_set else None
if "depends_on_task_ids" in payload.model_fields_set
else None
) )
updates.pop("comment", None) updates.pop("comment", None)
updates.pop("depends_on_task_ids", None) updates.pop("depends_on_task_ids", None)
@@ -906,13 +890,22 @@ async def delete_task(
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
await require_board_access(session, user=auth.user, board=board, write=True) await require_board_access(session, user=auth.user, board=board, write=True)
await crud.delete_where( await crud.delete_where(
session, ActivityEvent, col(ActivityEvent.task_id) == task.id, commit=False, session,
ActivityEvent,
col(ActivityEvent.task_id) == task.id,
commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, TaskFingerprint, col(TaskFingerprint.task_id) == task.id, commit=False, session,
TaskFingerprint,
col(TaskFingerprint.task_id) == task.id,
commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, Approval, col(Approval.task_id) == task.id, commit=False, session,
Approval,
col(Approval.task_id) == task.id,
commit=False,
) )
await crud.delete_where( await crud.delete_where(
session, session,
@@ -929,7 +922,8 @@ async def delete_task(
@router.get( @router.get(
"/{task_id}/comments", response_model=DefaultLimitOffsetPage[TaskCommentRead], "/{task_id}/comments",
response_model=DefaultLimitOffsetPage[TaskCommentRead],
) )
async def list_task_comments( async def list_task_comments(
task: Task = TASK_DEP, task: Task = TASK_DEP,
@@ -1241,11 +1235,7 @@ async def _lead_apply_assignment(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Board leads cannot assign tasks to themselves.", detail="Board leads cannot assign tasks to themselves.",
) )
if ( if agent.board_id and update.task.board_id and agent.board_id != update.task.board_id:
agent.board_id
and update.task.board_id
and agent.board_id != update.task.board_id
):
raise HTTPException(status_code=status.HTTP_409_CONFLICT) raise HTTPException(status_code=status.HTTP_409_CONFLICT)
update.task.assigned_agent_id = agent.id update.task.assigned_agent_id = agent.id
@@ -1256,19 +1246,13 @@ def _lead_apply_status(update: _TaskUpdateInput) -> None:
if update.task.status != "review": if update.task.status != "review":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=( detail=("Board leads can only change status when a task is " "in review."),
"Board leads can only change status when a task is "
"in review."
),
) )
target_status = _required_status_value(update.updates["status"]) target_status = _required_status_value(update.updates["status"])
if target_status not in {"done", "inbox"}: if target_status not in {"done", "inbox"}:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=( detail=("Board leads can only move review tasks to done " "or inbox."),
"Board leads can only move review tasks to done "
"or inbox."
),
) )
if target_status == "inbox": if target_status == "inbox":
update.task.assigned_agent_id = None update.task.assigned_agent_id = None
@@ -1397,9 +1381,7 @@ async def _apply_non_lead_agent_task_rules(
update.task.assigned_agent_id = None update.task.assigned_agent_id = None
update.task.in_progress_at = None update.task.in_progress_at = None
else: else:
update.task.assigned_agent_id = ( update.task.assigned_agent_id = update.actor.agent.id if update.actor.agent else None
update.actor.agent.id if update.actor.agent else None
)
if status_value == "in_progress": if status_value == "in_progress":
update.task.in_progress_at = utcnow() update.task.in_progress_at = utcnow()
@@ -1462,11 +1444,7 @@ async def _apply_admin_task_rules(
agent = await Agent.objects.by_id(assigned_agent_id).first(session) agent = await Agent.objects.by_id(assigned_agent_id).first(session)
if agent is None: if agent is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if ( if agent.board_id and update.task.board_id and agent.board_id != update.task.board_id:
agent.board_id
and update.task.board_id
and agent.board_id != update.task.board_id
):
raise HTTPException(status_code=status.HTTP_409_CONFLICT) raise HTTPException(status_code=status.HTTP_409_CONFLICT)
@@ -1481,9 +1459,11 @@ async def _record_task_comment_from_update(
event_type="task.comment", event_type="task.comment",
message=update.comment, message=update.comment,
task_id=update.task.id, task_id=update.task.id,
agent_id=update.actor.agent.id agent_id=(
update.actor.agent.id
if update.actor.actor_type == "agent" and update.actor.agent if update.actor.actor_type == "agent" and update.actor.agent
else None, else None
),
) )
session.add(event) session.add(event)
await session.commit() await session.commit()
@@ -1496,9 +1476,7 @@ async def _record_task_update_activity(
) -> None: ) -> None:
event_type, message = _task_event_details(update.task, update.previous_status) event_type, message = _task_event_details(update.task, update.previous_status)
actor_agent_id = ( actor_agent_id = (
update.actor.agent.id update.actor.agent.id if update.actor.actor_type == "agent" and update.actor.agent else None
if update.actor.actor_type == "agent" and update.actor.agent
else None
) )
record_activity( record_activity(
session, session,
@@ -1525,10 +1503,7 @@ async def _notify_task_update_assignment_changes(
if ( if (
update.task.status == "inbox" update.task.status == "inbox"
and update.task.assigned_agent_id is None and update.task.assigned_agent_id is None
and ( and (update.previous_status != "inbox" or update.previous_assigned is not None)
update.previous_status != "inbox"
or update.previous_assigned is not None
)
): ):
board = ( board = (
await Board.objects.by_id(update.task.board_id).first(session) await Board.objects.by_id(update.task.board_id).first(session)

View File

@@ -77,10 +77,7 @@ async def _touch_agent_presence(
real activity even if the heartbeat loop isn't running. real activity even if the heartbeat loop isn't running.
""" """
now = utcnow() now = utcnow()
if ( if agent.last_seen_at is not None and now - agent.last_seen_at < _LAST_SEEN_TOUCH_INTERVAL:
agent.last_seen_at is not None
and now - agent.last_seen_at < _LAST_SEEN_TOUCH_INTERVAL
):
return return
agent.last_seen_at = now agent.last_seen_at = now

View File

@@ -19,9 +19,7 @@ _MULTIPLIERS: dict[str, int] = {
_MAX_SCHEDULE_SECONDS = 60 * 60 * 24 * 365 * 10 _MAX_SCHEDULE_SECONDS = 60 * 60 * 24 * 365 * 10
_ERR_SCHEDULE_REQUIRED = "schedule is required" _ERR_SCHEDULE_REQUIRED = "schedule is required"
_ERR_SCHEDULE_INVALID = ( _ERR_SCHEDULE_INVALID = 'Invalid schedule. Expected format like "10m", "1h", "2d", "1w".'
'Invalid schedule. Expected format like "10m", "1h", "2d", "1w".'
)
_ERR_SCHEDULE_NONPOSITIVE = "Schedule must be greater than 0." _ERR_SCHEDULE_NONPOSITIVE = "Schedule must be greater than 0."
_ERR_SCHEDULE_TOO_LARGE = "Schedule is too large (max 10 years)." _ERR_SCHEDULE_TOO_LARGE = "Schedule is too large (max 10 years)."

View File

@@ -13,7 +13,7 @@ from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.responses import Response from starlette.responses import Response
if TYPE_CHECKING: if TYPE_CHECKING: # pragma: no cover
from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.types import ASGIApp, Message, Receive, Scope, Send
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -44,9 +44,7 @@ class RequestIdMiddleware:
if message["type"] == "http.response.start": if message["type"] == "http.response.start":
# Starlette uses `list[tuple[bytes, bytes]]` here. # Starlette uses `list[tuple[bytes, bytes]]` here.
headers: list[tuple[bytes, bytes]] = message.setdefault("headers", []) headers: list[tuple[bytes, bytes]] = message.setdefault("headers", [])
if not any( if not any(key.lower() == self._header_name_bytes for key, _ in headers):
key.lower() == self._header_name_bytes for key, _ in headers
):
request_id_bytes = request_id.encode("latin-1") request_id_bytes = request_id.encode("latin-1")
headers.append((self._header_name_bytes, request_id_bytes)) headers.append((self._header_name_bytes, request_id_bytes))
await send(message) await send(message)

View File

@@ -198,8 +198,7 @@ class AppLogger:
formatter: logging.Formatter = JsonFormatter() formatter: logging.Formatter = JsonFormatter()
else: else:
formatter = KeyValueFormatter( formatter = KeyValueFormatter(
"%(asctime)s %(levelname)s %(name)s %(message)s " "%(asctime)s %(levelname)s %(name)s %(message)s " "app=%(app)s version=%(version)s",
"app=%(app)s version=%(version)s",
) )
if settings.log_use_utc: if settings.log_use_utc:
formatter.converter = time.gmtime formatter.converter = time.gmtime

View File

@@ -161,10 +161,7 @@ async def list_by(
async def exists(session: AsyncSession, model: type[ModelT], **lookup: object) -> bool: async def exists(session: AsyncSession, model: type[ModelT], **lookup: object) -> bool:
"""Return whether any object exists for lookup values.""" """Return whether any object exists for lookup values."""
return ( return (await session.exec(_lookup_statement(model, lookup).limit(1))).first() is not None
(await session.exec(_lookup_statement(model, lookup).limit(1))).first()
is not None
)
def _criteria_statement( def _criteria_statement(
@@ -219,11 +216,7 @@ async def update_where(
commit = bool(options.pop("commit", False)) commit = bool(options.pop("commit", False))
exclude_none = bool(options.pop("exclude_none", False)) exclude_none = bool(options.pop("exclude_none", False))
allowed_fields_raw = options.pop("allowed_fields", None) allowed_fields_raw = options.pop("allowed_fields", None)
allowed_fields = ( allowed_fields = allowed_fields_raw if isinstance(allowed_fields_raw, set) else None
allowed_fields_raw
if isinstance(allowed_fields_raw, set)
else None
)
source_updates: dict[str, Any] = {} source_updates: dict[str, Any] = {}
if updates: if updates:
source_updates.update(dict(updates)) source_updates.update(dict(updates))
@@ -276,11 +269,7 @@ async def patch(
"""Apply partial updates and persist object.""" """Apply partial updates and persist object."""
exclude_none = bool(options.pop("exclude_none", False)) exclude_none = bool(options.pop("exclude_none", False))
allowed_fields_raw = options.pop("allowed_fields", None) allowed_fields_raw = options.pop("allowed_fields", None)
allowed_fields = ( allowed_fields = allowed_fields_raw if isinstance(allowed_fields_raw, set) else None
allowed_fields_raw
if isinstance(allowed_fields_raw, set)
else None
)
commit = bool(options.pop("commit", True)) commit = bool(options.pop("commit", True))
refresh = bool(options.pop("refresh", True)) refresh = bool(options.pop("refresh", True))
apply_updates( apply_updates(

View File

@@ -36,12 +36,8 @@ class BoardOnboardingConfirm(SQLModel):
@model_validator(mode="after") @model_validator(mode="after")
def validate_goal_fields(self) -> Self: def validate_goal_fields(self) -> Self:
"""Require goal metadata when the board type is `goal`.""" """Require goal metadata when the board type is `goal`."""
if self.board_type == "goal" and ( if self.board_type == "goal" and (not self.objective or not self.success_metrics):
not self.objective or not self.success_metrics message = "Confirmed goal boards require objective and success_metrics"
):
message = (
"Confirmed goal boards require objective and success_metrics"
)
raise ValueError(message) raise ValueError(message)
return self return self

View File

@@ -9,9 +9,7 @@ from uuid import UUID
from pydantic import model_validator from pydantic import model_validator
from sqlmodel import SQLModel from sqlmodel import SQLModel
_ERR_GOAL_FIELDS_REQUIRED = ( _ERR_GOAL_FIELDS_REQUIRED = "Confirmed goal boards require objective and success_metrics"
"Confirmed goal boards require objective and success_metrics"
)
_ERR_GATEWAY_REQUIRED = "gateway_id is required" _ERR_GATEWAY_REQUIRED = "gateway_id is required"
RUNTIME_ANNOTATION_TYPES = (datetime, UUID) RUNTIME_ANNOTATION_TYPES = (datetime, UUID)

View File

@@ -15,11 +15,7 @@ from jinja2 import Environment, FileSystemLoader, StrictUndefined, select_autoes
from app.core.config import settings from app.core.config import settings
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import ( from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, openclaw_call
OpenClawGatewayError,
ensure_session,
openclaw_call,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from app.models.agents import Agent from app.models.agents import Agent
@@ -414,7 +410,9 @@ async def _supported_gateway_files(config: GatewayClientConfig) -> set[str]:
if not agent_id: if not agent_id:
return set(DEFAULT_GATEWAY_FILES) return set(DEFAULT_GATEWAY_FILES)
files_payload = await openclaw_call( files_payload = await openclaw_call(
"agents.files.list", {"agentId": agent_id}, config=config, "agents.files.list",
{"agentId": agent_id},
config=config,
) )
if isinstance(files_payload, dict): if isinstance(files_payload, dict):
files = files_payload.get("files") or [] files = files_payload.get("files") or []
@@ -438,11 +436,14 @@ async def _reset_session(session_key: str, config: GatewayClientConfig) -> None:
async def _gateway_agent_files_index( async def _gateway_agent_files_index(
agent_id: str, config: GatewayClientConfig, agent_id: str,
config: GatewayClientConfig,
) -> dict[str, dict[str, Any]]: ) -> dict[str, dict[str, Any]]:
try: try:
payload = await openclaw_call( payload = await openclaw_call(
"agents.files.list", {"agentId": agent_id}, config=config, "agents.files.list",
{"agentId": agent_id},
config=config,
) )
if isinstance(payload, dict): if isinstance(payload, dict):
files = payload.get("files") or [] files = payload.get("files") or []
@@ -486,18 +487,14 @@ def _render_agent_files(
) )
heartbeat_path = _templates_root() / heartbeat_template heartbeat_path = _templates_root() / heartbeat_template
if heartbeat_path.exists(): if heartbeat_path.exists():
rendered[name] = ( rendered[name] = env.get_template(heartbeat_template).render(**context).strip()
env.get_template(heartbeat_template).render(**context).strip()
)
continue continue
override = overrides.get(name) override = overrides.get(name)
if override: if override:
rendered[name] = env.from_string(override).render(**context).strip() rendered[name] = env.from_string(override).render(**context).strip()
continue continue
template_name = ( template_name = (
template_overrides[name] template_overrides[name] if template_overrides and name in template_overrides else name
if template_overrides and name in template_overrides
else name
) )
path = _templates_root() / template_name path = _templates_root() / template_name
if path.exists(): if path.exists():
@@ -596,8 +593,7 @@ def _heartbeat_entry_map(
entries: list[tuple[str, str, dict[str, Any]]], entries: list[tuple[str, str, dict[str, Any]]],
) -> dict[str, tuple[str, dict[str, Any]]]: ) -> dict[str, tuple[str, dict[str, Any]]]:
return { return {
agent_id: (workspace_path, heartbeat) agent_id: (workspace_path, heartbeat) for agent_id, workspace_path, heartbeat in entries
for agent_id, workspace_path, heartbeat in entries
} }
@@ -694,9 +690,7 @@ async def _remove_gateway_agent_list(
raise OpenClawGatewayError(msg) raise OpenClawGatewayError(msg)
new_list = [ new_list = [
entry entry for entry in lst if not (isinstance(entry, dict) and entry.get("id") == agent_id)
for entry in lst
if not (isinstance(entry, dict) and entry.get("id") == agent_id)
] ]
if len(new_list) == len(lst): if len(new_list) == len(lst):
return return
@@ -841,7 +835,9 @@ async def provision_main_agent(
raise ValueError(msg) raise ValueError(msg)
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token) client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
await ensure_session( await ensure_session(
gateway.main_session_key, config=client_config, label="Main Agent", gateway.main_session_key,
config=client_config,
label="Main Agent",
) )
agent_id = await _gateway_default_agent_id( agent_id = await _gateway_default_agent_id(

View File

@@ -38,10 +38,7 @@ def _status_weight_expr() -> ColumnElement[int]:
def _priority_weight_expr() -> ColumnElement[int]: def _priority_weight_expr() -> ColumnElement[int]:
"""Return a SQL expression that sorts task priorities by configured order.""" """Return a SQL expression that sorts task priorities by configured order."""
whens = [ whens = [(col(Task.priority) == key, weight) for key, weight in _PRIORITY_ORDER.items()]
(col(Task.priority) == key, weight)
for key, weight in _PRIORITY_ORDER.items()
]
return case(*whens, else_=99) return case(*whens, else_=99)
@@ -106,11 +103,7 @@ async def _agent_names(
tasks: list[Task], tasks: list[Task],
) -> dict[UUID, str]: ) -> dict[UUID, str]:
"""Return agent names keyed by assigned agent ids in the provided tasks.""" """Return agent names keyed by assigned agent ids in the provided tasks."""
assigned_ids = { assigned_ids = {task.assigned_agent_id for task in tasks if task.assigned_agent_id is not None}
task.assigned_agent_id
for task in tasks
if task.assigned_agent_id is not None
}
if not assigned_ids: if not assigned_ids:
return {} return {}
return dict( return dict(

View File

@@ -10,11 +10,7 @@ from sqlmodel import col, select
from app.core.agent_tokens import generate_agent_token, hash_agent_token from app.core.agent_tokens import generate_agent_token, hash_agent_token
from app.core.time import utcnow from app.core.time import utcnow
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import ( from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
OpenClawGatewayError,
ensure_session,
send_message,
)
from app.models.agents import Agent from app.models.agents import Agent
from app.services.agent_provisioning import ( from app.services.agent_provisioning import (
DEFAULT_HEARTBEAT_CONFIG, DEFAULT_HEARTBEAT_CONFIG,

View File

@@ -55,8 +55,7 @@ def _agent_to_read(agent: Agent, main_session_keys: set[str]) -> AgentRead:
model = AgentRead.model_validate(agent, from_attributes=True) model = AgentRead.model_validate(agent, from_attributes=True)
computed_status = _computed_agent_status(agent) computed_status = _computed_agent_status(agent)
is_gateway_main = bool( is_gateway_main = bool(
agent.openclaw_session_id agent.openclaw_session_id and agent.openclaw_session_id in main_session_keys,
and agent.openclaw_session_id in main_session_keys,
) )
return model.model_copy( return model.model_copy(
update={ update={
@@ -84,11 +83,7 @@ def _task_to_card(
) -> TaskCardRead: ) -> TaskCardRead:
card = TaskCardRead.model_validate(task, from_attributes=True) card = TaskCardRead.model_validate(task, from_attributes=True)
approvals_count, approvals_pending_count = counts_by_task_id.get(task.id, (0, 0)) approvals_count, approvals_pending_count = counts_by_task_id.get(task.id, (0, 0))
assignee = ( assignee = agent_name_by_id.get(task.assigned_agent_id) if task.assigned_agent_id else None
agent_name_by_id.get(task.assigned_agent_id)
if task.assigned_agent_id
else None
)
depends_on_task_ids = deps_by_task_id.get(task.id, []) depends_on_task_ids = deps_by_task_id.get(task.id, [])
blocked_by_task_ids = blocked_by_dependency_ids( blocked_by_task_ids = blocked_by_dependency_ids(
dependency_ids=depends_on_task_ids, dependency_ids=depends_on_task_ids,

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import re import re
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING: # pragma: no cover
from app.models.agents import Agent from app.models.agents import Agent
# Mention tokens are single, space-free words (e.g. "@alex", "@lead"). # Mention tokens are single, space-free words (e.g. "@alex", "@lead").

View File

@@ -79,7 +79,8 @@ async def get_member(
async def get_first_membership( async def get_first_membership(
session: AsyncSession, user_id: UUID, session: AsyncSession,
user_id: UUID,
) -> OrganizationMember | None: ) -> OrganizationMember | None:
"""Return the oldest membership for a user, if any.""" """Return the oldest membership for a user, if any."""
return ( return (
@@ -99,7 +100,8 @@ async def set_active_organization(
member = await get_member(session, user_id=user.id, organization_id=organization_id) member = await get_member(session, user_id=user.id, organization_id=organization_id)
if member is None: if member is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="No org access", status_code=status.HTTP_403_FORBIDDEN,
detail="No org access",
) )
if user.active_organization_id != organization_id: if user.active_organization_id != organization_id:
user.active_organization_id = organization_id user.active_organization_id = organization_id
@@ -177,8 +179,7 @@ async def accept_invite(
access_rows = list( access_rows = list(
await session.exec( await session.exec(
select(OrganizationInviteBoardAccess).where( select(OrganizationInviteBoardAccess).where(
col(OrganizationInviteBoardAccess.organization_invite_id) col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id,
== invite.id,
), ),
), ),
) )
@@ -207,7 +208,8 @@ async def accept_invite(
async def ensure_member_for_user( async def ensure_member_for_user(
session: AsyncSession, user: User, session: AsyncSession,
user: User,
) -> OrganizationMember: ) -> OrganizationMember:
"""Ensure a user has some membership, creating one if necessary.""" """Ensure a user has some membership, creating one if necessary."""
existing = await get_active_membership(session, user) existing = await get_active_membership(session, user)
@@ -291,21 +293,27 @@ async def require_board_access(
) -> OrganizationMember: ) -> OrganizationMember:
"""Require board access for a user and return matching membership.""" """Require board access for a user and return matching membership."""
member = await get_member( member = await get_member(
session, user_id=user.id, organization_id=board.organization_id, session,
user_id=user.id,
organization_id=board.organization_id,
) )
if member is None: if member is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="No org access", status_code=status.HTTP_403_FORBIDDEN,
detail="No org access",
) )
if not await has_board_access(session, member=member, board=board, write=write): if not await has_board_access(session, member=member, board=board, write=write):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Board access denied", status_code=status.HTTP_403_FORBIDDEN,
detail="Board access denied",
) )
return member return member
def board_access_filter( def board_access_filter(
member: OrganizationMember, *, write: bool, member: OrganizationMember,
*,
write: bool,
) -> ColumnElement[bool]: ) -> ColumnElement[bool]:
"""Build a SQL filter expression for boards visible to a member.""" """Build a SQL filter expression for boards visible to a member."""
if write and member_all_boards_write(member): if write and member_all_boards_write(member):

View File

@@ -42,10 +42,7 @@ class SoulRef:
def _parse_sitemap_soul_refs(sitemap_xml: str) -> list[SoulRef]: def _parse_sitemap_soul_refs(sitemap_xml: str) -> list[SoulRef]:
"""Parse sitemap XML and extract valid souls.directory handle/slug refs.""" """Parse sitemap XML and extract valid souls.directory handle/slug refs."""
# Extract <loc> values without XML entity expansion. # Extract <loc> values without XML entity expansion.
urls = [ urls = [unescape(match.group(1)).strip() for match in _LOC_PATTERN.finditer(sitemap_xml)]
unescape(match.group(1)).strip()
for match in _LOC_PATTERN.finditer(sitemap_xml)
]
refs: list[SoulRef] = [] refs: list[SoulRef] = []
for url in urls: for url in urls:
@@ -110,10 +107,7 @@ async def fetch_soul_markdown(
normalized_slug = slug.strip().strip("/") normalized_slug = slug.strip().strip("/")
if normalized_slug.endswith(".md"): if normalized_slug.endswith(".md"):
normalized_slug = normalized_slug[: -len(".md")] normalized_slug = normalized_slug[: -len(".md")]
url = ( url = f"{SOULS_DIRECTORY_BASE_URL}/api/souls/" f"{normalized_handle}/{normalized_slug}.md"
f"{SOULS_DIRECTORY_BASE_URL}/api/souls/"
f"{normalized_handle}/{normalized_slug}.md"
)
owns_client = client is None owns_client = client is None
if client is None: if client is None:

View File

@@ -79,11 +79,7 @@ def blocked_by_dependency_ids(
status_by_id: Mapping[UUID, str], status_by_id: Mapping[UUID, str],
) -> list[UUID]: ) -> list[UUID]:
"""Return dependency ids that are not yet in the done status.""" """Return dependency ids that are not yet in the done status."""
return [ return [dep_id for dep_id in dependency_ids if status_by_id.get(dep_id) != DONE_STATUS]
dep_id
for dep_id in dependency_ids
if status_by_id.get(dep_id) != DONE_STATUS
]
async def blocked_by_for_task( async def blocked_by_for_task(

View File

@@ -14,11 +14,7 @@ from sqlalchemy import func
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.agent_tokens import ( from app.core.agent_tokens import generate_agent_token, hash_agent_token, verify_agent_token
generate_agent_token,
hash_agent_token,
verify_agent_token,
)
from app.core.time import utcnow from app.core.time import utcnow
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, openclaw_call from app.integrations.openclaw_gateway import OpenClawGatewayError, openclaw_call
@@ -108,10 +104,7 @@ def _is_transient_gateway_error(exc: Exception) -> bool:
def _gateway_timeout_message(exc: OpenClawGatewayError) -> str: def _gateway_timeout_message(exc: OpenClawGatewayError) -> str:
return ( return "Gateway unreachable after 10 minutes (template sync timeout). " f"Last error: {exc}"
"Gateway unreachable after 10 minutes (template sync timeout). "
f"Last error: {exc}"
)
class _GatewayBackoff: class _GatewayBackoff:
@@ -375,6 +368,7 @@ async def _rotate_agent_token(session: AsyncSession, agent: Agent) -> str:
async def _ping_gateway(ctx: _SyncContext, result: GatewayTemplatesSyncResult) -> bool: async def _ping_gateway(ctx: _SyncContext, result: GatewayTemplatesSyncResult) -> bool:
try: try:
async def _do_ping() -> object: async def _do_ping() -> object:
return await openclaw_call("agents.list", config=ctx.config) return await openclaw_call("agents.list", config=ctx.config)
@@ -486,6 +480,7 @@ async def _sync_one_agent(
if not auth_token: if not auth_token:
return False return False
try: try:
async def _do_provision() -> None: async def _do_provision() -> None:
await provision_agent( await provision_agent(
agent, agent,
@@ -533,10 +528,7 @@ async def _sync_main_agent(
if main_agent is None: if main_agent is None:
_append_sync_error( _append_sync_error(
result, result,
message=( message=("Gateway main agent record not found; " "skipping main agent template sync."),
"Gateway main agent record not found; "
"skipping main agent template sync."
),
) )
return True return True
try: try:
@@ -574,6 +566,7 @@ async def _sync_main_agent(
return True return True
stop_sync = False stop_sync = False
try: try:
async def _do_provision_main() -> None: async def _do_provision_main() -> None:
await provision_main_agent( await provision_main_agent(
main_agent, main_agent,

View File

@@ -32,17 +32,13 @@ def _parse_args() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--reset-sessions", "--reset-sessions",
action="store_true", action="store_true",
help=( help=("Reset agent sessions after syncing files " "(forces agents to re-read workspace)"),
"Reset agent sessions after syncing files "
"(forces agents to re-read workspace)"
),
) )
parser.add_argument( parser.add_argument(
"--rotate-tokens", "--rotate-tokens",
action="store_true", action="store_true",
help=( help=(
"Rotate agent tokens when TOOLS.md is missing/unreadable " "Rotate agent tokens when TOOLS.md is missing/unreadable " "or token drift is detected"
"or token drift is detected"
), ),
) )
parser.add_argument( parser.add_argument(
@@ -56,10 +52,7 @@ def _parse_args() -> argparse.Namespace:
async def _run() -> int: async def _run() -> int:
from app.db.session import async_session_maker from app.db.session import async_session_maker
from app.models.gateways import Gateway from app.models.gateways import Gateway
from app.services.template_sync import ( from app.services.template_sync import GatewayTemplateSyncOptions, sync_gateway_templates
GatewayTemplateSyncOptions,
sync_gateway_templates,
)
args = _parse_args() args = _parse_args()
gateway_id = UUID(args.gateway_id) gateway_id = UUID(args.gateway_id)
@@ -86,8 +79,7 @@ async def _run() -> int:
sys.stdout.write(f"gateway_id={result.gateway_id}\n") sys.stdout.write(f"gateway_id={result.gateway_id}\n")
sys.stdout.write( sys.stdout.write(
f"include_main={result.include_main} " f"include_main={result.include_main} " f"reset_sessions={result.reset_sessions}\n",
f"reset_sessions={result.reset_sessions}\n",
) )
sys.stdout.write( sys.stdout.write(
f"agents_updated={result.agents_updated} " f"agents_updated={result.agents_updated} "

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import pytest
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -11,6 +12,9 @@ from app.core.error_handling import (
REQUEST_ID_HEADER, REQUEST_ID_HEADER,
_error_payload, _error_payload,
_get_request_id, _get_request_id,
_http_exception_exception_handler,
_request_validation_exception_handler,
_response_validation_exception_handler,
install_error_handling, install_error_handling,
) )
@@ -123,3 +127,24 @@ def test_get_request_id_returns_none_for_missing_or_invalid_state() -> None:
def test_error_payload_omits_request_id_when_none() -> None: def test_error_payload_omits_request_id_when_none() -> None:
assert _error_payload(detail="x", request_id=None) == {"detail": "x"} assert _error_payload(detail="x", request_id=None) == {"detail": "x"}
@pytest.mark.asyncio
async def test_request_validation_exception_wrapper_rejects_wrong_exception() -> None:
req = Request({"type": "http", "headers": [], "state": {}})
with pytest.raises(TypeError, match="Expected RequestValidationError"):
await _request_validation_exception_handler(req, Exception("x"))
@pytest.mark.asyncio
async def test_response_validation_exception_wrapper_rejects_wrong_exception() -> None:
req = Request({"type": "http", "headers": [], "state": {}})
with pytest.raises(TypeError, match="Expected ResponseValidationError"):
await _response_validation_exception_handler(req, Exception("x"))
@pytest.mark.asyncio
async def test_http_exception_wrapper_rejects_wrong_exception() -> None:
req = Request({"type": "http", "headers": [], "state": {}})
with pytest.raises(TypeError, match="Expected StarletteHTTPException"):
await _http_exception_exception_handler(req, Exception("x"))

View File

@@ -16,10 +16,7 @@ from app.models.organization_invites import OrganizationInvite
from app.models.organization_members import OrganizationMember from app.models.organization_members import OrganizationMember
from app.models.organizations import Organization from app.models.organizations import Organization
from app.models.users import User from app.models.users import User
from app.schemas.organizations import ( from app.schemas.organizations import OrganizationBoardAccessSpec, OrganizationMemberAccessUpdate
OrganizationBoardAccessSpec,
OrganizationMemberAccessUpdate,
)
from app.services import organizations from app.services import organizations
@@ -87,10 +84,7 @@ class _FakeSession:
def test_normalize_invited_email_strips_and_lowercases() -> None: def test_normalize_invited_email_strips_and_lowercases() -> None:
assert ( assert organizations.normalize_invited_email(" Foo@Example.com ") == "foo@example.com"
organizations.normalize_invited_email(" Foo@Example.com ")
== "foo@example.com"
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -128,7 +122,9 @@ async def test_ensure_member_for_user_returns_existing_membership(
) -> None: ) -> None:
user = User(clerk_user_id="u1") user = User(clerk_user_id="u1")
existing = OrganizationMember( existing = OrganizationMember(
organization_id=uuid4(), user_id=user.id, role="member", organization_id=uuid4(),
user_id=user.id,
role="member",
) )
async def _fake_get_active(_session: Any, _user: User) -> OrganizationMember: async def _fake_get_active(_session: Any, _user: User) -> OrganizationMember:
@@ -161,11 +157,15 @@ async def test_ensure_member_for_user_accepts_pending_invite(
return invite return invite
accepted = OrganizationMember( accepted = OrganizationMember(
organization_id=org_id, user_id=user.id, role="member", organization_id=org_id,
user_id=user.id,
role="member",
) )
async def _fake_accept( async def _fake_accept(
_session: Any, _invite: OrganizationInvite, _user: User, _session: Any,
_invite: OrganizationInvite,
_user: User,
) -> OrganizationMember: ) -> OrganizationMember:
assert _invite is invite assert _invite is invite
assert _user is user assert _user is user
@@ -216,7 +216,10 @@ async def test_has_board_access_denies_cross_org() -> None:
board = Board(id=uuid4(), organization_id=uuid4(), name="b", slug="b") board = Board(id=uuid4(), organization_id=uuid4(), name="b", slug="b")
assert ( assert (
await organizations.has_board_access( await organizations.has_board_access(
session, member=member, board=board, write=False, session,
member=member,
board=board,
write=False,
) )
is False is False
) )
@@ -226,7 +229,10 @@ async def test_has_board_access_denies_cross_org() -> None:
async def test_has_board_access_uses_org_board_access_row_read_and_write() -> None: async def test_has_board_access_uses_org_board_access_row_read_and_write() -> None:
org_id = uuid4() org_id = uuid4()
member = OrganizationMember( member = OrganizationMember(
id=uuid4(), organization_id=org_id, user_id=uuid4(), role="member", id=uuid4(),
organization_id=org_id,
user_id=uuid4(),
role="member",
) )
board = Board(id=uuid4(), organization_id=org_id, name="b", slug="b") board = Board(id=uuid4(), organization_id=org_id, name="b", slug="b")
@@ -239,7 +245,10 @@ async def test_has_board_access_uses_org_board_access_row_read_and_write() -> No
session = _FakeSession(exec_results=[_FakeExecResult(first_value=access)]) session = _FakeSession(exec_results=[_FakeExecResult(first_value=access)])
assert ( assert (
await organizations.has_board_access( await organizations.has_board_access(
session, member=member, board=board, write=False, session,
member=member,
board=board,
write=False,
) )
is True is True
) )
@@ -253,7 +262,10 @@ async def test_has_board_access_uses_org_board_access_row_read_and_write() -> No
session2 = _FakeSession(exec_results=[_FakeExecResult(first_value=access2)]) session2 = _FakeSession(exec_results=[_FakeExecResult(first_value=access2)])
assert ( assert (
await organizations.has_board_access( await organizations.has_board_access(
session2, member=member, board=board, write=False, session2,
member=member,
board=board,
write=False,
) )
is True is True
) )
@@ -267,7 +279,10 @@ async def test_has_board_access_uses_org_board_access_row_read_and_write() -> No
session3 = _FakeSession(exec_results=[_FakeExecResult(first_value=access3)]) session3 = _FakeSession(exec_results=[_FakeExecResult(first_value=access3)])
assert ( assert (
await organizations.has_board_access( await organizations.has_board_access(
session3, member=member, board=board, write=True, session3,
member=member,
board=board,
write=True,
) )
is False is False
) )
@@ -288,7 +303,10 @@ async def test_require_board_access_raises_when_no_member(
session = _FakeSession(exec_results=[]) session = _FakeSession(exec_results=[])
with pytest.raises(HTTPException) as exc: with pytest.raises(HTTPException) as exc:
await organizations.require_board_access( await organizations.require_board_access(
session, user=user, board=board, write=False, session,
user=user,
board=board,
write=False,
) )
assert exc.value.status_code == 403 assert exc.value.status_code == 403
@@ -298,24 +316,33 @@ async def test_apply_member_access_update_deletes_existing_and_adds_rows_when_no
None None
): ):
member = OrganizationMember( member = OrganizationMember(
id=uuid4(), organization_id=uuid4(), user_id=uuid4(), role="member", id=uuid4(),
organization_id=uuid4(),
user_id=uuid4(),
role="member",
) )
update = OrganizationMemberAccessUpdate( update = OrganizationMemberAccessUpdate(
all_boards_read=False, all_boards_read=False,
all_boards_write=False, all_boards_write=False,
board_access=[ board_access=[
OrganizationBoardAccessSpec( OrganizationBoardAccessSpec(
board_id=uuid4(), can_read=True, can_write=False, board_id=uuid4(),
can_read=True,
can_write=False,
), ),
OrganizationBoardAccessSpec( OrganizationBoardAccessSpec(
board_id=uuid4(), can_read=True, can_write=True, board_id=uuid4(),
can_read=True,
can_write=True,
), ),
], ],
) )
session = _FakeSession(exec_results=[]) session = _FakeSession(exec_results=[])
await organizations.apply_member_access_update( await organizations.apply_member_access_update(
session, member=member, update=update, session,
member=member,
update=update,
) )
# delete statement executed once # delete statement executed once