feat: add approval-task links model and related functionality for task associations

This commit is contained in:
Abhimanyu Saharan
2026-02-11 20:27:04 +05:30
parent 3dfdfa3c3e
commit af8a263c27
19 changed files with 870 additions and 129 deletions

View File

@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import asc, case, func, or_
from sqlalchemy import asc, func, or_
from sqlmodel import col, select
from sse_starlette.sse import EventSourceResponse
@@ -29,10 +29,16 @@ from app.models.approvals import Approval
from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalStatus, ApprovalUpdate
from app.schemas.pagination import DefaultLimitOffsetPage
from app.services.activity_log import record_activity
from app.services.approval_task_links import (
load_task_ids_by_approval,
normalize_task_ids,
replace_approval_task_links,
task_counts_for_board,
)
from app.services.openclaw.gateway_dispatch import GatewayDispatchService
if TYPE_CHECKING:
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Sequence
from fastapi_pagination.limit_offset import LimitOffsetPage
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -42,7 +48,6 @@ if TYPE_CHECKING:
router = APIRouter(prefix="/boards/{board_id}/approvals", tags=["approvals"])
logger = get_logger(__name__)
TASK_ID_KEYS: tuple[str, ...] = ("task_id", "taskId", "taskID")
STREAM_POLL_SECONDS = 2
STATUS_FILTER_QUERY = Query(default=None, alias="status")
SINCE_QUERY = Query(default=None)
@@ -53,21 +58,6 @@ SESSION_DEP = Depends(get_session)
ACTOR_DEP = Depends(require_admin_or_agent)
def _extract_task_id(payload: dict[str, object] | None) -> UUID | None:
if not payload:
return None
for key in TASK_ID_KEYS:
value = payload.get(key)
if isinstance(value, UUID):
return value
if isinstance(value, str):
try:
return UUID(value)
except ValueError:
continue
return None
def _parse_since(value: str | None) -> datetime | None:
if not value:
return None
@@ -88,17 +78,47 @@ def _approval_updated_at(approval: Approval) -> datetime:
return approval.resolved_at or approval.created_at
def _serialize_approval(approval: Approval) -> dict[str, object]:
return ApprovalRead.model_validate(
approval,
from_attributes=True,
).model_dump(mode="json")
async def _approval_task_ids_map(
session: AsyncSession,
approvals: Sequence[Approval],
) -> dict[UUID, list[UUID]]:
approval_ids = [approval.id for approval in approvals]
mapping = await load_task_ids_by_approval(session, approval_ids=approval_ids)
for approval in approvals:
if mapping.get(approval.id):
continue
if approval.task_id is not None:
mapping[approval.id] = [approval.task_id]
else:
mapping[approval.id] = []
return mapping
def _approval_to_read(approval: Approval, *, task_ids: list[UUID]) -> ApprovalRead:
primary_task_id = task_ids[0] if task_ids else None
model = ApprovalRead.model_validate(approval, from_attributes=True)
return model.model_copy(update={"task_id": primary_task_id, "task_ids": task_ids})
async def _approval_reads(
session: AsyncSession,
approvals: Sequence[Approval],
) -> list[ApprovalRead]:
mapping = await _approval_task_ids_map(session, approvals)
return [
_approval_to_read(approval, task_ids=mapping.get(approval.id, [])) for approval in approvals
]
def _serialize_approval(approval: ApprovalRead) -> dict[str, object]:
return approval.model_dump(mode="json")
def _approval_resolution_message(
*,
board: Board,
approval: Approval,
task_ids: Sequence[UUID] | None = None,
) -> str:
status_text = "approved" if approval.status == "approved" else "rejected"
lines = [
@@ -109,8 +129,13 @@ def _approval_resolution_message(
f"Decision: {status_text}",
f"Confidence: {approval.confidence}",
]
if approval.task_id is not None:
lines.append(f"Task ID: {approval.task_id}")
normalized_task_ids = list(task_ids or [])
if not normalized_task_ids and approval.task_id is not None:
normalized_task_ids = [approval.task_id]
if len(normalized_task_ids) == 1:
lines.append(f"Task ID: {normalized_task_ids[0]}")
elif normalized_task_ids:
lines.append(f"Task IDs: {', '.join(str(value) for value in normalized_task_ids)}")
lines.append("")
lines.append("Take action: continue execution using the final approval decision.")
return "\n".join(lines)
@@ -145,7 +170,12 @@ async def _notify_lead_on_approval_resolution(
if config is None:
return
message = _approval_resolution_message(board=board, approval=approval)
task_ids_by_approval = await load_task_ids_by_approval(session, approval_ids=[approval.id])
message = _approval_resolution_message(
board=board,
approval=approval,
task_ids=task_ids_by_approval.get(approval.id, []),
)
error = await dispatch.try_send_agent_message(
session_key=lead.openclaw_session_id,
config=config,
@@ -202,7 +232,17 @@ async def list_approvals(
if status_filter:
statement = statement.filter(col(Approval.status) == status_filter)
statement = statement.order_by(col(Approval.created_at).desc())
return await paginate(session, statement.statement)
async def _transform(items: Sequence[object]) -> Sequence[ApprovalRead]:
approvals: list[Approval] = []
for item in items:
if not isinstance(item, Approval):
msg = "Expected Approval items from approvals pagination query."
raise TypeError(msg)
approvals.append(item)
return await _approval_reads(session, approvals)
return await paginate(session, statement.statement, transformer=_transform)
@router.get("/stream")
@@ -223,6 +263,7 @@ async def stream_approvals(
break
async with async_session_maker() as session:
approvals = await _fetch_approval_events(session, board.id, last_seen)
approval_reads = await _approval_reads(session, approvals)
pending_approvals_count = int(
(
await session.exec(
@@ -233,50 +274,36 @@ async def stream_approvals(
).one(),
)
task_ids = {
approval.task_id for approval in approvals if approval.task_id is not None
task_id
for approval_read in approval_reads
for task_id in approval_read.task_ids
}
counts_by_task_id: dict[UUID, tuple[int, int]] = {}
if task_ids:
rows = list(
await session.exec(
select(
col(Approval.task_id),
func.count(col(Approval.id)).label("total"),
func.sum(
case(
(col(Approval.status) == "pending", 1),
else_=0,
),
).label("pending"),
)
.where(col(Approval.board_id) == board.id)
.where(col(Approval.task_id).in_(task_ids))
.group_by(col(Approval.task_id)),
),
)
for task_id, total, pending in rows:
if task_id is None:
continue
counts_by_task_id[task_id] = (
int(total or 0),
int(pending or 0),
)
for approval in approvals:
counts_by_task_id = await task_counts_for_board(
session,
board_id=board.id,
task_ids=task_ids,
)
for approval, approval_read in zip(approvals, approval_reads, strict=True):
updated_at = _approval_updated_at(approval)
last_seen = max(updated_at, last_seen)
payload: dict[str, object] = {
"approval": _serialize_approval(approval),
"approval": _serialize_approval(approval_read),
"pending_approvals_count": pending_approvals_count,
}
if approval.task_id is not None:
counts = counts_by_task_id.get(approval.task_id)
if counts is not None:
total, pending = counts
payload["task_counts"] = {
"task_id": str(approval.task_id),
"approvals_count": total,
"approvals_pending_count": pending,
}
task_counts = [
{
"task_id": str(task_id),
"approvals_count": total,
"approvals_pending_count": pending,
}
for task_id in approval_read.task_ids
if (counts := counts_by_task_id.get(task_id)) is not None
for total, pending in [counts]
]
if len(task_counts) == 1:
payload["task_counts"] = task_counts[0]
elif task_counts:
payload["task_counts"] = task_counts
yield {"event": "approval", "data": json.dumps(payload)}
await asyncio.sleep(STREAM_POLL_SECONDS)
@@ -289,9 +316,14 @@ async def create_approval(
board: Board = BOARD_WRITE_DEP,
session: AsyncSession = SESSION_DEP,
_actor: ActorContext = ACTOR_DEP,
) -> Approval:
) -> ApprovalRead:
"""Create an approval for a board."""
task_id = payload.task_id or _extract_task_id(payload.payload)
task_ids = normalize_task_ids(
task_id=payload.task_id,
task_ids=payload.task_ids,
payload=payload.payload,
)
task_id = task_ids[0] if task_ids else None
approval = Approval(
board_id=board.id,
task_id=task_id,
@@ -303,9 +335,15 @@ async def create_approval(
status=payload.status,
)
session.add(approval)
await session.flush()
await replace_approval_task_links(
session,
approval_id=approval.id,
task_ids=task_ids,
)
await session.commit()
await session.refresh(approval)
return approval
return _approval_to_read(approval, task_ids=task_ids)
@router.patch("/{approval_id}", response_model=ApprovalRead)
@@ -314,7 +352,7 @@ async def update_approval(
payload: ApprovalUpdate,
board: Board = BOARD_USER_WRITE_DEP,
session: AsyncSession = SESSION_DEP,
) -> Approval:
) -> ApprovalRead:
"""Update an approval's status and resolution timestamp."""
approval = await Approval.objects.by_id(approval_id).first(session)
if approval is None or approval.board_id != board.id:
@@ -342,4 +380,5 @@ async def update_approval(
approval.id,
approval.status,
)
return approval
reads = await _approval_reads(session, [approval])
return reads[0]

View File

@@ -18,6 +18,7 @@ from app.db.pagination import paginate
from app.db.session import get_session
from app.models.activity_events import ActivityEvent
from app.models.agents import Agent
from app.models.approval_task_links import ApprovalTaskLink
from app.models.approvals import Approval
from app.models.board_group_memory import BoardGroupMemory
from app.models.board_groups import BoardGroup
@@ -269,6 +270,14 @@ async def delete_my_org(
col(TaskFingerprint.board_id).in_(board_ids),
commit=False,
)
await crud.delete_where(
session,
ApprovalTaskLink,
col(ApprovalTaskLink.approval_id).in_(
select(Approval.id).where(col(Approval.board_id).in_(board_ids))
),
commit=False,
)
await crud.delete_where(
session,
Approval,

View File

@@ -29,6 +29,7 @@ from app.db.pagination import paginate
from app.db.session import async_session_maker, get_session
from app.models.activity_events import ActivityEvent
from app.models.agents import Agent
from app.models.approval_task_links import ApprovalTaskLink
from app.models.approvals import Approval
from app.models.boards import Board
from app.models.task_dependencies import TaskDependency
@@ -39,6 +40,7 @@ from app.schemas.errors import BlockedTaskError
from app.schemas.pagination import DefaultLimitOffsetPage
from app.schemas.tasks import TaskCommentCreate, TaskCommentRead, TaskCreate, TaskRead, TaskUpdate
from app.services.activity_log import record_activity
from app.services.approval_task_links import load_task_ids_by_approval
from app.services.mentions import extract_mentions, matches_agent_mention
from app.services.openclaw.gateway_dispatch import GatewayDispatchService
from app.services.openclaw.gateway_rpc import GatewayConfig as GatewayClientConfig
@@ -922,12 +924,26 @@ async def delete_task(
col(TaskFingerprint.task_id) == task.id,
commit=False,
)
primary_approvals = list(
await Approval.objects.filter(col(Approval.task_id) == task.id).all(session),
)
await crud.delete_where(
session,
Approval,
col(Approval.task_id) == task.id,
ApprovalTaskLink,
col(ApprovalTaskLink.task_id) == task.id,
commit=False,
)
if primary_approvals:
primary_ids = [approval.id for approval in primary_approvals]
remaining_by_approval = await load_task_ids_by_approval(session, approval_ids=primary_ids)
for approval in primary_approvals:
remaining_task_ids = remaining_by_approval.get(approval.id, [])
if remaining_task_ids:
approval.task_id = remaining_task_ids[0]
session.add(approval)
continue
await session.delete(approval)
await crud.delete_where(
session,
TaskDependency,

View File

@@ -13,6 +13,7 @@ from app.db import crud
from app.db.session import get_session
from app.models.activity_events import ActivityEvent
from app.models.agents import Agent
from app.models.approval_task_links import ApprovalTaskLink
from app.models.approvals import Approval
from app.models.board_group_memory import BoardGroupMemory
from app.models.board_groups import BoardGroup
@@ -83,6 +84,14 @@ async def _delete_organization_tree(
col(TaskFingerprint.board_id).in_(board_ids),
commit=False,
)
await crud.delete_where(
session,
ApprovalTaskLink,
col(ApprovalTaskLink.approval_id).in_(
select(Approval.id).where(col(Approval.board_id).in_(board_ids))
),
commit=False,
)
await crud.delete_where(
session,
Approval,