feat: implement task dependencies with validation and update handling

This commit is contained in:
Abhimanyu Saharan
2026-02-07 00:21:44 +05:30
parent 8970ee6742
commit 4bab455912
34 changed files with 1241 additions and 157 deletions

View File

@@ -338,9 +338,7 @@ def _render_agent_files(
rendered[name] = env.from_string(override).render(**context).strip()
continue
template_name = (
template_overrides[name]
if template_overrides and name in template_overrides
else name
template_overrides[name] if template_overrides and name in template_overrides else name
)
path = _templates_root() / template_name
if path.exists():

View File

@@ -19,6 +19,11 @@ from app.schemas.approvals import ApprovalRead
from app.schemas.board_memory import BoardMemoryRead
from app.schemas.boards import BoardRead
from app.schemas.view_models import BoardSnapshot, TaskCardRead
from app.services.task_dependencies import (
blocked_by_dependency_ids,
dependency_ids_by_task_id,
dependency_status_by_id,
)
OFFLINE_AFTER = timedelta(minutes=10)
@@ -42,7 +47,9 @@ async def _gateway_main_session_keys(session: AsyncSession) -> set[str]:
def _agent_to_read(agent: Agent, main_session_keys: set[str]) -> AgentRead:
model = AgentRead.model_validate(agent, from_attributes=True)
computed_status = _computed_agent_status(agent)
is_gateway_main = bool(agent.openclaw_session_id and agent.openclaw_session_id in main_session_keys)
is_gateway_main = bool(
agent.openclaw_session_id and agent.openclaw_session_id in main_session_keys
)
return model.model_copy(update={"status": computed_status, "is_gateway_main": is_gateway_main})
@@ -59,17 +66,29 @@ def _task_to_card(
*,
agent_name_by_id: dict[UUID, str],
counts_by_task_id: dict[UUID, tuple[int, int]],
deps_by_task_id: dict[UUID, list[UUID]],
dependency_status_by_id_map: dict[UUID, str],
) -> TaskCardRead:
card = TaskCardRead.model_validate(task, from_attributes=True)
approvals_count, approvals_pending_count = counts_by_task_id.get(task.id, (0, 0))
assignee = (
agent_name_by_id.get(task.assigned_agent_id) if task.assigned_agent_id is not None else None
)
depends_on_task_ids = deps_by_task_id.get(task.id, [])
blocked_by_task_ids = blocked_by_dependency_ids(
dependency_ids=depends_on_task_ids,
status_by_id=dependency_status_by_id_map,
)
if task.status == "done":
blocked_by_task_ids = []
return card.model_copy(
update={
"assignee": assignee,
"approvals_count": approvals_count,
"approvals_pending_count": approvals_pending_count,
"depends_on_task_ids": depends_on_task_ids,
"blocked_by_task_ids": blocked_by_task_ids,
"is_blocked": bool(blocked_by_task_ids),
}
)
@@ -82,22 +101,37 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
select(Task).where(col(Task.board_id) == board.id).order_by(col(Task.created_at).desc())
)
)
task_ids = [task.id for task in tasks]
deps_by_task_id = await dependency_ids_by_task_id(session, board_id=board.id, task_ids=task_ids)
all_dependency_ids: list[UUID] = []
for values in deps_by_task_id.values():
all_dependency_ids.extend(values)
dependency_status_by_id_map = await dependency_status_by_id(
session,
board_id=board.id,
dependency_ids=list({*all_dependency_ids}),
)
main_session_keys = await _gateway_main_session_keys(session)
agents = list(
await session.exec(
select(Agent).where(col(Agent.board_id) == board.id).order_by(col(Agent.created_at).desc())
select(Agent)
.where(col(Agent.board_id) == board.id)
.order_by(col(Agent.created_at).desc())
)
)
agent_reads = [_agent_to_read(agent, main_session_keys) for agent in agents]
agent_name_by_id = {agent.id: agent.name for agent in agents}
pending_approvals_count = int(
(await session.exec(
select(func.count(col(Approval.id)))
.where(col(Approval.board_id) == board.id)
.where(col(Approval.status) == "pending")
)).one()
(
await session.exec(
select(func.count(col(Approval.id)))
.where(col(Approval.board_id) == board.id)
.where(col(Approval.status) == "pending")
)
).one()
)
approvals = list(
@@ -129,7 +163,13 @@ async def build_board_snapshot(session: AsyncSession, board: Board) -> BoardSnap
counts_by_task_id[task_id] = (int(total or 0), int(pending or 0))
task_cards = [
_task_to_card(task, agent_name_by_id=agent_name_by_id, counts_by_task_id=counts_by_task_id)
_task_to_card(
task,
agent_name_by_id=agent_name_by_id,
counts_by_task_id=counts_by_task_id,
deps_by_task_id=deps_by_task_id,
dependency_status_by_id_map=dependency_status_by_id_map,
)
for task in tasks
]

View File

@@ -0,0 +1,224 @@
from __future__ import annotations
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Final
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy import delete
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.task_dependencies import TaskDependency
from app.models.tasks import Task
DONE_STATUS: Final[str] = "done"
def _dedupe_uuid_list(values: Sequence[UUID]) -> list[UUID]:
# Preserve order; remove duplicates.
seen: set[UUID] = set()
output: list[UUID] = []
for value in values:
if value in seen:
continue
seen.add(value)
output.append(value)
return output
async def dependency_ids_by_task_id(
session: AsyncSession,
*,
board_id: UUID,
task_ids: Sequence[UUID],
) -> dict[UUID, list[UUID]]:
if not task_ids:
return {}
rows = list(
await session.exec(
select(col(TaskDependency.task_id), col(TaskDependency.depends_on_task_id))
.where(col(TaskDependency.board_id) == board_id)
.where(col(TaskDependency.task_id).in_(task_ids))
.order_by(col(TaskDependency.created_at).asc())
)
)
mapping: dict[UUID, list[UUID]] = defaultdict(list)
for task_id, depends_on_task_id in rows:
mapping[task_id].append(depends_on_task_id)
return dict(mapping)
async def dependency_status_by_id(
session: AsyncSession,
*,
board_id: UUID,
dependency_ids: Sequence[UUID],
) -> dict[UUID, str]:
if not dependency_ids:
return {}
rows = list(
await session.exec(
select(col(Task.id), col(Task.status))
.where(col(Task.board_id) == board_id)
.where(col(Task.id).in_(dependency_ids))
)
)
return {task_id: status_value for task_id, status_value in rows}
def blocked_by_dependency_ids(
*,
dependency_ids: Sequence[UUID],
status_by_id: Mapping[UUID, str],
) -> list[UUID]:
blocked: list[UUID] = []
for dep_id in dependency_ids:
if status_by_id.get(dep_id) != DONE_STATUS:
blocked.append(dep_id)
return blocked
async def blocked_by_for_task(
session: AsyncSession,
*,
board_id: UUID,
task_id: UUID,
dependency_ids: Sequence[UUID] | None = None,
) -> list[UUID]:
dep_ids = list(dependency_ids or [])
if dependency_ids is None:
deps_map = await dependency_ids_by_task_id(
session,
board_id=board_id,
task_ids=[task_id],
)
dep_ids = deps_map.get(task_id, [])
if not dep_ids:
return []
status_by_id = await dependency_status_by_id(session, board_id=board_id, dependency_ids=dep_ids)
return blocked_by_dependency_ids(dependency_ids=dep_ids, status_by_id=status_by_id)
def _has_cycle(nodes: Sequence[UUID], edges: Mapping[UUID, set[UUID]]) -> bool:
visited: set[UUID] = set()
in_stack: set[UUID] = set()
def dfs(node: UUID) -> bool:
if node in in_stack:
return True
if node in visited:
return False
visited.add(node)
in_stack.add(node)
for nxt in edges.get(node, set()):
if dfs(nxt):
return True
in_stack.remove(node)
return False
for node in nodes:
if dfs(node):
return True
return False
async def validate_dependency_update(
session: AsyncSession,
*,
board_id: UUID,
task_id: UUID,
depends_on_task_ids: Sequence[UUID],
) -> list[UUID]:
normalized = _dedupe_uuid_list(depends_on_task_ids)
if task_id in normalized:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Task cannot depend on itself.",
)
if not normalized:
return []
# Ensure all dependency tasks exist on this board.
existing_ids = set(
await session.exec(
select(col(Task.id))
.where(col(Task.board_id) == board_id)
.where(col(Task.id).in_(normalized))
)
)
missing = [dep_id for dep_id in normalized if dep_id not in existing_ids]
if missing:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"message": "One or more dependency tasks were not found on this board.",
"missing_task_ids": [str(value) for value in missing],
},
)
# Ensure the dependency graph is acyclic after applying the update.
task_ids = list(await session.exec(select(col(Task.id)).where(col(Task.board_id) == board_id)))
rows = list(
await session.exec(
select(col(TaskDependency.task_id), col(TaskDependency.depends_on_task_id)).where(
col(TaskDependency.board_id) == board_id
)
)
)
edges: dict[UUID, set[UUID]] = defaultdict(set)
for src, dst in rows:
edges[src].add(dst)
edges[task_id] = set(normalized)
if _has_cycle(task_ids, edges):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Dependency cycle detected. Remove the cycle before saving.",
)
return normalized
async def replace_task_dependencies(
session: AsyncSession,
*,
board_id: UUID,
task_id: UUID,
depends_on_task_ids: Sequence[UUID],
) -> list[UUID]:
normalized = await validate_dependency_update(
session,
board_id=board_id,
task_id=task_id,
depends_on_task_ids=depends_on_task_ids,
)
await session.execute(
delete(TaskDependency)
.where(col(TaskDependency.board_id) == board_id)
.where(col(TaskDependency.task_id) == task_id)
)
for dep_id in normalized:
session.add(
TaskDependency(
board_id=board_id,
task_id=task_id,
depends_on_task_id=dep_id,
)
)
return normalized
async def dependent_task_ids(
session: AsyncSession,
*,
board_id: UUID,
dependency_task_id: UUID,
) -> list[UUID]:
rows = await session.exec(
select(col(TaskDependency.task_id))
.where(col(TaskDependency.board_id) == board_id)
.where(col(TaskDependency.depends_on_task_id) == dependency_task_id)
)
return list(rows)