refactor: enhance docstrings for clarity and consistency across multiple files
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
"""Gateway CRUD and template synchronization endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlmodel import col
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.api.deps import require_org_admin
|
||||
from app.core.agent_tokens import generate_agent_token, hash_agent_token
|
||||
@@ -14,7 +17,11 @@ from app.db import crud
|
||||
from app.db.pagination import paginate
|
||||
from app.db.session import get_session
|
||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
|
||||
from app.integrations.openclaw_gateway import (
|
||||
OpenClawGatewayError,
|
||||
ensure_session,
|
||||
send_message,
|
||||
)
|
||||
from app.models.agents import Agent
|
||||
from app.models.gateways import Gateway
|
||||
from app.schemas.common import OkResponse
|
||||
@@ -25,11 +32,61 @@ from app.schemas.gateways import (
|
||||
GatewayUpdate,
|
||||
)
|
||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||
from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_main_agent
|
||||
from app.services.organizations import OrganizationContext
|
||||
from app.services.template_sync import sync_gateway_templates as sync_gateway_templates_service
|
||||
from app.services.agent_provisioning import (
|
||||
DEFAULT_HEARTBEAT_CONFIG,
|
||||
provision_main_agent,
|
||||
)
|
||||
from app.services.template_sync import (
|
||||
GatewayTemplateSyncOptions,
|
||||
)
|
||||
from app.services.template_sync import (
|
||||
sync_gateway_templates as sync_gateway_templates_service,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.services.organizations import OrganizationContext
|
||||
|
||||
router = APIRouter(prefix="/gateways", tags=["gateways"])
|
||||
SESSION_DEP = Depends(get_session)
|
||||
AUTH_DEP = Depends(get_auth_context)
|
||||
ORG_ADMIN_DEP = Depends(require_org_admin)
|
||||
INCLUDE_MAIN_QUERY = Query(default=True)
|
||||
RESET_SESSIONS_QUERY = Query(default=False)
|
||||
ROTATE_TOKENS_QUERY = Query(default=False)
|
||||
FORCE_BOOTSTRAP_QUERY = Query(default=False)
|
||||
BOARD_ID_QUERY = Query(default=None)
|
||||
_RUNTIME_TYPE_REFERENCES = (UUID,)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _TemplateSyncQuery:
|
||||
include_main: bool
|
||||
reset_sessions: bool
|
||||
rotate_tokens: bool
|
||||
force_bootstrap: bool
|
||||
board_id: UUID | None
|
||||
|
||||
|
||||
def _template_sync_query(
|
||||
*,
|
||||
include_main: bool = INCLUDE_MAIN_QUERY,
|
||||
reset_sessions: bool = RESET_SESSIONS_QUERY,
|
||||
rotate_tokens: bool = ROTATE_TOKENS_QUERY,
|
||||
force_bootstrap: bool = FORCE_BOOTSTRAP_QUERY,
|
||||
board_id: UUID | None = BOARD_ID_QUERY,
|
||||
) -> _TemplateSyncQuery:
|
||||
return _TemplateSyncQuery(
|
||||
include_main=include_main,
|
||||
reset_sessions=reset_sessions,
|
||||
rotate_tokens=rotate_tokens,
|
||||
force_bootstrap=force_bootstrap,
|
||||
board_id=board_id,
|
||||
)
|
||||
|
||||
|
||||
SYNC_QUERY_DEP = Depends(_template_sync_query)
|
||||
|
||||
|
||||
def _main_agent_name(gateway: Gateway) -> str:
|
||||
@@ -48,7 +105,9 @@ async def _require_gateway(
|
||||
.first(session)
|
||||
)
|
||||
if gateway is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Gateway not found",
|
||||
)
|
||||
return gateway
|
||||
|
||||
|
||||
@@ -59,14 +118,18 @@ async def _find_main_agent(
|
||||
previous_session_key: str | None = None,
|
||||
) -> Agent | None:
|
||||
if gateway.main_session_key:
|
||||
agent = await Agent.objects.filter_by(openclaw_session_id=gateway.main_session_key).first(
|
||||
session
|
||||
agent = await Agent.objects.filter_by(
|
||||
openclaw_session_id=gateway.main_session_key,
|
||||
).first(
|
||||
session,
|
||||
)
|
||||
if agent:
|
||||
return agent
|
||||
if previous_session_key:
|
||||
agent = await Agent.objects.filter_by(openclaw_session_id=previous_session_key).first(
|
||||
session
|
||||
agent = await Agent.objects.filter_by(
|
||||
openclaw_session_id=previous_session_key,
|
||||
).first(
|
||||
session,
|
||||
)
|
||||
if agent:
|
||||
return agent
|
||||
@@ -85,13 +148,17 @@ async def _ensure_main_agent(
|
||||
gateway: Gateway,
|
||||
auth: AuthContext,
|
||||
*,
|
||||
previous_name: str | None = None,
|
||||
previous_session_key: str | None = None,
|
||||
previous: tuple[str | None, str | None] | None = None,
|
||||
action: str = "provision",
|
||||
) -> Agent | None:
|
||||
if not gateway.url or not gateway.main_session_key:
|
||||
return None
|
||||
agent = await _find_main_agent(session, gateway, previous_name, previous_session_key)
|
||||
agent = await _find_main_agent(
|
||||
session,
|
||||
gateway,
|
||||
previous_name=previous[0] if previous else None,
|
||||
previous_session_key=previous[1] if previous else None,
|
||||
)
|
||||
if agent is None:
|
||||
agent = Agent(
|
||||
name=_main_agent_name(gateway),
|
||||
@@ -130,7 +197,8 @@ async def _ensure_main_agent(
|
||||
(
|
||||
f"Hello {agent.name}. Your gateway provisioning was updated.\n\n"
|
||||
"Please re-read AGENTS.md, USER.md, HEARTBEAT.md, and TOOLS.md. "
|
||||
"If BOOTSTRAP.md exists, run it once then delete it. Begin heartbeats after startup."
|
||||
"If BOOTSTRAP.md exists, run it once then delete it. "
|
||||
"Begin heartbeats after startup."
|
||||
),
|
||||
session_key=gateway.main_session_key,
|
||||
config=GatewayClientConfig(url=gateway.url, token=gateway.token),
|
||||
@@ -144,9 +212,10 @@ async def _ensure_main_agent(
|
||||
|
||||
@router.get("", response_model=DefaultLimitOffsetPage[GatewayRead])
|
||||
async def list_gateways(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> DefaultLimitOffsetPage[GatewayRead]:
|
||||
"""List gateways for the caller's organization."""
|
||||
statement = (
|
||||
Gateway.objects.filter_by(organization_id=ctx.organization.id)
|
||||
.order_by(col(Gateway.created_at).desc())
|
||||
@@ -158,10 +227,11 @@ async def list_gateways(
|
||||
@router.post("", response_model=GatewayRead)
|
||||
async def create_gateway(
|
||||
payload: GatewayCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
auth: AuthContext = AUTH_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> Gateway:
|
||||
"""Create a gateway and provision or refresh its main agent."""
|
||||
data = payload.model_dump()
|
||||
data["organization_id"] = ctx.organization.id
|
||||
gateway = await crud.create(session, Gateway, **data)
|
||||
@@ -172,9 +242,10 @@ async def create_gateway(
|
||||
@router.get("/{gateway_id}", response_model=GatewayRead)
|
||||
async def get_gateway(
|
||||
gateway_id: UUID,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> Gateway:
|
||||
"""Return one gateway by id for the caller's organization."""
|
||||
return await _require_gateway(
|
||||
session,
|
||||
gateway_id=gateway_id,
|
||||
@@ -186,10 +257,11 @@ async def get_gateway(
|
||||
async def update_gateway(
|
||||
gateway_id: UUID,
|
||||
payload: GatewayUpdate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
auth: AuthContext = AUTH_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> Gateway:
|
||||
"""Patch a gateway and refresh the main-agent provisioning state."""
|
||||
gateway = await _require_gateway(
|
||||
session,
|
||||
gateway_id=gateway_id,
|
||||
@@ -203,8 +275,7 @@ async def update_gateway(
|
||||
session,
|
||||
gateway,
|
||||
auth,
|
||||
previous_name=previous_name,
|
||||
previous_session_key=previous_session_key,
|
||||
previous=(previous_name, previous_session_key),
|
||||
action="update",
|
||||
)
|
||||
return gateway
|
||||
@@ -213,15 +284,12 @@ async def update_gateway(
|
||||
@router.post("/{gateway_id}/templates/sync", response_model=GatewayTemplatesSyncResult)
|
||||
async def sync_gateway_templates(
|
||||
gateway_id: UUID,
|
||||
include_main: bool = Query(default=True),
|
||||
reset_sessions: bool = Query(default=False),
|
||||
rotate_tokens: bool = Query(default=False),
|
||||
force_bootstrap: bool = Query(default=False),
|
||||
board_id: UUID | None = Query(default=None),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
sync_query: _TemplateSyncQuery = SYNC_QUERY_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
auth: AuthContext = AUTH_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> GatewayTemplatesSyncResult:
|
||||
"""Sync templates for a gateway and optionally rotate runtime settings."""
|
||||
gateway = await _require_gateway(
|
||||
session,
|
||||
gateway_id=gateway_id,
|
||||
@@ -230,21 +298,24 @@ async def sync_gateway_templates(
|
||||
return await sync_gateway_templates_service(
|
||||
session,
|
||||
gateway,
|
||||
user=auth.user,
|
||||
include_main=include_main,
|
||||
reset_sessions=reset_sessions,
|
||||
rotate_tokens=rotate_tokens,
|
||||
force_bootstrap=force_bootstrap,
|
||||
board_id=board_id,
|
||||
GatewayTemplateSyncOptions(
|
||||
user=auth.user,
|
||||
include_main=sync_query.include_main,
|
||||
reset_sessions=sync_query.reset_sessions,
|
||||
rotate_tokens=sync_query.rotate_tokens,
|
||||
force_bootstrap=sync_query.force_bootstrap,
|
||||
board_id=sync_query.board_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{gateway_id}", response_model=OkResponse)
|
||||
async def delete_gateway(
|
||||
gateway_id: UUID,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
ctx: OrganizationContext = Depends(require_org_admin),
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> OkResponse:
|
||||
"""Delete a gateway in the caller's organization."""
|
||||
gateway = await _require_gateway(
|
||||
session,
|
||||
gateway_id=gateway_id,
|
||||
|
||||
Reference in New Issue
Block a user