feat: add skill packs management with support for category, risk, and source fields
This commit is contained in:
committed by
Abhimanyu Saharan
parent
da6cc2544b
commit
10748f71a8
@@ -493,6 +493,120 @@ async def _dispatch_gateway_instruction(
|
||||
)
|
||||
|
||||
|
||||
async def _load_pack_skill_count_by_repo(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
organization_id: UUID,
|
||||
) -> dict[str, int]:
|
||||
skills = await MarketplaceSkill.objects.filter_by(organization_id=organization_id).all(session)
|
||||
return _build_skill_count_by_repo(skills)
|
||||
|
||||
|
||||
def _as_skill_pack_read_with_count(
|
||||
*,
|
||||
pack: SkillPack,
|
||||
count_by_repo: dict[str, int],
|
||||
) -> SkillPackRead:
|
||||
return _as_skill_pack_read(pack).model_copy(
|
||||
update={"skill_count": _pack_skill_count(pack=pack, count_by_repo=count_by_repo)},
|
||||
)
|
||||
|
||||
|
||||
async def _sync_gateway_installation_state(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
gateway_id: UUID,
|
||||
skill_id: UUID,
|
||||
installed: bool,
|
||||
) -> None:
|
||||
installation = await GatewayInstalledSkill.objects.filter_by(
|
||||
gateway_id=gateway_id,
|
||||
skill_id=skill_id,
|
||||
).first(session)
|
||||
if installed:
|
||||
if installation is None:
|
||||
session.add(
|
||||
GatewayInstalledSkill(
|
||||
gateway_id=gateway_id,
|
||||
skill_id=skill_id,
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
installation.updated_at = utcnow()
|
||||
session.add(installation)
|
||||
return
|
||||
|
||||
if installation is not None:
|
||||
await session.delete(installation)
|
||||
|
||||
|
||||
async def _run_marketplace_skill_action(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
ctx: OrganizationContext,
|
||||
skill_id: UUID,
|
||||
gateway_id: UUID,
|
||||
installed: bool,
|
||||
) -> MarketplaceSkillActionResponse:
|
||||
gateway = await _require_gateway_for_org(gateway_id=gateway_id, session=session, ctx=ctx)
|
||||
require_gateway_workspace_root(gateway)
|
||||
skill = await _require_marketplace_skill_for_org(skill_id=skill_id, session=session, ctx=ctx)
|
||||
instruction = (
|
||||
_install_instruction(skill=skill, gateway=gateway)
|
||||
if installed
|
||||
else _uninstall_instruction(skill=skill, gateway=gateway)
|
||||
)
|
||||
try:
|
||||
await _dispatch_gateway_instruction(
|
||||
session=session,
|
||||
gateway=gateway,
|
||||
message=instruction,
|
||||
)
|
||||
except OpenClawGatewayError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
|
||||
await _sync_gateway_installation_state(
|
||||
session=session,
|
||||
gateway_id=gateway.id,
|
||||
skill_id=skill.id,
|
||||
installed=installed,
|
||||
)
|
||||
await session.commit()
|
||||
return MarketplaceSkillActionResponse(
|
||||
skill_id=skill.id,
|
||||
gateway_id=gateway.id,
|
||||
installed=installed,
|
||||
)
|
||||
|
||||
|
||||
def _apply_pack_candidate_updates(
|
||||
*,
|
||||
existing: MarketplaceSkill,
|
||||
candidate: PackSkillCandidate,
|
||||
) -> bool:
|
||||
changed = False
|
||||
if existing.name != candidate.name:
|
||||
existing.name = candidate.name
|
||||
changed = True
|
||||
if existing.description != candidate.description:
|
||||
existing.description = candidate.description
|
||||
changed = True
|
||||
if existing.category != candidate.category:
|
||||
existing.category = candidate.category
|
||||
changed = True
|
||||
if existing.risk != candidate.risk:
|
||||
existing.risk = candidate.risk
|
||||
changed = True
|
||||
if existing.source != candidate.source:
|
||||
existing.source = candidate.source
|
||||
changed = True
|
||||
return changed
|
||||
|
||||
|
||||
@router.get("/marketplace", response_model=list[MarketplaceSkillCardRead])
|
||||
async def list_marketplace_skills(
|
||||
gateway_id: UUID = GATEWAY_ID_QUERY,
|
||||
@@ -580,39 +694,11 @@ async def install_marketplace_skill(
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> MarketplaceSkillActionResponse:
|
||||
"""Install a marketplace skill by dispatching instructions to the gateway agent."""
|
||||
gateway = await _require_gateway_for_org(gateway_id=gateway_id, session=session, ctx=ctx)
|
||||
require_gateway_workspace_root(gateway)
|
||||
skill = await _require_marketplace_skill_for_org(skill_id=skill_id, session=session, ctx=ctx)
|
||||
try:
|
||||
await _dispatch_gateway_instruction(
|
||||
session=session,
|
||||
gateway=gateway,
|
||||
message=_install_instruction(skill=skill, gateway=gateway),
|
||||
)
|
||||
except OpenClawGatewayError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
|
||||
installation = await GatewayInstalledSkill.objects.filter_by(
|
||||
gateway_id=gateway.id,
|
||||
skill_id=skill.id,
|
||||
).first(session)
|
||||
if installation is None:
|
||||
session.add(
|
||||
GatewayInstalledSkill(
|
||||
gateway_id=gateway.id,
|
||||
skill_id=skill.id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
installation.updated_at = utcnow()
|
||||
session.add(installation)
|
||||
await session.commit()
|
||||
return MarketplaceSkillActionResponse(
|
||||
skill_id=skill.id,
|
||||
gateway_id=gateway.id,
|
||||
return await _run_marketplace_skill_action(
|
||||
session=session,
|
||||
ctx=ctx,
|
||||
skill_id=skill_id,
|
||||
gateway_id=gateway_id,
|
||||
installed=True,
|
||||
)
|
||||
|
||||
@@ -628,31 +714,11 @@ async def uninstall_marketplace_skill(
|
||||
ctx: OrganizationContext = ORG_ADMIN_DEP,
|
||||
) -> MarketplaceSkillActionResponse:
|
||||
"""Uninstall a marketplace skill by dispatching instructions to the gateway agent."""
|
||||
gateway = await _require_gateway_for_org(gateway_id=gateway_id, session=session, ctx=ctx)
|
||||
require_gateway_workspace_root(gateway)
|
||||
skill = await _require_marketplace_skill_for_org(skill_id=skill_id, session=session, ctx=ctx)
|
||||
try:
|
||||
await _dispatch_gateway_instruction(
|
||||
session=session,
|
||||
gateway=gateway,
|
||||
message=_uninstall_instruction(skill=skill, gateway=gateway),
|
||||
)
|
||||
except OpenClawGatewayError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
|
||||
installation = await GatewayInstalledSkill.objects.filter_by(
|
||||
gateway_id=gateway.id,
|
||||
skill_id=skill.id,
|
||||
).first(session)
|
||||
if installation is not None:
|
||||
await session.delete(installation)
|
||||
await session.commit()
|
||||
return MarketplaceSkillActionResponse(
|
||||
skill_id=skill.id,
|
||||
gateway_id=gateway.id,
|
||||
return await _run_marketplace_skill_action(
|
||||
session=session,
|
||||
ctx=ctx,
|
||||
skill_id=skill_id,
|
||||
gateway_id=gateway_id,
|
||||
installed=False,
|
||||
)
|
||||
|
||||
@@ -668,14 +734,12 @@ async def list_skill_packs(
|
||||
.order_by(col(SkillPack.created_at).desc())
|
||||
.all(session)
|
||||
)
|
||||
marketplace_skills = await MarketplaceSkill.objects.filter_by(
|
||||
count_by_repo = await _load_pack_skill_count_by_repo(
|
||||
session=session,
|
||||
organization_id=ctx.organization.id,
|
||||
).all(session)
|
||||
count_by_repo = _build_skill_count_by_repo(marketplace_skills)
|
||||
)
|
||||
return [
|
||||
_as_skill_pack_read(pack).model_copy(
|
||||
update={"skill_count": _pack_skill_count(pack=pack, count_by_repo=count_by_repo)},
|
||||
)
|
||||
_as_skill_pack_read_with_count(pack=pack, count_by_repo=count_by_repo)
|
||||
for pack in packs
|
||||
]
|
||||
|
||||
@@ -688,13 +752,11 @@ async def get_skill_pack(
|
||||
) -> SkillPackRead:
|
||||
"""Get one skill pack by ID."""
|
||||
pack = await _require_skill_pack_for_org(pack_id=pack_id, session=session, ctx=ctx)
|
||||
marketplace_skills = await MarketplaceSkill.objects.filter_by(
|
||||
count_by_repo = await _load_pack_skill_count_by_repo(
|
||||
session=session,
|
||||
organization_id=ctx.organization.id,
|
||||
).all(session)
|
||||
count_by_repo = _build_skill_count_by_repo(marketplace_skills)
|
||||
return _as_skill_pack_read(pack).model_copy(
|
||||
update={"skill_count": _pack_skill_count(pack=pack, count_by_repo=count_by_repo)},
|
||||
)
|
||||
return _as_skill_pack_read_with_count(pack=pack, count_by_repo=count_by_repo)
|
||||
|
||||
|
||||
@router.post("/packs", response_model=SkillPackRead)
|
||||
@@ -722,7 +784,11 @@ async def create_skill_pack(
|
||||
session.add(existing)
|
||||
await session.commit()
|
||||
await session.refresh(existing)
|
||||
return _as_skill_pack_read(existing)
|
||||
count_by_repo = await _load_pack_skill_count_by_repo(
|
||||
session=session,
|
||||
organization_id=ctx.organization.id,
|
||||
)
|
||||
return _as_skill_pack_read_with_count(pack=existing, count_by_repo=count_by_repo)
|
||||
|
||||
pack = SkillPack(
|
||||
organization_id=ctx.organization.id,
|
||||
@@ -733,13 +799,11 @@ async def create_skill_pack(
|
||||
session.add(pack)
|
||||
await session.commit()
|
||||
await session.refresh(pack)
|
||||
marketplace_skills = await MarketplaceSkill.objects.filter_by(
|
||||
count_by_repo = await _load_pack_skill_count_by_repo(
|
||||
session=session,
|
||||
organization_id=ctx.organization.id,
|
||||
).all(session)
|
||||
count_by_repo = _build_skill_count_by_repo(marketplace_skills)
|
||||
return _as_skill_pack_read(pack).model_copy(
|
||||
update={"skill_count": _pack_skill_count(pack=pack, count_by_repo=count_by_repo)},
|
||||
)
|
||||
return _as_skill_pack_read_with_count(pack=pack, count_by_repo=count_by_repo)
|
||||
|
||||
|
||||
@router.patch("/packs/{pack_id}", response_model=SkillPackRead)
|
||||
@@ -770,13 +834,11 @@ async def update_skill_pack(
|
||||
session.add(pack)
|
||||
await session.commit()
|
||||
await session.refresh(pack)
|
||||
marketplace_skills = await MarketplaceSkill.objects.filter_by(
|
||||
count_by_repo = await _load_pack_skill_count_by_repo(
|
||||
session=session,
|
||||
organization_id=ctx.organization.id,
|
||||
).all(session)
|
||||
count_by_repo = _build_skill_count_by_repo(marketplace_skills)
|
||||
return _as_skill_pack_read(pack).model_copy(
|
||||
update={"skill_count": _pack_skill_count(pack=pack, count_by_repo=count_by_repo)},
|
||||
)
|
||||
return _as_skill_pack_read_with_count(pack=pack, count_by_repo=count_by_repo)
|
||||
|
||||
|
||||
@router.delete("/packs/{pack_id}", response_model=OkResponse)
|
||||
@@ -833,23 +895,7 @@ async def sync_skill_pack(
|
||||
created += 1
|
||||
continue
|
||||
|
||||
changed = False
|
||||
if existing.name != candidate.name:
|
||||
existing.name = candidate.name
|
||||
changed = True
|
||||
if existing.description != candidate.description:
|
||||
existing.description = candidate.description
|
||||
changed = True
|
||||
if existing.category != candidate.category:
|
||||
existing.category = candidate.category
|
||||
changed = True
|
||||
if existing.risk != candidate.risk:
|
||||
existing.risk = candidate.risk
|
||||
changed = True
|
||||
if existing.source != candidate.source:
|
||||
existing.source = candidate.source
|
||||
changed = True
|
||||
|
||||
changed = _apply_pack_candidate_updates(existing=existing, candidate=candidate)
|
||||
if changed:
|
||||
existing.updated_at = utcnow()
|
||||
session.add(existing)
|
||||
|
||||
@@ -38,6 +38,9 @@ def upgrade() -> None:
|
||||
sa.Column("organization_id", sa.Uuid(), nullable=False),
|
||||
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column("category", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column("risk", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column("source", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column("source_url", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), nullable=False),
|
||||
@@ -105,8 +108,49 @@ def upgrade() -> None:
|
||||
unique=False,
|
||||
)
|
||||
|
||||
if not _has_table("skill_packs"):
|
||||
op.create_table(
|
||||
"skill_packs",
|
||||
sa.Column("id", sa.Uuid(), nullable=False),
|
||||
sa.Column("organization_id", sa.Uuid(), nullable=False),
|
||||
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column("source_url", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"organization_id",
|
||||
"source_url",
|
||||
name="uq_skill_packs_org_source_url",
|
||||
),
|
||||
)
|
||||
|
||||
skill_packs_org_idx = op.f("ix_skill_packs_organization_id")
|
||||
if not _has_index("skill_packs", skill_packs_org_idx):
|
||||
op.create_index(
|
||||
skill_packs_org_idx,
|
||||
"skill_packs",
|
||||
["organization_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
skill_packs_org_idx = op.f("ix_skill_packs_organization_id")
|
||||
if _has_index("skill_packs", skill_packs_org_idx):
|
||||
op.drop_index(
|
||||
skill_packs_org_idx,
|
||||
table_name="skill_packs",
|
||||
)
|
||||
|
||||
if _has_table("skill_packs"):
|
||||
op.drop_table("skill_packs")
|
||||
|
||||
gateway_skill_idx = op.f("ix_gateway_installed_skills_skill_id")
|
||||
if _has_index("gateway_installed_skills", gateway_skill_idx):
|
||||
op.drop_index(
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
"""add skill packs table
|
||||
|
||||
Revision ID: d1b2c3e4f5a6
|
||||
Revises: c9d7e9b6a4f2
|
||||
Create Date: 2026-02-14 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d1b2c3e4f5a6"
|
||||
down_revision = "c9d7e9b6a4f2"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _has_table(table_name: str) -> bool:
|
||||
return sa.inspect(op.get_bind()).has_table(table_name)
|
||||
|
||||
|
||||
def _has_index(table_name: str, index_name: str) -> bool:
|
||||
if not _has_table(table_name):
|
||||
return False
|
||||
indexes = sa.inspect(op.get_bind()).get_indexes(table_name)
|
||||
return any(index["name"] == index_name for index in indexes)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
if not _has_table("skill_packs"):
|
||||
op.create_table(
|
||||
"skill_packs",
|
||||
sa.Column("id", sa.Uuid(), nullable=False),
|
||||
sa.Column("organization_id", sa.Uuid(), nullable=False),
|
||||
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column("source_url", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"organization_id",
|
||||
"source_url",
|
||||
name="uq_skill_packs_org_source_url",
|
||||
),
|
||||
)
|
||||
|
||||
org_idx = op.f("ix_skill_packs_organization_id")
|
||||
if not _has_index("skill_packs", org_idx):
|
||||
op.create_index(
|
||||
org_idx,
|
||||
"skill_packs",
|
||||
["organization_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
org_idx = op.f("ix_skill_packs_organization_id")
|
||||
if _has_index("skill_packs", org_idx):
|
||||
op.drop_index(
|
||||
org_idx,
|
||||
table_name="skill_packs",
|
||||
)
|
||||
|
||||
if _has_table("skill_packs"):
|
||||
op.drop_table("skill_packs")
|
||||
@@ -1,57 +0,0 @@
|
||||
"""add marketplace skill metadata fields
|
||||
|
||||
Revision ID: e7a9b1c2d3e4
|
||||
Revises: d1b2c3e4f5a6
|
||||
Create Date: 2026-02-14 00:00:01.000000
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e7a9b1c2d3e4"
|
||||
down_revision = "d1b2c3e4f5a6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _has_table(table_name: str) -> bool:
|
||||
return sa.inspect(op.get_bind()).has_table(table_name)
|
||||
|
||||
|
||||
def _has_column(table_name: str, column_name: str) -> bool:
|
||||
if not _has_table(table_name):
|
||||
return False
|
||||
columns = sa.inspect(op.get_bind()).get_columns(table_name)
|
||||
return any(column["name"] == column_name for column in columns)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
if not _has_column("marketplace_skills", "category"):
|
||||
op.add_column(
|
||||
"marketplace_skills",
|
||||
sa.Column("category", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
)
|
||||
if not _has_column("marketplace_skills", "risk"):
|
||||
op.add_column(
|
||||
"marketplace_skills",
|
||||
sa.Column("risk", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
)
|
||||
if not _has_column("marketplace_skills", "source"):
|
||||
op.add_column(
|
||||
"marketplace_skills",
|
||||
sa.Column("source", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
if _has_column("marketplace_skills", "source"):
|
||||
op.drop_column("marketplace_skills", "source")
|
||||
if _has_column("marketplace_skills", "risk"):
|
||||
op.drop_column("marketplace_skills", "risk")
|
||||
if _has_column("marketplace_skills", "category"):
|
||||
op.drop_column("marketplace_skills", "category")
|
||||
Reference in New Issue
Block a user