refactor: simplify code formatting and improve readability across multiple files
This commit is contained in:
@@ -22,10 +22,7 @@ from app.models.activity_events import ActivityEvent
|
|||||||
from app.models.agents import Agent
|
from app.models.agents import Agent
|
||||||
from app.models.boards import Board
|
from app.models.boards import Board
|
||||||
from app.models.tasks import Task
|
from app.models.tasks import Task
|
||||||
from app.schemas.activity_events import (
|
from app.schemas.activity_events import ActivityEventRead, ActivityTaskCommentFeedItemRead
|
||||||
ActivityEventRead,
|
|
||||||
ActivityTaskCommentFeedItemRead,
|
|
||||||
)
|
|
||||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||||
from app.services.organizations import (
|
from app.services.organizations import (
|
||||||
OrganizationContext,
|
OrganizationContext,
|
||||||
@@ -198,10 +195,7 @@ async def list_task_comment_feed(
|
|||||||
|
|
||||||
def _transform(items: Sequence[Any]) -> Sequence[Any]:
|
def _transform(items: Sequence[Any]) -> Sequence[Any]:
|
||||||
rows = _coerce_task_comment_rows(items)
|
rows = _coerce_task_comment_rows(items)
|
||||||
return [
|
return [_feed_item(event, task, board, agent) for event, task, board, agent in rows]
|
||||||
_feed_item(event, task, board, agent)
|
|
||||||
for event, task, board, agent in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
return await paginate(session, statement, transformer=_transform)
|
return await paginate(session, statement, transformer=_transform)
|
||||||
|
|
||||||
|
|||||||
@@ -53,19 +53,9 @@ from app.schemas.gateway_coordination import (
|
|||||||
GatewayMainAskUserResponse,
|
GatewayMainAskUserResponse,
|
||||||
)
|
)
|
||||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||||
from app.schemas.tasks import (
|
from app.schemas.tasks import TaskCommentCreate, TaskCommentRead, TaskCreate, TaskRead, TaskUpdate
|
||||||
TaskCommentCreate,
|
|
||||||
TaskCommentRead,
|
|
||||||
TaskCreate,
|
|
||||||
TaskRead,
|
|
||||||
TaskUpdate,
|
|
||||||
)
|
|
||||||
from app.services.activity_log import record_activity
|
from app.services.activity_log import record_activity
|
||||||
from app.services.board_leads import (
|
from app.services.board_leads import LeadAgentOptions, LeadAgentRequest, ensure_board_lead_agent
|
||||||
LeadAgentOptions,
|
|
||||||
LeadAgentRequest,
|
|
||||||
ensure_board_lead_agent,
|
|
||||||
)
|
|
||||||
from app.services.task_dependencies import (
|
from app.services.task_dependencies import (
|
||||||
blocked_by_dependency_ids,
|
blocked_by_dependency_ids,
|
||||||
dependency_status_by_id,
|
dependency_status_by_id,
|
||||||
@@ -212,7 +202,8 @@ async def _require_gateway_board(
|
|||||||
board = await Board.objects.by_id(board_id).first(session)
|
board = await Board.objects.by_id(board_id).first(session)
|
||||||
if board is None:
|
if board is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail="Board not found",
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Board not found",
|
||||||
)
|
)
|
||||||
if board.gateway_id != gateway.id:
|
if board.gateway_id != gateway.id:
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||||
@@ -322,7 +313,8 @@ async def create_task(
|
|||||||
dependency_ids=normalized_deps,
|
dependency_ids=normalized_deps,
|
||||||
)
|
)
|
||||||
blocked_by = blocked_by_dependency_ids(
|
blocked_by = blocked_by_dependency_ids(
|
||||||
dependency_ids=normalized_deps, status_by_id=dep_status,
|
dependency_ids=normalized_deps,
|
||||||
|
status_by_id=dep_status,
|
||||||
)
|
)
|
||||||
|
|
||||||
if blocked_by and (task.assigned_agent_id is not None or task.status != "inbox"):
|
if blocked_by and (task.assigned_agent_id is not None or task.status != "inbox"):
|
||||||
@@ -393,11 +385,7 @@ async def update_task(
|
|||||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||||
) -> TaskRead:
|
) -> TaskRead:
|
||||||
"""Update a task after board-level access checks."""
|
"""Update a task after board-level access checks."""
|
||||||
if (
|
if agent_ctx.agent.board_id and task.board_id and agent_ctx.agent.board_id != task.board_id:
|
||||||
agent_ctx.agent.board_id
|
|
||||||
and task.board_id
|
|
||||||
and agent_ctx.agent.board_id != task.board_id
|
|
||||||
):
|
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||||
return await tasks_api.update_task(
|
return await tasks_api.update_task(
|
||||||
payload=payload,
|
payload=payload,
|
||||||
@@ -417,11 +405,7 @@ async def list_task_comments(
|
|||||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||||
) -> LimitOffsetPage[TaskCommentRead]:
|
) -> LimitOffsetPage[TaskCommentRead]:
|
||||||
"""List comments for a task visible to the authenticated agent."""
|
"""List comments for a task visible to the authenticated agent."""
|
||||||
if (
|
if agent_ctx.agent.board_id and task.board_id and agent_ctx.agent.board_id != task.board_id:
|
||||||
agent_ctx.agent.board_id
|
|
||||||
and task.board_id
|
|
||||||
and agent_ctx.agent.board_id != task.board_id
|
|
||||||
):
|
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||||
return await tasks_api.list_task_comments(
|
return await tasks_api.list_task_comments(
|
||||||
task=task,
|
task=task,
|
||||||
@@ -430,7 +414,8 @@ async def list_task_comments(
|
|||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/boards/{board_id}/tasks/{task_id}/comments", response_model=TaskCommentRead,
|
"/boards/{board_id}/tasks/{task_id}/comments",
|
||||||
|
response_model=TaskCommentRead,
|
||||||
)
|
)
|
||||||
async def create_task_comment(
|
async def create_task_comment(
|
||||||
payload: TaskCommentCreate,
|
payload: TaskCommentCreate,
|
||||||
@@ -439,11 +424,7 @@ async def create_task_comment(
|
|||||||
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
agent_ctx: AgentAuthContext = AGENT_CTX_DEP,
|
||||||
) -> ActivityEvent:
|
) -> ActivityEvent:
|
||||||
"""Create a task comment on behalf of the authenticated agent."""
|
"""Create a task comment on behalf of the authenticated agent."""
|
||||||
if (
|
if agent_ctx.agent.board_id and task.board_id and agent_ctx.agent.board_id != task.board_id:
|
||||||
agent_ctx.agent.board_id
|
|
||||||
and task.board_id
|
|
||||||
and agent_ctx.agent.board_id != task.board_id
|
|
||||||
):
|
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||||
return await tasks_api.create_task_comment(
|
return await tasks_api.create_task_comment(
|
||||||
payload=payload,
|
payload=payload,
|
||||||
@@ -454,7 +435,8 @@ async def create_task_comment(
|
|||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/boards/{board_id}/memory", response_model=DefaultLimitOffsetPage[BoardMemoryRead],
|
"/boards/{board_id}/memory",
|
||||||
|
response_model=DefaultLimitOffsetPage[BoardMemoryRead],
|
||||||
)
|
)
|
||||||
async def list_board_memory(
|
async def list_board_memory(
|
||||||
is_chat: bool | None = IS_CHAT_QUERY,
|
is_chat: bool | None = IS_CHAT_QUERY,
|
||||||
@@ -588,7 +570,9 @@ async def nudge_agent(
|
|||||||
config = await _gateway_config(session, board)
|
config = await _gateway_config(session, board)
|
||||||
try:
|
try:
|
||||||
await ensure_session(
|
await ensure_session(
|
||||||
target.openclaw_session_id, config=config, label=target.name,
|
target.openclaw_session_id,
|
||||||
|
config=config,
|
||||||
|
label=target.name,
|
||||||
)
|
)
|
||||||
await send_message(
|
await send_message(
|
||||||
message,
|
message,
|
||||||
@@ -605,7 +589,8 @@ async def nudge_agent(
|
|||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=str(exc),
|
||||||
) from exc
|
) from exc
|
||||||
record_activity(
|
record_activity(
|
||||||
session,
|
session,
|
||||||
@@ -657,7 +642,8 @@ async def get_agent_soul(
|
|||||||
)
|
)
|
||||||
except OpenClawGatewayError as exc:
|
except OpenClawGatewayError as exc:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=str(exc),
|
||||||
) from exc
|
) from exc
|
||||||
if isinstance(payload, str):
|
if isinstance(payload, str):
|
||||||
return payload
|
return payload
|
||||||
@@ -671,7 +657,8 @@ async def get_agent_soul(
|
|||||||
if isinstance(nested, str):
|
if isinstance(nested, str):
|
||||||
return nested
|
return nested
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_502_BAD_GATEWAY, detail="Invalid gateway response",
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail="Invalid gateway response",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -712,7 +699,8 @@ async def update_agent_soul(
|
|||||||
)
|
)
|
||||||
except OpenClawGatewayError as exc:
|
except OpenClawGatewayError as exc:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=str(exc),
|
||||||
) from exc
|
) from exc
|
||||||
reason = (payload.reason or "").strip()
|
reason = (payload.reason or "").strip()
|
||||||
source_url = (payload.source_url or "").strip()
|
source_url = (payload.source_url or "").strip()
|
||||||
@@ -770,9 +758,7 @@ async def ask_user_via_gateway_main(
|
|||||||
correlation = payload.correlation_id.strip() if payload.correlation_id else ""
|
correlation = payload.correlation_id.strip() if payload.correlation_id else ""
|
||||||
correlation_line = f"Correlation ID: {correlation}\n" if correlation else ""
|
correlation_line = f"Correlation ID: {correlation}\n" if correlation else ""
|
||||||
preferred_channel = (payload.preferred_channel or "").strip()
|
preferred_channel = (payload.preferred_channel or "").strip()
|
||||||
channel_line = (
|
channel_line = f"Preferred channel: {preferred_channel}\n" if preferred_channel else ""
|
||||||
f"Preferred channel: {preferred_channel}\n" if preferred_channel else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
tags = payload.reply_tags or ["gateway_main", "user_reply"]
|
tags = payload.reply_tags or ["gateway_main", "user_reply"]
|
||||||
tags_json = json.dumps(tags)
|
tags_json = json.dumps(tags)
|
||||||
@@ -801,7 +787,10 @@ async def ask_user_via_gateway_main(
|
|||||||
try:
|
try:
|
||||||
await ensure_session(main_session_key, config=config, label="Main Agent")
|
await ensure_session(main_session_key, config=config, label="Main Agent")
|
||||||
await send_message(
|
await send_message(
|
||||||
message, session_key=main_session_key, config=config, deliver=True,
|
message,
|
||||||
|
session_key=main_session_key,
|
||||||
|
config=config,
|
||||||
|
deliver=True,
|
||||||
)
|
)
|
||||||
except OpenClawGatewayError as exc:
|
except OpenClawGatewayError as exc:
|
||||||
record_activity(
|
record_activity(
|
||||||
@@ -812,7 +801,8 @@ async def ask_user_via_gateway_main(
|
|||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=str(exc),
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
record_activity(
|
record_activity(
|
||||||
@@ -867,11 +857,7 @@ async def message_gateway_board_lead(
|
|||||||
)
|
)
|
||||||
|
|
||||||
base_url = settings.base_url or "http://localhost:8000"
|
base_url = settings.base_url or "http://localhost:8000"
|
||||||
header = (
|
header = "GATEWAY MAIN QUESTION" if payload.kind == "question" else "GATEWAY MAIN HANDOFF"
|
||||||
"GATEWAY MAIN QUESTION"
|
|
||||||
if payload.kind == "question"
|
|
||||||
else "GATEWAY MAIN HANDOFF"
|
|
||||||
)
|
|
||||||
correlation = payload.correlation_id.strip() if payload.correlation_id else ""
|
correlation = payload.correlation_id.strip() if payload.correlation_id else ""
|
||||||
correlation_line = f"Correlation ID: {correlation}\n" if correlation else ""
|
correlation_line = f"Correlation ID: {correlation}\n" if correlation else ""
|
||||||
tags = payload.reply_tags or ["gateway_main", "lead_reply"]
|
tags = payload.reply_tags or ["gateway_main", "lead_reply"]
|
||||||
@@ -903,7 +889,8 @@ async def message_gateway_board_lead(
|
|||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=str(exc),
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
record_activity(
|
record_activity(
|
||||||
@@ -946,11 +933,7 @@ async def broadcast_gateway_lead_message(
|
|||||||
boards = list(await session.exec(statement))
|
boards = list(await session.exec(statement))
|
||||||
|
|
||||||
base_url = settings.base_url or "http://localhost:8000"
|
base_url = settings.base_url or "http://localhost:8000"
|
||||||
header = (
|
header = "GATEWAY MAIN QUESTION" if payload.kind == "question" else "GATEWAY MAIN HANDOFF"
|
||||||
"GATEWAY MAIN QUESTION"
|
|
||||||
if payload.kind == "question"
|
|
||||||
else "GATEWAY MAIN HANDOFF"
|
|
||||||
)
|
|
||||||
correlation = payload.correlation_id.strip() if payload.correlation_id else ""
|
correlation = payload.correlation_id.strip() if payload.correlation_id else ""
|
||||||
correlation_line = f"Correlation ID: {correlation}\n" if correlation else ""
|
correlation_line = f"Correlation ID: {correlation}\n" if correlation else ""
|
||||||
tags = payload.reply_tags or ["gateway_main", "lead_reply"]
|
tags = payload.reply_tags or ["gateway_main", "lead_reply"]
|
||||||
|
|||||||
@@ -23,11 +23,7 @@ from app.db import crud
|
|||||||
from app.db.pagination import paginate
|
from app.db.pagination import paginate
|
||||||
from app.db.session import async_session_maker, get_session
|
from app.db.session import async_session_maker, get_session
|
||||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||||
from app.integrations.openclaw_gateway import (
|
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
|
||||||
OpenClawGatewayError,
|
|
||||||
ensure_session,
|
|
||||||
send_message,
|
|
||||||
)
|
|
||||||
from app.models.activity_events import ActivityEvent
|
from app.models.activity_events import ActivityEvent
|
||||||
from app.models.agents import Agent
|
from app.models.agents import Agent
|
||||||
from app.models.boards import Board
|
from app.models.boards import Board
|
||||||
@@ -154,7 +150,8 @@ async def _require_board(
|
|||||||
board = await Board.objects.by_id(board_id).first(session)
|
board = await Board.objects.by_id(board_id).first(session)
|
||||||
if board is None:
|
if board is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail="Board not found",
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Board not found",
|
||||||
)
|
)
|
||||||
if user is not None:
|
if user is not None:
|
||||||
await require_board_access(session, user=user, board=board, write=write)
|
await require_board_access(session, user=user, board=board, write=write)
|
||||||
@@ -162,7 +159,8 @@ async def _require_board(
|
|||||||
|
|
||||||
|
|
||||||
async def _require_gateway(
|
async def _require_gateway(
|
||||||
session: AsyncSession, board: Board,
|
session: AsyncSession,
|
||||||
|
board: Board,
|
||||||
) -> tuple[Gateway, GatewayClientConfig]:
|
) -> tuple[Gateway, GatewayClientConfig]:
|
||||||
if not board.gateway_id:
|
if not board.gateway_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -246,7 +244,8 @@ def _coerce_agent_items(items: Sequence[Any]) -> list[Agent]:
|
|||||||
|
|
||||||
|
|
||||||
async def _find_gateway_for_main_session(
|
async def _find_gateway_for_main_session(
|
||||||
session: AsyncSession, session_key: str | None,
|
session: AsyncSession,
|
||||||
|
session_key: str | None,
|
||||||
) -> Gateway | None:
|
) -> Gateway | None:
|
||||||
if not session_key:
|
if not session_key:
|
||||||
return None
|
return None
|
||||||
@@ -306,7 +305,8 @@ async def _fetch_agent_events(
|
|||||||
|
|
||||||
|
|
||||||
async def _require_user_context(
|
async def _require_user_context(
|
||||||
session: AsyncSession, user: User | None,
|
session: AsyncSession,
|
||||||
|
user: User | None,
|
||||||
) -> OrganizationContext:
|
) -> OrganizationContext:
|
||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||||
@@ -332,7 +332,8 @@ async def _require_agent_access(
|
|||||||
if not is_org_admin(ctx.member):
|
if not is_org_admin(ctx.member):
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||||
gateway = await _find_gateway_for_main_session(
|
gateway = await _find_gateway_for_main_session(
|
||||||
session, agent.openclaw_session_id,
|
session,
|
||||||
|
agent.openclaw_session_id,
|
||||||
)
|
)
|
||||||
if gateway is None or gateway.organization_id != ctx.organization.id:
|
if gateway is None or gateway.organization_id != ctx.organization.id:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||||
@@ -355,7 +356,10 @@ def _record_heartbeat(session: AsyncSession, agent: Agent) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _record_instruction_failure(
|
def _record_instruction_failure(
|
||||||
session: AsyncSession, agent: Agent, error: str, action: str,
|
session: AsyncSession,
|
||||||
|
agent: Agent,
|
||||||
|
error: str,
|
||||||
|
action: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
action_label = action.replace("_", " ").capitalize()
|
action_label = action.replace("_", " ").capitalize()
|
||||||
record_activity(
|
record_activity(
|
||||||
@@ -432,10 +436,7 @@ async def _ensure_unique_agent_name(
|
|||||||
if existing_gateway:
|
if existing_gateway:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
detail=(
|
detail=("An agent with this name already exists in this gateway " "workspace."),
|
||||||
"An agent with this name already exists in this gateway "
|
|
||||||
"workspace."
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
desired_session_key = _build_session_key(requested_name)
|
desired_session_key = _build_session_key(requested_name)
|
||||||
@@ -938,7 +939,9 @@ async def _commit_heartbeat(
|
|||||||
|
|
||||||
|
|
||||||
async def _send_wakeup_message(
|
async def _send_wakeup_message(
|
||||||
agent: Agent, config: GatewayClientConfig, verb: str = "provisioned",
|
agent: Agent,
|
||||||
|
config: GatewayClientConfig,
|
||||||
|
verb: str = "provisioned",
|
||||||
) -> None:
|
) -> None:
|
||||||
session_key = agent.openclaw_session_id or _build_session_key(agent.name)
|
session_key = agent.openclaw_session_id or _build_session_key(agent.name)
|
||||||
await ensure_session(session_key, config=config, label=agent.name)
|
await ensure_session(session_key, config=config, label=agent.name)
|
||||||
@@ -971,7 +974,8 @@ async def list_agents(
|
|||||||
col(Gateway.organization_id) == ctx.organization.id,
|
col(Gateway.organization_id) == ctx.organization.id,
|
||||||
)
|
)
|
||||||
base_filter = or_(
|
base_filter = or_(
|
||||||
base_filter, col(Agent.openclaw_session_id).in_(gateway_keys),
|
base_filter,
|
||||||
|
col(Agent.openclaw_session_id).in_(gateway_keys),
|
||||||
)
|
)
|
||||||
statement = select(Agent).where(base_filter)
|
statement = select(Agent).where(base_filter)
|
||||||
if board_id is not None:
|
if board_id is not None:
|
||||||
@@ -987,10 +991,7 @@ async def list_agents(
|
|||||||
|
|
||||||
def _transform(items: Sequence[Any]) -> Sequence[Any]:
|
def _transform(items: Sequence[Any]) -> Sequence[Any]:
|
||||||
agents = _coerce_agent_items(items)
|
agents = _coerce_agent_items(items)
|
||||||
return [
|
return [_to_agent_read(_with_computed_status(agent), main_session_keys) for agent in agents]
|
||||||
_to_agent_read(_with_computed_status(agent), main_session_keys)
|
|
||||||
for agent in agents
|
|
||||||
]
|
|
||||||
|
|
||||||
return await paginate(session, statement, transformer=_transform)
|
return await paginate(session, statement, transformer=_transform)
|
||||||
|
|
||||||
@@ -1019,19 +1020,17 @@ async def stream_agents(
|
|||||||
async with async_session_maker() as stream_session:
|
async with async_session_maker() as stream_session:
|
||||||
if board_id is not None:
|
if board_id is not None:
|
||||||
agents = await _fetch_agent_events(
|
agents = await _fetch_agent_events(
|
||||||
stream_session, board_id, last_seen,
|
stream_session,
|
||||||
|
board_id,
|
||||||
|
last_seen,
|
||||||
)
|
)
|
||||||
elif allowed_ids:
|
elif allowed_ids:
|
||||||
agents = await _fetch_agent_events(stream_session, None, last_seen)
|
agents = await _fetch_agent_events(stream_session, None, last_seen)
|
||||||
agents = [
|
agents = [agent for agent in agents if agent.board_id in allowed_ids]
|
||||||
agent for agent in agents if agent.board_id in allowed_ids
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
agents = []
|
agents = []
|
||||||
main_session_keys = (
|
main_session_keys = (
|
||||||
await _get_gateway_main_session_keys(stream_session)
|
await _get_gateway_main_session_keys(stream_session) if agents else set()
|
||||||
if agents
|
|
||||||
else set()
|
|
||||||
)
|
)
|
||||||
for agent in agents:
|
for agent in agents:
|
||||||
updated_at = agent.updated_at or agent.last_seen_at or utcnow()
|
updated_at = agent.updated_at or agent.last_seen_at or utcnow()
|
||||||
@@ -1252,7 +1251,8 @@ async def delete_agent(
|
|||||||
await _require_agent_access(session, agent=agent, ctx=ctx, write=True)
|
await _require_agent_access(session, agent=agent, ctx=ctx, write=True)
|
||||||
|
|
||||||
board = await _require_board(
|
board = await _require_board(
|
||||||
session, str(agent.board_id) if agent.board_id else None,
|
session,
|
||||||
|
str(agent.board_id) if agent.board_id else None,
|
||||||
)
|
)
|
||||||
gateway, client_config = await _require_gateway(session, board)
|
gateway, client_config = await _require_gateway(session, board)
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -24,12 +24,7 @@ from app.core.time import utcnow
|
|||||||
from app.db.pagination import paginate
|
from app.db.pagination import paginate
|
||||||
from app.db.session import async_session_maker, get_session
|
from app.db.session import async_session_maker, get_session
|
||||||
from app.models.approvals import Approval
|
from app.models.approvals import Approval
|
||||||
from app.schemas.approvals import (
|
from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalStatus, ApprovalUpdate
|
||||||
ApprovalCreate,
|
|
||||||
ApprovalRead,
|
|
||||||
ApprovalStatus,
|
|
||||||
ApprovalUpdate,
|
|
||||||
)
|
|
||||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -156,9 +151,7 @@ async def stream_approvals(
|
|||||||
).one(),
|
).one(),
|
||||||
)
|
)
|
||||||
task_ids = {
|
task_ids = {
|
||||||
approval.task_id
|
approval.task_id for approval in approvals if approval.task_id is not None
|
||||||
for approval in approvals
|
|
||||||
if approval.task_id is not None
|
|
||||||
}
|
}
|
||||||
counts_by_task_id: dict[UUID, tuple[int, int]] = {}
|
counts_by_task_id: dict[UUID, tuple[int, int]] = {}
|
||||||
if task_ids:
|
if task_ids:
|
||||||
|
|||||||
@@ -26,21 +26,14 @@ from app.core.time import utcnow
|
|||||||
from app.db.pagination import paginate
|
from app.db.pagination import paginate
|
||||||
from app.db.session import async_session_maker, get_session
|
from app.db.session import async_session_maker, get_session
|
||||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||||
from app.integrations.openclaw_gateway import (
|
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
|
||||||
OpenClawGatewayError,
|
|
||||||
ensure_session,
|
|
||||||
send_message,
|
|
||||||
)
|
|
||||||
from app.models.agents import Agent
|
from app.models.agents import Agent
|
||||||
from app.models.board_group_memory import BoardGroupMemory
|
from app.models.board_group_memory import BoardGroupMemory
|
||||||
from app.models.board_groups import BoardGroup
|
from app.models.board_groups import BoardGroup
|
||||||
from app.models.boards import Board
|
from app.models.boards import Board
|
||||||
from app.models.gateways import Gateway
|
from app.models.gateways import Gateway
|
||||||
from app.models.users import User
|
from app.models.users import User
|
||||||
from app.schemas.board_group_memory import (
|
from app.schemas.board_group_memory import BoardGroupMemoryCreate, BoardGroupMemoryRead
|
||||||
BoardGroupMemoryCreate,
|
|
||||||
BoardGroupMemoryRead,
|
|
||||||
)
|
|
||||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||||
from app.services.mentions import extract_mentions, matches_agent_mention
|
from app.services.mentions import extract_mentions, matches_agent_mention
|
||||||
from app.services.organizations import (
|
from app.services.organizations import (
|
||||||
|
|||||||
@@ -10,12 +10,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
|||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
|
|
||||||
from app.api.deps import (
|
from app.api.deps import ActorContext, require_admin_or_agent, require_org_admin, require_org_member
|
||||||
ActorContext,
|
|
||||||
require_admin_or_agent,
|
|
||||||
require_org_admin,
|
|
||||||
require_org_member,
|
|
||||||
)
|
|
||||||
from app.core.time import utcnow
|
from app.core.time import utcnow
|
||||||
from app.db import crud
|
from app.db import crud
|
||||||
from app.db.pagination import paginate
|
from app.db.pagination import paginate
|
||||||
@@ -34,10 +29,7 @@ from app.schemas.board_groups import BoardGroupCreate, BoardGroupRead, BoardGrou
|
|||||||
from app.schemas.common import OkResponse
|
from app.schemas.common import OkResponse
|
||||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||||
from app.schemas.view_models import BoardGroupSnapshot
|
from app.schemas.view_models import BoardGroupSnapshot
|
||||||
from app.services.agent_provisioning import (
|
from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, sync_gateway_agent_heartbeats
|
||||||
DEFAULT_HEARTBEAT_CONFIG,
|
|
||||||
sync_gateway_agent_heartbeats,
|
|
||||||
)
|
|
||||||
from app.services.board_group_snapshot import build_group_snapshot
|
from app.services.board_group_snapshot import build_group_snapshot
|
||||||
from app.services.organizations import (
|
from app.services.organizations import (
|
||||||
OrganizationContext,
|
OrganizationContext,
|
||||||
@@ -86,8 +78,7 @@ async def _require_group_access(
|
|||||||
return group
|
return group
|
||||||
|
|
||||||
board_ids = [
|
board_ids = [
|
||||||
board.id
|
board.id for board in await Board.objects.filter_by(board_group_id=group_id).all(session)
|
||||||
for board in await Board.objects.filter_by(board_group_id=group_id).all(session)
|
|
||||||
]
|
]
|
||||||
if not board_ids:
|
if not board_ids:
|
||||||
if is_org_admin(member):
|
if is_org_admin(member):
|
||||||
@@ -144,7 +135,10 @@ async def get_board_group(
|
|||||||
) -> BoardGroup:
|
) -> BoardGroup:
|
||||||
"""Get a board group by id."""
|
"""Get a board group by id."""
|
||||||
return await _require_group_access(
|
return await _require_group_access(
|
||||||
session, group_id=group_id, member=ctx.member, write=False,
|
session,
|
||||||
|
group_id=group_id,
|
||||||
|
member=ctx.member,
|
||||||
|
write=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -159,7 +153,10 @@ async def get_board_group_snapshot(
|
|||||||
) -> BoardGroupSnapshot:
|
) -> BoardGroupSnapshot:
|
||||||
"""Get a snapshot across boards in a group."""
|
"""Get a snapshot across boards in a group."""
|
||||||
group = await _require_group_access(
|
group = await _require_group_access(
|
||||||
session, group_id=group_id, member=ctx.member, write=False,
|
session,
|
||||||
|
group_id=group_id,
|
||||||
|
member=ctx.member,
|
||||||
|
write=False,
|
||||||
)
|
)
|
||||||
if per_board_task_limit < 0:
|
if per_board_task_limit < 0:
|
||||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||||
@@ -174,9 +171,7 @@ async def get_board_group_snapshot(
|
|||||||
allowed_ids = set(
|
allowed_ids = set(
|
||||||
await list_accessible_board_ids(session, member=ctx.member, write=False),
|
await list_accessible_board_ids(session, member=ctx.member, write=False),
|
||||||
)
|
)
|
||||||
snapshot.boards = [
|
snapshot.boards = [item for item in snapshot.boards if item.board.id in allowed_ids]
|
||||||
item for item in snapshot.boards if item.board.id in allowed_ids
|
|
||||||
]
|
|
||||||
return snapshot
|
return snapshot
|
||||||
|
|
||||||
|
|
||||||
@@ -339,14 +334,13 @@ async def update_board_group(
|
|||||||
) -> BoardGroup:
|
) -> BoardGroup:
|
||||||
"""Update a board group."""
|
"""Update a board group."""
|
||||||
group = await _require_group_access(
|
group = await _require_group_access(
|
||||||
session, group_id=group_id, member=ctx.member, write=True,
|
session,
|
||||||
|
group_id=group_id,
|
||||||
|
member=ctx.member,
|
||||||
|
write=True,
|
||||||
)
|
)
|
||||||
updates = payload.model_dump(exclude_unset=True)
|
updates = payload.model_dump(exclude_unset=True)
|
||||||
if (
|
if "slug" in updates and updates["slug"] is not None and not updates["slug"].strip():
|
||||||
"slug" in updates
|
|
||||||
and updates["slug"] is not None
|
|
||||||
and not updates["slug"].strip()
|
|
||||||
):
|
|
||||||
updates["slug"] = _slugify(updates.get("name") or group.name)
|
updates["slug"] = _slugify(updates.get("name") or group.name)
|
||||||
updates["updated_at"] = utcnow()
|
updates["updated_at"] = utcnow()
|
||||||
return await crud.patch(session, group, updates)
|
return await crud.patch(session, group, updates)
|
||||||
@@ -360,7 +354,10 @@ async def delete_board_group(
|
|||||||
) -> OkResponse:
|
) -> OkResponse:
|
||||||
"""Delete a board group."""
|
"""Delete a board group."""
|
||||||
await _require_group_access(
|
await _require_group_access(
|
||||||
session, group_id=group_id, member=ctx.member, write=True,
|
session,
|
||||||
|
group_id=group_id,
|
||||||
|
member=ctx.member,
|
||||||
|
write=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Boards reference groups, so clear the FK first to keep deletes simple.
|
# Boards reference groups, so clear the FK first to keep deletes simple.
|
||||||
@@ -378,7 +375,10 @@ async def delete_board_group(
|
|||||||
commit=False,
|
commit=False,
|
||||||
)
|
)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session, BoardGroup, col(BoardGroup.id) == group_id, commit=False,
|
session,
|
||||||
|
BoardGroup,
|
||||||
|
col(BoardGroup.id) == group_id,
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return OkResponse()
|
return OkResponse()
|
||||||
|
|||||||
@@ -24,11 +24,7 @@ from app.core.time import utcnow
|
|||||||
from app.db.pagination import paginate
|
from app.db.pagination import paginate
|
||||||
from app.db.session import async_session_maker, get_session
|
from app.db.session import async_session_maker, get_session
|
||||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||||
from app.integrations.openclaw_gateway import (
|
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
|
||||||
OpenClawGatewayError,
|
|
||||||
ensure_session,
|
|
||||||
send_message,
|
|
||||||
)
|
|
||||||
from app.models.agents import Agent
|
from app.models.agents import Agent
|
||||||
from app.models.board_memory import BoardMemory
|
from app.models.board_memory import BoardMemory
|
||||||
from app.models.gateways import Gateway
|
from app.models.gateways import Gateway
|
||||||
|
|||||||
@@ -21,11 +21,7 @@ from app.core.config import settings
|
|||||||
from app.core.time import utcnow
|
from app.core.time import utcnow
|
||||||
from app.db.session import get_session
|
from app.db.session import get_session
|
||||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||||
from app.integrations.openclaw_gateway import (
|
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
|
||||||
OpenClawGatewayError,
|
|
||||||
ensure_session,
|
|
||||||
send_message,
|
|
||||||
)
|
|
||||||
from app.models.board_onboarding import BoardOnboardingSession
|
from app.models.board_onboarding import BoardOnboardingSession
|
||||||
from app.models.gateways import Gateway
|
from app.models.gateways import Gateway
|
||||||
from app.schemas.board_onboarding import (
|
from app.schemas.board_onboarding import (
|
||||||
@@ -39,11 +35,7 @@ from app.schemas.board_onboarding import (
|
|||||||
BoardOnboardingUserProfile,
|
BoardOnboardingUserProfile,
|
||||||
)
|
)
|
||||||
from app.schemas.boards import BoardRead
|
from app.schemas.boards import BoardRead
|
||||||
from app.services.board_leads import (
|
from app.services.board_leads import LeadAgentOptions, LeadAgentRequest, ensure_board_lead_agent
|
||||||
LeadAgentOptions,
|
|
||||||
LeadAgentRequest,
|
|
||||||
ensure_board_lead_agent,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
@@ -62,7 +54,8 @@ ADMIN_AUTH_DEP = Depends(require_admin_auth)
|
|||||||
|
|
||||||
|
|
||||||
async def _gateway_config(
|
async def _gateway_config(
|
||||||
session: AsyncSession, board: Board,
|
session: AsyncSession,
|
||||||
|
board: Board,
|
||||||
) -> tuple[Gateway, GatewayClientConfig]:
|
) -> tuple[Gateway, GatewayClientConfig]:
|
||||||
if not board.gateway_id:
|
if not board.gateway_id:
|
||||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||||
@@ -255,11 +248,15 @@ async def start_onboarding(
|
|||||||
try:
|
try:
|
||||||
await ensure_session(session_key, config=config, label="Main Agent")
|
await ensure_session(session_key, config=config, label="Main Agent")
|
||||||
await send_message(
|
await send_message(
|
||||||
prompt, session_key=session_key, config=config, deliver=False,
|
prompt,
|
||||||
|
session_key=session_key,
|
||||||
|
config=config,
|
||||||
|
deliver=False,
|
||||||
)
|
)
|
||||||
except OpenClawGatewayError as exc:
|
except OpenClawGatewayError as exc:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=str(exc),
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
onboarding = BoardOnboardingSession(
|
onboarding = BoardOnboardingSession(
|
||||||
@@ -311,7 +308,8 @@ async def answer_onboarding(
|
|||||||
)
|
)
|
||||||
except OpenClawGatewayError as exc:
|
except OpenClawGatewayError as exc:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=str(exc),
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
onboarding.messages = messages
|
onboarding.messages = messages
|
||||||
|
|||||||
@@ -104,7 +104,9 @@ async def _require_gateway_for_create(
|
|||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
) -> Gateway:
|
) -> Gateway:
|
||||||
return await _require_gateway(
|
return await _require_gateway(
|
||||||
session, payload.gateway_id, organization_id=ctx.organization.id,
|
session,
|
||||||
|
payload.gateway_id,
|
||||||
|
organization_id=ctx.organization.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -155,7 +157,9 @@ async def _apply_board_update(
|
|||||||
updates = payload.model_dump(exclude_unset=True)
|
updates = payload.model_dump(exclude_unset=True)
|
||||||
if "gateway_id" in updates:
|
if "gateway_id" in updates:
|
||||||
await _require_gateway(
|
await _require_gateway(
|
||||||
session, updates["gateway_id"], organization_id=board.organization_id,
|
session,
|
||||||
|
updates["gateway_id"],
|
||||||
|
organization_id=board.organization_id,
|
||||||
)
|
)
|
||||||
if "board_group_id" in updates and updates["board_group_id"] is not None:
|
if "board_group_id" in updates and updates["board_group_id"] is not None:
|
||||||
await _require_board_group(
|
await _require_board_group(
|
||||||
@@ -164,10 +168,7 @@ async def _apply_board_update(
|
|||||||
organization_id=board.organization_id,
|
organization_id=board.organization_id,
|
||||||
)
|
)
|
||||||
crud.apply_updates(board, updates)
|
crud.apply_updates(board, updates)
|
||||||
if (
|
if updates.get("board_type") == "goal" and (not board.objective or not board.success_metrics):
|
||||||
updates.get("board_type") == "goal"
|
|
||||||
and (not board.objective or not board.success_metrics)
|
|
||||||
):
|
|
||||||
# Validate only when explicitly switching to goal boards.
|
# Validate only when explicitly switching to goal boards.
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
@@ -183,7 +184,8 @@ async def _apply_board_update(
|
|||||||
|
|
||||||
|
|
||||||
async def _board_gateway(
|
async def _board_gateway(
|
||||||
session: AsyncSession, board: Board,
|
session: AsyncSession,
|
||||||
|
board: Board,
|
||||||
) -> tuple[Gateway | None, GatewayClientConfig | None]:
|
) -> tuple[Gateway | None, GatewayClientConfig | None]:
|
||||||
if not board.gateway_id:
|
if not board.gateway_id:
|
||||||
return None, None
|
return None, None
|
||||||
@@ -255,7 +257,8 @@ async def list_boards(
|
|||||||
if board_group_id is not None:
|
if board_group_id is not None:
|
||||||
statement = statement.where(col(Board.board_group_id) == board_group_id)
|
statement = statement.where(col(Board.board_group_id) == board_group_id)
|
||||||
statement = statement.order_by(
|
statement = statement.order_by(
|
||||||
func.lower(col(Board.name)).asc(), col(Board.created_at).desc(),
|
func.lower(col(Board.name)).asc(),
|
||||||
|
col(Board.created_at).desc(),
|
||||||
)
|
)
|
||||||
return await paginate(session, statement)
|
return await paginate(session, statement)
|
||||||
|
|
||||||
@@ -350,10 +353,14 @@ async def delete_board(
|
|||||||
commit=False,
|
commit=False,
|
||||||
)
|
)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session, TaskDependency, col(TaskDependency.board_id) == board.id,
|
session,
|
||||||
|
TaskDependency,
|
||||||
|
col(TaskDependency.board_id) == board.id,
|
||||||
)
|
)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session, TaskFingerprint, col(TaskFingerprint.board_id) == board.id,
|
session,
|
||||||
|
TaskFingerprint,
|
||||||
|
col(TaskFingerprint.board_id) == board.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Approvals can reference tasks and agents, so delete before both.
|
# Approvals can reference tasks and agents, so delete before both.
|
||||||
|
|||||||
@@ -91,7 +91,8 @@ async def _resolve_gateway(
|
|||||||
board = await Board.objects.by_id(params.board_id).first(session)
|
board = await Board.objects.by_id(params.board_id).first(session)
|
||||||
if board is None:
|
if board is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail="Board not found",
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Board not found",
|
||||||
)
|
)
|
||||||
if user is not None:
|
if user is not None:
|
||||||
await require_board_access(session, user=user, board=board, write=False)
|
await require_board_access(session, user=user, board=board, write=False)
|
||||||
@@ -119,7 +120,10 @@ async def _resolve_gateway(
|
|||||||
|
|
||||||
|
|
||||||
async def _require_gateway(
|
async def _require_gateway(
|
||||||
session: AsyncSession, board_id: str | None, *, user: User | None = None,
|
session: AsyncSession,
|
||||||
|
board_id: str | None,
|
||||||
|
*,
|
||||||
|
user: User | None = None,
|
||||||
) -> tuple[Board, GatewayClientConfig, str | None]:
|
) -> tuple[Board, GatewayClientConfig, str | None]:
|
||||||
params = GatewayResolveQuery(board_id=board_id)
|
params = GatewayResolveQuery(board_id=board_id)
|
||||||
board, config, main_session = await _resolve_gateway(
|
board, config, main_session = await _resolve_gateway(
|
||||||
@@ -161,7 +165,9 @@ async def gateways_status(
|
|||||||
if main_session:
|
if main_session:
|
||||||
try:
|
try:
|
||||||
ensured = await ensure_session(
|
ensured = await ensure_session(
|
||||||
main_session, config=config, label="Main Agent",
|
main_session,
|
||||||
|
config=config,
|
||||||
|
label="Main Agent",
|
||||||
)
|
)
|
||||||
if isinstance(ensured, dict):
|
if isinstance(ensured, dict):
|
||||||
main_session_entry = ensured.get("entry") or ensured
|
main_session_entry = ensured.get("entry") or ensured
|
||||||
@@ -178,7 +184,9 @@ async def gateways_status(
|
|||||||
)
|
)
|
||||||
except OpenClawGatewayError as exc:
|
except OpenClawGatewayError as exc:
|
||||||
return GatewaysStatusResponse(
|
return GatewaysStatusResponse(
|
||||||
connected=False, gateway_url=config.url, error=str(exc),
|
connected=False,
|
||||||
|
gateway_url=config.url,
|
||||||
|
error=str(exc),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -202,7 +210,8 @@ async def list_gateway_sessions(
|
|||||||
sessions = await openclaw_call("sessions.list", config=config)
|
sessions = await openclaw_call("sessions.list", config=config)
|
||||||
except OpenClawGatewayError as exc:
|
except OpenClawGatewayError as exc:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=str(exc),
|
||||||
) from exc
|
) from exc
|
||||||
if isinstance(sessions, dict):
|
if isinstance(sessions, dict):
|
||||||
sessions_list = _as_object_list(sessions.get("sessions"))
|
sessions_list = _as_object_list(sessions.get("sessions"))
|
||||||
@@ -213,7 +222,9 @@ async def list_gateway_sessions(
|
|||||||
if main_session:
|
if main_session:
|
||||||
try:
|
try:
|
||||||
ensured = await ensure_session(
|
ensured = await ensure_session(
|
||||||
main_session, config=config, label="Main Agent",
|
main_session,
|
||||||
|
config=config,
|
||||||
|
label="Main Agent",
|
||||||
)
|
)
|
||||||
if isinstance(ensured, dict):
|
if isinstance(ensured, dict):
|
||||||
main_session_entry = ensured.get("entry") or ensured
|
main_session_entry = ensured.get("entry") or ensured
|
||||||
@@ -233,11 +244,7 @@ async def _list_sessions(config: GatewayClientConfig) -> list[dict[str, object]]
|
|||||||
raw_items = _as_object_list(sessions.get("sessions"))
|
raw_items = _as_object_list(sessions.get("sessions"))
|
||||||
else:
|
else:
|
||||||
raw_items = _as_object_list(sessions)
|
raw_items = _as_object_list(sessions)
|
||||||
return [
|
return [item for item in raw_items if isinstance(item, dict)]
|
||||||
item
|
|
||||||
for item in raw_items
|
|
||||||
if isinstance(item, dict)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
async def _with_main_session(
|
async def _with_main_session(
|
||||||
@@ -246,9 +253,7 @@ async def _with_main_session(
|
|||||||
config: GatewayClientConfig,
|
config: GatewayClientConfig,
|
||||||
main_session: str | None,
|
main_session: str | None,
|
||||||
) -> list[dict[str, object]]:
|
) -> list[dict[str, object]]:
|
||||||
if not main_session or any(
|
if not main_session or any(item.get("key") == main_session for item in sessions_list):
|
||||||
item.get("key") == main_session for item in sessions_list
|
|
||||||
):
|
|
||||||
return sessions_list
|
return sessions_list
|
||||||
try:
|
try:
|
||||||
await ensure_session(main_session, config=config, label="Main Agent")
|
await ensure_session(main_session, config=config, label="Main Agent")
|
||||||
@@ -278,7 +283,8 @@ async def get_gateway_session(
|
|||||||
sessions_list = await _list_sessions(config)
|
sessions_list = await _list_sessions(config)
|
||||||
except OpenClawGatewayError as exc:
|
except OpenClawGatewayError as exc:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=str(exc),
|
||||||
) from exc
|
) from exc
|
||||||
sessions_list = await _with_main_session(
|
sessions_list = await _with_main_session(
|
||||||
sessions_list,
|
sessions_list,
|
||||||
@@ -286,12 +292,15 @@ async def get_gateway_session(
|
|||||||
main_session=main_session,
|
main_session=main_session,
|
||||||
)
|
)
|
||||||
session_entry = next(
|
session_entry = next(
|
||||||
(item for item in sessions_list if item.get("key") == session_id), None,
|
(item for item in sessions_list if item.get("key") == session_id),
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
if session_entry is None and main_session and session_id == main_session:
|
if session_entry is None and main_session and session_id == main_session:
|
||||||
try:
|
try:
|
||||||
ensured = await ensure_session(
|
ensured = await ensure_session(
|
||||||
main_session, config=config, label="Main Agent",
|
main_session,
|
||||||
|
config=config,
|
||||||
|
label="Main Agent",
|
||||||
)
|
)
|
||||||
if isinstance(ensured, dict):
|
if isinstance(ensured, dict):
|
||||||
session_entry = ensured.get("entry") or ensured
|
session_entry = ensured.get("entry") or ensured
|
||||||
@@ -299,13 +308,15 @@ async def get_gateway_session(
|
|||||||
session_entry = None
|
session_entry = None
|
||||||
if session_entry is None:
|
if session_entry is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail="Session not found",
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Session not found",
|
||||||
)
|
)
|
||||||
return GatewaySessionResponse(session=session_entry)
|
return GatewaySessionResponse(session=session_entry)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/sessions/{session_id}/history", response_model=GatewaySessionHistoryResponse,
|
"/sessions/{session_id}/history",
|
||||||
|
response_model=GatewaySessionHistoryResponse,
|
||||||
)
|
)
|
||||||
async def get_session_history(
|
async def get_session_history(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
@@ -322,7 +333,8 @@ async def get_session_history(
|
|||||||
history = await get_chat_history(session_id, config=config)
|
history = await get_chat_history(session_id, config=config)
|
||||||
except OpenClawGatewayError as exc:
|
except OpenClawGatewayError as exc:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=str(exc),
|
||||||
) from exc
|
) from exc
|
||||||
if isinstance(history, dict) and isinstance(history.get("messages"), list):
|
if isinstance(history, dict) and isinstance(history.get("messages"), list):
|
||||||
return GatewaySessionHistoryResponse(history=history["messages"])
|
return GatewaySessionHistoryResponse(history=history["messages"])
|
||||||
@@ -339,7 +351,9 @@ async def send_gateway_session_message(
|
|||||||
) -> OkResponse:
|
) -> OkResponse:
|
||||||
"""Send a message into a specific gateway session."""
|
"""Send a message into a specific gateway session."""
|
||||||
board, config, main_session = await _require_gateway(
|
board, config, main_session = await _require_gateway(
|
||||||
session, board_id, user=auth.user,
|
session,
|
||||||
|
board_id,
|
||||||
|
user=auth.user,
|
||||||
)
|
)
|
||||||
if auth.user is None:
|
if auth.user is None:
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||||
@@ -350,7 +364,8 @@ async def send_gateway_session_message(
|
|||||||
await send_message(payload.content, session_key=session_id, config=config)
|
await send_message(payload.content, session_key=session_id, config=config)
|
||||||
except OpenClawGatewayError as exc:
|
except OpenClawGatewayError as exc:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc),
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=str(exc),
|
||||||
) from exc
|
) from exc
|
||||||
return OkResponse()
|
return OkResponse()
|
||||||
|
|
||||||
|
|||||||
@@ -17,11 +17,7 @@ from app.db import crud
|
|||||||
from app.db.pagination import paginate
|
from app.db.pagination import paginate
|
||||||
from app.db.session import get_session
|
from app.db.session import get_session
|
||||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||||
from app.integrations.openclaw_gateway import (
|
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
|
||||||
OpenClawGatewayError,
|
|
||||||
ensure_session,
|
|
||||||
send_message,
|
|
||||||
)
|
|
||||||
from app.models.agents import Agent
|
from app.models.agents import Agent
|
||||||
from app.models.gateways import Gateway
|
from app.models.gateways import Gateway
|
||||||
from app.schemas.common import OkResponse
|
from app.schemas.common import OkResponse
|
||||||
@@ -38,12 +34,8 @@ from app.services.agent_provisioning import (
|
|||||||
ProvisionOptions,
|
ProvisionOptions,
|
||||||
provision_main_agent,
|
provision_main_agent,
|
||||||
)
|
)
|
||||||
from app.services.template_sync import (
|
from app.services.template_sync import GatewayTemplateSyncOptions
|
||||||
GatewayTemplateSyncOptions,
|
from app.services.template_sync import sync_gateway_templates as sync_gateway_templates_service
|
||||||
)
|
|
||||||
from app.services.template_sync import (
|
|
||||||
sync_gateway_templates as sync_gateway_templates_service,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from fastapi_pagination.limit_offset import LimitOffsetPage
|
from fastapi_pagination.limit_offset import LimitOffsetPage
|
||||||
@@ -109,7 +101,8 @@ async def _require_gateway(
|
|||||||
)
|
)
|
||||||
if gateway is None:
|
if gateway is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found",
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Gateway not found",
|
||||||
)
|
)
|
||||||
return gateway
|
return gateway
|
||||||
|
|
||||||
|
|||||||
@@ -8,8 +8,9 @@ from typing import Literal
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from sqlalchemy import DateTime, case, func
|
from sqlalchemy import DateTime, case
|
||||||
from sqlalchemy import cast as sql_cast
|
from sqlalchemy import cast as sql_cast
|
||||||
|
from sqlalchemy import func
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|||||||
@@ -80,7 +80,8 @@ ORG_ADMIN_DEP = Depends(require_org_admin)
|
|||||||
|
|
||||||
|
|
||||||
def _member_to_read(
|
def _member_to_read(
|
||||||
member: OrganizationMember, user: User | None,
|
member: OrganizationMember,
|
||||||
|
user: User | None,
|
||||||
) -> OrganizationMemberRead:
|
) -> OrganizationMemberRead:
|
||||||
model = OrganizationMemberRead.model_validate(member, from_attributes=True)
|
model = OrganizationMemberRead.model_validate(member, from_attributes=True)
|
||||||
if user is not None:
|
if user is not None:
|
||||||
@@ -167,9 +168,7 @@ async def list_my_organizations(
|
|||||||
|
|
||||||
await get_active_membership(session, auth.user)
|
await get_active_membership(session, auth.user)
|
||||||
db_user = await User.objects.by_id(auth.user.id).first(session)
|
db_user = await User.objects.by_id(auth.user.id).first(session)
|
||||||
active_id = (
|
active_id = db_user.active_organization_id if db_user else auth.user.active_organization_id
|
||||||
db_user.active_organization_id if db_user else auth.user.active_organization_id
|
|
||||||
)
|
|
||||||
|
|
||||||
statement = (
|
statement = (
|
||||||
select(Organization, OrganizationMember)
|
select(Organization, OrganizationMember)
|
||||||
@@ -202,7 +201,9 @@ async def set_active_org(
|
|||||||
if auth.user is None:
|
if auth.user is None:
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||||
member = await set_active_organization(
|
member = await set_active_organization(
|
||||||
session, user=auth.user, organization_id=payload.organization_id,
|
session,
|
||||||
|
user=auth.user,
|
||||||
|
organization_id=payload.organization_id,
|
||||||
)
|
)
|
||||||
organization = await Organization.objects.by_id(member.organization_id).first(
|
organization = await Organization.objects.by_id(member.organization_id).first(
|
||||||
session,
|
session,
|
||||||
@@ -245,7 +246,10 @@ async def delete_my_org(
|
|||||||
group_ids = select(BoardGroup.id).where(col(BoardGroup.organization_id) == org_id)
|
group_ids = select(BoardGroup.id).where(col(BoardGroup.organization_id) == org_id)
|
||||||
|
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session, ActivityEvent, col(ActivityEvent.task_id).in_(task_ids), commit=False,
|
session,
|
||||||
|
ActivityEvent,
|
||||||
|
col(ActivityEvent.task_id).in_(task_ids),
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session,
|
session,
|
||||||
@@ -266,10 +270,16 @@ async def delete_my_org(
|
|||||||
commit=False,
|
commit=False,
|
||||||
)
|
)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session, Approval, col(Approval.board_id).in_(board_ids), commit=False,
|
session,
|
||||||
|
Approval,
|
||||||
|
col(Approval.board_id).in_(board_ids),
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session, BoardMemory, col(BoardMemory.board_id).in_(board_ids), commit=False,
|
session,
|
||||||
|
BoardMemory,
|
||||||
|
col(BoardMemory.board_id).in_(board_ids),
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session,
|
session,
|
||||||
@@ -302,13 +312,22 @@ async def delete_my_org(
|
|||||||
commit=False,
|
commit=False,
|
||||||
)
|
)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session, Task, col(Task.board_id).in_(board_ids), commit=False,
|
session,
|
||||||
|
Task,
|
||||||
|
col(Task.board_id).in_(board_ids),
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session, Agent, col(Agent.board_id).in_(board_ids), commit=False,
|
session,
|
||||||
|
Agent,
|
||||||
|
col(Agent.board_id).in_(board_ids),
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session, Board, col(Board.organization_id) == org_id, commit=False,
|
session,
|
||||||
|
Board,
|
||||||
|
col(Board.organization_id) == org_id,
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session,
|
session,
|
||||||
@@ -317,10 +336,16 @@ async def delete_my_org(
|
|||||||
commit=False,
|
commit=False,
|
||||||
)
|
)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session, BoardGroup, col(BoardGroup.organization_id) == org_id, commit=False,
|
session,
|
||||||
|
BoardGroup,
|
||||||
|
col(BoardGroup.organization_id) == org_id,
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session, Gateway, col(Gateway.organization_id) == org_id, commit=False,
|
session,
|
||||||
|
Gateway,
|
||||||
|
col(Gateway.organization_id) == org_id,
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session,
|
session,
|
||||||
@@ -342,7 +367,10 @@ async def delete_my_org(
|
|||||||
commit=False,
|
commit=False,
|
||||||
)
|
)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session, Organization, col(Organization.id) == org_id, commit=False,
|
session,
|
||||||
|
Organization,
|
||||||
|
col(Organization.id) == org_id,
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return OkResponse()
|
return OkResponse()
|
||||||
@@ -360,14 +388,14 @@ async def get_my_membership(
|
|||||||
).all(session)
|
).all(session)
|
||||||
model = _member_to_read(ctx.member, user)
|
model = _member_to_read(ctx.member, user)
|
||||||
model.board_access = [
|
model.board_access = [
|
||||||
OrganizationBoardAccessRead.model_validate(row, from_attributes=True)
|
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows
|
||||||
for row in access_rows
|
|
||||||
]
|
]
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/me/members", response_model=DefaultLimitOffsetPage[OrganizationMemberRead],
|
"/me/members",
|
||||||
|
response_model=DefaultLimitOffsetPage[OrganizationMemberRead],
|
||||||
)
|
)
|
||||||
async def list_org_members(
|
async def list_org_members(
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
@@ -410,8 +438,7 @@ async def get_org_member(
|
|||||||
).all(session)
|
).all(session)
|
||||||
model = _member_to_read(member, user)
|
model = _member_to_read(member, user)
|
||||||
model.board_access = [
|
model.board_access = [
|
||||||
OrganizationBoardAccessRead.model_validate(row, from_attributes=True)
|
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) for row in access_rows
|
||||||
for row in access_rows
|
|
||||||
]
|
]
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@@ -529,9 +556,7 @@ async def remove_org_member(
|
|||||||
user.active_organization_id = fallback_membership
|
user.active_organization_id = fallback_membership
|
||||||
else:
|
else:
|
||||||
user.active_organization_id = (
|
user.active_organization_id = (
|
||||||
fallback_membership.organization_id
|
fallback_membership.organization_id if fallback_membership is not None else None
|
||||||
if fallback_membership is not None
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
session.add(user)
|
session.add(user)
|
||||||
|
|
||||||
@@ -540,7 +565,8 @@ async def remove_org_member(
|
|||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/me/invites", response_model=DefaultLimitOffsetPage[OrganizationInviteRead],
|
"/me/invites",
|
||||||
|
response_model=DefaultLimitOffsetPage[OrganizationInviteRead],
|
||||||
)
|
)
|
||||||
async def list_org_invites(
|
async def list_org_invites(
|
||||||
session: AsyncSession = SESSION_DEP,
|
session: AsyncSession = SESSION_DEP,
|
||||||
@@ -607,7 +633,9 @@ async def create_org_invite(
|
|||||||
if valid_board_ids != board_ids:
|
if valid_board_ids != board_ids:
|
||||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
||||||
await apply_invite_board_access(
|
await apply_invite_board_access(
|
||||||
session, invite=invite, entries=payload.board_access,
|
session,
|
||||||
|
invite=invite,
|
||||||
|
entries=payload.board_access,
|
||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(invite)
|
await session.refresh(invite)
|
||||||
|
|||||||
@@ -29,11 +29,7 @@ from app.db import crud
|
|||||||
from app.db.pagination import paginate
|
from app.db.pagination import paginate
|
||||||
from app.db.session import async_session_maker, get_session
|
from app.db.session import async_session_maker, get_session
|
||||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||||
from app.integrations.openclaw_gateway import (
|
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
|
||||||
OpenClawGatewayError,
|
|
||||||
ensure_session,
|
|
||||||
send_message,
|
|
||||||
)
|
|
||||||
from app.models.activity_events import ActivityEvent
|
from app.models.activity_events import ActivityEvent
|
||||||
from app.models.agents import Agent
|
from app.models.agents import Agent
|
||||||
from app.models.approvals import Approval
|
from app.models.approvals import Approval
|
||||||
@@ -45,13 +41,7 @@ from app.models.tasks import Task
|
|||||||
from app.schemas.common import OkResponse
|
from app.schemas.common import OkResponse
|
||||||
from app.schemas.errors import BlockedTaskError
|
from app.schemas.errors import BlockedTaskError
|
||||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||||
from app.schemas.tasks import (
|
from app.schemas.tasks import TaskCommentCreate, TaskCommentRead, TaskCreate, TaskRead, TaskUpdate
|
||||||
TaskCommentCreate,
|
|
||||||
TaskCommentRead,
|
|
||||||
TaskCreate,
|
|
||||||
TaskRead,
|
|
||||||
TaskUpdate,
|
|
||||||
)
|
|
||||||
from app.services.activity_log import record_activity
|
from app.services.activity_log import record_activity
|
||||||
from app.services.mentions import extract_mentions, matches_agent_mention
|
from app.services.mentions import extract_mentions, matches_agent_mention
|
||||||
from app.services.organizations import require_board_access
|
from app.services.organizations import require_board_access
|
||||||
@@ -263,8 +253,7 @@ async def _reconcile_dependents_for_dependency_toggle(
|
|||||||
event_type="task.status_changed",
|
event_type="task.status_changed",
|
||||||
task_id=dependent.id,
|
task_id=dependent.id,
|
||||||
message=(
|
message=(
|
||||||
"Task returned to inbox: dependency reopened "
|
"Task returned to inbox: dependency reopened " f"({dependency_task.title})."
|
||||||
f"({dependency_task.title})."
|
|
||||||
),
|
),
|
||||||
agent_id=actor_agent_id,
|
agent_id=actor_agent_id,
|
||||||
)
|
)
|
||||||
@@ -313,7 +302,8 @@ def _serialize_comment(event: ActivityEvent) -> dict[str, object]:
|
|||||||
|
|
||||||
|
|
||||||
async def _gateway_config(
|
async def _gateway_config(
|
||||||
session: AsyncSession, board: Board,
|
session: AsyncSession,
|
||||||
|
board: Board,
|
||||||
) -> GatewayClientConfig | None:
|
) -> GatewayClientConfig | None:
|
||||||
if not board.gateway_id:
|
if not board.gateway_id:
|
||||||
return None
|
return None
|
||||||
@@ -368,10 +358,7 @@ async def _notify_agent_on_task_assign(
|
|||||||
message = (
|
message = (
|
||||||
"TASK ASSIGNED\n"
|
"TASK ASSIGNED\n"
|
||||||
+ "\n".join(details)
|
+ "\n".join(details)
|
||||||
+ (
|
+ ("\n\nTake action: open the task and begin work. " "Post updates as task comments.")
|
||||||
"\n\nTake action: open the task and begin work. "
|
|
||||||
"Post updates as task comments."
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
await _send_agent_task_message(
|
await _send_agent_task_message(
|
||||||
@@ -607,9 +594,7 @@ async def _stream_dependency_state(
|
|||||||
rows: list[tuple[ActivityEvent, Task | None]],
|
rows: list[tuple[ActivityEvent, Task | None]],
|
||||||
) -> tuple[dict[UUID, list[UUID]], dict[UUID, str]]:
|
) -> tuple[dict[UUID, list[UUID]], dict[UUID, str]]:
|
||||||
task_ids = [
|
task_ids = [
|
||||||
task.id
|
task.id for event, task in rows if task is not None and event.event_type != "task.comment"
|
||||||
for event, task in rows
|
|
||||||
if task is not None and event.event_type != "task.comment"
|
|
||||||
]
|
]
|
||||||
if not task_ids:
|
if not task_ids:
|
||||||
return {}, {}
|
return {}, {}
|
||||||
@@ -786,7 +771,8 @@ async def create_task(
|
|||||||
dependency_ids=normalized_deps,
|
dependency_ids=normalized_deps,
|
||||||
)
|
)
|
||||||
blocked_by = blocked_by_dependency_ids(
|
blocked_by = blocked_by_dependency_ids(
|
||||||
dependency_ids=normalized_deps, status_by_id=dep_status,
|
dependency_ids=normalized_deps,
|
||||||
|
status_by_id=dep_status,
|
||||||
)
|
)
|
||||||
if blocked_by and (task.assigned_agent_id is not None or task.status != "inbox"):
|
if blocked_by and (task.assigned_agent_id is not None or task.status != "inbox"):
|
||||||
raise _blocked_task_error(blocked_by)
|
raise _blocked_task_error(blocked_by)
|
||||||
@@ -861,9 +847,7 @@ async def update_task(
|
|||||||
updates = payload.model_dump(exclude_unset=True)
|
updates = payload.model_dump(exclude_unset=True)
|
||||||
comment = payload.comment if "comment" in payload.model_fields_set else None
|
comment = payload.comment if "comment" in payload.model_fields_set else None
|
||||||
depends_on_task_ids = (
|
depends_on_task_ids = (
|
||||||
payload.depends_on_task_ids
|
payload.depends_on_task_ids if "depends_on_task_ids" in payload.model_fields_set else None
|
||||||
if "depends_on_task_ids" in payload.model_fields_set
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
updates.pop("comment", None)
|
updates.pop("comment", None)
|
||||||
updates.pop("depends_on_task_ids", None)
|
updates.pop("depends_on_task_ids", None)
|
||||||
@@ -906,13 +890,22 @@ async def delete_task(
|
|||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||||
await require_board_access(session, user=auth.user, board=board, write=True)
|
await require_board_access(session, user=auth.user, board=board, write=True)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session, ActivityEvent, col(ActivityEvent.task_id) == task.id, commit=False,
|
session,
|
||||||
|
ActivityEvent,
|
||||||
|
col(ActivityEvent.task_id) == task.id,
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session, TaskFingerprint, col(TaskFingerprint.task_id) == task.id, commit=False,
|
session,
|
||||||
|
TaskFingerprint,
|
||||||
|
col(TaskFingerprint.task_id) == task.id,
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session, Approval, col(Approval.task_id) == task.id, commit=False,
|
session,
|
||||||
|
Approval,
|
||||||
|
col(Approval.task_id) == task.id,
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await crud.delete_where(
|
await crud.delete_where(
|
||||||
session,
|
session,
|
||||||
@@ -929,7 +922,8 @@ async def delete_task(
|
|||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/{task_id}/comments", response_model=DefaultLimitOffsetPage[TaskCommentRead],
|
"/{task_id}/comments",
|
||||||
|
response_model=DefaultLimitOffsetPage[TaskCommentRead],
|
||||||
)
|
)
|
||||||
async def list_task_comments(
|
async def list_task_comments(
|
||||||
task: Task = TASK_DEP,
|
task: Task = TASK_DEP,
|
||||||
@@ -1241,11 +1235,7 @@ async def _lead_apply_assignment(
|
|||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="Board leads cannot assign tasks to themselves.",
|
detail="Board leads cannot assign tasks to themselves.",
|
||||||
)
|
)
|
||||||
if (
|
if agent.board_id and update.task.board_id and agent.board_id != update.task.board_id:
|
||||||
agent.board_id
|
|
||||||
and update.task.board_id
|
|
||||||
and agent.board_id != update.task.board_id
|
|
||||||
):
|
|
||||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
|
||||||
update.task.assigned_agent_id = agent.id
|
update.task.assigned_agent_id = agent.id
|
||||||
|
|
||||||
@@ -1256,19 +1246,13 @@ def _lead_apply_status(update: _TaskUpdateInput) -> None:
|
|||||||
if update.task.status != "review":
|
if update.task.status != "review":
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=(
|
detail=("Board leads can only change status when a task is " "in review."),
|
||||||
"Board leads can only change status when a task is "
|
|
||||||
"in review."
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
target_status = _required_status_value(update.updates["status"])
|
target_status = _required_status_value(update.updates["status"])
|
||||||
if target_status not in {"done", "inbox"}:
|
if target_status not in {"done", "inbox"}:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=(
|
detail=("Board leads can only move review tasks to done " "or inbox."),
|
||||||
"Board leads can only move review tasks to done "
|
|
||||||
"or inbox."
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
if target_status == "inbox":
|
if target_status == "inbox":
|
||||||
update.task.assigned_agent_id = None
|
update.task.assigned_agent_id = None
|
||||||
@@ -1397,9 +1381,7 @@ async def _apply_non_lead_agent_task_rules(
|
|||||||
update.task.assigned_agent_id = None
|
update.task.assigned_agent_id = None
|
||||||
update.task.in_progress_at = None
|
update.task.in_progress_at = None
|
||||||
else:
|
else:
|
||||||
update.task.assigned_agent_id = (
|
update.task.assigned_agent_id = update.actor.agent.id if update.actor.agent else None
|
||||||
update.actor.agent.id if update.actor.agent else None
|
|
||||||
)
|
|
||||||
if status_value == "in_progress":
|
if status_value == "in_progress":
|
||||||
update.task.in_progress_at = utcnow()
|
update.task.in_progress_at = utcnow()
|
||||||
|
|
||||||
@@ -1462,11 +1444,7 @@ async def _apply_admin_task_rules(
|
|||||||
agent = await Agent.objects.by_id(assigned_agent_id).first(session)
|
agent = await Agent.objects.by_id(assigned_agent_id).first(session)
|
||||||
if agent is None:
|
if agent is None:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||||
if (
|
if agent.board_id and update.task.board_id and agent.board_id != update.task.board_id:
|
||||||
agent.board_id
|
|
||||||
and update.task.board_id
|
|
||||||
and agent.board_id != update.task.board_id
|
|
||||||
):
|
|
||||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
|
||||||
|
|
||||||
|
|
||||||
@@ -1481,9 +1459,11 @@ async def _record_task_comment_from_update(
|
|||||||
event_type="task.comment",
|
event_type="task.comment",
|
||||||
message=update.comment,
|
message=update.comment,
|
||||||
task_id=update.task.id,
|
task_id=update.task.id,
|
||||||
agent_id=update.actor.agent.id
|
agent_id=(
|
||||||
|
update.actor.agent.id
|
||||||
if update.actor.actor_type == "agent" and update.actor.agent
|
if update.actor.actor_type == "agent" and update.actor.agent
|
||||||
else None,
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
session.add(event)
|
session.add(event)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -1496,9 +1476,7 @@ async def _record_task_update_activity(
|
|||||||
) -> None:
|
) -> None:
|
||||||
event_type, message = _task_event_details(update.task, update.previous_status)
|
event_type, message = _task_event_details(update.task, update.previous_status)
|
||||||
actor_agent_id = (
|
actor_agent_id = (
|
||||||
update.actor.agent.id
|
update.actor.agent.id if update.actor.actor_type == "agent" and update.actor.agent else None
|
||||||
if update.actor.actor_type == "agent" and update.actor.agent
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
record_activity(
|
record_activity(
|
||||||
session,
|
session,
|
||||||
@@ -1525,10 +1503,7 @@ async def _notify_task_update_assignment_changes(
|
|||||||
if (
|
if (
|
||||||
update.task.status == "inbox"
|
update.task.status == "inbox"
|
||||||
and update.task.assigned_agent_id is None
|
and update.task.assigned_agent_id is None
|
||||||
and (
|
and (update.previous_status != "inbox" or update.previous_assigned is not None)
|
||||||
update.previous_status != "inbox"
|
|
||||||
or update.previous_assigned is not None
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
board = (
|
board = (
|
||||||
await Board.objects.by_id(update.task.board_id).first(session)
|
await Board.objects.by_id(update.task.board_id).first(session)
|
||||||
|
|||||||
@@ -77,10 +77,7 @@ async def _touch_agent_presence(
|
|||||||
real activity even if the heartbeat loop isn't running.
|
real activity even if the heartbeat loop isn't running.
|
||||||
"""
|
"""
|
||||||
now = utcnow()
|
now = utcnow()
|
||||||
if (
|
if agent.last_seen_at is not None and now - agent.last_seen_at < _LAST_SEEN_TOUCH_INTERVAL:
|
||||||
agent.last_seen_at is not None
|
|
||||||
and now - agent.last_seen_at < _LAST_SEEN_TOUCH_INTERVAL
|
|
||||||
):
|
|
||||||
return
|
return
|
||||||
|
|
||||||
agent.last_seen_at = now
|
agent.last_seen_at = now
|
||||||
|
|||||||
@@ -19,9 +19,7 @@ _MULTIPLIERS: dict[str, int] = {
|
|||||||
_MAX_SCHEDULE_SECONDS = 60 * 60 * 24 * 365 * 10
|
_MAX_SCHEDULE_SECONDS = 60 * 60 * 24 * 365 * 10
|
||||||
|
|
||||||
_ERR_SCHEDULE_REQUIRED = "schedule is required"
|
_ERR_SCHEDULE_REQUIRED = "schedule is required"
|
||||||
_ERR_SCHEDULE_INVALID = (
|
_ERR_SCHEDULE_INVALID = 'Invalid schedule. Expected format like "10m", "1h", "2d", "1w".'
|
||||||
'Invalid schedule. Expected format like "10m", "1h", "2d", "1w".'
|
|
||||||
)
|
|
||||||
_ERR_SCHEDULE_NONPOSITIVE = "Schedule must be greater than 0."
|
_ERR_SCHEDULE_NONPOSITIVE = "Schedule must be greater than 0."
|
||||||
_ERR_SCHEDULE_TOO_LARGE = "Schedule is too large (max 10 years)."
|
_ERR_SCHEDULE_TOO_LARGE = "Schedule is too large (max 10 years)."
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from fastapi.responses import JSONResponse
|
|||||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -44,9 +44,7 @@ class RequestIdMiddleware:
|
|||||||
if message["type"] == "http.response.start":
|
if message["type"] == "http.response.start":
|
||||||
# Starlette uses `list[tuple[bytes, bytes]]` here.
|
# Starlette uses `list[tuple[bytes, bytes]]` here.
|
||||||
headers: list[tuple[bytes, bytes]] = message.setdefault("headers", [])
|
headers: list[tuple[bytes, bytes]] = message.setdefault("headers", [])
|
||||||
if not any(
|
if not any(key.lower() == self._header_name_bytes for key, _ in headers):
|
||||||
key.lower() == self._header_name_bytes for key, _ in headers
|
|
||||||
):
|
|
||||||
request_id_bytes = request_id.encode("latin-1")
|
request_id_bytes = request_id.encode("latin-1")
|
||||||
headers.append((self._header_name_bytes, request_id_bytes))
|
headers.append((self._header_name_bytes, request_id_bytes))
|
||||||
await send(message)
|
await send(message)
|
||||||
|
|||||||
@@ -198,8 +198,7 @@ class AppLogger:
|
|||||||
formatter: logging.Formatter = JsonFormatter()
|
formatter: logging.Formatter = JsonFormatter()
|
||||||
else:
|
else:
|
||||||
formatter = KeyValueFormatter(
|
formatter = KeyValueFormatter(
|
||||||
"%(asctime)s %(levelname)s %(name)s %(message)s "
|
"%(asctime)s %(levelname)s %(name)s %(message)s " "app=%(app)s version=%(version)s",
|
||||||
"app=%(app)s version=%(version)s",
|
|
||||||
)
|
)
|
||||||
if settings.log_use_utc:
|
if settings.log_use_utc:
|
||||||
formatter.converter = time.gmtime
|
formatter.converter = time.gmtime
|
||||||
|
|||||||
@@ -161,10 +161,7 @@ async def list_by(
|
|||||||
|
|
||||||
async def exists(session: AsyncSession, model: type[ModelT], **lookup: object) -> bool:
|
async def exists(session: AsyncSession, model: type[ModelT], **lookup: object) -> bool:
|
||||||
"""Return whether any object exists for lookup values."""
|
"""Return whether any object exists for lookup values."""
|
||||||
return (
|
return (await session.exec(_lookup_statement(model, lookup).limit(1))).first() is not None
|
||||||
(await session.exec(_lookup_statement(model, lookup).limit(1))).first()
|
|
||||||
is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _criteria_statement(
|
def _criteria_statement(
|
||||||
@@ -219,11 +216,7 @@ async def update_where(
|
|||||||
commit = bool(options.pop("commit", False))
|
commit = bool(options.pop("commit", False))
|
||||||
exclude_none = bool(options.pop("exclude_none", False))
|
exclude_none = bool(options.pop("exclude_none", False))
|
||||||
allowed_fields_raw = options.pop("allowed_fields", None)
|
allowed_fields_raw = options.pop("allowed_fields", None)
|
||||||
allowed_fields = (
|
allowed_fields = allowed_fields_raw if isinstance(allowed_fields_raw, set) else None
|
||||||
allowed_fields_raw
|
|
||||||
if isinstance(allowed_fields_raw, set)
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
source_updates: dict[str, Any] = {}
|
source_updates: dict[str, Any] = {}
|
||||||
if updates:
|
if updates:
|
||||||
source_updates.update(dict(updates))
|
source_updates.update(dict(updates))
|
||||||
@@ -276,11 +269,7 @@ async def patch(
|
|||||||
"""Apply partial updates and persist object."""
|
"""Apply partial updates and persist object."""
|
||||||
exclude_none = bool(options.pop("exclude_none", False))
|
exclude_none = bool(options.pop("exclude_none", False))
|
||||||
allowed_fields_raw = options.pop("allowed_fields", None)
|
allowed_fields_raw = options.pop("allowed_fields", None)
|
||||||
allowed_fields = (
|
allowed_fields = allowed_fields_raw if isinstance(allowed_fields_raw, set) else None
|
||||||
allowed_fields_raw
|
|
||||||
if isinstance(allowed_fields_raw, set)
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
commit = bool(options.pop("commit", True))
|
commit = bool(options.pop("commit", True))
|
||||||
refresh = bool(options.pop("refresh", True))
|
refresh = bool(options.pop("refresh", True))
|
||||||
apply_updates(
|
apply_updates(
|
||||||
|
|||||||
@@ -36,12 +36,8 @@ class BoardOnboardingConfirm(SQLModel):
|
|||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_goal_fields(self) -> Self:
|
def validate_goal_fields(self) -> Self:
|
||||||
"""Require goal metadata when the board type is `goal`."""
|
"""Require goal metadata when the board type is `goal`."""
|
||||||
if self.board_type == "goal" and (
|
if self.board_type == "goal" and (not self.objective or not self.success_metrics):
|
||||||
not self.objective or not self.success_metrics
|
message = "Confirmed goal boards require objective and success_metrics"
|
||||||
):
|
|
||||||
message = (
|
|
||||||
"Confirmed goal boards require objective and success_metrics"
|
|
||||||
)
|
|
||||||
raise ValueError(message)
|
raise ValueError(message)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|||||||
@@ -9,9 +9,7 @@ from uuid import UUID
|
|||||||
from pydantic import model_validator
|
from pydantic import model_validator
|
||||||
from sqlmodel import SQLModel
|
from sqlmodel import SQLModel
|
||||||
|
|
||||||
_ERR_GOAL_FIELDS_REQUIRED = (
|
_ERR_GOAL_FIELDS_REQUIRED = "Confirmed goal boards require objective and success_metrics"
|
||||||
"Confirmed goal boards require objective and success_metrics"
|
|
||||||
)
|
|
||||||
_ERR_GATEWAY_REQUIRED = "gateway_id is required"
|
_ERR_GATEWAY_REQUIRED = "gateway_id is required"
|
||||||
RUNTIME_ANNOTATION_TYPES = (datetime, UUID)
|
RUNTIME_ANNOTATION_TYPES = (datetime, UUID)
|
||||||
|
|
||||||
|
|||||||
@@ -15,11 +15,7 @@ from jinja2 import Environment, FileSystemLoader, StrictUndefined, select_autoes
|
|||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||||
from app.integrations.openclaw_gateway import (
|
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, openclaw_call
|
||||||
OpenClawGatewayError,
|
|
||||||
ensure_session,
|
|
||||||
openclaw_call,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.models.agents import Agent
|
from app.models.agents import Agent
|
||||||
@@ -414,7 +410,9 @@ async def _supported_gateway_files(config: GatewayClientConfig) -> set[str]:
|
|||||||
if not agent_id:
|
if not agent_id:
|
||||||
return set(DEFAULT_GATEWAY_FILES)
|
return set(DEFAULT_GATEWAY_FILES)
|
||||||
files_payload = await openclaw_call(
|
files_payload = await openclaw_call(
|
||||||
"agents.files.list", {"agentId": agent_id}, config=config,
|
"agents.files.list",
|
||||||
|
{"agentId": agent_id},
|
||||||
|
config=config,
|
||||||
)
|
)
|
||||||
if isinstance(files_payload, dict):
|
if isinstance(files_payload, dict):
|
||||||
files = files_payload.get("files") or []
|
files = files_payload.get("files") or []
|
||||||
@@ -438,11 +436,14 @@ async def _reset_session(session_key: str, config: GatewayClientConfig) -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def _gateway_agent_files_index(
|
async def _gateway_agent_files_index(
|
||||||
agent_id: str, config: GatewayClientConfig,
|
agent_id: str,
|
||||||
|
config: GatewayClientConfig,
|
||||||
) -> dict[str, dict[str, Any]]:
|
) -> dict[str, dict[str, Any]]:
|
||||||
try:
|
try:
|
||||||
payload = await openclaw_call(
|
payload = await openclaw_call(
|
||||||
"agents.files.list", {"agentId": agent_id}, config=config,
|
"agents.files.list",
|
||||||
|
{"agentId": agent_id},
|
||||||
|
config=config,
|
||||||
)
|
)
|
||||||
if isinstance(payload, dict):
|
if isinstance(payload, dict):
|
||||||
files = payload.get("files") or []
|
files = payload.get("files") or []
|
||||||
@@ -486,18 +487,14 @@ def _render_agent_files(
|
|||||||
)
|
)
|
||||||
heartbeat_path = _templates_root() / heartbeat_template
|
heartbeat_path = _templates_root() / heartbeat_template
|
||||||
if heartbeat_path.exists():
|
if heartbeat_path.exists():
|
||||||
rendered[name] = (
|
rendered[name] = env.get_template(heartbeat_template).render(**context).strip()
|
||||||
env.get_template(heartbeat_template).render(**context).strip()
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
override = overrides.get(name)
|
override = overrides.get(name)
|
||||||
if override:
|
if override:
|
||||||
rendered[name] = env.from_string(override).render(**context).strip()
|
rendered[name] = env.from_string(override).render(**context).strip()
|
||||||
continue
|
continue
|
||||||
template_name = (
|
template_name = (
|
||||||
template_overrides[name]
|
template_overrides[name] if template_overrides and name in template_overrides else name
|
||||||
if template_overrides and name in template_overrides
|
|
||||||
else name
|
|
||||||
)
|
)
|
||||||
path = _templates_root() / template_name
|
path = _templates_root() / template_name
|
||||||
if path.exists():
|
if path.exists():
|
||||||
@@ -596,8 +593,7 @@ def _heartbeat_entry_map(
|
|||||||
entries: list[tuple[str, str, dict[str, Any]]],
|
entries: list[tuple[str, str, dict[str, Any]]],
|
||||||
) -> dict[str, tuple[str, dict[str, Any]]]:
|
) -> dict[str, tuple[str, dict[str, Any]]]:
|
||||||
return {
|
return {
|
||||||
agent_id: (workspace_path, heartbeat)
|
agent_id: (workspace_path, heartbeat) for agent_id, workspace_path, heartbeat in entries
|
||||||
for agent_id, workspace_path, heartbeat in entries
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -694,9 +690,7 @@ async def _remove_gateway_agent_list(
|
|||||||
raise OpenClawGatewayError(msg)
|
raise OpenClawGatewayError(msg)
|
||||||
|
|
||||||
new_list = [
|
new_list = [
|
||||||
entry
|
entry for entry in lst if not (isinstance(entry, dict) and entry.get("id") == agent_id)
|
||||||
for entry in lst
|
|
||||||
if not (isinstance(entry, dict) and entry.get("id") == agent_id)
|
|
||||||
]
|
]
|
||||||
if len(new_list) == len(lst):
|
if len(new_list) == len(lst):
|
||||||
return
|
return
|
||||||
@@ -841,7 +835,9 @@ async def provision_main_agent(
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
|
client_config = GatewayClientConfig(url=gateway.url, token=gateway.token)
|
||||||
await ensure_session(
|
await ensure_session(
|
||||||
gateway.main_session_key, config=client_config, label="Main Agent",
|
gateway.main_session_key,
|
||||||
|
config=client_config,
|
||||||
|
label="Main Agent",
|
||||||
)
|
)
|
||||||
|
|
||||||
agent_id = await _gateway_default_agent_id(
|
agent_id = await _gateway_default_agent_id(
|
||||||
|
|||||||
@@ -38,10 +38,7 @@ def _status_weight_expr() -> ColumnElement[int]:
|
|||||||
|
|
||||||
def _priority_weight_expr() -> ColumnElement[int]:
|
def _priority_weight_expr() -> ColumnElement[int]:
|
||||||
"""Return a SQL expression that sorts task priorities by configured order."""
|
"""Return a SQL expression that sorts task priorities by configured order."""
|
||||||
whens = [
|
whens = [(col(Task.priority) == key, weight) for key, weight in _PRIORITY_ORDER.items()]
|
||||||
(col(Task.priority) == key, weight)
|
|
||||||
for key, weight in _PRIORITY_ORDER.items()
|
|
||||||
]
|
|
||||||
return case(*whens, else_=99)
|
return case(*whens, else_=99)
|
||||||
|
|
||||||
|
|
||||||
@@ -106,11 +103,7 @@ async def _agent_names(
|
|||||||
tasks: list[Task],
|
tasks: list[Task],
|
||||||
) -> dict[UUID, str]:
|
) -> dict[UUID, str]:
|
||||||
"""Return agent names keyed by assigned agent ids in the provided tasks."""
|
"""Return agent names keyed by assigned agent ids in the provided tasks."""
|
||||||
assigned_ids = {
|
assigned_ids = {task.assigned_agent_id for task in tasks if task.assigned_agent_id is not None}
|
||||||
task.assigned_agent_id
|
|
||||||
for task in tasks
|
|
||||||
if task.assigned_agent_id is not None
|
|
||||||
}
|
|
||||||
if not assigned_ids:
|
if not assigned_ids:
|
||||||
return {}
|
return {}
|
||||||
return dict(
|
return dict(
|
||||||
|
|||||||
@@ -10,11 +10,7 @@ from sqlmodel import col, select
|
|||||||
from app.core.agent_tokens import generate_agent_token, hash_agent_token
|
from app.core.agent_tokens import generate_agent_token, hash_agent_token
|
||||||
from app.core.time import utcnow
|
from app.core.time import utcnow
|
||||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||||
from app.integrations.openclaw_gateway import (
|
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
|
||||||
OpenClawGatewayError,
|
|
||||||
ensure_session,
|
|
||||||
send_message,
|
|
||||||
)
|
|
||||||
from app.models.agents import Agent
|
from app.models.agents import Agent
|
||||||
from app.services.agent_provisioning import (
|
from app.services.agent_provisioning import (
|
||||||
DEFAULT_HEARTBEAT_CONFIG,
|
DEFAULT_HEARTBEAT_CONFIG,
|
||||||
|
|||||||
@@ -55,8 +55,7 @@ def _agent_to_read(agent: Agent, main_session_keys: set[str]) -> AgentRead:
|
|||||||
model = AgentRead.model_validate(agent, from_attributes=True)
|
model = AgentRead.model_validate(agent, from_attributes=True)
|
||||||
computed_status = _computed_agent_status(agent)
|
computed_status = _computed_agent_status(agent)
|
||||||
is_gateway_main = bool(
|
is_gateway_main = bool(
|
||||||
agent.openclaw_session_id
|
agent.openclaw_session_id and agent.openclaw_session_id in main_session_keys,
|
||||||
and agent.openclaw_session_id in main_session_keys,
|
|
||||||
)
|
)
|
||||||
return model.model_copy(
|
return model.model_copy(
|
||||||
update={
|
update={
|
||||||
@@ -84,11 +83,7 @@ def _task_to_card(
|
|||||||
) -> TaskCardRead:
|
) -> TaskCardRead:
|
||||||
card = TaskCardRead.model_validate(task, from_attributes=True)
|
card = TaskCardRead.model_validate(task, from_attributes=True)
|
||||||
approvals_count, approvals_pending_count = counts_by_task_id.get(task.id, (0, 0))
|
approvals_count, approvals_pending_count = counts_by_task_id.get(task.id, (0, 0))
|
||||||
assignee = (
|
assignee = agent_name_by_id.get(task.assigned_agent_id) if task.assigned_agent_id else None
|
||||||
agent_name_by_id.get(task.assigned_agent_id)
|
|
||||||
if task.assigned_agent_id
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
depends_on_task_ids = deps_by_task_id.get(task.id, [])
|
depends_on_task_ids = deps_by_task_id.get(task.id, [])
|
||||||
blocked_by_task_ids = blocked_by_dependency_ids(
|
blocked_by_task_ids = blocked_by_dependency_ids(
|
||||||
dependency_ids=depends_on_task_ids,
|
dependency_ids=depends_on_task_ids,
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
from app.models.agents import Agent
|
from app.models.agents import Agent
|
||||||
|
|
||||||
# Mention tokens are single, space-free words (e.g. "@alex", "@lead").
|
# Mention tokens are single, space-free words (e.g. "@alex", "@lead").
|
||||||
|
|||||||
@@ -79,7 +79,8 @@ async def get_member(
|
|||||||
|
|
||||||
|
|
||||||
async def get_first_membership(
|
async def get_first_membership(
|
||||||
session: AsyncSession, user_id: UUID,
|
session: AsyncSession,
|
||||||
|
user_id: UUID,
|
||||||
) -> OrganizationMember | None:
|
) -> OrganizationMember | None:
|
||||||
"""Return the oldest membership for a user, if any."""
|
"""Return the oldest membership for a user, if any."""
|
||||||
return (
|
return (
|
||||||
@@ -99,7 +100,8 @@ async def set_active_organization(
|
|||||||
member = await get_member(session, user_id=user.id, organization_id=organization_id)
|
member = await get_member(session, user_id=user.id, organization_id=organization_id)
|
||||||
if member is None:
|
if member is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail="No org access",
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="No org access",
|
||||||
)
|
)
|
||||||
if user.active_organization_id != organization_id:
|
if user.active_organization_id != organization_id:
|
||||||
user.active_organization_id = organization_id
|
user.active_organization_id = organization_id
|
||||||
@@ -177,8 +179,7 @@ async def accept_invite(
|
|||||||
access_rows = list(
|
access_rows = list(
|
||||||
await session.exec(
|
await session.exec(
|
||||||
select(OrganizationInviteBoardAccess).where(
|
select(OrganizationInviteBoardAccess).where(
|
||||||
col(OrganizationInviteBoardAccess.organization_invite_id)
|
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id,
|
||||||
== invite.id,
|
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -207,7 +208,8 @@ async def accept_invite(
|
|||||||
|
|
||||||
|
|
||||||
async def ensure_member_for_user(
|
async def ensure_member_for_user(
|
||||||
session: AsyncSession, user: User,
|
session: AsyncSession,
|
||||||
|
user: User,
|
||||||
) -> OrganizationMember:
|
) -> OrganizationMember:
|
||||||
"""Ensure a user has some membership, creating one if necessary."""
|
"""Ensure a user has some membership, creating one if necessary."""
|
||||||
existing = await get_active_membership(session, user)
|
existing = await get_active_membership(session, user)
|
||||||
@@ -291,21 +293,27 @@ async def require_board_access(
|
|||||||
) -> OrganizationMember:
|
) -> OrganizationMember:
|
||||||
"""Require board access for a user and return matching membership."""
|
"""Require board access for a user and return matching membership."""
|
||||||
member = await get_member(
|
member = await get_member(
|
||||||
session, user_id=user.id, organization_id=board.organization_id,
|
session,
|
||||||
|
user_id=user.id,
|
||||||
|
organization_id=board.organization_id,
|
||||||
)
|
)
|
||||||
if member is None:
|
if member is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail="No org access",
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="No org access",
|
||||||
)
|
)
|
||||||
if not await has_board_access(session, member=member, board=board, write=write):
|
if not await has_board_access(session, member=member, board=board, write=write):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail="Board access denied",
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Board access denied",
|
||||||
)
|
)
|
||||||
return member
|
return member
|
||||||
|
|
||||||
|
|
||||||
def board_access_filter(
|
def board_access_filter(
|
||||||
member: OrganizationMember, *, write: bool,
|
member: OrganizationMember,
|
||||||
|
*,
|
||||||
|
write: bool,
|
||||||
) -> ColumnElement[bool]:
|
) -> ColumnElement[bool]:
|
||||||
"""Build a SQL filter expression for boards visible to a member."""
|
"""Build a SQL filter expression for boards visible to a member."""
|
||||||
if write and member_all_boards_write(member):
|
if write and member_all_boards_write(member):
|
||||||
|
|||||||
@@ -42,10 +42,7 @@ class SoulRef:
|
|||||||
def _parse_sitemap_soul_refs(sitemap_xml: str) -> list[SoulRef]:
|
def _parse_sitemap_soul_refs(sitemap_xml: str) -> list[SoulRef]:
|
||||||
"""Parse sitemap XML and extract valid souls.directory handle/slug refs."""
|
"""Parse sitemap XML and extract valid souls.directory handle/slug refs."""
|
||||||
# Extract <loc> values without XML entity expansion.
|
# Extract <loc> values without XML entity expansion.
|
||||||
urls = [
|
urls = [unescape(match.group(1)).strip() for match in _LOC_PATTERN.finditer(sitemap_xml)]
|
||||||
unescape(match.group(1)).strip()
|
|
||||||
for match in _LOC_PATTERN.finditer(sitemap_xml)
|
|
||||||
]
|
|
||||||
|
|
||||||
refs: list[SoulRef] = []
|
refs: list[SoulRef] = []
|
||||||
for url in urls:
|
for url in urls:
|
||||||
@@ -110,10 +107,7 @@ async def fetch_soul_markdown(
|
|||||||
normalized_slug = slug.strip().strip("/")
|
normalized_slug = slug.strip().strip("/")
|
||||||
if normalized_slug.endswith(".md"):
|
if normalized_slug.endswith(".md"):
|
||||||
normalized_slug = normalized_slug[: -len(".md")]
|
normalized_slug = normalized_slug[: -len(".md")]
|
||||||
url = (
|
url = f"{SOULS_DIRECTORY_BASE_URL}/api/souls/" f"{normalized_handle}/{normalized_slug}.md"
|
||||||
f"{SOULS_DIRECTORY_BASE_URL}/api/souls/"
|
|
||||||
f"{normalized_handle}/{normalized_slug}.md"
|
|
||||||
)
|
|
||||||
|
|
||||||
owns_client = client is None
|
owns_client = client is None
|
||||||
if client is None:
|
if client is None:
|
||||||
|
|||||||
@@ -79,11 +79,7 @@ def blocked_by_dependency_ids(
|
|||||||
status_by_id: Mapping[UUID, str],
|
status_by_id: Mapping[UUID, str],
|
||||||
) -> list[UUID]:
|
) -> list[UUID]:
|
||||||
"""Return dependency ids that are not yet in the done status."""
|
"""Return dependency ids that are not yet in the done status."""
|
||||||
return [
|
return [dep_id for dep_id in dependency_ids if status_by_id.get(dep_id) != DONE_STATUS]
|
||||||
dep_id
|
|
||||||
for dep_id in dependency_ids
|
|
||||||
if status_by_id.get(dep_id) != DONE_STATUS
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
async def blocked_by_for_task(
|
async def blocked_by_for_task(
|
||||||
|
|||||||
@@ -14,11 +14,7 @@ from sqlalchemy import func
|
|||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.core.agent_tokens import (
|
from app.core.agent_tokens import generate_agent_token, hash_agent_token, verify_agent_token
|
||||||
generate_agent_token,
|
|
||||||
hash_agent_token,
|
|
||||||
verify_agent_token,
|
|
||||||
)
|
|
||||||
from app.core.time import utcnow
|
from app.core.time import utcnow
|
||||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||||
from app.integrations.openclaw_gateway import OpenClawGatewayError, openclaw_call
|
from app.integrations.openclaw_gateway import OpenClawGatewayError, openclaw_call
|
||||||
@@ -108,10 +104,7 @@ def _is_transient_gateway_error(exc: Exception) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def _gateway_timeout_message(exc: OpenClawGatewayError) -> str:
|
def _gateway_timeout_message(exc: OpenClawGatewayError) -> str:
|
||||||
return (
|
return "Gateway unreachable after 10 minutes (template sync timeout). " f"Last error: {exc}"
|
||||||
"Gateway unreachable after 10 minutes (template sync timeout). "
|
|
||||||
f"Last error: {exc}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _GatewayBackoff:
|
class _GatewayBackoff:
|
||||||
@@ -375,6 +368,7 @@ async def _rotate_agent_token(session: AsyncSession, agent: Agent) -> str:
|
|||||||
|
|
||||||
async def _ping_gateway(ctx: _SyncContext, result: GatewayTemplatesSyncResult) -> bool:
|
async def _ping_gateway(ctx: _SyncContext, result: GatewayTemplatesSyncResult) -> bool:
|
||||||
try:
|
try:
|
||||||
|
|
||||||
async def _do_ping() -> object:
|
async def _do_ping() -> object:
|
||||||
return await openclaw_call("agents.list", config=ctx.config)
|
return await openclaw_call("agents.list", config=ctx.config)
|
||||||
|
|
||||||
@@ -486,6 +480,7 @@ async def _sync_one_agent(
|
|||||||
if not auth_token:
|
if not auth_token:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
|
|
||||||
async def _do_provision() -> None:
|
async def _do_provision() -> None:
|
||||||
await provision_agent(
|
await provision_agent(
|
||||||
agent,
|
agent,
|
||||||
@@ -533,10 +528,7 @@ async def _sync_main_agent(
|
|||||||
if main_agent is None:
|
if main_agent is None:
|
||||||
_append_sync_error(
|
_append_sync_error(
|
||||||
result,
|
result,
|
||||||
message=(
|
message=("Gateway main agent record not found; " "skipping main agent template sync."),
|
||||||
"Gateway main agent record not found; "
|
|
||||||
"skipping main agent template sync."
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
try:
|
try:
|
||||||
@@ -574,6 +566,7 @@ async def _sync_main_agent(
|
|||||||
return True
|
return True
|
||||||
stop_sync = False
|
stop_sync = False
|
||||||
try:
|
try:
|
||||||
|
|
||||||
async def _do_provision_main() -> None:
|
async def _do_provision_main() -> None:
|
||||||
await provision_main_agent(
|
await provision_main_agent(
|
||||||
main_agent,
|
main_agent,
|
||||||
|
|||||||
@@ -32,17 +32,13 @@ def _parse_args() -> argparse.Namespace:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--reset-sessions",
|
"--reset-sessions",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help=(
|
help=("Reset agent sessions after syncing files " "(forces agents to re-read workspace)"),
|
||||||
"Reset agent sessions after syncing files "
|
|
||||||
"(forces agents to re-read workspace)"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--rotate-tokens",
|
"--rotate-tokens",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help=(
|
help=(
|
||||||
"Rotate agent tokens when TOOLS.md is missing/unreadable "
|
"Rotate agent tokens when TOOLS.md is missing/unreadable " "or token drift is detected"
|
||||||
"or token drift is detected"
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -56,10 +52,7 @@ def _parse_args() -> argparse.Namespace:
|
|||||||
async def _run() -> int:
|
async def _run() -> int:
|
||||||
from app.db.session import async_session_maker
|
from app.db.session import async_session_maker
|
||||||
from app.models.gateways import Gateway
|
from app.models.gateways import Gateway
|
||||||
from app.services.template_sync import (
|
from app.services.template_sync import GatewayTemplateSyncOptions, sync_gateway_templates
|
||||||
GatewayTemplateSyncOptions,
|
|
||||||
sync_gateway_templates,
|
|
||||||
)
|
|
||||||
|
|
||||||
args = _parse_args()
|
args = _parse_args()
|
||||||
gateway_id = UUID(args.gateway_id)
|
gateway_id = UUID(args.gateway_id)
|
||||||
@@ -86,8 +79,7 @@ async def _run() -> int:
|
|||||||
|
|
||||||
sys.stdout.write(f"gateway_id={result.gateway_id}\n")
|
sys.stdout.write(f"gateway_id={result.gateway_id}\n")
|
||||||
sys.stdout.write(
|
sys.stdout.write(
|
||||||
f"include_main={result.include_main} "
|
f"include_main={result.include_main} " f"reset_sessions={result.reset_sessions}\n",
|
||||||
f"reset_sessions={result.reset_sessions}\n",
|
|
||||||
)
|
)
|
||||||
sys.stdout.write(
|
sys.stdout.write(
|
||||||
f"agents_updated={result.agents_updated} "
|
f"agents_updated={result.agents_updated} "
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@@ -11,6 +12,9 @@ from app.core.error_handling import (
|
|||||||
REQUEST_ID_HEADER,
|
REQUEST_ID_HEADER,
|
||||||
_error_payload,
|
_error_payload,
|
||||||
_get_request_id,
|
_get_request_id,
|
||||||
|
_http_exception_exception_handler,
|
||||||
|
_request_validation_exception_handler,
|
||||||
|
_response_validation_exception_handler,
|
||||||
install_error_handling,
|
install_error_handling,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -123,3 +127,24 @@ def test_get_request_id_returns_none_for_missing_or_invalid_state() -> None:
|
|||||||
|
|
||||||
def test_error_payload_omits_request_id_when_none() -> None:
|
def test_error_payload_omits_request_id_when_none() -> None:
|
||||||
assert _error_payload(detail="x", request_id=None) == {"detail": "x"}
|
assert _error_payload(detail="x", request_id=None) == {"detail": "x"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_request_validation_exception_wrapper_rejects_wrong_exception() -> None:
|
||||||
|
req = Request({"type": "http", "headers": [], "state": {}})
|
||||||
|
with pytest.raises(TypeError, match="Expected RequestValidationError"):
|
||||||
|
await _request_validation_exception_handler(req, Exception("x"))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_response_validation_exception_wrapper_rejects_wrong_exception() -> None:
|
||||||
|
req = Request({"type": "http", "headers": [], "state": {}})
|
||||||
|
with pytest.raises(TypeError, match="Expected ResponseValidationError"):
|
||||||
|
await _response_validation_exception_handler(req, Exception("x"))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_http_exception_wrapper_rejects_wrong_exception() -> None:
|
||||||
|
req = Request({"type": "http", "headers": [], "state": {}})
|
||||||
|
with pytest.raises(TypeError, match="Expected StarletteHTTPException"):
|
||||||
|
await _http_exception_exception_handler(req, Exception("x"))
|
||||||
|
|||||||
@@ -16,10 +16,7 @@ from app.models.organization_invites import OrganizationInvite
|
|||||||
from app.models.organization_members import OrganizationMember
|
from app.models.organization_members import OrganizationMember
|
||||||
from app.models.organizations import Organization
|
from app.models.organizations import Organization
|
||||||
from app.models.users import User
|
from app.models.users import User
|
||||||
from app.schemas.organizations import (
|
from app.schemas.organizations import OrganizationBoardAccessSpec, OrganizationMemberAccessUpdate
|
||||||
OrganizationBoardAccessSpec,
|
|
||||||
OrganizationMemberAccessUpdate,
|
|
||||||
)
|
|
||||||
from app.services import organizations
|
from app.services import organizations
|
||||||
|
|
||||||
|
|
||||||
@@ -87,10 +84,7 @@ class _FakeSession:
|
|||||||
|
|
||||||
|
|
||||||
def test_normalize_invited_email_strips_and_lowercases() -> None:
|
def test_normalize_invited_email_strips_and_lowercases() -> None:
|
||||||
assert (
|
assert organizations.normalize_invited_email(" Foo@Example.com ") == "foo@example.com"
|
||||||
organizations.normalize_invited_email(" Foo@Example.com ")
|
|
||||||
== "foo@example.com"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -128,7 +122,9 @@ async def test_ensure_member_for_user_returns_existing_membership(
|
|||||||
) -> None:
|
) -> None:
|
||||||
user = User(clerk_user_id="u1")
|
user = User(clerk_user_id="u1")
|
||||||
existing = OrganizationMember(
|
existing = OrganizationMember(
|
||||||
organization_id=uuid4(), user_id=user.id, role="member",
|
organization_id=uuid4(),
|
||||||
|
user_id=user.id,
|
||||||
|
role="member",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _fake_get_active(_session: Any, _user: User) -> OrganizationMember:
|
async def _fake_get_active(_session: Any, _user: User) -> OrganizationMember:
|
||||||
@@ -161,11 +157,15 @@ async def test_ensure_member_for_user_accepts_pending_invite(
|
|||||||
return invite
|
return invite
|
||||||
|
|
||||||
accepted = OrganizationMember(
|
accepted = OrganizationMember(
|
||||||
organization_id=org_id, user_id=user.id, role="member",
|
organization_id=org_id,
|
||||||
|
user_id=user.id,
|
||||||
|
role="member",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _fake_accept(
|
async def _fake_accept(
|
||||||
_session: Any, _invite: OrganizationInvite, _user: User,
|
_session: Any,
|
||||||
|
_invite: OrganizationInvite,
|
||||||
|
_user: User,
|
||||||
) -> OrganizationMember:
|
) -> OrganizationMember:
|
||||||
assert _invite is invite
|
assert _invite is invite
|
||||||
assert _user is user
|
assert _user is user
|
||||||
@@ -216,7 +216,10 @@ async def test_has_board_access_denies_cross_org() -> None:
|
|||||||
board = Board(id=uuid4(), organization_id=uuid4(), name="b", slug="b")
|
board = Board(id=uuid4(), organization_id=uuid4(), name="b", slug="b")
|
||||||
assert (
|
assert (
|
||||||
await organizations.has_board_access(
|
await organizations.has_board_access(
|
||||||
session, member=member, board=board, write=False,
|
session,
|
||||||
|
member=member,
|
||||||
|
board=board,
|
||||||
|
write=False,
|
||||||
)
|
)
|
||||||
is False
|
is False
|
||||||
)
|
)
|
||||||
@@ -226,7 +229,10 @@ async def test_has_board_access_denies_cross_org() -> None:
|
|||||||
async def test_has_board_access_uses_org_board_access_row_read_and_write() -> None:
|
async def test_has_board_access_uses_org_board_access_row_read_and_write() -> None:
|
||||||
org_id = uuid4()
|
org_id = uuid4()
|
||||||
member = OrganizationMember(
|
member = OrganizationMember(
|
||||||
id=uuid4(), organization_id=org_id, user_id=uuid4(), role="member",
|
id=uuid4(),
|
||||||
|
organization_id=org_id,
|
||||||
|
user_id=uuid4(),
|
||||||
|
role="member",
|
||||||
)
|
)
|
||||||
board = Board(id=uuid4(), organization_id=org_id, name="b", slug="b")
|
board = Board(id=uuid4(), organization_id=org_id, name="b", slug="b")
|
||||||
|
|
||||||
@@ -239,7 +245,10 @@ async def test_has_board_access_uses_org_board_access_row_read_and_write() -> No
|
|||||||
session = _FakeSession(exec_results=[_FakeExecResult(first_value=access)])
|
session = _FakeSession(exec_results=[_FakeExecResult(first_value=access)])
|
||||||
assert (
|
assert (
|
||||||
await organizations.has_board_access(
|
await organizations.has_board_access(
|
||||||
session, member=member, board=board, write=False,
|
session,
|
||||||
|
member=member,
|
||||||
|
board=board,
|
||||||
|
write=False,
|
||||||
)
|
)
|
||||||
is True
|
is True
|
||||||
)
|
)
|
||||||
@@ -253,7 +262,10 @@ async def test_has_board_access_uses_org_board_access_row_read_and_write() -> No
|
|||||||
session2 = _FakeSession(exec_results=[_FakeExecResult(first_value=access2)])
|
session2 = _FakeSession(exec_results=[_FakeExecResult(first_value=access2)])
|
||||||
assert (
|
assert (
|
||||||
await organizations.has_board_access(
|
await organizations.has_board_access(
|
||||||
session2, member=member, board=board, write=False,
|
session2,
|
||||||
|
member=member,
|
||||||
|
board=board,
|
||||||
|
write=False,
|
||||||
)
|
)
|
||||||
is True
|
is True
|
||||||
)
|
)
|
||||||
@@ -267,7 +279,10 @@ async def test_has_board_access_uses_org_board_access_row_read_and_write() -> No
|
|||||||
session3 = _FakeSession(exec_results=[_FakeExecResult(first_value=access3)])
|
session3 = _FakeSession(exec_results=[_FakeExecResult(first_value=access3)])
|
||||||
assert (
|
assert (
|
||||||
await organizations.has_board_access(
|
await organizations.has_board_access(
|
||||||
session3, member=member, board=board, write=True,
|
session3,
|
||||||
|
member=member,
|
||||||
|
board=board,
|
||||||
|
write=True,
|
||||||
)
|
)
|
||||||
is False
|
is False
|
||||||
)
|
)
|
||||||
@@ -288,7 +303,10 @@ async def test_require_board_access_raises_when_no_member(
|
|||||||
session = _FakeSession(exec_results=[])
|
session = _FakeSession(exec_results=[])
|
||||||
with pytest.raises(HTTPException) as exc:
|
with pytest.raises(HTTPException) as exc:
|
||||||
await organizations.require_board_access(
|
await organizations.require_board_access(
|
||||||
session, user=user, board=board, write=False,
|
session,
|
||||||
|
user=user,
|
||||||
|
board=board,
|
||||||
|
write=False,
|
||||||
)
|
)
|
||||||
assert exc.value.status_code == 403
|
assert exc.value.status_code == 403
|
||||||
|
|
||||||
@@ -298,24 +316,33 @@ async def test_apply_member_access_update_deletes_existing_and_adds_rows_when_no
|
|||||||
None
|
None
|
||||||
):
|
):
|
||||||
member = OrganizationMember(
|
member = OrganizationMember(
|
||||||
id=uuid4(), organization_id=uuid4(), user_id=uuid4(), role="member",
|
id=uuid4(),
|
||||||
|
organization_id=uuid4(),
|
||||||
|
user_id=uuid4(),
|
||||||
|
role="member",
|
||||||
)
|
)
|
||||||
update = OrganizationMemberAccessUpdate(
|
update = OrganizationMemberAccessUpdate(
|
||||||
all_boards_read=False,
|
all_boards_read=False,
|
||||||
all_boards_write=False,
|
all_boards_write=False,
|
||||||
board_access=[
|
board_access=[
|
||||||
OrganizationBoardAccessSpec(
|
OrganizationBoardAccessSpec(
|
||||||
board_id=uuid4(), can_read=True, can_write=False,
|
board_id=uuid4(),
|
||||||
|
can_read=True,
|
||||||
|
can_write=False,
|
||||||
),
|
),
|
||||||
OrganizationBoardAccessSpec(
|
OrganizationBoardAccessSpec(
|
||||||
board_id=uuid4(), can_read=True, can_write=True,
|
board_id=uuid4(),
|
||||||
|
can_read=True,
|
||||||
|
can_write=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
session = _FakeSession(exec_results=[])
|
session = _FakeSession(exec_results=[])
|
||||||
|
|
||||||
await organizations.apply_member_access_update(
|
await organizations.apply_member_access_update(
|
||||||
session, member=member, update=update,
|
session,
|
||||||
|
member=member,
|
||||||
|
update=update,
|
||||||
)
|
)
|
||||||
|
|
||||||
# delete statement executed once
|
# delete statement executed once
|
||||||
|
|||||||
Reference in New Issue
Block a user