feat: add approval-task links model and related functionality for task associations
This commit is contained in:
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from sqlalchemy import asc, case, func, or_
|
||||
from sqlalchemy import asc, func, or_
|
||||
from sqlmodel import col, select
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
@@ -29,10 +29,16 @@ from app.models.approvals import Approval
|
||||
from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalStatus, ApprovalUpdate
|
||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||
from app.services.activity_log import record_activity
|
||||
from app.services.approval_task_links import (
|
||||
load_task_ids_by_approval,
|
||||
normalize_task_ids,
|
||||
replace_approval_task_links,
|
||||
task_counts_for_board,
|
||||
)
|
||||
from app.services.openclaw.gateway_dispatch import GatewayDispatchService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
|
||||
from fastapi_pagination.limit_offset import LimitOffsetPage
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -42,7 +48,6 @@ if TYPE_CHECKING:
|
||||
router = APIRouter(prefix="/boards/{board_id}/approvals", tags=["approvals"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
TASK_ID_KEYS: tuple[str, ...] = ("task_id", "taskId", "taskID")
|
||||
STREAM_POLL_SECONDS = 2
|
||||
STATUS_FILTER_QUERY = Query(default=None, alias="status")
|
||||
SINCE_QUERY = Query(default=None)
|
||||
@@ -53,21 +58,6 @@ SESSION_DEP = Depends(get_session)
|
||||
ACTOR_DEP = Depends(require_admin_or_agent)
|
||||
|
||||
|
||||
def _extract_task_id(payload: dict[str, object] | None) -> UUID | None:
|
||||
if not payload:
|
||||
return None
|
||||
for key in TASK_ID_KEYS:
|
||||
value = payload.get(key)
|
||||
if isinstance(value, UUID):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return UUID(value)
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _parse_since(value: str | None) -> datetime | None:
|
||||
if not value:
|
||||
return None
|
||||
@@ -88,17 +78,47 @@ def _approval_updated_at(approval: Approval) -> datetime:
|
||||
return approval.resolved_at or approval.created_at
|
||||
|
||||
|
||||
def _serialize_approval(approval: Approval) -> dict[str, object]:
|
||||
return ApprovalRead.model_validate(
|
||||
approval,
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json")
|
||||
async def _approval_task_ids_map(
|
||||
session: AsyncSession,
|
||||
approvals: Sequence[Approval],
|
||||
) -> dict[UUID, list[UUID]]:
|
||||
approval_ids = [approval.id for approval in approvals]
|
||||
mapping = await load_task_ids_by_approval(session, approval_ids=approval_ids)
|
||||
for approval in approvals:
|
||||
if mapping.get(approval.id):
|
||||
continue
|
||||
if approval.task_id is not None:
|
||||
mapping[approval.id] = [approval.task_id]
|
||||
else:
|
||||
mapping[approval.id] = []
|
||||
return mapping
|
||||
|
||||
|
||||
def _approval_to_read(approval: Approval, *, task_ids: list[UUID]) -> ApprovalRead:
|
||||
primary_task_id = task_ids[0] if task_ids else None
|
||||
model = ApprovalRead.model_validate(approval, from_attributes=True)
|
||||
return model.model_copy(update={"task_id": primary_task_id, "task_ids": task_ids})
|
||||
|
||||
|
||||
async def _approval_reads(
|
||||
session: AsyncSession,
|
||||
approvals: Sequence[Approval],
|
||||
) -> list[ApprovalRead]:
|
||||
mapping = await _approval_task_ids_map(session, approvals)
|
||||
return [
|
||||
_approval_to_read(approval, task_ids=mapping.get(approval.id, [])) for approval in approvals
|
||||
]
|
||||
|
||||
|
||||
def _serialize_approval(approval: ApprovalRead) -> dict[str, object]:
|
||||
return approval.model_dump(mode="json")
|
||||
|
||||
|
||||
def _approval_resolution_message(
|
||||
*,
|
||||
board: Board,
|
||||
approval: Approval,
|
||||
task_ids: Sequence[UUID] | None = None,
|
||||
) -> str:
|
||||
status_text = "approved" if approval.status == "approved" else "rejected"
|
||||
lines = [
|
||||
@@ -109,8 +129,13 @@ def _approval_resolution_message(
|
||||
f"Decision: {status_text}",
|
||||
f"Confidence: {approval.confidence}",
|
||||
]
|
||||
if approval.task_id is not None:
|
||||
lines.append(f"Task ID: {approval.task_id}")
|
||||
normalized_task_ids = list(task_ids or [])
|
||||
if not normalized_task_ids and approval.task_id is not None:
|
||||
normalized_task_ids = [approval.task_id]
|
||||
if len(normalized_task_ids) == 1:
|
||||
lines.append(f"Task ID: {normalized_task_ids[0]}")
|
||||
elif normalized_task_ids:
|
||||
lines.append(f"Task IDs: {', '.join(str(value) for value in normalized_task_ids)}")
|
||||
lines.append("")
|
||||
lines.append("Take action: continue execution using the final approval decision.")
|
||||
return "\n".join(lines)
|
||||
@@ -145,7 +170,12 @@ async def _notify_lead_on_approval_resolution(
|
||||
if config is None:
|
||||
return
|
||||
|
||||
message = _approval_resolution_message(board=board, approval=approval)
|
||||
task_ids_by_approval = await load_task_ids_by_approval(session, approval_ids=[approval.id])
|
||||
message = _approval_resolution_message(
|
||||
board=board,
|
||||
approval=approval,
|
||||
task_ids=task_ids_by_approval.get(approval.id, []),
|
||||
)
|
||||
error = await dispatch.try_send_agent_message(
|
||||
session_key=lead.openclaw_session_id,
|
||||
config=config,
|
||||
@@ -202,7 +232,17 @@ async def list_approvals(
|
||||
if status_filter:
|
||||
statement = statement.filter(col(Approval.status) == status_filter)
|
||||
statement = statement.order_by(col(Approval.created_at).desc())
|
||||
return await paginate(session, statement.statement)
|
||||
|
||||
async def _transform(items: Sequence[object]) -> Sequence[ApprovalRead]:
|
||||
approvals: list[Approval] = []
|
||||
for item in items:
|
||||
if not isinstance(item, Approval):
|
||||
msg = "Expected Approval items from approvals pagination query."
|
||||
raise TypeError(msg)
|
||||
approvals.append(item)
|
||||
return await _approval_reads(session, approvals)
|
||||
|
||||
return await paginate(session, statement.statement, transformer=_transform)
|
||||
|
||||
|
||||
@router.get("/stream")
|
||||
@@ -223,6 +263,7 @@ async def stream_approvals(
|
||||
break
|
||||
async with async_session_maker() as session:
|
||||
approvals = await _fetch_approval_events(session, board.id, last_seen)
|
||||
approval_reads = await _approval_reads(session, approvals)
|
||||
pending_approvals_count = int(
|
||||
(
|
||||
await session.exec(
|
||||
@@ -233,50 +274,36 @@ async def stream_approvals(
|
||||
).one(),
|
||||
)
|
||||
task_ids = {
|
||||
approval.task_id for approval in approvals if approval.task_id is not None
|
||||
task_id
|
||||
for approval_read in approval_reads
|
||||
for task_id in approval_read.task_ids
|
||||
}
|
||||
counts_by_task_id: dict[UUID, tuple[int, int]] = {}
|
||||
if task_ids:
|
||||
rows = list(
|
||||
await session.exec(
|
||||
select(
|
||||
col(Approval.task_id),
|
||||
func.count(col(Approval.id)).label("total"),
|
||||
func.sum(
|
||||
case(
|
||||
(col(Approval.status) == "pending", 1),
|
||||
else_=0,
|
||||
),
|
||||
).label("pending"),
|
||||
)
|
||||
.where(col(Approval.board_id) == board.id)
|
||||
.where(col(Approval.task_id).in_(task_ids))
|
||||
.group_by(col(Approval.task_id)),
|
||||
),
|
||||
)
|
||||
for task_id, total, pending in rows:
|
||||
if task_id is None:
|
||||
continue
|
||||
counts_by_task_id[task_id] = (
|
||||
int(total or 0),
|
||||
int(pending or 0),
|
||||
)
|
||||
for approval in approvals:
|
||||
counts_by_task_id = await task_counts_for_board(
|
||||
session,
|
||||
board_id=board.id,
|
||||
task_ids=task_ids,
|
||||
)
|
||||
for approval, approval_read in zip(approvals, approval_reads, strict=True):
|
||||
updated_at = _approval_updated_at(approval)
|
||||
last_seen = max(updated_at, last_seen)
|
||||
payload: dict[str, object] = {
|
||||
"approval": _serialize_approval(approval),
|
||||
"approval": _serialize_approval(approval_read),
|
||||
"pending_approvals_count": pending_approvals_count,
|
||||
}
|
||||
if approval.task_id is not None:
|
||||
counts = counts_by_task_id.get(approval.task_id)
|
||||
if counts is not None:
|
||||
total, pending = counts
|
||||
payload["task_counts"] = {
|
||||
"task_id": str(approval.task_id),
|
||||
"approvals_count": total,
|
||||
"approvals_pending_count": pending,
|
||||
}
|
||||
task_counts = [
|
||||
{
|
||||
"task_id": str(task_id),
|
||||
"approvals_count": total,
|
||||
"approvals_pending_count": pending,
|
||||
}
|
||||
for task_id in approval_read.task_ids
|
||||
if (counts := counts_by_task_id.get(task_id)) is not None
|
||||
for total, pending in [counts]
|
||||
]
|
||||
if len(task_counts) == 1:
|
||||
payload["task_counts"] = task_counts[0]
|
||||
elif task_counts:
|
||||
payload["task_counts"] = task_counts
|
||||
yield {"event": "approval", "data": json.dumps(payload)}
|
||||
await asyncio.sleep(STREAM_POLL_SECONDS)
|
||||
|
||||
@@ -289,9 +316,14 @@ async def create_approval(
|
||||
board: Board = BOARD_WRITE_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
_actor: ActorContext = ACTOR_DEP,
|
||||
) -> Approval:
|
||||
) -> ApprovalRead:
|
||||
"""Create an approval for a board."""
|
||||
task_id = payload.task_id or _extract_task_id(payload.payload)
|
||||
task_ids = normalize_task_ids(
|
||||
task_id=payload.task_id,
|
||||
task_ids=payload.task_ids,
|
||||
payload=payload.payload,
|
||||
)
|
||||
task_id = task_ids[0] if task_ids else None
|
||||
approval = Approval(
|
||||
board_id=board.id,
|
||||
task_id=task_id,
|
||||
@@ -303,9 +335,15 @@ async def create_approval(
|
||||
status=payload.status,
|
||||
)
|
||||
session.add(approval)
|
||||
await session.flush()
|
||||
await replace_approval_task_links(
|
||||
session,
|
||||
approval_id=approval.id,
|
||||
task_ids=task_ids,
|
||||
)
|
||||
await session.commit()
|
||||
await session.refresh(approval)
|
||||
return approval
|
||||
return _approval_to_read(approval, task_ids=task_ids)
|
||||
|
||||
|
||||
@router.patch("/{approval_id}", response_model=ApprovalRead)
|
||||
@@ -314,7 +352,7 @@ async def update_approval(
|
||||
payload: ApprovalUpdate,
|
||||
board: Board = BOARD_USER_WRITE_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
) -> Approval:
|
||||
) -> ApprovalRead:
|
||||
"""Update an approval's status and resolution timestamp."""
|
||||
approval = await Approval.objects.by_id(approval_id).first(session)
|
||||
if approval is None or approval.board_id != board.id:
|
||||
@@ -342,4 +380,5 @@ async def update_approval(
|
||||
approval.id,
|
||||
approval.status,
|
||||
)
|
||||
return approval
|
||||
reads = await _approval_reads(session, [approval])
|
||||
return reads[0]
|
||||
|
||||
@@ -18,6 +18,7 @@ from app.db.pagination import paginate
|
||||
from app.db.session import get_session
|
||||
from app.models.activity_events import ActivityEvent
|
||||
from app.models.agents import Agent
|
||||
from app.models.approval_task_links import ApprovalTaskLink
|
||||
from app.models.approvals import Approval
|
||||
from app.models.board_group_memory import BoardGroupMemory
|
||||
from app.models.board_groups import BoardGroup
|
||||
@@ -269,6 +270,14 @@ async def delete_my_org(
|
||||
col(TaskFingerprint.board_id).in_(board_ids),
|
||||
commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session,
|
||||
ApprovalTaskLink,
|
||||
col(ApprovalTaskLink.approval_id).in_(
|
||||
select(Approval.id).where(col(Approval.board_id).in_(board_ids))
|
||||
),
|
||||
commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session,
|
||||
Approval,
|
||||
|
||||
@@ -29,6 +29,7 @@ from app.db.pagination import paginate
|
||||
from app.db.session import async_session_maker, get_session
|
||||
from app.models.activity_events import ActivityEvent
|
||||
from app.models.agents import Agent
|
||||
from app.models.approval_task_links import ApprovalTaskLink
|
||||
from app.models.approvals import Approval
|
||||
from app.models.boards import Board
|
||||
from app.models.task_dependencies import TaskDependency
|
||||
@@ -39,6 +40,7 @@ from app.schemas.errors import BlockedTaskError
|
||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||
from app.schemas.tasks import TaskCommentCreate, TaskCommentRead, TaskCreate, TaskRead, TaskUpdate
|
||||
from app.services.activity_log import record_activity
|
||||
from app.services.approval_task_links import load_task_ids_by_approval
|
||||
from app.services.mentions import extract_mentions, matches_agent_mention
|
||||
from app.services.openclaw.gateway_dispatch import GatewayDispatchService
|
||||
from app.services.openclaw.gateway_rpc import GatewayConfig as GatewayClientConfig
|
||||
@@ -922,12 +924,26 @@ async def delete_task(
|
||||
col(TaskFingerprint.task_id) == task.id,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
primary_approvals = list(
|
||||
await Approval.objects.filter(col(Approval.task_id) == task.id).all(session),
|
||||
)
|
||||
await crud.delete_where(
|
||||
session,
|
||||
Approval,
|
||||
col(Approval.task_id) == task.id,
|
||||
ApprovalTaskLink,
|
||||
col(ApprovalTaskLink.task_id) == task.id,
|
||||
commit=False,
|
||||
)
|
||||
if primary_approvals:
|
||||
primary_ids = [approval.id for approval in primary_approvals]
|
||||
remaining_by_approval = await load_task_ids_by_approval(session, approval_ids=primary_ids)
|
||||
for approval in primary_approvals:
|
||||
remaining_task_ids = remaining_by_approval.get(approval.id, [])
|
||||
if remaining_task_ids:
|
||||
approval.task_id = remaining_task_ids[0]
|
||||
session.add(approval)
|
||||
continue
|
||||
await session.delete(approval)
|
||||
await crud.delete_where(
|
||||
session,
|
||||
TaskDependency,
|
||||
|
||||
@@ -13,6 +13,7 @@ from app.db import crud
|
||||
from app.db.session import get_session
|
||||
from app.models.activity_events import ActivityEvent
|
||||
from app.models.agents import Agent
|
||||
from app.models.approval_task_links import ApprovalTaskLink
|
||||
from app.models.approvals import Approval
|
||||
from app.models.board_group_memory import BoardGroupMemory
|
||||
from app.models.board_groups import BoardGroup
|
||||
@@ -83,6 +84,14 @@ async def _delete_organization_tree(
|
||||
col(TaskFingerprint.board_id).in_(board_ids),
|
||||
commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session,
|
||||
ApprovalTaskLink,
|
||||
col(ApprovalTaskLink.approval_id).in_(
|
||||
select(Approval.id).where(col(Approval.board_id).in_(board_ids))
|
||||
),
|
||||
commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session,
|
||||
Approval,
|
||||
|
||||
Reference in New Issue
Block a user