feat: add approval-task links model and related functionality for task associations
This commit is contained in:
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from sqlalchemy import asc, case, func, or_
|
||||
from sqlalchemy import asc, func, or_
|
||||
from sqlmodel import col, select
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
@@ -29,10 +29,16 @@ from app.models.approvals import Approval
|
||||
from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalStatus, ApprovalUpdate
|
||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||
from app.services.activity_log import record_activity
|
||||
from app.services.approval_task_links import (
|
||||
load_task_ids_by_approval,
|
||||
normalize_task_ids,
|
||||
replace_approval_task_links,
|
||||
task_counts_for_board,
|
||||
)
|
||||
from app.services.openclaw.gateway_dispatch import GatewayDispatchService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
|
||||
from fastapi_pagination.limit_offset import LimitOffsetPage
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -42,7 +48,6 @@ if TYPE_CHECKING:
|
||||
router = APIRouter(prefix="/boards/{board_id}/approvals", tags=["approvals"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
TASK_ID_KEYS: tuple[str, ...] = ("task_id", "taskId", "taskID")
|
||||
STREAM_POLL_SECONDS = 2
|
||||
STATUS_FILTER_QUERY = Query(default=None, alias="status")
|
||||
SINCE_QUERY = Query(default=None)
|
||||
@@ -53,21 +58,6 @@ SESSION_DEP = Depends(get_session)
|
||||
ACTOR_DEP = Depends(require_admin_or_agent)
|
||||
|
||||
|
||||
def _extract_task_id(payload: dict[str, object] | None) -> UUID | None:
|
||||
if not payload:
|
||||
return None
|
||||
for key in TASK_ID_KEYS:
|
||||
value = payload.get(key)
|
||||
if isinstance(value, UUID):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return UUID(value)
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _parse_since(value: str | None) -> datetime | None:
|
||||
if not value:
|
||||
return None
|
||||
@@ -88,17 +78,47 @@ def _approval_updated_at(approval: Approval) -> datetime:
|
||||
return approval.resolved_at or approval.created_at
|
||||
|
||||
|
||||
def _serialize_approval(approval: Approval) -> dict[str, object]:
|
||||
return ApprovalRead.model_validate(
|
||||
approval,
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json")
|
||||
async def _approval_task_ids_map(
|
||||
session: AsyncSession,
|
||||
approvals: Sequence[Approval],
|
||||
) -> dict[UUID, list[UUID]]:
|
||||
approval_ids = [approval.id for approval in approvals]
|
||||
mapping = await load_task_ids_by_approval(session, approval_ids=approval_ids)
|
||||
for approval in approvals:
|
||||
if mapping.get(approval.id):
|
||||
continue
|
||||
if approval.task_id is not None:
|
||||
mapping[approval.id] = [approval.task_id]
|
||||
else:
|
||||
mapping[approval.id] = []
|
||||
return mapping
|
||||
|
||||
|
||||
def _approval_to_read(approval: Approval, *, task_ids: list[UUID]) -> ApprovalRead:
|
||||
primary_task_id = task_ids[0] if task_ids else None
|
||||
model = ApprovalRead.model_validate(approval, from_attributes=True)
|
||||
return model.model_copy(update={"task_id": primary_task_id, "task_ids": task_ids})
|
||||
|
||||
|
||||
async def _approval_reads(
|
||||
session: AsyncSession,
|
||||
approvals: Sequence[Approval],
|
||||
) -> list[ApprovalRead]:
|
||||
mapping = await _approval_task_ids_map(session, approvals)
|
||||
return [
|
||||
_approval_to_read(approval, task_ids=mapping.get(approval.id, [])) for approval in approvals
|
||||
]
|
||||
|
||||
|
||||
def _serialize_approval(approval: ApprovalRead) -> dict[str, object]:
|
||||
return approval.model_dump(mode="json")
|
||||
|
||||
|
||||
def _approval_resolution_message(
|
||||
*,
|
||||
board: Board,
|
||||
approval: Approval,
|
||||
task_ids: Sequence[UUID] | None = None,
|
||||
) -> str:
|
||||
status_text = "approved" if approval.status == "approved" else "rejected"
|
||||
lines = [
|
||||
@@ -109,8 +129,13 @@ def _approval_resolution_message(
|
||||
f"Decision: {status_text}",
|
||||
f"Confidence: {approval.confidence}",
|
||||
]
|
||||
if approval.task_id is not None:
|
||||
lines.append(f"Task ID: {approval.task_id}")
|
||||
normalized_task_ids = list(task_ids or [])
|
||||
if not normalized_task_ids and approval.task_id is not None:
|
||||
normalized_task_ids = [approval.task_id]
|
||||
if len(normalized_task_ids) == 1:
|
||||
lines.append(f"Task ID: {normalized_task_ids[0]}")
|
||||
elif normalized_task_ids:
|
||||
lines.append(f"Task IDs: {', '.join(str(value) for value in normalized_task_ids)}")
|
||||
lines.append("")
|
||||
lines.append("Take action: continue execution using the final approval decision.")
|
||||
return "\n".join(lines)
|
||||
@@ -145,7 +170,12 @@ async def _notify_lead_on_approval_resolution(
|
||||
if config is None:
|
||||
return
|
||||
|
||||
message = _approval_resolution_message(board=board, approval=approval)
|
||||
task_ids_by_approval = await load_task_ids_by_approval(session, approval_ids=[approval.id])
|
||||
message = _approval_resolution_message(
|
||||
board=board,
|
||||
approval=approval,
|
||||
task_ids=task_ids_by_approval.get(approval.id, []),
|
||||
)
|
||||
error = await dispatch.try_send_agent_message(
|
||||
session_key=lead.openclaw_session_id,
|
||||
config=config,
|
||||
@@ -202,7 +232,17 @@ async def list_approvals(
|
||||
if status_filter:
|
||||
statement = statement.filter(col(Approval.status) == status_filter)
|
||||
statement = statement.order_by(col(Approval.created_at).desc())
|
||||
return await paginate(session, statement.statement)
|
||||
|
||||
async def _transform(items: Sequence[object]) -> Sequence[ApprovalRead]:
|
||||
approvals: list[Approval] = []
|
||||
for item in items:
|
||||
if not isinstance(item, Approval):
|
||||
msg = "Expected Approval items from approvals pagination query."
|
||||
raise TypeError(msg)
|
||||
approvals.append(item)
|
||||
return await _approval_reads(session, approvals)
|
||||
|
||||
return await paginate(session, statement.statement, transformer=_transform)
|
||||
|
||||
|
||||
@router.get("/stream")
|
||||
@@ -223,6 +263,7 @@ async def stream_approvals(
|
||||
break
|
||||
async with async_session_maker() as session:
|
||||
approvals = await _fetch_approval_events(session, board.id, last_seen)
|
||||
approval_reads = await _approval_reads(session, approvals)
|
||||
pending_approvals_count = int(
|
||||
(
|
||||
await session.exec(
|
||||
@@ -233,50 +274,36 @@ async def stream_approvals(
|
||||
).one(),
|
||||
)
|
||||
task_ids = {
|
||||
approval.task_id for approval in approvals if approval.task_id is not None
|
||||
task_id
|
||||
for approval_read in approval_reads
|
||||
for task_id in approval_read.task_ids
|
||||
}
|
||||
counts_by_task_id: dict[UUID, tuple[int, int]] = {}
|
||||
if task_ids:
|
||||
rows = list(
|
||||
await session.exec(
|
||||
select(
|
||||
col(Approval.task_id),
|
||||
func.count(col(Approval.id)).label("total"),
|
||||
func.sum(
|
||||
case(
|
||||
(col(Approval.status) == "pending", 1),
|
||||
else_=0,
|
||||
),
|
||||
).label("pending"),
|
||||
)
|
||||
.where(col(Approval.board_id) == board.id)
|
||||
.where(col(Approval.task_id).in_(task_ids))
|
||||
.group_by(col(Approval.task_id)),
|
||||
),
|
||||
)
|
||||
for task_id, total, pending in rows:
|
||||
if task_id is None:
|
||||
continue
|
||||
counts_by_task_id[task_id] = (
|
||||
int(total or 0),
|
||||
int(pending or 0),
|
||||
)
|
||||
for approval in approvals:
|
||||
counts_by_task_id = await task_counts_for_board(
|
||||
session,
|
||||
board_id=board.id,
|
||||
task_ids=task_ids,
|
||||
)
|
||||
for approval, approval_read in zip(approvals, approval_reads, strict=True):
|
||||
updated_at = _approval_updated_at(approval)
|
||||
last_seen = max(updated_at, last_seen)
|
||||
payload: dict[str, object] = {
|
||||
"approval": _serialize_approval(approval),
|
||||
"approval": _serialize_approval(approval_read),
|
||||
"pending_approvals_count": pending_approvals_count,
|
||||
}
|
||||
if approval.task_id is not None:
|
||||
counts = counts_by_task_id.get(approval.task_id)
|
||||
if counts is not None:
|
||||
total, pending = counts
|
||||
payload["task_counts"] = {
|
||||
"task_id": str(approval.task_id),
|
||||
"approvals_count": total,
|
||||
"approvals_pending_count": pending,
|
||||
}
|
||||
task_counts = [
|
||||
{
|
||||
"task_id": str(task_id),
|
||||
"approvals_count": total,
|
||||
"approvals_pending_count": pending,
|
||||
}
|
||||
for task_id in approval_read.task_ids
|
||||
if (counts := counts_by_task_id.get(task_id)) is not None
|
||||
for total, pending in [counts]
|
||||
]
|
||||
if len(task_counts) == 1:
|
||||
payload["task_counts"] = task_counts[0]
|
||||
elif task_counts:
|
||||
payload["task_counts"] = task_counts
|
||||
yield {"event": "approval", "data": json.dumps(payload)}
|
||||
await asyncio.sleep(STREAM_POLL_SECONDS)
|
||||
|
||||
@@ -289,9 +316,14 @@ async def create_approval(
|
||||
board: Board = BOARD_WRITE_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
_actor: ActorContext = ACTOR_DEP,
|
||||
) -> Approval:
|
||||
) -> ApprovalRead:
|
||||
"""Create an approval for a board."""
|
||||
task_id = payload.task_id or _extract_task_id(payload.payload)
|
||||
task_ids = normalize_task_ids(
|
||||
task_id=payload.task_id,
|
||||
task_ids=payload.task_ids,
|
||||
payload=payload.payload,
|
||||
)
|
||||
task_id = task_ids[0] if task_ids else None
|
||||
approval = Approval(
|
||||
board_id=board.id,
|
||||
task_id=task_id,
|
||||
@@ -303,9 +335,15 @@ async def create_approval(
|
||||
status=payload.status,
|
||||
)
|
||||
session.add(approval)
|
||||
await session.flush()
|
||||
await replace_approval_task_links(
|
||||
session,
|
||||
approval_id=approval.id,
|
||||
task_ids=task_ids,
|
||||
)
|
||||
await session.commit()
|
||||
await session.refresh(approval)
|
||||
return approval
|
||||
return _approval_to_read(approval, task_ids=task_ids)
|
||||
|
||||
|
||||
@router.patch("/{approval_id}", response_model=ApprovalRead)
|
||||
@@ -314,7 +352,7 @@ async def update_approval(
|
||||
payload: ApprovalUpdate,
|
||||
board: Board = BOARD_USER_WRITE_DEP,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
) -> Approval:
|
||||
) -> ApprovalRead:
|
||||
"""Update an approval's status and resolution timestamp."""
|
||||
approval = await Approval.objects.by_id(approval_id).first(session)
|
||||
if approval is None or approval.board_id != board.id:
|
||||
@@ -342,4 +380,5 @@ async def update_approval(
|
||||
approval.id,
|
||||
approval.status,
|
||||
)
|
||||
return approval
|
||||
reads = await _approval_reads(session, [approval])
|
||||
return reads[0]
|
||||
|
||||
@@ -18,6 +18,7 @@ from app.db.pagination import paginate
|
||||
from app.db.session import get_session
|
||||
from app.models.activity_events import ActivityEvent
|
||||
from app.models.agents import Agent
|
||||
from app.models.approval_task_links import ApprovalTaskLink
|
||||
from app.models.approvals import Approval
|
||||
from app.models.board_group_memory import BoardGroupMemory
|
||||
from app.models.board_groups import BoardGroup
|
||||
@@ -269,6 +270,14 @@ async def delete_my_org(
|
||||
col(TaskFingerprint.board_id).in_(board_ids),
|
||||
commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session,
|
||||
ApprovalTaskLink,
|
||||
col(ApprovalTaskLink.approval_id).in_(
|
||||
select(Approval.id).where(col(Approval.board_id).in_(board_ids))
|
||||
),
|
||||
commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session,
|
||||
Approval,
|
||||
|
||||
@@ -29,6 +29,7 @@ from app.db.pagination import paginate
|
||||
from app.db.session import async_session_maker, get_session
|
||||
from app.models.activity_events import ActivityEvent
|
||||
from app.models.agents import Agent
|
||||
from app.models.approval_task_links import ApprovalTaskLink
|
||||
from app.models.approvals import Approval
|
||||
from app.models.boards import Board
|
||||
from app.models.task_dependencies import TaskDependency
|
||||
@@ -39,6 +40,7 @@ from app.schemas.errors import BlockedTaskError
|
||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||
from app.schemas.tasks import TaskCommentCreate, TaskCommentRead, TaskCreate, TaskRead, TaskUpdate
|
||||
from app.services.activity_log import record_activity
|
||||
from app.services.approval_task_links import load_task_ids_by_approval
|
||||
from app.services.mentions import extract_mentions, matches_agent_mention
|
||||
from app.services.openclaw.gateway_dispatch import GatewayDispatchService
|
||||
from app.services.openclaw.gateway_rpc import GatewayConfig as GatewayClientConfig
|
||||
@@ -922,12 +924,26 @@ async def delete_task(
|
||||
col(TaskFingerprint.task_id) == task.id,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
primary_approvals = list(
|
||||
await Approval.objects.filter(col(Approval.task_id) == task.id).all(session),
|
||||
)
|
||||
await crud.delete_where(
|
||||
session,
|
||||
Approval,
|
||||
col(Approval.task_id) == task.id,
|
||||
ApprovalTaskLink,
|
||||
col(ApprovalTaskLink.task_id) == task.id,
|
||||
commit=False,
|
||||
)
|
||||
if primary_approvals:
|
||||
primary_ids = [approval.id for approval in primary_approvals]
|
||||
remaining_by_approval = await load_task_ids_by_approval(session, approval_ids=primary_ids)
|
||||
for approval in primary_approvals:
|
||||
remaining_task_ids = remaining_by_approval.get(approval.id, [])
|
||||
if remaining_task_ids:
|
||||
approval.task_id = remaining_task_ids[0]
|
||||
session.add(approval)
|
||||
continue
|
||||
await session.delete(approval)
|
||||
await crud.delete_where(
|
||||
session,
|
||||
TaskDependency,
|
||||
|
||||
@@ -13,6 +13,7 @@ from app.db import crud
|
||||
from app.db.session import get_session
|
||||
from app.models.activity_events import ActivityEvent
|
||||
from app.models.agents import Agent
|
||||
from app.models.approval_task_links import ApprovalTaskLink
|
||||
from app.models.approvals import Approval
|
||||
from app.models.board_group_memory import BoardGroupMemory
|
||||
from app.models.board_groups import BoardGroup
|
||||
@@ -83,6 +84,14 @@ async def _delete_organization_tree(
|
||||
col(TaskFingerprint.board_id).in_(board_ids),
|
||||
commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session,
|
||||
ApprovalTaskLink,
|
||||
col(ApprovalTaskLink.approval_id).in_(
|
||||
select(Approval.id).where(col(Approval.board_id).in_(board_ids))
|
||||
),
|
||||
commit=False,
|
||||
)
|
||||
await crud.delete_where(
|
||||
session,
|
||||
Approval,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from app.models.activity_events import ActivityEvent
|
||||
from app.models.agents import Agent
|
||||
from app.models.approval_task_links import ApprovalTaskLink
|
||||
from app.models.approvals import Approval
|
||||
from app.models.board_group_memory import BoardGroupMemory
|
||||
from app.models.board_groups import BoardGroup
|
||||
@@ -22,6 +23,7 @@ from app.models.users import User
|
||||
__all__ = [
|
||||
"ActivityEvent",
|
||||
"Agent",
|
||||
"ApprovalTaskLink",
|
||||
"Approval",
|
||||
"BoardGroupMemory",
|
||||
"BoardMemory",
|
||||
|
||||
32
backend/app/models/approval_task_links.py
Normal file
32
backend/app/models/approval_task_links.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Approval-task link model for many-to-many approval associations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import UniqueConstraint
|
||||
from sqlmodel import Field
|
||||
|
||||
from app.core.time import utcnow
|
||||
from app.models.base import QueryModel
|
||||
|
||||
RUNTIME_ANNOTATION_TYPES = (datetime,)
|
||||
|
||||
|
||||
class ApprovalTaskLink(QueryModel, table=True):
|
||||
"""Map an approval request to one task (many links per approval allowed)."""
|
||||
|
||||
__tablename__ = "approval_task_links" # pyright: ignore[reportAssignmentType]
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"approval_id",
|
||||
"task_id",
|
||||
name="uq_approval_task_links_approval_id_task_id",
|
||||
),
|
||||
)
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
approval_id: UUID = Field(foreign_key="approvals.id", index=True)
|
||||
task_id: UUID = Field(foreign_key="tasks.id", index=True)
|
||||
created_at: datetime = Field(default_factory=utcnow)
|
||||
@@ -7,7 +7,7 @@ from typing import Literal, Self
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import model_validator
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
ApprovalStatus = Literal["pending", "approved", "rejected"]
|
||||
STATUS_REQUIRED_ERROR = "status is required"
|
||||
@@ -19,11 +19,29 @@ class ApprovalBase(SQLModel):
|
||||
|
||||
action_type: str
|
||||
task_id: UUID | None = None
|
||||
task_ids: list[UUID] = Field(default_factory=list)
|
||||
payload: dict[str, object] | None = None
|
||||
confidence: int
|
||||
rubric_scores: dict[str, int] | None = None
|
||||
status: ApprovalStatus = "pending"
|
||||
|
||||
@model_validator(mode="after")
|
||||
def normalize_task_links(self) -> Self:
|
||||
"""Keep task identifiers deduplicated and task_id aligned with task_ids."""
|
||||
deduped: list[UUID] = []
|
||||
seen: set[UUID] = set()
|
||||
if self.task_id is not None:
|
||||
deduped.append(self.task_id)
|
||||
seen.add(self.task_id)
|
||||
for task_id in self.task_ids:
|
||||
if task_id in seen:
|
||||
continue
|
||||
seen.add(task_id)
|
||||
deduped.append(task_id)
|
||||
self.task_ids = deduped
|
||||
self.task_id = deduped[0] if deduped else None
|
||||
return self
|
||||
|
||||
|
||||
class ApprovalCreate(ApprovalBase):
|
||||
"""Payload for creating a new approval request."""
|
||||
|
||||
190
backend/app/services/approval_task_links.py
Normal file
190
backend/app/services/approval_task_links.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""Helpers for normalizing and querying approval-task associations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import case, delete, exists, func
|
||||
from sqlmodel import col, select
|
||||
|
||||
from app.models.approval_task_links import ApprovalTaskLink
|
||||
from app.models.approvals import Approval
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
TASK_ID_KEYS: tuple[str, ...] = ("task_id", "taskId", "taskID")
|
||||
TASK_IDS_KEYS: tuple[str, ...] = ("task_ids", "taskIds", "taskIDs")
|
||||
|
||||
|
||||
def _coerce_uuid(value: object) -> UUID | None:
|
||||
if isinstance(value, UUID):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return UUID(value)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def extract_task_ids(payload: dict[str, object] | None) -> list[UUID]:
|
||||
"""Extract task UUIDs from approval payload aliases."""
|
||||
if not payload:
|
||||
return []
|
||||
|
||||
collected: list[UUID] = []
|
||||
for key in TASK_IDS_KEYS:
|
||||
raw = payload.get(key)
|
||||
if isinstance(raw, Sequence) and not isinstance(raw, (str, bytes, bytearray)):
|
||||
for item in raw:
|
||||
task_id = _coerce_uuid(item)
|
||||
if task_id is not None:
|
||||
collected.append(task_id)
|
||||
for key in TASK_ID_KEYS:
|
||||
task_id = _coerce_uuid(payload.get(key))
|
||||
if task_id is not None:
|
||||
collected.append(task_id)
|
||||
|
||||
deduped: list[UUID] = []
|
||||
seen: set[UUID] = set()
|
||||
for task_id in collected:
|
||||
if task_id in seen:
|
||||
continue
|
||||
seen.add(task_id)
|
||||
deduped.append(task_id)
|
||||
return deduped
|
||||
|
||||
|
||||
def normalize_task_ids(
|
||||
*,
|
||||
task_id: UUID | None,
|
||||
task_ids: Sequence[UUID],
|
||||
payload: dict[str, object] | None,
|
||||
) -> list[UUID]:
|
||||
"""Merge explicit and payload-provided task references into an ordered unique list."""
|
||||
merged: list[UUID] = []
|
||||
merged.extend(task_ids)
|
||||
if task_id is not None:
|
||||
merged.append(task_id)
|
||||
merged.extend(extract_task_ids(payload))
|
||||
|
||||
deduped: list[UUID] = []
|
||||
seen: set[UUID] = set()
|
||||
for value in merged:
|
||||
if value in seen:
|
||||
continue
|
||||
seen.add(value)
|
||||
deduped.append(value)
|
||||
return deduped
|
||||
|
||||
|
||||
async def load_task_ids_by_approval(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
approval_ids: Iterable[UUID],
|
||||
) -> dict[UUID, list[UUID]]:
|
||||
"""Return task ids grouped by approval id in insertion order."""
|
||||
ids = list({*approval_ids})
|
||||
if not ids:
|
||||
return {}
|
||||
|
||||
rows = list(
|
||||
await session.exec(
|
||||
select(col(ApprovalTaskLink.approval_id), col(ApprovalTaskLink.task_id))
|
||||
.where(col(ApprovalTaskLink.approval_id).in_(ids))
|
||||
.order_by(col(ApprovalTaskLink.created_at).asc()),
|
||||
),
|
||||
)
|
||||
|
||||
mapping: dict[UUID, list[UUID]] = {approval_id: [] for approval_id in ids}
|
||||
for approval_id, task_id in rows:
|
||||
mapping.setdefault(approval_id, []).append(task_id)
|
||||
return mapping
|
||||
|
||||
|
||||
async def replace_approval_task_links(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
approval_id: UUID,
|
||||
task_ids: Sequence[UUID],
|
||||
) -> None:
|
||||
"""Replace approval-task link rows for an approval id."""
|
||||
await session.exec(
|
||||
delete(ApprovalTaskLink).where(
|
||||
col(ApprovalTaskLink.approval_id) == approval_id,
|
||||
),
|
||||
)
|
||||
for task_id in task_ids:
|
||||
session.add(ApprovalTaskLink(approval_id=approval_id, task_id=task_id))
|
||||
|
||||
|
||||
async def task_counts_for_board(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
board_id: UUID,
|
||||
task_ids: set[UUID] | None = None,
|
||||
) -> dict[UUID, tuple[int, int]]:
|
||||
"""Compute total/pending approval counts per task across all linked tasks on a board."""
|
||||
|
||||
link_statement = (
|
||||
select(
|
||||
col(ApprovalTaskLink.task_id),
|
||||
func.count(col(Approval.id)).label("total"),
|
||||
func.sum(
|
||||
case(
|
||||
(col(Approval.status) == "pending", 1),
|
||||
else_=0,
|
||||
),
|
||||
).label("pending"),
|
||||
)
|
||||
.join(Approval, col(Approval.id) == col(ApprovalTaskLink.approval_id))
|
||||
.where(col(Approval.board_id) == board_id)
|
||||
)
|
||||
if task_ids is not None:
|
||||
if not task_ids:
|
||||
return {}
|
||||
link_statement = link_statement.where(col(ApprovalTaskLink.task_id).in_(task_ids))
|
||||
link_statement = link_statement.group_by(col(ApprovalTaskLink.task_id))
|
||||
|
||||
counts: dict[UUID, tuple[int, int]] = {}
|
||||
for task_id, total, pending in list(await session.exec(link_statement)):
|
||||
counts[task_id] = (int(total or 0), int(pending or 0))
|
||||
|
||||
# Backward compatibility: include legacy rows that have task_id set but no link rows.
|
||||
legacy_statement = (
|
||||
select(
|
||||
col(Approval.task_id),
|
||||
func.count(col(Approval.id)).label("total"),
|
||||
func.sum(
|
||||
case(
|
||||
(col(Approval.status) == "pending", 1),
|
||||
else_=0,
|
||||
),
|
||||
).label("pending"),
|
||||
)
|
||||
.where(col(Approval.board_id) == board_id)
|
||||
.where(col(Approval.task_id).is_not(None))
|
||||
.where(
|
||||
~exists(
|
||||
select(1)
|
||||
.where(col(ApprovalTaskLink.approval_id) == col(Approval.id))
|
||||
.correlate(Approval),
|
||||
),
|
||||
)
|
||||
)
|
||||
if task_ids is not None:
|
||||
legacy_statement = legacy_statement.where(col(Approval.task_id).in_(task_ids))
|
||||
legacy_statement = legacy_statement.group_by(col(Approval.task_id))
|
||||
|
||||
for legacy_task_id, total, pending in list(await session.exec(legacy_statement)):
|
||||
if legacy_task_id is None:
|
||||
continue
|
||||
previous = counts.get(legacy_task_id, (0, 0))
|
||||
counts[legacy_task_id] = (
|
||||
previous[0] + int(total or 0),
|
||||
previous[1] + int(pending or 0),
|
||||
)
|
||||
return counts
|
||||
@@ -14,6 +14,7 @@ from sqlmodel import col, select
|
||||
from app.db import crud
|
||||
from app.models.activity_events import ActivityEvent
|
||||
from app.models.agents import Agent
|
||||
from app.models.approval_task_links import ApprovalTaskLink
|
||||
from app.models.approvals import Approval
|
||||
from app.models.board_memory import BoardMemory
|
||||
from app.models.board_onboarding import BoardOnboardingSession
|
||||
@@ -73,6 +74,13 @@ async def delete_board(session: AsyncSession, *, board: Board) -> OkResponse:
|
||||
)
|
||||
|
||||
# Approvals can reference tasks and agents, so delete before both.
|
||||
approval_ids = select(Approval.id).where(col(Approval.board_id) == board.id)
|
||||
await crud.delete_where(
|
||||
session,
|
||||
ApprovalTaskLink,
|
||||
col(ApprovalTaskLink.approval_id).in_(approval_ids),
|
||||
commit=False,
|
||||
)
|
||||
await crud.delete_where(session, Approval, col(Approval.board_id) == board.id)
|
||||
|
||||
await crud.delete_where(session, BoardMemory, col(BoardMemory.board_id) == board.id)
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import case, func
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import col, select
|
||||
|
||||
from app.models.agents import Agent
|
||||
@@ -15,6 +15,7 @@ from app.schemas.approvals import ApprovalRead
|
||||
from app.schemas.board_memory import BoardMemoryRead
|
||||
from app.schemas.boards import BoardRead
|
||||
from app.schemas.view_models import BoardSnapshot, TaskCardRead
|
||||
from app.services.approval_task_links import load_task_ids_by_approval, task_counts_for_board
|
||||
from app.services.openclaw.provisioning_db import AgentLifecycleService
|
||||
from app.services.task_dependencies import (
|
||||
blocked_by_dependency_ids,
|
||||
@@ -34,8 +35,10 @@ def _memory_to_read(memory: BoardMemory) -> BoardMemoryRead:
|
||||
return BoardMemoryRead.model_validate(memory, from_attributes=True)
|
||||
|
||||
|
||||
def _approval_to_read(approval: Approval) -> ApprovalRead:
|
||||
return ApprovalRead.model_validate(approval, from_attributes=True)
|
||||
def _approval_to_read(approval: Approval, *, task_ids: list[UUID]) -> ApprovalRead:
|
||||
model = ApprovalRead.model_validate(approval, from_attributes=True)
|
||||
primary_task_id = task_ids[0] if task_ids else None
|
||||
return model.model_copy(update={"task_id": primary_task_id, "task_ids": task_ids})
|
||||
|
||||
|
||||
def _task_to_card(
|
||||
@@ -120,27 +123,23 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
|
||||
.limit(200)
|
||||
.all(session)
|
||||
)
|
||||
approval_reads = [_approval_to_read(approval) for approval in approvals]
|
||||
|
||||
counts_by_task_id: dict[UUID, tuple[int, int]] = {}
|
||||
rows = list(
|
||||
await session.exec(
|
||||
select(
|
||||
col(Approval.task_id),
|
||||
func.count(col(Approval.id)).label("total"),
|
||||
func.sum(
|
||||
case((col(Approval.status) == "pending", 1), else_=0),
|
||||
).label("pending"),
|
||||
)
|
||||
.where(col(Approval.board_id) == board.id)
|
||||
.where(col(Approval.task_id).is_not(None))
|
||||
.group_by(col(Approval.task_id)),
|
||||
),
|
||||
approval_ids = [approval.id for approval in approvals]
|
||||
task_ids_by_approval = await load_task_ids_by_approval(
|
||||
session,
|
||||
approval_ids=approval_ids,
|
||||
)
|
||||
for task_id, total, pending in rows:
|
||||
if task_id is None:
|
||||
continue
|
||||
counts_by_task_id[task_id] = (int(total or 0), int(pending or 0))
|
||||
approval_reads = [
|
||||
_approval_to_read(
|
||||
approval,
|
||||
task_ids=task_ids_by_approval.get(
|
||||
approval.id,
|
||||
[approval.task_id] if approval.task_id is not None else [],
|
||||
),
|
||||
)
|
||||
for approval in approvals
|
||||
]
|
||||
|
||||
counts_by_task_id = await task_counts_for_board(session, board_id=board.id)
|
||||
|
||||
task_cards = [
|
||||
_task_to_card(
|
||||
|
||||
Reference in New Issue
Block a user