refactor: enhance gateway agent handling with dedicated OpenClaw agent IDs

This commit is contained in:
Abhimanyu Saharan
2026-02-10 01:33:01 +05:30
parent 822b13e6eb
commit 50f71960de
6 changed files with 406 additions and 82 deletions

View File

@@ -956,29 +956,36 @@ async def list_agents(
board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False) board_ids = await list_accessible_board_ids(session, member=ctx.member, write=False)
if board_id is not None and board_id not in set(board_ids): if board_id is not None and board_id not in set(board_ids):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if not board_ids: base_filters: list[ColumnElement[bool]] = []
statement = select(Agent).where(col(Agent.id).is_(None)) if board_ids:
base_filters.append(col(Agent.board_id).in_(board_ids))
if is_org_admin(ctx.member):
gateways = await Gateway.objects.filter_by(
organization_id=ctx.organization.id,
).all(session)
gateway_keys = [gateway_agent_session_key(gateway) for gateway in gateways]
if gateway_keys:
base_filters.append(col(Agent.openclaw_session_id).in_(gateway_keys))
if base_filters:
if len(base_filters) == 1:
statement = select(Agent).where(base_filters[0])
else:
statement = select(Agent).where(or_(*base_filters))
else: else:
base_filter: ColumnElement[bool] = col(Agent.board_id).in_(board_ids) statement = select(Agent).where(col(Agent.id).is_(None))
if is_org_admin(ctx.member):
gateways = await Gateway.objects.filter_by(
organization_id=ctx.organization.id,
).all(session)
gateway_keys = [gateway_agent_session_key(gateway) for gateway in gateways]
if gateway_keys:
base_filter = or_(
base_filter,
col(Agent.openclaw_session_id).in_(gateway_keys),
)
statement = select(Agent).where(base_filter)
if board_id is not None: if board_id is not None:
statement = statement.where(col(Agent.board_id) == board_id) statement = statement.where(col(Agent.board_id) == board_id)
if gateway_id is not None: if gateway_id is not None:
gateway = await Gateway.objects.by_id(gateway_id).first(session) gateway = await Gateway.objects.by_id(gateway_id).first(session)
if gateway is None or gateway.organization_id != ctx.organization.id: if gateway is None or gateway.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
statement = statement.join(Board, col(Agent.board_id) == col(Board.id)).where( gateway_main_key = gateway_agent_session_key(gateway)
col(Board.gateway_id) == gateway_id, gateway_board_ids = select(Board.id).where(col(Board.gateway_id) == gateway_id)
statement = statement.where(
or_(
col(Agent.board_id).in_(gateway_board_ids),
col(Agent.openclaw_session_id) == gateway_main_key,
),
) )
statement = statement.order_by(col(Agent.created_at).desc()) statement = statement.order_by(col(Agent.created_at).desc())

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from uuid import UUID, uuid4 from uuid import UUID, uuid4
@@ -17,9 +18,17 @@ from app.db import crud
from app.db.pagination import paginate from app.db.pagination import paginate
from app.db.session import get_session from app.db.session import get_session
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message from app.integrations.openclaw_gateway import (
OpenClawGatewayError,
ensure_session,
openclaw_call,
send_message,
)
from app.models.activity_events import ActivityEvent
from app.models.agents import Agent from app.models.agents import Agent
from app.models.approvals import Approval
from app.models.gateways import Gateway from app.models.gateways import Gateway
from app.models.tasks import Task
from app.schemas.common import OkResponse from app.schemas.common import OkResponse
from app.schemas.gateways import ( from app.schemas.gateways import (
GatewayCreate, GatewayCreate,
@@ -34,7 +43,11 @@ from app.services.agent_provisioning import (
ProvisionOptions, ProvisionOptions,
provision_main_agent, provision_main_agent,
) )
from app.services.gateway_agents import gateway_agent_session_key, gateway_agent_session_key_for_id from app.services.gateway_agents import (
gateway_agent_session_key,
gateway_agent_session_key_for_id,
gateway_openclaw_agent_id,
)
from app.services.template_sync import GatewayTemplateSyncOptions from app.services.template_sync import GatewayTemplateSyncOptions
from app.services.template_sync import sync_gateway_templates as sync_gateway_templates_service from app.services.template_sync import sync_gateway_templates as sync_gateway_templates_service
@@ -42,6 +55,7 @@ if TYPE_CHECKING:
from fastapi_pagination.limit_offset import LimitOffsetPage from fastapi_pagination.limit_offset import LimitOffsetPage
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.users import User
from app.services.organizations import OrganizationContext from app.services.organizations import OrganizationContext
router = APIRouter(prefix="/gateways", tags=["gateways"]) router = APIRouter(prefix="/gateways", tags=["gateways"])
@@ -54,6 +68,7 @@ ROTATE_TOKENS_QUERY = Query(default=False)
FORCE_BOOTSTRAP_QUERY = Query(default=False) FORCE_BOOTSTRAP_QUERY = Query(default=False)
BOARD_ID_QUERY = Query(default=None) BOARD_ID_QUERY = Query(default=None)
_RUNTIME_TYPE_REFERENCES = (UUID,) _RUNTIME_TYPE_REFERENCES = (UUID,)
logger = logging.getLogger(__name__)
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -89,6 +104,14 @@ def _main_agent_name(gateway: Gateway) -> str:
return f"{gateway.name} Gateway Agent" return f"{gateway.name} Gateway Agent"
def _gateway_identity_profile() -> dict[str, str]:
return {
"role": "Gateway Agent",
"communication_style": "direct, concise, practical",
"emoji": ":compass:",
}
async def _require_gateway( async def _require_gateway(
session: AsyncSession, session: AsyncSession,
*, *,
@@ -149,21 +172,19 @@ async def _find_main_agent(
return None return None
async def _ensure_main_agent( async def _upsert_main_agent_record(
session: AsyncSession, session: AsyncSession,
gateway: Gateway, gateway: Gateway,
auth: AuthContext,
*, *,
previous: tuple[str | None, str | None] | None = None, previous: tuple[str | None, str | None] | None = None,
action: str = "provision", ) -> tuple[Agent, bool]:
) -> Agent | None: changed = False
if not gateway.url:
return None
session_key = gateway_agent_session_key(gateway) session_key = gateway_agent_session_key(gateway)
if gateway.main_session_key != session_key: if gateway.main_session_key != session_key:
gateway.main_session_key = session_key gateway.main_session_key = session_key
gateway.updated_at = utcnow() gateway.updated_at = utcnow()
session.add(gateway) session.add(gateway)
changed = True
agent = await _find_main_agent( agent = await _find_main_agent(
session, session,
gateway, gateway,
@@ -178,15 +199,112 @@ async def _ensure_main_agent(
is_board_lead=False, is_board_lead=False,
openclaw_session_id=session_key, openclaw_session_id=session_key,
heartbeat_config=DEFAULT_HEARTBEAT_CONFIG.copy(), heartbeat_config=DEFAULT_HEARTBEAT_CONFIG.copy(),
identity_profile={ identity_profile=_gateway_identity_profile(),
"role": "Gateway Agent",
"communication_style": "direct, concise, practical",
"emoji": ":compass:",
},
) )
session.add(agent) session.add(agent)
agent.name = _main_agent_name(gateway) changed = True
agent.openclaw_session_id = session_key if agent.board_id is not None:
agent.board_id = None
changed = True
if agent.is_board_lead:
agent.is_board_lead = False
changed = True
if agent.name != _main_agent_name(gateway):
agent.name = _main_agent_name(gateway)
changed = True
if agent.openclaw_session_id != session_key:
agent.openclaw_session_id = session_key
changed = True
if agent.heartbeat_config is None:
agent.heartbeat_config = DEFAULT_HEARTBEAT_CONFIG.copy()
changed = True
if agent.identity_profile is None:
agent.identity_profile = _gateway_identity_profile()
changed = True
if not agent.status:
agent.status = "provisioning"
changed = True
if changed:
agent.updated_at = utcnow()
session.add(agent)
return agent, changed
async def _ensure_gateway_agents_exist(
session: AsyncSession,
gateways: list[Gateway],
) -> None:
for gateway in gateways:
agent, gateway_changed = await _upsert_main_agent_record(session, gateway)
has_gateway_entry = await _gateway_has_main_agent_entry(gateway)
needs_provision = gateway_changed or not bool(agent.agent_token_hash) or not has_gateway_entry
if needs_provision:
await _provision_main_agent_record(
session,
gateway,
agent,
user=None,
action="provision",
notify=False,
)
def _extract_agent_id_from_entry(item: object) -> str | None:
if isinstance(item, str):
value = item.strip()
return value or None
if not isinstance(item, dict):
return None
for key in ("id", "agentId", "agent_id"):
raw = item.get(key)
if isinstance(raw, str) and raw.strip():
return raw.strip()
return None
def _extract_config_agents_list(payload: object) -> list[object]:
if not isinstance(payload, dict):
return []
data = payload.get("config") or payload.get("parsed") or {}
if not isinstance(data, dict):
return []
agents = data.get("agents") or {}
if isinstance(agents, list):
return [item for item in agents]
if not isinstance(agents, dict):
return []
agents_list = agents.get("list") or []
if not isinstance(agents_list, list):
return []
return [item for item in agents_list]
async def _gateway_has_main_agent_entry(gateway: Gateway) -> bool:
if not gateway.url:
return False
config = GatewayClientConfig(url=gateway.url, token=gateway.token)
target_id = gateway_openclaw_agent_id(gateway)
try:
payload = await openclaw_call("config.get", config=config)
except OpenClawGatewayError:
# Avoid treating transient gateway connectivity issues as a missing agent entry.
return True
for item in _extract_config_agents_list(payload):
if _extract_agent_id_from_entry(item) == target_id:
return True
return False
async def _provision_main_agent_record(
session: AsyncSession,
gateway: Gateway,
agent: Agent,
*,
user: User | None,
action: str,
notify: bool,
) -> Agent:
session_key = gateway_agent_session_key(gateway)
raw_token = generate_agent_token() raw_token = generate_agent_token()
agent.agent_token_hash = hash_agent_token(raw_token) agent.agent_token_hash = hash_agent_token(raw_token)
agent.provision_requested_at = utcnow() agent.provision_requested_at = utcnow()
@@ -197,13 +315,15 @@ async def _ensure_main_agent(
session.add(agent) session.add(agent)
await session.commit() await session.commit()
await session.refresh(agent) await session.refresh(agent)
if not gateway.url:
return agent
try: try:
await provision_main_agent( await provision_main_agent(
agent, agent,
MainAgentProvisionRequest( MainAgentProvisionRequest(
gateway=gateway, gateway=gateway,
auth_token=raw_token, auth_token=raw_token,
user=auth.user, user=user,
session_key=session_key, session_key=session_key,
options=ProvisionOptions(action=action), options=ProvisionOptions(action=action),
), ),
@@ -213,29 +333,117 @@ async def _ensure_main_agent(
config=GatewayClientConfig(url=gateway.url, token=gateway.token), config=GatewayClientConfig(url=gateway.url, token=gateway.token),
label=agent.name, label=agent.name,
) )
await send_message( if notify:
( await send_message(
f"Hello {agent.name}. Your gateway provisioning was updated.\n\n" (
"Please re-read AGENTS.md, USER.md, HEARTBEAT.md, and TOOLS.md. " f"Hello {agent.name}. Your gateway provisioning was updated.\n\n"
"If BOOTSTRAP.md exists, run it once then delete it. " "Please re-read AGENTS.md, USER.md, HEARTBEAT.md, and TOOLS.md. "
"Begin heartbeats after startup." "If BOOTSTRAP.md exists, run it once then delete it. "
), "Begin heartbeats after startup."
session_key=session_key, ),
config=GatewayClientConfig(url=gateway.url, token=gateway.token), session_key=session_key,
deliver=True, config=GatewayClientConfig(url=gateway.url, token=gateway.token),
deliver=True,
)
except OpenClawGatewayError as exc:
logger.warning(
"gateway.main_agent.provision_failed_gateway gateway_id=%s agent_id=%s error=%s",
gateway.id,
agent.id,
str(exc),
)
except (OSError, RuntimeError, ValueError) as exc:
logger.warning(
"gateway.main_agent.provision_failed gateway_id=%s agent_id=%s error=%s",
gateway.id,
agent.id,
str(exc),
)
except Exception as exc: # pragma: no cover - defensive fallback
logger.warning(
"gateway.main_agent.provision_failed_unexpected gateway_id=%s agent_id=%s "
"error_type=%s error=%s",
gateway.id,
agent.id,
exc.__class__.__name__,
str(exc),
) )
except OpenClawGatewayError:
# Best-effort provisioning.
pass
return agent return agent
async def _ensure_main_agent(
session: AsyncSession,
gateway: Gateway,
auth: AuthContext,
*,
previous: tuple[str | None, str | None] | None = None,
action: str = "provision",
) -> Agent:
agent, _ = await _upsert_main_agent_record(
session,
gateway,
previous=previous,
)
return await _provision_main_agent_record(
session,
gateway,
agent,
user=auth.user,
action=action,
notify=True,
)
async def _clear_agent_foreign_keys(
session: AsyncSession,
*,
agent_id: UUID,
) -> None:
now = utcnow()
await crud.update_where(
session,
Task,
col(Task.assigned_agent_id) == agent_id,
col(Task.status) == "in_progress",
assigned_agent_id=None,
status="inbox",
in_progress_at=None,
updated_at=now,
commit=False,
)
await crud.update_where(
session,
Task,
col(Task.assigned_agent_id) == agent_id,
col(Task.status) != "in_progress",
assigned_agent_id=None,
updated_at=now,
commit=False,
)
await crud.update_where(
session,
ActivityEvent,
col(ActivityEvent.agent_id) == agent_id,
agent_id=None,
commit=False,
)
await crud.update_where(
session,
Approval,
col(Approval.agent_id) == agent_id,
agent_id=None,
commit=False,
)
@router.get("", response_model=DefaultLimitOffsetPage[GatewayRead]) @router.get("", response_model=DefaultLimitOffsetPage[GatewayRead])
async def list_gateways( async def list_gateways(
session: AsyncSession = SESSION_DEP, session: AsyncSession = SESSION_DEP,
ctx: OrganizationContext = ORG_ADMIN_DEP, ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> LimitOffsetPage[GatewayRead]: ) -> LimitOffsetPage[GatewayRead]:
"""List gateways for the caller's organization.""" """List gateways for the caller's organization."""
gateways = await Gateway.objects.filter_by(organization_id=ctx.organization.id).all(session)
await _ensure_gateway_agents_exist(session, gateways)
statement = ( statement = (
Gateway.objects.filter_by(organization_id=ctx.organization.id) Gateway.objects.filter_by(organization_id=ctx.organization.id)
.order_by(col(Gateway.created_at).desc()) .order_by(col(Gateway.created_at).desc())
@@ -269,11 +477,13 @@ async def get_gateway(
ctx: OrganizationContext = ORG_ADMIN_DEP, ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> Gateway: ) -> Gateway:
"""Return one gateway by id for the caller's organization.""" """Return one gateway by id for the caller's organization."""
return await _require_gateway( gateway = await _require_gateway(
session, session,
gateway_id=gateway_id, gateway_id=gateway_id,
organization_id=ctx.organization.id, organization_id=ctx.organization.id,
) )
await _ensure_gateway_agents_exist(session, [gateway])
return gateway
@router.patch("/{gateway_id}", response_model=GatewayRead) @router.patch("/{gateway_id}", response_model=GatewayRead)
@@ -318,6 +528,7 @@ async def sync_gateway_templates(
gateway_id=gateway_id, gateway_id=gateway_id,
organization_id=ctx.organization.id, organization_id=ctx.organization.id,
) )
await _ensure_gateway_agents_exist(session, [gateway])
return await sync_gateway_templates_service( return await sync_gateway_templates_service(
session, session,
gateway, gateway,
@@ -344,5 +555,21 @@ async def delete_gateway(
gateway_id=gateway_id, gateway_id=gateway_id,
organization_id=ctx.organization.id, organization_id=ctx.organization.id,
) )
await crud.delete(session, gateway) gateway_session_key = gateway_agent_session_key(gateway)
main_agent = await _find_main_agent(session, gateway)
if main_agent is not None:
await _clear_agent_foreign_keys(session, agent_id=main_agent.id)
await session.delete(main_agent)
duplicate_main_agents = await Agent.objects.filter_by(
openclaw_session_id=gateway_session_key,
).all(session)
for agent in duplicate_main_agents:
if main_agent is not None and agent.id == main_agent.id:
continue
await _clear_agent_foreign_keys(session, agent_id=agent.id)
await session.delete(agent)
await session.delete(gateway)
await session.commit()
return OkResponse() return OkResponse()

View File

@@ -16,7 +16,11 @@ from jinja2 import Environment, FileSystemLoader, StrictUndefined, select_autoes
from app.core.config import settings from app.core.config import settings
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, openclaw_call from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, openclaw_call
from app.services.gateway_agents import gateway_agent_session_key from app.services.gateway_agents import (
gateway_agent_session_key,
gateway_openclaw_agent_id,
parse_gateway_agent_session_key,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from app.models.agents import Agent from app.models.agents import Agent
@@ -146,6 +150,9 @@ def _agent_id_from_session_key(session_key: str | None) -> str | None:
value = (session_key or "").strip() value = (session_key or "").strip()
if not value: if not value:
return None return None
# Dedicated Mission Control gateway-agent session keys are not gateway config agent ids.
if parse_gateway_agent_session_key(value) is not None:
return None
if not value.startswith("agent:"): if not value.startswith("agent:"):
return None return None
parts = value.split(":") parts = value.split(":")
@@ -880,22 +887,14 @@ async def provision_main_agent(
label=agent.name or "Gateway Agent", label=agent.name or "Gateway Agent",
) )
agent_id = _agent_id_from_session_key(session_key) # Keep gateway default agent intact and use a dedicated OpenClaw agent id for Mission Control.
if agent_id: if not gateway.workspace_root:
if not gateway.workspace_root: msg = "gateway_workspace_root is required"
msg = "gateway_workspace_root is required" raise ValueError(msg)
raise ValueError(msg) agent_id = gateway_openclaw_agent_id(gateway)
workspace_path = _workspace_path(agent, gateway.workspace_root) workspace_path = _workspace_path(agent, gateway.workspace_root)
heartbeat = _heartbeat_config(agent) heartbeat = _heartbeat_config(agent)
await _patch_gateway_agent_list(agent_id, workspace_path, heartbeat, client_config) await _patch_gateway_agent_list(agent_id, workspace_path, heartbeat, client_config)
else:
agent_id = await _gateway_default_agent_id(
client_config,
fallback_session_key=session_key,
)
if not agent_id:
msg = "Unable to resolve gateway main agent id"
raise OpenClawGatewayError(msg)
context = _build_main_context(agent, gateway, request.auth_token, request.user) context = _build_main_context(agent, gateway, request.auth_token, request.user)
supported = set(await _supported_gateway_files(client_config)) supported = set(await _supported_gateway_files(client_config))

View File

@@ -8,6 +8,7 @@ from app.models.gateways import Gateway
_GATEWAY_AGENT_PREFIX = "agent:gateway-" _GATEWAY_AGENT_PREFIX = "agent:gateway-"
_GATEWAY_AGENT_SUFFIX = ":main" _GATEWAY_AGENT_SUFFIX = ":main"
_GATEWAY_OPENCLAW_AGENT_PREFIX = "mc-gateway-"
def gateway_agent_session_key_for_id(gateway_id: UUID) -> str: def gateway_agent_session_key_for_id(gateway_id: UUID) -> str:
@@ -20,6 +21,16 @@ def gateway_agent_session_key(gateway: Gateway) -> str:
return gateway_agent_session_key_for_id(gateway.id) return gateway_agent_session_key_for_id(gateway.id)
def gateway_openclaw_agent_id_for_id(gateway_id: UUID) -> str:
"""Return the dedicated OpenClaw config `agentId` for a gateway agent."""
return f"{_GATEWAY_OPENCLAW_AGENT_PREFIX}{gateway_id}"
def gateway_openclaw_agent_id(gateway: Gateway) -> str:
"""Return the dedicated OpenClaw config `agentId` for a gateway agent."""
return gateway_openclaw_agent_id_for_id(gateway.id)
def parse_gateway_agent_session_key(session_key: str | None) -> UUID | None: def parse_gateway_agent_session_key(session_key: str | None) -> UUID | None:
"""Parse a gateway id from a dedicated gateway-agent session key.""" """Parse a gateway id from a dedicated gateway-agent session key."""
value = (session_key or "").strip() value = (session_key or "").strip()

View File

@@ -31,7 +31,11 @@ from app.services.agent_provisioning import (
provision_agent, provision_agent,
provision_main_agent, provision_main_agent,
) )
from app.services.gateway_agents import gateway_agent_session_key from app.services.gateway_agents import (
gateway_agent_session_key,
gateway_openclaw_agent_id,
parse_gateway_agent_session_key,
)
_TOOLS_KV_RE = re.compile(r"^(?P<key>[A-Z0-9_]+)=(?P<value>.*)$") _TOOLS_KV_RE = re.compile(r"^(?P<key>[A-Z0-9_]+)=(?P<value>.*)$")
SESSION_KEY_PARTS_MIN = 2 SESSION_KEY_PARTS_MIN = 2
@@ -179,6 +183,9 @@ def _agent_id_from_session_key(session_key: str | None) -> str | None:
value = (session_key or "").strip() value = (session_key or "").strip()
if not value: if not value:
return None return None
# Dedicated Mission Control gateway-agent session keys are not gateway config agent ids.
if parse_gateway_agent_session_key(value) is not None:
return None
if not value.startswith("agent:"): if not value.startswith("agent:"):
return None return None
parts = value.split(":") parts = value.split(":")
@@ -314,6 +321,7 @@ async def _gateway_default_agent_id(
return agent_id return agent_id
except OpenClawGatewayError: except OpenClawGatewayError:
pass pass
# Avoid falling back to dedicated gateway session keys, which are not agent ids.
return _agent_id_from_session_key(fallback_session_key) return _agent_id_from_session_key(fallback_session_key)
@@ -533,22 +541,7 @@ async def _sync_main_agent(
message=("Gateway agent record not found; " "skipping gateway agent template sync."), message=("Gateway agent record not found; " "skipping gateway agent template sync."),
) )
return True return True
try: main_gateway_agent_id = gateway_openclaw_agent_id(ctx.gateway)
main_gateway_agent_id = await _gateway_default_agent_id(
ctx.config,
fallback_session_key=main_session_key,
backoff=ctx.backoff,
)
except TimeoutError as exc:
_append_sync_error(result, agent=main_agent, message=str(exc))
return True
if not main_gateway_agent_id:
_append_sync_error(
result,
agent=main_agent,
message="Unable to resolve gateway agent id.",
)
return True
token, fatal = await _resolve_agent_auth_token( token, fatal = await _resolve_agent_auth_token(
ctx, ctx,

View File

@@ -2,9 +2,16 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass, field
from uuid import UUID, uuid4
import pytest
from app.services import agent_provisioning from app.services import agent_provisioning
from app.services.gateway_agents import (
gateway_agent_session_key_for_id,
gateway_openclaw_agent_id_for_id,
)
def test_slugify_normalizes_and_trims(): def test_slugify_normalizes_and_trims():
@@ -28,6 +35,11 @@ def test_agent_id_from_session_key_parses_agent_prefix():
assert agent_provisioning._agent_id_from_session_key("agent:riya:main") == "riya" assert agent_provisioning._agent_id_from_session_key("agent:riya:main") == "riya"
def test_agent_id_from_session_key_ignores_gateway_main_session_key():
session_key = gateway_agent_session_key_for_id(uuid4())
assert agent_provisioning._agent_id_from_session_key(session_key) is None
def test_extract_agent_id_supports_lists_and_dicts(): def test_extract_agent_id_supports_lists_and_dicts():
assert agent_provisioning._extract_agent_id(["", " ", "abc"]) == "abc" assert agent_provisioning._extract_agent_id(["", " ", "abc"]) == "abc"
assert agent_provisioning._extract_agent_id([{"agent_id": "xyz"}]) == "xyz" assert agent_provisioning._extract_agent_id([{"agent_id": "xyz"}]) == "xyz"
@@ -55,6 +67,10 @@ class _AgentStub:
openclaw_session_id: str | None = None openclaw_session_id: str | None = None
heartbeat_config: dict | None = None heartbeat_config: dict | None = None
is_board_lead: bool = False is_board_lead: bool = False
id: UUID = field(default_factory=uuid4)
identity_profile: dict | None = None
identity_template: str | None = None
soul_template: str | None = None
def test_agent_key_uses_session_key_when_present(monkeypatch): def test_agent_key_uses_session_key_when_present(monkeypatch):
@@ -64,3 +80,74 @@ def test_agent_key_uses_session_key_when_present(monkeypatch):
monkeypatch.setattr(agent_provisioning, "_slugify", lambda value: "slugged") monkeypatch.setattr(agent_provisioning, "_slugify", lambda value: "slugged")
agent2 = _AgentStub(name="Alice", openclaw_session_id=None) agent2 = _AgentStub(name="Alice", openclaw_session_id=None)
assert agent_provisioning._agent_key(agent2) == "slugged" assert agent_provisioning._agent_key(agent2) == "slugged"
@dataclass
class _GatewayStub:
id: UUID
name: str
url: str
token: str | None
workspace_root: str
main_session_key: str
@pytest.mark.asyncio
async def test_provision_main_agent_uses_dedicated_openclaw_agent_id(monkeypatch):
gateway_id = uuid4()
session_key = gateway_agent_session_key_for_id(gateway_id)
gateway = _GatewayStub(
id=gateway_id,
name="Acme",
url="ws://gateway.example/ws",
token=None,
workspace_root="/tmp/openclaw",
main_session_key=session_key,
)
agent = _AgentStub(name="Acme Gateway Agent", openclaw_session_id=session_key)
captured: dict[str, object] = {}
async def _fake_ensure_session(*args, **kwargs):
return None
async def _fake_patch_gateway_agent_list(agent_id, workspace_path, heartbeat, config):
captured["patched_agent_id"] = agent_id
captured["workspace_path"] = workspace_path
async def _fake_supported_gateway_files(config):
return set()
async def _fake_gateway_agent_files_index(agent_id, config):
captured["files_index_agent_id"] = agent_id
return {}
def _fake_render_agent_files(*args, **kwargs):
return {}
async def _fake_set_agent_files(*args, **kwargs):
return None
monkeypatch.setattr(agent_provisioning, "ensure_session", _fake_ensure_session)
monkeypatch.setattr(agent_provisioning, "_patch_gateway_agent_list", _fake_patch_gateway_agent_list)
monkeypatch.setattr(agent_provisioning, "_supported_gateway_files", _fake_supported_gateway_files)
monkeypatch.setattr(
agent_provisioning,
"_gateway_agent_files_index",
_fake_gateway_agent_files_index,
)
monkeypatch.setattr(agent_provisioning, "_render_agent_files", _fake_render_agent_files)
monkeypatch.setattr(agent_provisioning, "_set_agent_files", _fake_set_agent_files)
await agent_provisioning.provision_main_agent(
agent,
agent_provisioning.MainAgentProvisionRequest(
gateway=gateway,
auth_token="secret-token",
user=None,
session_key=session_key,
),
)
expected_agent_id = gateway_openclaw_agent_id_for_id(gateway_id)
assert captured["patched_agent_id"] == expected_agent_id
assert captured["files_index_agent_id"] == expected_agent_id