refactor: streamline agent lifecycle management with new DB service helpers
This commit is contained in:
@@ -3,10 +3,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
from typing import TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
@@ -27,11 +26,13 @@ from app.schemas.gateway_coordination import (
|
||||
GatewayMainAskUserResponse,
|
||||
)
|
||||
from app.services.activity_log import record_activity
|
||||
from app.services.openclaw.db_service import OpenClawDBService
|
||||
from app.services.openclaw.exceptions import (
|
||||
GatewayOperation,
|
||||
map_gateway_error_message,
|
||||
map_gateway_error_to_http_exception,
|
||||
)
|
||||
from app.services.openclaw.gateway_dispatch import GatewayDispatchService
|
||||
from app.services.openclaw.gateway_rpc import GatewayConfig as GatewayClientConfig
|
||||
from app.services.openclaw.gateway_rpc import OpenClawGatewayError, openclaw_call
|
||||
from app.services.openclaw.internal.agent_key import agent_key
|
||||
@@ -42,43 +43,14 @@ from app.services.openclaw.provisioning_db import (
|
||||
LeadAgentRequest,
|
||||
OpenClawProvisioningService,
|
||||
)
|
||||
from app.services.openclaw.shared import (
|
||||
GatewayAgentIdentity,
|
||||
require_gateway_config_for_board,
|
||||
resolve_trace_id,
|
||||
send_gateway_agent_message,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.services.openclaw.shared import GatewayAgentIdentity
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class AbstractGatewayMessagingService(ABC):
|
||||
class AbstractGatewayMessagingService(OpenClawDBService, ABC):
|
||||
"""Shared gateway messaging primitives with retry semantics."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
self._logger = logging.getLogger(__name__)
|
||||
|
||||
@property
|
||||
def session(self) -> AsyncSession:
|
||||
return self._session
|
||||
|
||||
@session.setter
|
||||
def session(self, value: AsyncSession) -> None:
|
||||
self._session = value
|
||||
|
||||
@property
|
||||
def logger(self) -> logging.Logger:
|
||||
return self._logger
|
||||
|
||||
@logger.setter
|
||||
def logger(self, value: logging.Logger) -> None:
|
||||
self._logger = value
|
||||
|
||||
@staticmethod
|
||||
async def _with_gateway_retry(fn: Callable[[], Awaitable[_T]]) -> _T:
|
||||
return await with_coordination_gateway_retry(fn)
|
||||
@@ -93,7 +65,7 @@ class AbstractGatewayMessagingService(ABC):
|
||||
deliver: bool,
|
||||
) -> None:
|
||||
async def _do_send() -> bool:
|
||||
await send_gateway_agent_message(
|
||||
await GatewayDispatchService(self.session).send_agent_message(
|
||||
session_key=session_key,
|
||||
config=config,
|
||||
agent_name=agent_name,
|
||||
@@ -198,7 +170,7 @@ class GatewayCoordinationService(AbstractGatewayMessagingService):
|
||||
message: str,
|
||||
correlation_id: str | None = None,
|
||||
) -> None:
|
||||
trace_id = resolve_trace_id(correlation_id, prefix="coord.nudge")
|
||||
trace_id = GatewayDispatchService.resolve_trace_id(correlation_id, prefix="coord.nudge")
|
||||
self.logger.log(
|
||||
5,
|
||||
"gateway.coordination.nudge.start trace_id=%s board_id=%s actor_agent_id=%s "
|
||||
@@ -214,7 +186,9 @@ class GatewayCoordinationService(AbstractGatewayMessagingService):
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Target agent has no session key",
|
||||
)
|
||||
_gateway, config = await require_gateway_config_for_board(self.session, board)
|
||||
_gateway, config = await GatewayDispatchService(
|
||||
self.session
|
||||
).require_gateway_config_for_board(board)
|
||||
try:
|
||||
await self._dispatch_gateway_message(
|
||||
session_key=target.openclaw_session_id or "",
|
||||
@@ -276,7 +250,7 @@ class GatewayCoordinationService(AbstractGatewayMessagingService):
|
||||
target_agent_id: str,
|
||||
correlation_id: str | None = None,
|
||||
) -> str:
|
||||
trace_id = resolve_trace_id(correlation_id, prefix="coord.soul.read")
|
||||
trace_id = GatewayDispatchService.resolve_trace_id(correlation_id, prefix="coord.soul.read")
|
||||
self.logger.log(
|
||||
5,
|
||||
"gateway.coordination.soul_read.start trace_id=%s board_id=%s target_agent_id=%s",
|
||||
@@ -285,7 +259,9 @@ class GatewayCoordinationService(AbstractGatewayMessagingService):
|
||||
target_agent_id,
|
||||
)
|
||||
target = await self._board_agent_or_404(board=board, agent_id=target_agent_id)
|
||||
_gateway, config = await require_gateway_config_for_board(self.session, board)
|
||||
_gateway, config = await GatewayDispatchService(
|
||||
self.session
|
||||
).require_gateway_config_for_board(board)
|
||||
try:
|
||||
|
||||
async def _do_get() -> object:
|
||||
@@ -342,7 +318,9 @@ class GatewayCoordinationService(AbstractGatewayMessagingService):
|
||||
actor_agent_id: UUID,
|
||||
correlation_id: str | None = None,
|
||||
) -> None:
|
||||
trace_id = resolve_trace_id(correlation_id, prefix="coord.soul.write")
|
||||
trace_id = GatewayDispatchService.resolve_trace_id(
|
||||
correlation_id, prefix="coord.soul.write"
|
||||
)
|
||||
self.logger.log(
|
||||
5,
|
||||
"gateway.coordination.soul_write.start trace_id=%s board_id=%s target_agent_id=%s "
|
||||
@@ -365,7 +343,9 @@ class GatewayCoordinationService(AbstractGatewayMessagingService):
|
||||
self.session.add(target)
|
||||
await self.session.commit()
|
||||
|
||||
_gateway, config = await require_gateway_config_for_board(self.session, board)
|
||||
_gateway, config = await GatewayDispatchService(
|
||||
self.session
|
||||
).require_gateway_config_for_board(board)
|
||||
try:
|
||||
|
||||
async def _do_set() -> object:
|
||||
@@ -434,7 +414,9 @@ class GatewayCoordinationService(AbstractGatewayMessagingService):
|
||||
payload: GatewayMainAskUserRequest,
|
||||
actor_agent: Agent,
|
||||
) -> GatewayMainAskUserResponse:
|
||||
trace_id = resolve_trace_id(payload.correlation_id, prefix="coord.ask_user")
|
||||
trace_id = GatewayDispatchService.resolve_trace_id(
|
||||
payload.correlation_id, prefix="coord.ask_user"
|
||||
)
|
||||
self.logger.log(
|
||||
5,
|
||||
"gateway.coordination.ask_user.start trace_id=%s board_id=%s actor_agent_id=%s",
|
||||
@@ -442,7 +424,9 @@ class GatewayCoordinationService(AbstractGatewayMessagingService):
|
||||
board.id,
|
||||
actor_agent.id,
|
||||
)
|
||||
gateway, config = await require_gateway_config_for_board(self.session, board)
|
||||
gateway, config = await GatewayDispatchService(
|
||||
self.session
|
||||
).require_gateway_config_for_board(board)
|
||||
main_session_key = GatewayAgentIdentity.session_key(gateway)
|
||||
|
||||
correlation = payload.correlation_id.strip() if payload.correlation_id else ""
|
||||
@@ -575,7 +559,9 @@ class GatewayCoordinationService(AbstractGatewayMessagingService):
|
||||
board_id: UUID,
|
||||
payload: GatewayLeadMessageRequest,
|
||||
) -> GatewayLeadMessageResponse:
|
||||
trace_id = resolve_trace_id(payload.correlation_id, prefix="coord.lead_message")
|
||||
trace_id = GatewayDispatchService.resolve_trace_id(
|
||||
payload.correlation_id, prefix="coord.lead_message"
|
||||
)
|
||||
self.logger.log(
|
||||
5,
|
||||
"gateway.coordination.lead_message.start trace_id=%s board_id=%s actor_agent_id=%s",
|
||||
@@ -662,7 +648,9 @@ class GatewayCoordinationService(AbstractGatewayMessagingService):
|
||||
actor_agent: Agent,
|
||||
payload: GatewayLeadBroadcastRequest,
|
||||
) -> GatewayLeadBroadcastResponse:
|
||||
trace_id = resolve_trace_id(payload.correlation_id, prefix="coord.lead_broadcast")
|
||||
trace_id = GatewayDispatchService.resolve_trace_id(
|
||||
payload.correlation_id, prefix="coord.lead_broadcast"
|
||||
)
|
||||
self.logger.log(
|
||||
5,
|
||||
"gateway.coordination.lead_broadcast.start trace_id=%s actor_agent_id=%s",
|
||||
|
||||
Reference in New Issue
Block a user