refactor: update internal helpers and improve slugify function usage

This commit is contained in:
Abhimanyu Saharan
2026-02-11 00:27:44 +05:30
parent b038d0df4c
commit f4161494d9
6 changed files with 83 additions and 198 deletions

View File

@@ -34,7 +34,8 @@ from app.services.openclaw.exceptions import (
) )
from app.services.openclaw.gateway_rpc import GatewayConfig as GatewayClientConfig from app.services.openclaw.gateway_rpc import GatewayConfig as GatewayClientConfig
from app.services.openclaw.gateway_rpc import OpenClawGatewayError, openclaw_call from app.services.openclaw.gateway_rpc import OpenClawGatewayError, openclaw_call
from app.services.openclaw.internal import agent_key, with_coordination_gateway_retry from app.services.openclaw.internal.agent_key import agent_key
from app.services.openclaw.internal.retry import with_coordination_gateway_retry
from app.services.openclaw.policies import OpenClawAuthorizationPolicy from app.services.openclaw.policies import OpenClawAuthorizationPolicy
from app.services.openclaw.provisioning_db import ( from app.services.openclaw.provisioning_db import (
LeadAgentOptions, LeadAgentOptions,

View File

@@ -1,6 +1,7 @@
"""Internal typed helpers shared across OpenClaw service modules.""" """Internal typed helpers shared across OpenClaw service modules.
from .agent_key import agent_key Import submodules directly (for example: ``app.services.openclaw.internal.agent_key``)
from .retry import with_coordination_gateway_retry to avoid shadowing submodule names with re-exported symbols.
"""
__all__ = ["agent_key", "with_coordination_gateway_retry"] __all__: list[str] = []

View File

@@ -9,7 +9,7 @@ from app.models.agents import Agent
from app.services.openclaw.constants import _SESSION_KEY_PARTS_MIN from app.services.openclaw.constants import _SESSION_KEY_PARTS_MIN
def _slugify(value: str) -> str: def slugify(value: str) -> str:
slug = re.sub(r"[^a-z0-9]+", "-", value.lower()).strip("-") slug = re.sub(r"[^a-z0-9]+", "-", value.lower()).strip("-")
return slug or uuid4().hex return slug or uuid4().hex
@@ -21,4 +21,4 @@ def agent_key(agent: Agent) -> str:
parts = session_key.split(":") parts = session_key.split(":")
if len(parts) >= _SESSION_KEY_PARTS_MIN and parts[1]: if len(parts) >= _SESSION_KEY_PARTS_MIN and parts[1]:
return parts[1] return parts[1]
return _slugify(agent.name) return slugify(agent.name)

View File

@@ -8,12 +8,10 @@ DB-backed workflows (template sync, lead-agent record creation) live in
from __future__ import annotations from __future__ import annotations
import json import json
import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from uuid import uuid4
from jinja2 import Environment, FileSystemLoader, StrictUndefined, select_autoescape from jinja2 import Environment, FileSystemLoader, StrictUndefined, select_autoescape
@@ -40,7 +38,8 @@ from app.services.openclaw.gateway_rpc import (
openclaw_call, openclaw_call,
send_message, send_message,
) )
from app.services.openclaw.internal import agent_key as _agent_key from app.services.openclaw.internal.agent_key import agent_key as _agent_key
from app.services.openclaw.internal.agent_key import slugify
from app.services.openclaw.shared import GatewayAgentIdentity from app.services.openclaw.shared import GatewayAgentIdentity
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -78,11 +77,6 @@ def _templates_root() -> Path:
return _repo_root() / "templates" return _repo_root() / "templates"
def _slugify(value: str) -> str:
slug = re.sub(r"[^a-z0-9]+", "-", value.lower()).strip("-")
return slug or uuid4().hex
def _heartbeat_config(agent: Agent) -> dict[str, Any]: def _heartbeat_config(agent: Agent) -> dict[str, Any]:
merged = DEFAULT_HEARTBEAT_CONFIG.copy() merged = DEFAULT_HEARTBEAT_CONFIG.copy()
if isinstance(agent.heartbeat_config, dict): if isinstance(agent.heartbeat_config, dict):
@@ -134,7 +128,58 @@ def _workspace_path(agent: Agent, workspace_root: str) -> str:
# lead agents (session key includes board id) even if multiple boards share the same # lead agents (session key includes board id) even if multiple boards share the same
# display name (e.g. "Lead Agent"). # display name (e.g. "Lead Agent").
key = _agent_key(agent) key = _agent_key(agent)
return f"{root}/workspace-{_slugify(key)}" return f"{root}/workspace-{slugify(key)}"
def _preferred_name(user: User | None) -> str:
preferred_name = (user.preferred_name or "") if user else ""
if preferred_name:
preferred_name = preferred_name.strip().split()[0]
return preferred_name
def _user_context(user: User | None) -> dict[str, str]:
return {
"user_name": (user.name or "") if user else "",
"user_preferred_name": _preferred_name(user),
"user_pronouns": (user.pronouns or "") if user else "",
"user_timezone": (user.timezone or "") if user else "",
"user_notes": (user.notes or "") if user else "",
"user_context": (user.context or "") if user else "",
}
def _normalized_identity_profile(agent: Agent) -> dict[str, str]:
identity_profile: dict[str, Any] = {}
if isinstance(agent.identity_profile, dict):
identity_profile = agent.identity_profile
normalized_identity: dict[str, str] = {}
for key, value in identity_profile.items():
if value is None:
continue
if isinstance(value, list):
parts = [str(item).strip() for item in value if str(item).strip()]
if not parts:
continue
normalized_identity[key] = ", ".join(parts)
continue
text = str(value).strip()
if text:
normalized_identity[key] = text
return normalized_identity
def _identity_context(agent: Agent) -> dict[str, str]:
normalized_identity = _normalized_identity_profile(agent)
identity_context = {
context_key: normalized_identity.get(field, DEFAULT_IDENTITY_PROFILE[field])
for field, context_key in IDENTITY_PROFILE_FIELDS.items()
}
extra_identity_context = {
context_key: normalized_identity.get(field, "")
for field, context_key in EXTRA_IDENTITY_PROFILE_FIELDS.items()
}
return {**identity_context, **extra_identity_context}
def _build_context( def _build_context(
@@ -153,33 +198,8 @@ def _build_context(
session_key = agent.openclaw_session_id or "" session_key = agent.openclaw_session_id or ""
base_url = settings.base_url or "REPLACE_WITH_BASE_URL" base_url = settings.base_url or "REPLACE_WITH_BASE_URL"
main_session_key = GatewayAgentIdentity.session_key(gateway) main_session_key = GatewayAgentIdentity.session_key(gateway)
identity_profile: dict[str, Any] = {} identity_context = _identity_context(agent)
if isinstance(agent.identity_profile, dict): user_context = _user_context(user)
identity_profile = agent.identity_profile
normalized_identity: dict[str, str] = {}
for key, value in identity_profile.items():
if value is None:
continue
if isinstance(value, list):
parts = [str(item).strip() for item in value if str(item).strip()]
if not parts:
continue
normalized_identity[key] = ", ".join(parts)
continue
text = str(value).strip()
if text:
normalized_identity[key] = text
identity_context = {
context_key: normalized_identity.get(field, DEFAULT_IDENTITY_PROFILE[field])
for field, context_key in IDENTITY_PROFILE_FIELDS.items()
}
extra_identity_context = {
context_key: normalized_identity.get(field, "")
for field, context_key in EXTRA_IDENTITY_PROFILE_FIELDS.items()
}
preferred_name = (user.preferred_name or "") if user else ""
if preferred_name:
preferred_name = preferred_name.strip().split()[0]
return { return {
"agent_name": agent.name, "agent_name": agent.name,
"agent_id": agent_id, "agent_id": agent_id,
@@ -197,14 +217,8 @@ def _build_context(
"auth_token": auth_token, "auth_token": auth_token,
"main_session_key": main_session_key, "main_session_key": main_session_key,
"workspace_root": workspace_root, "workspace_root": workspace_root,
"user_name": (user.name or "") if user else "", **user_context,
"user_preferred_name": preferred_name,
"user_pronouns": (user.pronouns or "") if user else "",
"user_timezone": (user.timezone or "") if user else "",
"user_notes": (user.notes or "") if user else "",
"user_context": (user.context or "") if user else "",
**identity_context, **identity_context,
**extra_identity_context,
} }
@@ -215,33 +229,8 @@ def _build_main_context(
user: User | None, user: User | None,
) -> dict[str, str]: ) -> dict[str, str]:
base_url = settings.base_url or "REPLACE_WITH_BASE_URL" base_url = settings.base_url or "REPLACE_WITH_BASE_URL"
identity_profile: dict[str, Any] = {} identity_context = _identity_context(agent)
if isinstance(agent.identity_profile, dict): user_context = _user_context(user)
identity_profile = agent.identity_profile
normalized_identity: dict[str, str] = {}
for key, value in identity_profile.items():
if value is None:
continue
if isinstance(value, list):
parts = [str(item).strip() for item in value if str(item).strip()]
if not parts:
continue
normalized_identity[key] = ", ".join(parts)
continue
text = str(value).strip()
if text:
normalized_identity[key] = text
identity_context = {
context_key: normalized_identity.get(field, DEFAULT_IDENTITY_PROFILE[field])
for field, context_key in IDENTITY_PROFILE_FIELDS.items()
}
extra_identity_context = {
context_key: normalized_identity.get(field, "")
for field, context_key in EXTRA_IDENTITY_PROFILE_FIELDS.items()
}
preferred_name = (user.preferred_name or "") if user else ""
if preferred_name:
preferred_name = preferred_name.strip().split()[0]
return { return {
"agent_name": agent.name, "agent_name": agent.name,
"agent_id": str(agent.id), "agent_id": str(agent.id),
@@ -250,14 +239,8 @@ def _build_main_context(
"auth_token": auth_token, "auth_token": auth_token,
"main_session_key": GatewayAgentIdentity.session_key(gateway), "main_session_key": GatewayAgentIdentity.session_key(gateway),
"workspace_root": gateway.workspace_root or "", "workspace_root": gateway.workspace_root or "",
"user_name": (user.name or "") if user else "", **user_context,
"user_preferred_name": preferred_name,
"user_pronouns": (user.pronouns or "") if user else "",
"user_timezone": (user.timezone or "") if user else "",
"user_notes": (user.notes or "") if user else "",
"user_context": (user.context or "") if user else "",
**identity_context, **identity_context,
**extra_identity_context,
} }

View File

@@ -45,10 +45,7 @@ from app.schemas.common import OkResponse
from app.schemas.gateways import GatewayTemplatesSyncError, GatewayTemplatesSyncResult from app.schemas.gateways import GatewayTemplatesSyncError, GatewayTemplatesSyncResult
from app.services.activity_log import record_activity from app.services.activity_log import record_activity
from app.services.openclaw.constants import ( from app.services.openclaw.constants import (
_NON_TRANSIENT_GATEWAY_ERROR_MARKERS,
_SECURE_RANDOM,
_TOOLS_KV_RE, _TOOLS_KV_RE,
_TRANSIENT_GATEWAY_ERROR_MARKERS,
AGENT_SESSION_PREFIX, AGENT_SESSION_PREFIX,
DEFAULT_HEARTBEAT_CONFIG, DEFAULT_HEARTBEAT_CONFIG,
OFFLINE_AFTER, OFFLINE_AFTER,
@@ -59,7 +56,8 @@ from app.services.openclaw.gateway_rpc import (
ensure_session, ensure_session,
send_message, send_message,
) )
from app.services.openclaw.internal import agent_key as _agent_key from app.services.openclaw.internal.agent_key import agent_key as _agent_key
from app.services.openclaw.internal.retry import GatewayBackoff
from app.services.openclaw.policies import OpenClawAuthorizationPolicy from app.services.openclaw.policies import OpenClawAuthorizationPolicy
from app.services.openclaw.provisioning import ( from app.services.openclaw.provisioning import (
OpenClawGatewayControlPlane, OpenClawGatewayControlPlane,
@@ -77,7 +75,7 @@ from app.services.organizations import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence from collections.abc import AsyncIterator, Sequence
from fastapi_pagination.limit_offset import LimitOffsetPage from fastapi_pagination.limit_offset import LimitOffsetPage
from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import ColumnElement
@@ -272,7 +270,7 @@ class OpenClawProvisioningService:
session=self.session, session=self.session,
gateway=gateway, gateway=gateway,
control_plane=control_plane, control_plane=control_plane,
backoff=_GatewayBackoff(timeout_s=10 * 60, timeout_context="template sync"), backoff=GatewayBackoff(timeout_s=10 * 60, timeout_context="template sync"),
options=options, options=options,
provisioner=self._gateway, provisioner=self._gateway,
) )
@@ -325,110 +323,11 @@ class _SyncContext:
session: AsyncSession session: AsyncSession
gateway: Gateway gateway: Gateway
control_plane: OpenClawGatewayControlPlane control_plane: OpenClawGatewayControlPlane
backoff: _GatewayBackoff backoff: GatewayBackoff
options: GatewayTemplateSyncOptions options: GatewayTemplateSyncOptions
provisioner: OpenClawGatewayProvisioner provisioner: OpenClawGatewayProvisioner
def _is_transient_gateway_error(exc: Exception) -> bool:
if not isinstance(exc, OpenClawGatewayError):
return False
message = str(exc).lower()
if not message:
return False
if any(marker in message for marker in _NON_TRANSIENT_GATEWAY_ERROR_MARKERS):
return False
return ("503" in message and "websocket" in message) or any(
marker in message for marker in _TRANSIENT_GATEWAY_ERROR_MARKERS
)
def _gateway_timeout_message(
exc: OpenClawGatewayError,
*,
timeout_s: float,
context: str,
) -> str:
rounded_timeout = int(timeout_s)
timeout_text = f"{rounded_timeout} seconds"
if rounded_timeout >= 120:
timeout_text = f"{rounded_timeout // 60} minutes"
return f"Gateway unreachable after {timeout_text} ({context} timeout). Last error: {exc}"
class _GatewayBackoff:
def __init__(
self,
*,
timeout_s: float = 10 * 60,
base_delay_s: float = 0.75,
max_delay_s: float = 30.0,
jitter: float = 0.2,
timeout_context: str = "gateway operation",
) -> None:
self._timeout_s = timeout_s
self._base_delay_s = base_delay_s
self._max_delay_s = max_delay_s
self._jitter = jitter
self._timeout_context = timeout_context
self._delay_s = base_delay_s
def reset(self) -> None:
self._delay_s = self._base_delay_s
@staticmethod
async def _attempt(
fn: Callable[[], Awaitable[_T]],
) -> tuple[_T | None, OpenClawGatewayError | None]:
try:
return await fn(), None
except OpenClawGatewayError as exc:
return None, exc
async def run(self, fn: Callable[[], Awaitable[_T]]) -> _T:
deadline_s = asyncio.get_running_loop().time() + self._timeout_s
while True:
value, error = await self._attempt(fn)
if error is not None:
exc = error
if not _is_transient_gateway_error(exc):
raise exc
now = asyncio.get_running_loop().time()
remaining = deadline_s - now
if remaining <= 0:
raise TimeoutError(
_gateway_timeout_message(
exc,
timeout_s=self._timeout_s,
context=self._timeout_context,
),
) from exc
sleep_s = min(self._delay_s, remaining)
if self._jitter:
sleep_s *= 1.0 + _SECURE_RANDOM.uniform(
-self._jitter,
self._jitter,
)
sleep_s = max(0.0, min(sleep_s, remaining))
await asyncio.sleep(sleep_s)
self._delay_s = min(self._delay_s * 2.0, self._max_delay_s)
continue
self.reset()
if value is None:
msg = "Gateway retry produced no value without an error"
raise RuntimeError(msg)
return value
async def _with_gateway_retry(
fn: Callable[[], Awaitable[_T]],
*,
backoff: _GatewayBackoff,
) -> _T:
return await backoff.run(fn)
def _parse_tools_md(content: str) -> dict[str, str]: def _parse_tools_md(content: str) -> dict[str, str]:
values: dict[str, str] = {} values: dict[str, str] = {}
for raw in content.splitlines(): for raw in content.splitlines():
@@ -447,7 +346,7 @@ async def _get_agent_file(
agent_gateway_id: str, agent_gateway_id: str,
name: str, name: str,
control_plane: OpenClawGatewayControlPlane, control_plane: OpenClawGatewayControlPlane,
backoff: _GatewayBackoff | None = None, backoff: GatewayBackoff | None = None,
) -> str | None: ) -> str | None:
try: try:
@@ -475,7 +374,7 @@ async def _get_existing_auth_token(
*, *,
agent_gateway_id: str, agent_gateway_id: str,
control_plane: OpenClawGatewayControlPlane, control_plane: OpenClawGatewayControlPlane,
backoff: _GatewayBackoff | None = None, backoff: GatewayBackoff | None = None,
) -> str | None: ) -> str | None:
tools = await _get_agent_file( tools = await _get_agent_file(
agent_gateway_id=agent_gateway_id, agent_gateway_id=agent_gateway_id,
@@ -672,7 +571,7 @@ async def _sync_one_agent(
) )
return True return True
await _with_gateway_retry(_do_provision, backoff=ctx.backoff) await ctx.backoff.run(_do_provision)
result.agents_updated += 1 result.agents_updated += 1
except TimeoutError as exc: # pragma: no cover - gateway/network dependent except TimeoutError as exc: # pragma: no cover - gateway/network dependent
result.agents_skipped += 1 result.agents_skipped += 1
@@ -742,7 +641,7 @@ async def _sync_main_agent(
) )
return True return True
await _with_gateway_retry(_do_provision_main, backoff=ctx.backoff) await ctx.backoff.run(_do_provision_main)
except TimeoutError as exc: # pragma: no cover - gateway/network dependent except TimeoutError as exc: # pragma: no cover - gateway/network dependent
_append_sync_error(result, agent=main_agent, message=str(exc)) _append_sync_error(result, agent=main_agent, message=str(exc))
stop_sync = True stop_sync = True

View File

@@ -7,22 +7,23 @@ from uuid import UUID, uuid4
import pytest import pytest
import app.services.openclaw.internal.agent_key as agent_key_mod
import app.services.openclaw.provisioning as agent_provisioning import app.services.openclaw.provisioning as agent_provisioning
from app.services.openclaw.provisioning_db import AgentLifecycleService from app.services.openclaw.provisioning_db import AgentLifecycleService
from app.services.openclaw.shared import GatewayAgentIdentity from app.services.openclaw.shared import GatewayAgentIdentity
def test_slugify_normalizes_and_trims(): def test_slugify_normalizes_and_trims():
assert agent_provisioning._slugify("Hello, World") == "hello-world" assert agent_provisioning.slugify("Hello, World") == "hello-world"
assert agent_provisioning._slugify(" A B ") == "a-b" assert agent_provisioning.slugify(" A B ") == "a-b"
def test_slugify_falls_back_to_uuid_hex(monkeypatch): def test_slugify_falls_back_to_uuid_hex(monkeypatch):
class _FakeUuid: class _FakeUuid:
hex = "deadbeef" hex = "deadbeef"
monkeypatch.setattr(agent_provisioning, "uuid4", lambda: _FakeUuid()) monkeypatch.setattr(agent_key_mod, "uuid4", lambda: _FakeUuid())
assert agent_provisioning._slugify("!!!") == "deadbeef" assert agent_provisioning.slugify("!!!") == "deadbeef"
@dataclass @dataclass