feat: refactor organization context usage in board and gateway endpoints

This commit is contained in:
Abhimanyu Saharan
2026-02-08 21:37:20 +05:30
parent 3f556802a9
commit 061563964d
7 changed files with 37 additions and 32 deletions

View File

@@ -18,6 +18,7 @@ from app.models.agents import Agent
from app.models.board_groups import BoardGroup from app.models.board_groups import BoardGroup
from app.models.boards import Board from app.models.boards import Board
from app.models.gateways import Gateway from app.models.gateways import Gateway
from app.models.organization_members import OrganizationMember
from app.schemas.board_group_heartbeat import ( from app.schemas.board_group_heartbeat import (
BoardGroupHeartbeatApply, BoardGroupHeartbeatApply,
BoardGroupHeartbeatApplyResult, BoardGroupHeartbeatApplyResult,
@@ -29,6 +30,7 @@ from app.schemas.view_models import BoardGroupSnapshot
from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, sync_gateway_agent_heartbeats from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, sync_gateway_agent_heartbeats
from app.services.board_group_snapshot import build_group_snapshot from app.services.board_group_snapshot import build_group_snapshot
from app.services.organizations import ( from app.services.organizations import (
OrganizationContext,
board_access_filter, board_access_filter,
get_member, get_member,
is_org_admin, is_org_admin,
@@ -49,7 +51,7 @@ async def _require_group_access(
session: AsyncSession, session: AsyncSession,
*, *,
group_id: UUID, group_id: UUID,
member, member: OrganizationMember,
write: bool, write: bool,
) -> BoardGroup: ) -> BoardGroup:
group = await session.get(BoardGroup, group_id) group = await session.get(BoardGroup, group_id)
@@ -80,7 +82,7 @@ async def _require_group_access(
@router.get("", response_model=DefaultLimitOffsetPage[BoardGroupRead]) @router.get("", response_model=DefaultLimitOffsetPage[BoardGroupRead])
async def list_board_groups( async def list_board_groups(
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_member), ctx: OrganizationContext = Depends(require_org_member),
) -> DefaultLimitOffsetPage[BoardGroupRead]: ) -> DefaultLimitOffsetPage[BoardGroupRead]:
if member_all_boards_read(ctx.member): if member_all_boards_read(ctx.member):
statement = select(BoardGroup).where(col(BoardGroup.organization_id) == ctx.organization.id) statement = select(BoardGroup).where(col(BoardGroup.organization_id) == ctx.organization.id)
@@ -100,7 +102,7 @@ async def list_board_groups(
async def create_board_group( async def create_board_group(
payload: BoardGroupCreate, payload: BoardGroupCreate,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> BoardGroup: ) -> BoardGroup:
data = payload.model_dump() data = payload.model_dump()
if not (data.get("slug") or "").strip(): if not (data.get("slug") or "").strip():
@@ -113,7 +115,7 @@ async def create_board_group(
async def get_board_group( async def get_board_group(
group_id: UUID, group_id: UUID,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_member), ctx: OrganizationContext = Depends(require_org_member),
) -> BoardGroup: ) -> BoardGroup:
return await _require_group_access(session, group_id=group_id, member=ctx.member, write=False) return await _require_group_access(session, group_id=group_id, member=ctx.member, write=False)
@@ -124,7 +126,7 @@ async def get_board_group_snapshot(
include_done: bool = False, include_done: bool = False,
per_board_task_limit: int = 5, per_board_task_limit: int = 5,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_member), ctx: OrganizationContext = Depends(require_org_member),
) -> BoardGroupSnapshot: ) -> BoardGroupSnapshot:
group = await _require_group_access(session, group_id=group_id, member=ctx.member, write=False) group = await _require_group_access(session, group_id=group_id, member=ctx.member, write=False)
if per_board_task_limit < 0: if per_board_task_limit < 0:
@@ -253,7 +255,7 @@ async def update_board_group(
payload: BoardGroupUpdate, payload: BoardGroupUpdate,
group_id: UUID, group_id: UUID,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> BoardGroup: ) -> BoardGroup:
group = await _require_group_access(session, group_id=group_id, member=ctx.member, write=True) group = await _require_group_access(session, group_id=group_id, member=ctx.member, write=True)
updates = payload.model_dump(exclude_unset=True) updates = payload.model_dump(exclude_unset=True)
@@ -269,7 +271,7 @@ async def update_board_group(
async def delete_board_group( async def delete_board_group(
group_id: UUID, group_id: UUID,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> OkResponse: ) -> OkResponse:
await _require_group_access(session, group_id=group_id, member=ctx.member, write=True) await _require_group_access(session, group_id=group_id, member=ctx.member, write=True)

View File

@@ -43,7 +43,7 @@ from app.schemas.pagination import DefaultLimitOffsetPage
from app.schemas.view_models import BoardGroupSnapshot, BoardSnapshot from app.schemas.view_models import BoardGroupSnapshot, BoardSnapshot
from app.services.board_group_snapshot import build_board_group_snapshot from app.services.board_group_snapshot import build_board_group_snapshot
from app.services.board_snapshot import build_board_snapshot from app.services.board_snapshot import build_board_snapshot
from app.services.organizations import board_access_filter from app.services.organizations import OrganizationContext, board_access_filter
router = APIRouter(prefix="/boards", tags=["boards"]) router = APIRouter(prefix="/boards", tags=["boards"])
@@ -81,7 +81,7 @@ async def _require_gateway(
async def _require_gateway_for_create( async def _require_gateway_for_create(
payload: BoardCreate, payload: BoardCreate,
ctx=Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
) -> Gateway: ) -> Gateway:
return await _require_gateway(session, payload.gateway_id, organization_id=ctx.organization.id) return await _require_gateway(session, payload.gateway_id, organization_id=ctx.organization.id)
@@ -109,7 +109,7 @@ async def _require_board_group(
async def _require_board_group_for_create( async def _require_board_group_for_create(
payload: BoardCreate, payload: BoardCreate,
ctx=Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
) -> BoardGroup | None: ) -> BoardGroup | None:
if payload.board_group_id is None: if payload.board_group_id is None:
@@ -220,7 +220,7 @@ async def list_boards(
gateway_id: UUID | None = Query(default=None), gateway_id: UUID | None = Query(default=None),
board_group_id: UUID | None = Query(default=None), board_group_id: UUID | None = Query(default=None),
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_member), ctx: OrganizationContext = Depends(require_org_member),
) -> DefaultLimitOffsetPage[BoardRead]: ) -> DefaultLimitOffsetPage[BoardRead]:
statement = select(Board).where(board_access_filter(ctx.member, write=False)) statement = select(Board).where(board_access_filter(ctx.member, write=False))
if gateway_id is not None: if gateway_id is not None:
@@ -237,7 +237,7 @@ async def create_board(
_gateway: Gateway = Depends(_require_gateway_for_create), _gateway: Gateway = Depends(_require_gateway_for_create),
_board_group: BoardGroup | None = Depends(_require_board_group_for_create), _board_group: BoardGroup | None = Depends(_require_board_group_for_create),
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> Board: ) -> Board:
data = payload.model_dump() data = payload.model_dump()
data["organization_id"] = ctx.organization.id data["organization_id"] = ctx.organization.id

View File

@@ -25,6 +25,7 @@ from app.schemas.gateways import (
) )
from app.schemas.pagination import DefaultLimitOffsetPage from app.schemas.pagination import DefaultLimitOffsetPage
from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_main_agent from app.services.agent_provisioning import DEFAULT_HEARTBEAT_CONFIG, provision_main_agent
from app.services.organizations import OrganizationContext
from app.services.template_sync import sync_gateway_templates as sync_gateway_templates_service from app.services.template_sync import sync_gateway_templates as sync_gateway_templates_service
router = APIRouter(prefix="/gateways", tags=["gateways"]) router = APIRouter(prefix="/gateways", tags=["gateways"])
@@ -131,7 +132,7 @@ async def _ensure_main_agent(
@router.get("", response_model=DefaultLimitOffsetPage[GatewayRead]) @router.get("", response_model=DefaultLimitOffsetPage[GatewayRead])
async def list_gateways( async def list_gateways(
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> DefaultLimitOffsetPage[GatewayRead]: ) -> DefaultLimitOffsetPage[GatewayRead]:
statement = ( statement = (
select(Gateway) select(Gateway)
@@ -146,7 +147,7 @@ async def create_gateway(
payload: GatewayCreate, payload: GatewayCreate,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(get_auth_context),
ctx=Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> Gateway: ) -> Gateway:
data = payload.model_dump() data = payload.model_dump()
data["organization_id"] = ctx.organization.id data["organization_id"] = ctx.organization.id
@@ -162,7 +163,7 @@ async def create_gateway(
async def get_gateway( async def get_gateway(
gateway_id: UUID, gateway_id: UUID,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> Gateway: ) -> Gateway:
gateway = await session.get(Gateway, gateway_id) gateway = await session.get(Gateway, gateway_id)
if gateway is None or gateway.organization_id != ctx.organization.id: if gateway is None or gateway.organization_id != ctx.organization.id:
@@ -176,7 +177,7 @@ async def update_gateway(
payload: GatewayUpdate, payload: GatewayUpdate,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(get_auth_context),
ctx=Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> Gateway: ) -> Gateway:
gateway = await session.get(Gateway, gateway_id) gateway = await session.get(Gateway, gateway_id)
if gateway is None or gateway.organization_id != ctx.organization.id: if gateway is None or gateway.organization_id != ctx.organization.id:
@@ -210,7 +211,7 @@ async def sync_gateway_templates(
board_id: UUID | None = Query(default=None), board_id: UUID | None = Query(default=None),
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(get_auth_context),
ctx=Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> GatewayTemplatesSyncResult: ) -> GatewayTemplatesSyncResult:
gateway = await session.get(Gateway, gateway_id) gateway = await session.get(Gateway, gateway_id)
if gateway is None or gateway.organization_id != ctx.organization.id: if gateway is None or gateway.organization_id != ctx.organization.id:
@@ -231,7 +232,7 @@ async def sync_gateway_templates(
async def delete_gateway( async def delete_gateway(
gateway_id: UUID, gateway_id: UUID,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_admin), ctx: OrganizationContext = Depends(require_org_admin),
) -> OkResponse: ) -> OkResponse:
gateway = await session.get(Gateway, gateway_id) gateway = await session.get(Gateway, gateway_id)
if gateway is None or gateway.organization_id != ctx.organization.id: if gateway is None or gateway.organization_id != ctx.organization.id:

View File

@@ -26,7 +26,7 @@ from app.schemas.metrics import (
DashboardWipRangeSeries, DashboardWipRangeSeries,
DashboardWipSeriesSet, DashboardWipSeriesSet,
) )
from app.services.organizations import list_accessible_board_ids from app.services.organizations import OrganizationContext, list_accessible_board_ids
router = APIRouter(prefix="/metrics", tags=["metrics"]) router = APIRouter(prefix="/metrics", tags=["metrics"])
@@ -304,7 +304,7 @@ async def _tasks_in_progress(session: AsyncSession, board_ids: list[UUID]) -> in
async def dashboard_metrics( async def dashboard_metrics(
range: Literal["24h", "7d"] = Query(default="24h"), range: Literal["24h", "7d"] = Query(default="24h"),
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
ctx=Depends(require_org_member), ctx: OrganizationContext = Depends(require_org_member),
) -> DashboardMetrics: ) -> DashboardMetrics:
primary = _resolve_range(range) primary = _resolve_range(range)
comparison = _comparison_range(range) comparison = _comparison_range(range)

View File

@@ -5,7 +5,7 @@ from typing import Any, Sequence
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func from sqlalchemy import delete, func
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -168,7 +168,7 @@ async def get_my_membership(
) )
model = _member_to_read(ctx.member, user) model = _member_to_read(ctx.member, user)
model.board_access = [ model.board_access = [
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) # type: ignore[name-defined] OrganizationBoardAccessRead.model_validate(row, from_attributes=True)
for row in access_rows for row in access_rows
] ]
return model return model
@@ -216,7 +216,7 @@ async def get_org_member(
) )
model = _member_to_read(member, user) model = _member_to_read(member, user)
model.board_access = [ model.board_access = [
OrganizationBoardAccessRead.model_validate(row, from_attributes=True) # type: ignore[name-defined] OrganizationBoardAccessRead.model_validate(row, from_attributes=True)
for row in access_rows for row in access_rows
] ]
return model return model
@@ -351,9 +351,9 @@ async def revoke_org_invite(
if invite is None or invite.organization_id != ctx.organization.id: if invite is None or invite.organization_id != ctx.organization.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
await session.execute( await session.execute(
OrganizationInviteBoardAccess.__table__.delete().where( delete(OrganizationInviteBoardAccess).where(
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
) ),
) )
await session.delete(invite) await session.delete(invite)
await session.commit() await session.commit()

View File

@@ -5,7 +5,8 @@ from typing import Iterable
from uuid import UUID from uuid import UUID
from fastapi import HTTPException, status from fastapi import HTTPException, status
from sqlalchemy import func, or_ from sqlalchemy import delete, func, or_
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -267,7 +268,7 @@ async def require_board_access(
return member return member
def board_access_filter(member: OrganizationMember, *, write: bool) -> object: def board_access_filter(member: OrganizationMember, *, write: bool) -> ColumnElement[bool]:
if write and member_all_boards_write(member): if write and member_all_boards_write(member):
return col(Board.organization_id) == member.organization_id return col(Board.organization_id) == member.organization_id
if not write and member_all_boards_read(member): if not write and member_all_boards_read(member):
@@ -330,9 +331,9 @@ async def apply_member_access_update(
session.add(member) session.add(member)
await session.execute( await session.execute(
OrganizationBoardAccess.__table__.delete().where( delete(OrganizationBoardAccess).where(
col(OrganizationBoardAccess.organization_member_id) == member.id col(OrganizationBoardAccess.organization_member_id) == member.id
) ),
) )
if update.all_boards_read or update.all_boards_write: if update.all_boards_read or update.all_boards_write:
@@ -360,9 +361,9 @@ async def apply_invite_board_access(
entries: Iterable[OrganizationBoardAccessSpec], entries: Iterable[OrganizationBoardAccessSpec],
) -> None: ) -> None:
await session.execute( await session.execute(
OrganizationInviteBoardAccess.__table__.delete().where( delete(OrganizationInviteBoardAccess).where(
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
) ),
) )
if invite.all_boards_read or invite.all_boards_write: if invite.all_boards_read or invite.all_boards_write:
return return

View File

@@ -3,7 +3,8 @@ import { clerkSetup } from "@clerk/testing/cypress";
export default defineConfig({ export default defineConfig({
env: { env: {
NEXT_PUBLIC_CLERK_PUBLISHABLE_KEY: process.env.NEXT_PUBLIC_CLERK_PUBLISHABLE_KEY, NEXT_PUBLIC_CLERK_PUBLISHABLE_KEY:
process.env.NEXT_PUBLIC_CLERK_PUBLISHABLE_KEY,
// Optional overrides. // Optional overrides.
CLERK_ORIGIN: process.env.CYPRESS_CLERK_ORIGIN, CLERK_ORIGIN: process.env.CYPRESS_CLERK_ORIGIN,
CLERK_TEST_EMAIL: process.env.CYPRESS_CLERK_TEST_EMAIL, CLERK_TEST_EMAIL: process.env.CYPRESS_CLERK_TEST_EMAIL,