feat(approvals): enhance approval model with task titles and confidence as float
This commit is contained in:
@@ -26,6 +26,7 @@ from app.db.pagination import paginate
|
||||
from app.db.session import async_session_maker, get_session
|
||||
from app.models.agents import Agent
|
||||
from app.models.approvals import Approval
|
||||
from app.models.tasks import Task
|
||||
from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalStatus, ApprovalUpdate
|
||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||
from app.services.activity_log import record_activity
|
||||
@@ -96,10 +97,36 @@ async def _approval_task_ids_map(
|
||||
return mapping
|
||||
|
||||
|
||||
def _approval_to_read(approval: Approval, *, task_ids: list[UUID]) -> ApprovalRead:
|
||||
async def _task_titles_by_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
task_ids: set[UUID],
|
||||
) -> dict[UUID, str]:
|
||||
if not task_ids:
|
||||
return {}
|
||||
rows = list(
|
||||
await session.exec(
|
||||
select(col(Task.id), col(Task.title)).where(col(Task.id).in_(task_ids)),
|
||||
),
|
||||
)
|
||||
return {task_id: title for task_id, title in rows}
|
||||
|
||||
|
||||
def _approval_to_read(
|
||||
approval: Approval,
|
||||
*,
|
||||
task_ids: list[UUID],
|
||||
task_titles: list[str],
|
||||
) -> ApprovalRead:
|
||||
primary_task_id = task_ids[0] if task_ids else None
|
||||
model = ApprovalRead.model_validate(approval, from_attributes=True)
|
||||
return model.model_copy(update={"task_id": primary_task_id, "task_ids": task_ids})
|
||||
return model.model_copy(
|
||||
update={
|
||||
"task_id": primary_task_id,
|
||||
"task_ids": task_ids,
|
||||
"task_titles": task_titles,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _approval_reads(
|
||||
@@ -107,8 +134,17 @@ async def _approval_reads(
|
||||
approvals: Sequence[Approval],
|
||||
) -> list[ApprovalRead]:
|
||||
mapping = await _approval_task_ids_map(session, approvals)
|
||||
title_by_id = await _task_titles_by_id(
|
||||
session,
|
||||
task_ids={task_id for task_ids in mapping.values() for task_id in task_ids},
|
||||
)
|
||||
return [
|
||||
_approval_to_read(approval, task_ids=mapping.get(approval.id, [])) for approval in approvals
|
||||
_approval_to_read(
|
||||
approval,
|
||||
task_ids=(task_ids := mapping.get(approval.id, [])),
|
||||
task_titles=[title_by_id[task_id] for task_id in task_ids if task_id in title_by_id],
|
||||
)
|
||||
for approval in approvals
|
||||
]
|
||||
|
||||
|
||||
@@ -389,7 +425,12 @@ async def create_approval(
|
||||
)
|
||||
await session.commit()
|
||||
await session.refresh(approval)
|
||||
return _approval_to_read(approval, task_ids=task_ids)
|
||||
title_by_id = await _task_titles_by_id(session, task_ids=set(task_ids))
|
||||
return _approval_to_read(
|
||||
approval,
|
||||
task_ids=task_ids,
|
||||
task_titles=[title_by_id[task_id] for task_id in task_ids if task_id in title_by_id],
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/{approval_id}", response_model=ApprovalRead)
|
||||
|
||||
@@ -250,9 +250,7 @@ async def _query_wip(
|
||||
if not board_ids:
|
||||
return _wip_series_from_mapping(range_spec, {})
|
||||
|
||||
inbox_bucket_col = func.date_trunc(range_spec.bucket, Task.created_at).label(
|
||||
"inbox_bucket"
|
||||
)
|
||||
inbox_bucket_col = func.date_trunc(range_spec.bucket, Task.created_at).label("inbox_bucket")
|
||||
inbox_statement = (
|
||||
select(inbox_bucket_col, func.count())
|
||||
.where(col(Task.status) == "inbox")
|
||||
@@ -264,9 +262,7 @@ async def _query_wip(
|
||||
)
|
||||
inbox_results = (await session.exec(inbox_statement)).all()
|
||||
|
||||
status_bucket_col = func.date_trunc(range_spec.bucket, Task.updated_at).label(
|
||||
"status_bucket"
|
||||
)
|
||||
status_bucket_col = func.date_trunc(range_spec.bucket, Task.updated_at).label("status_bucket")
|
||||
progress_case = case((col(Task.status) == "in_progress", 1), else_=0)
|
||||
review_case = case((col(Task.status) == "review", 1), else_=0)
|
||||
done_case = case((col(Task.status) == "done", 1), else_=0)
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import JSON, Column
|
||||
from sqlalchemy import JSON, Column, Float
|
||||
from sqlmodel import Field
|
||||
|
||||
from app.core.time import utcnow
|
||||
@@ -25,7 +25,7 @@ class Approval(QueryModel, table=True):
|
||||
agent_id: UUID | None = Field(default=None, foreign_key="agents.id", index=True)
|
||||
action_type: str
|
||||
payload: dict[str, object] | None = Field(default=None, sa_column=Column(JSON))
|
||||
confidence: int
|
||||
confidence: float = Field(sa_column=Column(Float, nullable=False))
|
||||
rubric_scores: dict[str, int] | None = Field(default=None, sa_column=Column(JSON))
|
||||
status: str = Field(default="pending", index=True)
|
||||
created_at: datetime = Field(default_factory=utcnow)
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlmodel import Field, SQLModel
|
||||
|
||||
ApprovalStatus = Literal["pending", "approved", "rejected"]
|
||||
STATUS_REQUIRED_ERROR = "status is required"
|
||||
LEAD_REASONING_REQUIRED_ERROR = "lead reasoning is required"
|
||||
RUNTIME_ANNOTATION_TYPES = (datetime, UUID)
|
||||
|
||||
|
||||
@@ -21,7 +22,7 @@ class ApprovalBase(SQLModel):
|
||||
task_id: UUID | None = None
|
||||
task_ids: list[UUID] = Field(default_factory=list)
|
||||
payload: dict[str, object] | None = None
|
||||
confidence: int
|
||||
confidence: float = Field(ge=0, le=100)
|
||||
rubric_scores: dict[str, int] | None = None
|
||||
status: ApprovalStatus = "pending"
|
||||
|
||||
@@ -48,6 +49,21 @@ class ApprovalCreate(ApprovalBase):
|
||||
|
||||
agent_id: UUID | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_lead_reasoning(self) -> Self:
|
||||
"""Ensure each approval request includes explicit lead reasoning."""
|
||||
payload = self.payload
|
||||
if isinstance(payload, dict):
|
||||
reason = payload.get("reason")
|
||||
if isinstance(reason, str) and reason.strip():
|
||||
return self
|
||||
decision = payload.get("decision")
|
||||
if isinstance(decision, dict):
|
||||
nested_reason = decision.get("reason")
|
||||
if isinstance(nested_reason, str) and nested_reason.strip():
|
||||
return self
|
||||
raise ValueError(LEAD_REASONING_REQUIRED_ERROR)
|
||||
|
||||
|
||||
class ApprovalUpdate(SQLModel):
|
||||
"""Payload for mutating approval status."""
|
||||
@@ -67,6 +83,7 @@ class ApprovalRead(ApprovalBase):
|
||||
|
||||
id: UUID
|
||||
board_id: UUID
|
||||
task_titles: list[str] = Field(default_factory=list)
|
||||
agent_id: UUID | None = None
|
||||
created_at: datetime
|
||||
resolved_at: datetime | None = None
|
||||
|
||||
@@ -36,10 +36,21 @@ def _memory_to_read(memory: BoardMemory) -> BoardMemoryRead:
|
||||
return BoardMemoryRead.model_validate(memory, from_attributes=True)
|
||||
|
||||
|
||||
def _approval_to_read(approval: Approval, *, task_ids: list[UUID]) -> ApprovalRead:
|
||||
def _approval_to_read(
|
||||
approval: Approval,
|
||||
*,
|
||||
task_ids: list[UUID],
|
||||
task_titles: list[str],
|
||||
) -> ApprovalRead:
|
||||
model = ApprovalRead.model_validate(approval, from_attributes=True)
|
||||
primary_task_id = task_ids[0] if task_ids else None
|
||||
return model.model_copy(update={"task_id": primary_task_id, "task_ids": task_ids})
|
||||
return model.model_copy(
|
||||
update={
|
||||
"task_id": primary_task_id,
|
||||
"task_ids": task_ids,
|
||||
"task_titles": task_titles,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _task_to_card(
|
||||
@@ -137,13 +148,21 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
|
||||
session,
|
||||
approval_ids=approval_ids,
|
||||
)
|
||||
task_title_by_id = {task.id: task.title for task in tasks}
|
||||
approval_reads = [
|
||||
_approval_to_read(
|
||||
approval,
|
||||
task_ids=task_ids_by_approval.get(
|
||||
approval.id,
|
||||
[approval.task_id] if approval.task_id is not None else [],
|
||||
task_ids=(
|
||||
linked_task_ids := task_ids_by_approval.get(
|
||||
approval.id,
|
||||
[approval.task_id] if approval.task_id is not None else [],
|
||||
)
|
||||
),
|
||||
task_titles=[
|
||||
task_title_by_id[task_id]
|
||||
for task_id in linked_task_ids
|
||||
if task_id in task_title_by_id
|
||||
],
|
||||
)
|
||||
for approval in approvals
|
||||
]
|
||||
|
||||
@@ -5,16 +5,16 @@ from __future__ import annotations
|
||||
import hashlib
|
||||
from typing import Mapping
|
||||
|
||||
CONFIDENCE_THRESHOLD = 80
|
||||
CONFIDENCE_THRESHOLD = 80.0
|
||||
MIN_PLANNING_SIGNALS = 2
|
||||
|
||||
|
||||
def compute_confidence(rubric_scores: Mapping[str, int]) -> int:
|
||||
def compute_confidence(rubric_scores: Mapping[str, int]) -> float:
|
||||
"""Compute aggregate confidence from rubric score components."""
|
||||
return int(sum(rubric_scores.values()))
|
||||
return float(sum(rubric_scores.values()))
|
||||
|
||||
|
||||
def approval_required(*, confidence: int, is_external: bool, is_risky: bool) -> bool:
|
||||
def approval_required(*, confidence: float, is_external: bool, is_risky: bool) -> bool:
|
||||
"""Return whether an action must go through explicit approval."""
|
||||
return is_external or is_risky or confidence < CONFIDENCE_THRESHOLD
|
||||
|
||||
|
||||
Reference in New Issue
Block a user