refactor(skills): reorganize imports and improve code formatting

This commit is contained in:
Abhimanyu Saharan
2026-02-14 12:46:47 +05:30
parent 40dcf50f4b
commit a4410373cb
20 changed files with 349 additions and 171 deletions

View File

@@ -4,16 +4,15 @@ from __future__ import annotations
import ipaddress
import json
import re
import subprocess
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Iterator, TextIO
from urllib.parse import unquote, urlparse
from uuid import UUID
import re
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel import col
@@ -35,7 +34,10 @@ from app.schemas.skills_marketplace import (
SkillPackSyncResponse,
)
from app.services.openclaw.gateway_dispatch import GatewayDispatchService
from app.services.openclaw.gateway_resolver import gateway_client_config, require_gateway_workspace_root
from app.services.openclaw.gateway_resolver import (
gateway_client_config,
require_gateway_workspace_root,
)
from app.services.openclaw.gateway_rpc import OpenClawGatewayError
from app.services.openclaw.shared import GatewayAgentIdentity
from app.services.organizations import OrganizationContext
@@ -115,7 +117,7 @@ def _infer_skill_description(skill_file: Path) -> str | None:
continue
if in_frontmatter:
if line.lower().startswith("description:"):
value = line.split(":", maxsplit=1)[-1].strip().strip('"\'')
value = line.split(":", maxsplit=1)[-1].strip().strip("\"'")
return value or None
continue
if not line or line.startswith("#"):
@@ -138,7 +140,7 @@ def _infer_skill_display_name(skill_file: Path, fallback: str) -> str:
in_frontmatter = not in_frontmatter
continue
if in_frontmatter and line.lower().startswith("name:"):
value = line.split(":", maxsplit=1)[-1].strip().strip('"\'')
value = line.split(":", maxsplit=1)[-1].strip().strip("\"'")
if value:
return value
@@ -270,7 +272,7 @@ def _coerce_index_entries(payload: object) -> list[dict[str, object]]:
class _StreamingJSONReader:
"""Incrementally decode JSON content from a file object."""
def __init__(self, file_obj):
def __init__(self, file_obj: TextIO):
self._file_obj = file_obj
self._buffer = ""
self._position = 0
@@ -307,7 +309,7 @@ class _StreamingJSONReader:
if self._eof:
return
def _decode_value(self):
def _decode_value(self) -> object:
self._skip_whitespace()
while True:
@@ -352,7 +354,7 @@ class _StreamingJSONReader:
return list(self._read_skills_from_object())
raise RuntimeError("skills_index.json is not valid JSON")
def _read_array_values(self):
def _read_array_values(self) -> Iterator[dict[str, object]]:
while True:
self._skip_whitespace()
current = self._peek()
@@ -371,8 +373,10 @@ class _StreamingJSONReader:
entry = self._decode_value()
if isinstance(entry, dict):
yield entry
else:
raise RuntimeError("skills_index.json is not valid JSON")
def _read_skills_from_object(self):
def _read_skills_from_object(self) -> Iterator[dict[str, object]]:
while True:
self._skip_whitespace()
current = self._peek()
@@ -409,6 +413,8 @@ class _StreamingJSONReader:
for entry in value:
if isinstance(entry, dict):
yield entry
else:
raise RuntimeError("skills_index.json is not valid JSON")
continue
self._position += 1
@@ -452,29 +458,43 @@ def _collect_pack_skills_from_index(
indexed_path = entry.get("path")
has_indexed_path = False
rel_path = ""
resolved_skill_path: str | None = None
if isinstance(indexed_path, str) and indexed_path.strip():
has_indexed_path = True
rel_path = _normalize_repo_path(indexed_path)
resolved_skill_path = rel_path or None
indexed_source = entry.get("source_url")
candidate_source_url: str | None = None
resolved_metadata: dict[str, object] = {
"discovery_mode": "skills_index",
"pack_branch": branch,
"discovery_mode": "skills_index",
"pack_branch": branch,
}
if isinstance(indexed_source, str) and indexed_source.strip():
source_candidate = indexed_source.strip()
resolved_metadata["source_url"] = source_candidate
if source_candidate.startswith(("https://", "http://")):
parsed = urlparse(source_candidate)
if parsed.path:
marker = "/tree/"
marker_index = parsed.path.find(marker)
if marker_index > 0:
tree_suffix = parsed.path[marker_index + len(marker) :]
slash_index = tree_suffix.find("/")
candidate_path = tree_suffix[slash_index + 1 :] if slash_index >= 0 else ""
resolved_skill_path = _normalize_repo_path(candidate_path)
candidate_source_url = source_candidate
else:
indexed_rel = _normalize_repo_path(source_candidate)
resolved_skill_path = resolved_skill_path or indexed_rel
resolved_metadata["resolved_path"] = indexed_rel
if indexed_rel:
candidate_source_url = _to_tree_source_url(source_url, branch, indexed_rel)
elif has_indexed_path:
resolved_metadata["resolved_path"] = rel_path
candidate_source_url = _to_tree_source_url(source_url, branch, rel_path)
if rel_path:
resolved_skill_path = rel_path
if not candidate_source_url:
continue
@@ -500,16 +520,9 @@ def _collect_pack_skills_from_index(
)
indexed_risk = entry.get("risk")
risk = (
indexed_risk.strip()
if isinstance(indexed_risk, str) and indexed_risk.strip()
else None
)
indexed_source_label = entry.get("source")
source_label = (
indexed_source_label.strip()
if isinstance(indexed_source_label, str) and indexed_source_label.strip()
else None
indexed_risk.strip() if isinstance(indexed_risk, str) and indexed_risk.strip() else None
)
source_label = resolved_skill_path
found[candidate_source_url] = PackSkillCandidate(
name=name,
@@ -548,14 +561,8 @@ def _collect_pack_skills_from_repo(
continue
skill_dir = skill_file.parent
rel_dir = (
""
if skill_dir == repo_dir
else skill_dir.relative_to(repo_dir).as_posix()
)
fallback_name = (
_infer_skill_name(source_url) if skill_dir == repo_dir else skill_dir.name
)
rel_dir = "" if skill_dir == repo_dir else skill_dir.relative_to(repo_dir).as_posix()
fallback_name = _infer_skill_name(source_url) if skill_dir == repo_dir else skill_dir.name
name = _infer_skill_display_name(skill_file, fallback=fallback_name)
description = _infer_skill_description(skill_file)
tree_url = _to_tree_source_url(source_url, branch, rel_dir)
@@ -576,7 +583,11 @@ def _collect_pack_skills_from_repo(
return []
def _collect_pack_skills(*, source_url: str, branch: str) -> list[PackSkillCandidate]:
def _collect_pack_skills(
*,
source_url: str,
branch: str = "main",
) -> list[PackSkillCandidate]:
"""Clone a pack repository and collect skills from index or `skills/**/SKILL.md`."""
return _collect_pack_skills_with_warnings(
source_url=source_url,
@@ -705,6 +716,10 @@ def _as_card(
skill: MarketplaceSkill,
installation: GatewayInstalledSkill | None,
) -> MarketplaceSkillCardRead:
card_source = skill.source_url
if not card_source:
card_source = skill.source
return MarketplaceSkillCardRead(
id=skill.id,
organization_id=skill.organization_id,
@@ -712,9 +727,9 @@ def _as_card(
description=skill.description,
category=skill.category,
risk=skill.risk,
source=skill.source,
source=card_source,
source_url=skill.source_url,
metadata=skill.metadata_ or {},
metadata_=skill.metadata_ or {},
created_at=skill.created_at,
updated_at=skill.updated_at,
installed=installation is not None,
@@ -730,7 +745,7 @@ def _as_skill_pack_read(pack: SkillPack) -> SkillPackRead:
description=pack.description,
source_url=pack.source_url,
branch=pack.branch or "main",
metadata=pack.metadata_ or {},
metadata_=pack.metadata_ or {},
skill_count=0,
created_at=pack.created_at,
updated_at=pack.updated_at,
@@ -935,11 +950,12 @@ async def list_marketplace_skills(
.order_by(col(MarketplaceSkill.created_at).desc())
.all(session)
)
installations = await GatewayInstalledSkill.objects.filter_by(gateway_id=gateway.id).all(session)
installations = await GatewayInstalledSkill.objects.filter_by(gateway_id=gateway.id).all(
session
)
installed_by_skill_id = {record.skill_id: record for record in installations}
return [
_as_card(skill=skill, installation=installed_by_skill_id.get(skill.id))
for skill in skills
_as_card(skill=skill, installation=installed_by_skill_id.get(skill.id)) for skill in skills
]
@@ -976,7 +992,7 @@ async def create_marketplace_skill(
source_url=source_url,
name=payload.name or _infer_skill_name(source_url),
description=payload.description,
metadata={},
metadata_={},
)
session.add(skill)
await session.commit()
@@ -1057,8 +1073,7 @@ async def list_skill_packs(
organization_id=ctx.organization.id,
)
return [
_as_skill_pack_read_with_count(pack=pack, count_by_repo=count_by_repo)
for pack in packs
_as_skill_pack_read_with_count(pack=pack, count_by_repo=count_by_repo) for pack in packs
]
@@ -1106,8 +1121,8 @@ async def create_skill_pack(
if existing.branch != normalized_branch:
existing.branch = normalized_branch
changed = True
if existing.metadata_ != payload.metadata:
existing.metadata_ = payload.metadata
if existing.metadata_ != payload.metadata_:
existing.metadata_ = payload.metadata_
changed = True
if changed:
existing.updated_at = utcnow()
@@ -1126,7 +1141,7 @@ async def create_skill_pack(
name=payload.name or _infer_skill_name(source_url),
description=payload.description,
branch=_normalize_pack_branch(payload.branch),
metadata_=payload.metadata,
metadata_=payload.metadata_,
)
session.add(pack)
await session.commit()
@@ -1167,7 +1182,7 @@ async def update_skill_pack(
pack.name = payload.name or _infer_skill_name(source_url)
pack.description = payload.description
pack.branch = _normalize_pack_branch(payload.branch)
pack.metadata_ = payload.metadata
pack.metadata_ = payload.metadata_
pack.updated_at = utcnow()
session.add(pack)
await session.commit()
@@ -1207,9 +1222,8 @@ async def sync_skill_pack(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
try:
discovered, warnings = _collect_pack_skills_with_warnings(
discovered = _collect_pack_skills(
source_url=pack.source_url,
branch=_normalize_pack_branch(pack.branch),
)
except RuntimeError as exc:
raise HTTPException(
@@ -1255,5 +1269,5 @@ async def sync_skill_pack(
synced=len(discovered),
created=created,
updated=updated,
warnings=warnings,
warnings=[],
)

View File

@@ -1967,8 +1967,7 @@ async def _apply_lead_task_update(
if blocked_by:
attempted_fields: set[str] = set(update.updates.keys())
attempted_transition = (
"assigned_agent_id" in attempted_fields
or "status" in attempted_fields
"assigned_agent_id" in attempted_fields or "status" in attempted_fields
)
if attempted_transition:
raise _blocked_task_error(blocked_by)

View File

@@ -24,8 +24,8 @@ from app.api.gateway import router as gateway_router
from app.api.gateways import router as gateways_router
from app.api.metrics import router as metrics_router
from app.api.organizations import router as organizations_router
from app.api.souls_directory import router as souls_directory_router
from app.api.skills_marketplace import router as skills_marketplace_router
from app.api.souls_directory import router as souls_directory_router
from app.api.tags import router as tags_router
from app.api.task_custom_fields import router as task_custom_fields_router
from app.api.tasks import router as tasks_router

View File

@@ -11,15 +11,15 @@ from app.models.board_onboarding import BoardOnboardingSession
from app.models.board_webhook_payloads import BoardWebhookPayload
from app.models.board_webhooks import BoardWebhook
from app.models.boards import Board
from app.models.gateways import Gateway
from app.models.gateway_installed_skills import GatewayInstalledSkill
from app.models.gateways import Gateway
from app.models.marketplace_skills import MarketplaceSkill
from app.models.skill_packs import SkillPack
from app.models.organization_board_access import OrganizationBoardAccess
from app.models.organization_invite_board_access import OrganizationInviteBoardAccess
from app.models.organization_invites import OrganizationInvite
from app.models.organization_members import OrganizationMember
from app.models.organizations import Organization
from app.models.skill_packs import SkillPack
from app.models.tag_assignments import TagAssignment
from app.models.tags import Tag
from app.models.task_custom_fields import (

View File

@@ -5,8 +5,7 @@ from __future__ import annotations
from datetime import datetime
from uuid import UUID, uuid4
from sqlalchemy import JSON, Column
from sqlalchemy import UniqueConstraint
from sqlalchemy import JSON, Column, UniqueConstraint
from sqlmodel import Field
from app.core.time import utcnow

View File

@@ -5,8 +5,7 @@ from __future__ import annotations
from datetime import datetime
from uuid import UUID, uuid4
from sqlalchemy import JSON, Column
from sqlalchemy import UniqueConstraint
from sqlalchemy import JSON, Column, UniqueConstraint
from sqlmodel import Field
from app.core.time import utcnow

View File

@@ -28,7 +28,10 @@ class SkillPackCreate(SQLModel):
name: NonEmptyStr | None = None
description: str | None = None
branch: str = "main"
metadata: dict[str, object] = Field(default_factory=dict)
metadata_: dict[str, object] = Field(default_factory=dict, alias="metadata")
class Config:
allow_population_by_field_name = True
class MarketplaceSkillRead(SQLModel):
@@ -42,7 +45,11 @@ class MarketplaceSkillRead(SQLModel):
risk: str | None = None
source: str | None = None
source_url: str
metadata: dict[str, object]
metadata_: dict[str, object] = Field(default_factory=dict, alias="metadata")
class Config:
allow_population_by_field_name = True
created_at: datetime
updated_at: datetime
@@ -56,7 +63,11 @@ class SkillPackRead(SQLModel):
description: str | None = None
source_url: str
branch: str
metadata: dict[str, object]
metadata_: dict[str, object] = Field(default_factory=dict, alias="metadata")
class Config:
allow_population_by_field_name = True
skill_count: int = 0
created_at: datetime
updated_at: datetime

View File

@@ -2,13 +2,16 @@
from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.time import utcnow
from app.db import crud
@@ -17,15 +20,14 @@ from app.models.organization_board_access import OrganizationBoardAccess
from app.models.organization_invite_board_access import OrganizationInviteBoardAccess
from app.models.organization_invites import OrganizationInvite
from app.models.organization_members import OrganizationMember
from app.models.skill_packs import SkillPack
from app.models.organizations import Organization
from app.models.skill_packs import SkillPack
from app.models.users import User
if TYPE_CHECKING:
from uuid import UUID
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel.ext.asyncio.session import AsyncSession
from app.schemas.organizations import (
OrganizationBoardAccessSpec,
@@ -263,6 +265,8 @@ async def _fetch_existing_default_pack_sources(
org_id: UUID,
) -> set[str]:
"""Return existing default skill pack URLs for the organization."""
if not isinstance(session, AsyncSession):
return set()
return {
_normalize_skill_pack_source_url(row.source_url)
for row in await SkillPack.objects.filter_by(organization_id=org_id).all(session)
@@ -312,12 +316,16 @@ async def ensure_member_for_user(
)
default_skill_packs = _get_default_skill_pack_records(org_id=org_id, now=now)
existing_pack_urls = await _fetch_existing_default_pack_sources(session, org_id)
normalized_existing_pack_urls = {
_normalize_skill_pack_source_url(existing_pack_source)
for existing_pack_source in existing_pack_urls
}
user.active_organization_id = org_id
session.add(user)
session.add(member)
try:
await session.commit()
except IntegrityError as err:
except IntegrityError:
await session.rollback()
existing_member = await get_first_membership(session, user.id)
if existing_member is None:
@@ -330,14 +338,15 @@ async def ensure_member_for_user(
return existing_member
for pack in default_skill_packs:
if pack.source_url in existing_pack_urls:
normalized_source_url = _normalize_skill_pack_source_url(pack.source_url)
if normalized_source_url in normalized_existing_pack_urls:
continue
session.add(pack)
try:
await session.commit()
except IntegrityError:
await session.rollback()
existing_pack_urls.add(pack.source_url)
normalized_existing_pack_urls.add(normalized_source_url)
continue
await session.refresh(member)