diff --git a/README.md b/README.md index 42066ef..a6f4ad3 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@

FrameX

- Build modular Python services with plug-and-play plugins, clear team boundaries, and transparent API integration. + 🚀 Build scalable Python services with plugins — like FastAPI + Ray, but modular by design.

diff --git a/src/framex/config.py b/src/framex/config.py index 64ba37b..1973212 100644 --- a/src/framex/config.py +++ b/src/framex/config.py @@ -55,6 +55,7 @@ class TestConfig(BaseModel): class OauthConfig(BaseModel): + provider: str = "" client_id: str = "" client_secret: str = "" authorization_url: str = "" @@ -84,6 +85,73 @@ def model_post_init(self, context: Any) -> None: self.jwt_secret = secrets.token_urlsafe(32) +class RepositoryProviderAuthConfig(BaseModel): + token: str = "" + token_header: str = "Authorization" # noqa + token_scheme: str = "Bearer" # noqa + + def build_headers(self) -> dict[str, str]: + if not self.token: + return {} + if self.token_scheme: + return {self.token_header: f"{self.token_scheme} {self.token}"} + return {self.token_header: self.token} + + +class GitLabRepositoryAuthEndpointConfig(RepositoryProviderAuthConfig): + host: str + path_prefix: str = "" + token_header: str = "PRIVATE-TOKEN" # noqa + token_scheme: str = "" + + def matches(self, host: str, path: str) -> bool: + normalized_prefix = self.normalized_path_prefix + if self.host.lower() != host.lower(): + return False + if not normalized_prefix: + return True + return path == normalized_prefix or path.startswith(f"{normalized_prefix}/") + + @property + def normalized_path_prefix(self) -> str: + if not self.path_prefix: + return "" + return self.path_prefix if self.path_prefix.startswith("/") else f"/{self.path_prefix}" + + +class GitLabRepositoryAuthConfig(RepositoryProviderAuthConfig): + token_header: str = "PRIVATE-TOKEN" # noqa + token_scheme: str = "" + endpoints: list[GitLabRepositoryAuthEndpointConfig] = Field(default_factory=list) + + def configured_hosts(self) -> set[str]: + return {endpoint.host.lower() for endpoint in self.endpoints} + + def build_headers_for_url(self, host: str, path: str) -> dict[str, str]: + if endpoint := self.resolve_endpoint(host, path): + return endpoint.build_headers() + return self.build_headers() + + def resolve_endpoint(self, host: str, 
path: str) -> GitLabRepositoryAuthEndpointConfig | None: + matches = [endpoint for endpoint in self.endpoints if endpoint.matches(host, path)] + if not matches: + return None + return max(matches, key=lambda endpoint: len(endpoint.normalized_path_prefix)) + + +class RepositoryAuthConfig(BaseModel): + github: RepositoryProviderAuthConfig = Field(default_factory=RepositoryProviderAuthConfig) + gitlab: GitLabRepositoryAuthConfig = Field(default_factory=GitLabRepositoryAuthConfig) + + +class RepositoryConfig(BaseModel): + auth: RepositoryAuthConfig = Field(default_factory=RepositoryAuthConfig) + + +class DocsConfig(BaseModel): + embedded_config_file_whitelist: list[str] = Field(default_factory=list) + + class AuthConfig(BaseModel): oauth: OauthConfig | None = Field(default=None) rules: dict[str, list[str]] = Field(default_factory=dict) @@ -131,8 +199,10 @@ class Settings(BaseSettings): load_builtin_plugins: list[str] = Field(default_factory=list) test: TestConfig = Field(default_factory=TestConfig) + docs: DocsConfig = Field(default_factory=DocsConfig) sentry: SentryConfig = Field(default_factory=SentryConfig) auth: AuthConfig = Field(default_factory=AuthConfig) + repository: RepositoryConfig = Field(default_factory=RepositoryConfig) model_config = SettingsConfigDict( # `.env.prod` takes priority over `.env` diff --git a/src/framex/driver/application.py b/src/framex/driver/application.py index 1c730ea..13b5e76 100644 --- a/src/framex/driver/application.py +++ b/src/framex/driver/application.py @@ -5,6 +5,7 @@ from collections.abc import Callable from contextlib import asynccontextmanager from datetime import UTC, datetime +from pathlib import Path from typing import Annotated, Any from zoneinfo import ZoneInfo @@ -23,8 +24,21 @@ from framex.config import settings from framex.consts import API_PRE_STR, DOCS_URL, OPENAPI_URL, PROJECT_NAME, REDOC_URL, VERSION -from framex.driver.auth import authenticate, oauth_callback -from framex.utils import build_swagger_ui_html, 
format_uptime, safe_error_message +from framex.driver.auth import authenticate, get_auth_payload, oauth_callback +from framex.plugin import get_plugin +from framex.repository import ( + can_access_repository, + get_latest_repository_version, + has_newer_release_version, + is_private_repository, +) +from framex.utils import ( + build_plugin_config_html, + build_swagger_ui_html, + collect_embedded_config_files, + format_uptime, + safe_error_message, +) FRAME_START_TIME = datetime.now(tz=UTC) SHANGHAI_TZ = ZoneInfo("Asia/Shanghai") @@ -109,6 +123,69 @@ async def _on_start(deployment: Any) -> None: async def get_documentation(_: Annotated[str, Depends(authenticate)]) -> HTMLResponse: return build_swagger_ui_html(openapi_url=OPENAPI_URL, title="FrameX Docs") + @application.get("/docs/plugin-config", include_in_schema=False) + async def get_plugin_config_documentation( + request: Request, + plugin: str, + _: Annotated[str, Depends(authenticate)], + ) -> HTMLResponse: + if not settings.auth.oauth: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="Plugin config documentation requires auth" + ) + + loaded_plugin = get_plugin(plugin) + auth_payload = get_auth_payload(request) + repo_url = ( + loaded_plugin.metadata.url if loaded_plugin is not None and loaded_plugin.metadata is not None else "" + ) + + if not repo_url or auth_payload is None: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Repository access denied: {plugin}") + + repository_is_private = is_private_repository(repo_url) + if repository_is_private is not False: + access_result = can_access_repository( + repo_url, + auth_payload.get("oauth_provider"), + auth_payload.get("oauth_access_token"), + ) + if access_result is not True: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=f"Repository access denied: {plugin}" + ) + + loaded_config = loaded_plugin.config.model_dump() if loaded_plugin and loaded_plugin.config else None + config_data = 
loaded_config or settings.plugins.get(plugin) # type: ignore + if config_data is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Plugin config not found: {plugin}") + + return build_plugin_config_html( + config_data, + collect_embedded_config_files( + config_data, + workspace_root=Path.cwd().resolve(), + whitelist=settings.docs.embedded_config_file_whitelist, + ), + ) + + @application.get("/docs/plugin-release", include_in_schema=False) + async def get_plugin_release_documentation( + plugin: str, + _: Annotated[str, Depends(authenticate)], + ) -> dict[str, Any]: + loaded_plugin = get_plugin(plugin) + if loaded_plugin is None or loaded_plugin.metadata is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Plugin not found: {plugin}") + + current_version = loaded_plugin.metadata.version + current_version = current_version if current_version.startswith("v") else f"v{current_version}" + repo_url = loaded_plugin.metadata.url + latest_version = get_latest_repository_version(repo_url) + if not latest_version or not has_newer_release_version(current_version, latest_version): + return {"has_update": False, "latest_version": None, "repo_url": repo_url} + return {"has_update": True, "latest_version": latest_version, "repo_url": repo_url} + @application.get(REDOC_URL, include_in_schema=False) async def get_redoc_documentation(_: Annotated[str, Depends(authenticate)]) -> HTMLResponse: return get_redoc_html(openapi_url=OPENAPI_URL, title="FrameX Redoc") diff --git a/src/framex/driver/auth.py b/src/framex/driver/auth.py index 4199c73..0f3d7ee 100644 --- a/src/framex/driver/auth.py +++ b/src/framex/driver/auth.py @@ -1,4 +1,6 @@ from datetime import UTC, datetime, timedelta +from secrets import token_urlsafe +from typing import Any import httpx import jwt @@ -11,57 +13,97 @@ from framex.consts import AUTH_COOKIE_NAME, DOCS_URL api_key_header = APIKeyHeader(name="Authorization", auto_error=False) +SESSION_LIFETIME = 
timedelta(hours=24) +_AUTH_SESSIONS: dict[str, dict[str, Any]] = {} -def create_jwt(payload: dict) -> str: +def _now_utc() -> datetime: + return datetime.now(UTC) + + +def _purge_expired_sessions(now_utc: datetime | None = None) -> None: + current = now_utc or _now_utc() + expired_session_ids = [ + session_id for session_id, payload in _AUTH_SESSIONS.items() if payload.get("expires_at", current) <= current + ] + for session_id in expired_session_ids: + _AUTH_SESSIONS.pop(session_id, None) + + +def create_jwt(payload: dict[str, Any]) -> str: if not settings.auth.oauth: raise RuntimeError("OAuth not configured") - now_utc = datetime.now(UTC) + now_utc = _now_utc() + token_payload = { + **payload, + "iat": int(now_utc.timestamp()), + "exp": int((now_utc + SESSION_LIFETIME).timestamp()), + } + return jwt.encode(token_payload, settings.auth.oauth.jwt_secret, algorithm=settings.auth.oauth.jwt_algorithm) - payload.update( - { - "iat": int(now_utc.timestamp()), - "exp": int((now_utc + timedelta(hours=24)).timestamp()), - } - ) - return jwt.encode(payload, settings.auth.oauth.jwt_secret, algorithm=settings.auth.oauth.jwt_algorithm) +def create_auth_session(session_payload: dict[str, Any]) -> str: + now_utc = _now_utc() + expires_at = now_utc + SESSION_LIFETIME + session_id = token_urlsafe(32) + _AUTH_SESSIONS[session_id] = { + **session_payload, + "expires_at": expires_at, + } + _purge_expired_sessions(now_utc) + return session_id -def auth_jwt(request: Request) -> bool: - if not settings.auth.oauth: - return False - token = request.cookies.get(AUTH_COOKIE_NAME) - if not token: - return False +def decode_auth_token(token: str | None) -> dict[str, Any] | None: + if not settings.auth.oauth or not token: + return None try: - jwt.decode( + payload = jwt.decode( token, settings.auth.oauth.jwt_secret, algorithms=[settings.auth.oauth.jwt_algorithm], ) - return True except (jwt.InvalidTokenError, jwt.ExpiredSignatureError): - return False + return None + if not isinstance(payload, 
dict): + return None -def authenticate(request: Request, api_key: str | None = Depends(api_key_header)) -> None: - if settings.auth.oauth: - if token := request.cookies.get(AUTH_COOKIE_NAME): - try: - jwt.decode( - token, - settings.auth.oauth.jwt_secret, - algorithms=[settings.auth.oauth.jwt_algorithm], - ) - return + session_id = payload.get("session_id") + if not isinstance(session_id, str) or not session_id: + return None - except Exception as e: - from framex.log import logger + now_utc = _now_utc() + _purge_expired_sessions(now_utc) + session_payload = _AUTH_SESSIONS.get(session_id) + if session_payload is None: + return None - logger.warning(f"JWT decode failed: {e}") + expires_at = session_payload.get("expires_at") + if not isinstance(expires_at, datetime) or expires_at <= now_utc: + _AUTH_SESSIONS.pop(session_id, None) + return None + + return { + **payload, + **{key: value for key, value in session_payload.items() if key != "expires_at"}, + } + + +def get_auth_payload(request: Request) -> dict[str, Any] | None: + return decode_auth_token(request.cookies.get(AUTH_COOKIE_NAME)) + + +def auth_jwt(request: Request) -> bool: + return get_auth_payload(request) is not None + + +def authenticate(request: Request, api_key: str | None = Depends(api_key_header)) -> None: + if settings.auth.oauth: + if get_auth_payload(request) is not None: + return if api_key and api_key in (settings.auth.get_auth_keys(request.url.path) or []): return @@ -74,7 +116,7 @@ def authenticate(request: Request, api_key: str | None = Depends(api_key_header) f"?client_id={settings.auth.oauth.client_id}" "&response_type=code" f"&redirect_uri={settings.auth.oauth.call_back_url}" - "&scope=read_user" + "&scope=read_user%20read_api" ) }, ) @@ -116,12 +158,23 @@ async def oauth_callback(code: str) -> Response: "message": f"Welcome {username}", "username": username, "email": user_resp.get("email"), + "oauth_provider": settings.auth.oauth.provider, + "oauth_access_token": auth_token, } + session_id 
= create_auth_session(user_info) res = RedirectResponse(url=DOCS_URL, status_code=status.HTTP_302_FOUND) res.set_cookie( AUTH_COOKIE_NAME, - create_jwt(user_info), + create_jwt( + { + "message": user_info["message"], + "username": username, + "email": user_info["email"], + "oauth_provider": settings.auth.oauth.provider, + "session_id": session_id, + } + ), httponly=True, samesite="lax", ) diff --git a/src/framex/plugin/on.py b/src/framex/plugin/on.py index d418a1f..a4c9880 100644 --- a/src/framex/plugin/on.py +++ b/src/framex/plugin/on.py @@ -57,6 +57,7 @@ def decorator(cls: type) -> type: version, plugin.module.__plugin_meta__.description, plugin.module.__plugin_meta__.url, + plugin.name, ) plugin_apis.append( diff --git a/src/framex/plugins/proxy/__init__.py b/src/framex/plugins/proxy/__init__.py index dfe2c63..78d1aa2 100644 --- a/src/framex/plugins/proxy/__init__.py +++ b/src/framex/plugins/proxy/__init__.py @@ -174,6 +174,7 @@ async def _parse_openai_docs(self, url: str) -> None: f"v{__plugin_meta__.version}", __plugin_meta__.description, __plugin_meta__.url, + __plugin_meta__.name, ) await adapter.call_func( plugin_api, diff --git a/src/framex/repository/__init__.py b/src/framex/repository/__init__.py new file mode 100644 index 0000000..3cbd378 --- /dev/null +++ b/src/framex/repository/__init__.py @@ -0,0 +1,13 @@ +from .versioning import ( + can_access_repository, + get_latest_repository_version, + has_newer_release_version, + is_private_repository, +) + +__all__ = [ + "can_access_repository", + "get_latest_repository_version", + "has_newer_release_version", + "is_private_repository", +] diff --git a/src/framex/repository/providers/__init__.py b/src/framex/repository/providers/__init__.py new file mode 100644 index 0000000..e724da5 --- /dev/null +++ b/src/framex/repository/providers/__init__.py @@ -0,0 +1,17 @@ +"""Repository hosting providers used by version lookup.""" + +from .base import RepositoryVersionProvider +from .github import GITHUB_PROVIDER +from 
.gitlab import GITLAB_PROVIDER + +REPOSITORY_VERSION_PROVIDERS: tuple[RepositoryVersionProvider, ...] = ( + GITHUB_PROVIDER, + GITLAB_PROVIDER, +) + +__all__ = [ + "GITHUB_PROVIDER", + "GITLAB_PROVIDER", + "REPOSITORY_VERSION_PROVIDERS", + "RepositoryVersionProvider", +] diff --git a/src/framex/repository/providers/base.py b/src/framex/repository/providers/base.py new file mode 100644 index 0000000..b351484 --- /dev/null +++ b/src/framex/repository/providers/base.py @@ -0,0 +1,93 @@ +"""Base types and shared helpers for repository version providers.""" + +from abc import ABC, abstractmethod +from typing import Any +from urllib.parse import ParseResult + +import httpx + +DEFAULT_HTTP_TIMEOUT = 2.0 + + +class RepositoryVersionProvider(ABC): + """Abstract interface for repository hosting providers.""" + + name: str + + @abstractmethod + def matches(self, parsed_url: ParseResult) -> bool: + """Return whether this provider can handle the parsed repository URL.""" + + @abstractmethod + def get_latest_version(self, parsed_url: ParseResult) -> str | None: + """Return the latest published version for the repository URL.""" + + def has_repository_access(self, parsed_url: ParseResult, access_token: str) -> bool: + """Return whether the given user token can access the repository URL.""" + + raise NotImplementedError("RepositoryVersionProvider does not implement access checking") + + def is_public_repository(self, parsed_url: ParseResult) -> bool | None: + """Return whether the repository is publicly accessible without authentication.""" + + raise NotImplementedError("RepositoryVersionProvider does not implement public repository checking") + + @staticmethod + def extract_repository_parts(parsed_url: ParseResult) -> list[str]: + """Split a repository path into normalized URL parts.""" + + parts = [part for part in parsed_url.path.split("/") if part] + if parts: + parts[-1] = parts[-1].removesuffix(".git") + return parts + + @staticmethod + def fetch_json(url: str, headers: 
dict[str, str] | None = None) -> dict[str, Any] | None: + """Fetch a JSON object and return `None` when it cannot be consumed.""" + + try: + with httpx.Client(timeout=DEFAULT_HTTP_TIMEOUT, headers=headers) as client: + response = client.get(url, follow_redirects=True) + except httpx.HTTPError: + return None + + if response.status_code != 200: + return None + + try: + payload = response.json() + except ValueError: + return None + + return payload if isinstance(payload, dict) else None + + @staticmethod + def can_fetch( + url: str, + headers: dict[str, str] | None = None, + *, + follow_redirects: bool = True, + ) -> bool: + """Return whether the resource can be fetched successfully.""" + + try: + with httpx.Client(timeout=DEFAULT_HTTP_TIMEOUT, headers=headers) as client: + response = client.get(url, follow_redirects=follow_redirects) + except httpx.HTTPError: + return False + + return response.status_code == 200 + + @staticmethod + def extract_version(payload: dict[str, Any] | None) -> str | None: + """Read a version-like string from a release payload.""" + + if payload is None: + return None + + latest_version = payload.get("tag_name") or payload.get("name") + if not isinstance(latest_version, str): + return None + + latest_version = latest_version.strip() + return latest_version or None diff --git a/src/framex/repository/providers/github.py b/src/framex/repository/providers/github.py new file mode 100644 index 0000000..c8ebe78 --- /dev/null +++ b/src/framex/repository/providers/github.py @@ -0,0 +1,63 @@ +"""GitHub repository version provider.""" + +from urllib.parse import ParseResult + +from framex.config import settings + +from .base import RepositoryVersionProvider + +GITHUB_HOSTS = frozenset({"github.com", "www.github.com"}) +GITHUB_API_HEADERS = { + "Accept": "application/vnd.github+json", + "User-Agent": "framex-docs", +} + + +class GitHubRepositoryVersionProvider(RepositoryVersionProvider): + """Resolve latest release versions for GitHub repositories.""" + 
+ name = "github" + + def matches(self, parsed_url: ParseResult) -> bool: + return parsed_url.netloc in GITHUB_HOSTS + + def get_latest_version(self, parsed_url: ParseResult) -> str | None: + repository = self._extract_owner_and_repository(parsed_url) + if repository is None: + return None + + owner, repo = repository + api_url = f"https://api.github.com/repos/{owner}/{repo}/releases/latest" + headers = {**GITHUB_API_HEADERS, **settings.repository.auth.github.build_headers()} + payload = self.fetch_json(api_url, headers=headers) + return self.extract_version(payload) + + def is_public_repository(self, parsed_url: ParseResult) -> bool | None: + repository = self._extract_owner_and_repository(parsed_url) + if repository is None: + return None + + owner, repo = repository + api_url = f"https://api.github.com/repos/{owner}/{repo}" + return self.can_fetch(api_url, headers=GITHUB_API_HEADERS) + + def has_repository_access(self, parsed_url: ParseResult, access_token: str) -> bool: + repository = self._extract_owner_and_repository(parsed_url) + if repository is None or not access_token: + return False + + owner, repo = repository + api_url = f"https://api.github.com/repos/{owner}/{repo}" + headers = {**GITHUB_API_HEADERS, "Authorization": f"Bearer {access_token}"} + return self.can_fetch(api_url, headers=headers) + + def _extract_owner_and_repository(self, parsed_url: ParseResult) -> tuple[str, str] | None: + parts = self.extract_repository_parts(parsed_url) + if len(parts) < 2: + return None + return parts[0], parts[1] + + +GITHUB_PROVIDER = GitHubRepositoryVersionProvider() + +__all__ = ["GITHUB_PROVIDER", "GitHubRepositoryVersionProvider"] diff --git a/src/framex/repository/providers/gitlab.py b/src/framex/repository/providers/gitlab.py new file mode 100644 index 0000000..c0f85d3 --- /dev/null +++ b/src/framex/repository/providers/gitlab.py @@ -0,0 +1,91 @@ +"""GitLab repository version provider.""" + +from urllib.parse import ParseResult, quote + +from framex.config 
import settings + +from .base import RepositoryVersionProvider + +GITLAB_PRIMARY_HOST = "gitlab.com" +GITLAB_HOST_SUFFIX = ".gitlab.com" +GITLAB_RESERVED_PATH_MARKERS = {"-", "tree", "blob", "raw", "commits", "branches", "tags", "merge_requests"} + + +class GitLabRepositoryVersionProvider(RepositoryVersionProvider): + """Resolve latest release versions for GitLab repositories.""" + + name = "gitlab" + + def matches(self, parsed_url: ParseResult) -> bool: + host = parsed_url.netloc.lower() + return ( + host == GITLAB_PRIMARY_HOST + or host.endswith(GITLAB_HOST_SUFFIX) + or host in settings.repository.auth.gitlab.configured_hosts() + ) + + def get_latest_version(self, parsed_url: ParseResult) -> str | None: + headers = settings.repository.auth.gitlab.build_headers_for_url(parsed_url.netloc, parsed_url.path) or None + project_path = self._resolve_project_path(parsed_url, headers=headers, require_fetch=False) + if project_path is None: + return None + + api_url = self._build_release_api_url(parsed_url, project_path) + payload = self.fetch_json(api_url, headers=headers) + return self.extract_version(payload) + + def is_public_repository(self, parsed_url: ParseResult) -> bool | None: + return self.can_fetch(self._build_repository_web_url(parsed_url), follow_redirects=False) + + def has_repository_access(self, parsed_url: ParseResult, access_token: str) -> bool: + if not access_token: + return False + + headers = {"Authorization": f"Bearer {access_token}"} + return self._resolve_project_path(parsed_url, headers=headers, require_fetch=True) is not None + + def _resolve_project_path( + self, + parsed_url: ParseResult, + headers: dict[str, str] | None = None, + require_fetch: bool = True, + ) -> str | None: + candidates = self._iter_project_path_candidates(parsed_url) + if not candidates: + return None + if not require_fetch and len(candidates) == 1: + return candidates[0] + + for candidate in candidates: + if self.can_fetch(self._build_project_api_url(parsed_url, 
candidate), headers=headers): + return candidate + return None + + def _iter_project_path_candidates(self, parsed_url: ParseResult) -> list[str]: + parts = self.extract_repository_parts(parsed_url) + if len(parts) < 2: + return [] + + marker_index = next((index for index, part in enumerate(parts) if part in GITLAB_RESERVED_PATH_MARKERS), None) + max_length = marker_index if marker_index is not None else len(parts) + if max_length < 2: + return [] + + return ["/".join(parts[:length]) for length in range(max_length, 1, -1)] + + def _build_project_api_url(self, parsed_url: ParseResult, project_path: str) -> str: + project_id = quote(project_path, safe="") + return f"{parsed_url.scheme}://{parsed_url.netloc}/api/v4/projects/{project_id}" + + def _build_release_api_url(self, parsed_url: ParseResult, project_path: str) -> str: + project_id = quote(project_path, safe="") + return f"{parsed_url.scheme}://{parsed_url.netloc}/api/v4/projects/{project_id}/releases/permalink/latest" + + def _build_repository_web_url(self, parsed_url: ParseResult) -> str: + normalized = parsed_url._replace(params="", query="", fragment="") + return normalized.geturl() + + +GITLAB_PROVIDER = GitLabRepositoryVersionProvider() + +__all__ = ["GITLAB_PROVIDER", "GitLabRepositoryVersionProvider"] diff --git a/src/framex/repository/versioning.py b/src/framex/repository/versioning.py new file mode 100644 index 0000000..53c2c33 --- /dev/null +++ b/src/framex/repository/versioning.py @@ -0,0 +1,68 @@ +"""Repository version lookup entrypoints.""" + +import re +from functools import lru_cache +from urllib.parse import ParseResult, urlparse + +from framex.repository.providers.base import RepositoryVersionProvider + +from .providers import REPOSITORY_VERSION_PROVIDERS + +VERSION_PATTERN = re.compile(r"\d+(?:\.\d+)*") + + +def _get_provider_for_url(repo_url: str) -> tuple[RepositoryVersionProvider | None, ParseResult]: + parsed_url = urlparse(repo_url) + provider = next((provider for provider in 
REPOSITORY_VERSION_PROVIDERS if provider.matches(parsed_url)), None) + return provider, parsed_url + + +@lru_cache(maxsize=128) +def get_latest_repository_version(repo_url: str) -> str | None: + provider, parsed_url = _get_provider_for_url(repo_url) + if provider is None: + return None + return provider.get_latest_version(parsed_url) + + +def is_private_repository(repo_url: str) -> bool | None: + provider, parsed_url = _get_provider_for_url(repo_url) + if provider is None: + return None + + is_public = provider.is_public_repository(parsed_url) + if is_public is None: + return None + return not is_public + + +def can_access_repository(repo_url: str, provider_name: str | None, access_token: str | None) -> bool | None: + provider, parsed_url = _get_provider_for_url(repo_url) + if provider is None: + return None + + if not provider_name or provider.name != provider_name: + return None + if not access_token: + return False + + return provider.has_repository_access(parsed_url, access_token) + + +def has_newer_release_version(current_version: str, latest_version: str) -> bool: + current_parts = _normalize_version(current_version) + latest_parts = _normalize_version(latest_version) + if current_parts is None or latest_parts is None: + return False + + max_length = max(len(current_parts), len(latest_parts)) + current_padded = current_parts + (0,) * (max_length - len(current_parts)) + latest_padded = latest_parts + (0,) * (max_length - len(latest_parts)) + return latest_padded > current_padded + + +def _normalize_version(version: str) -> tuple[int, ...] 
| None: + match = VERSION_PATTERN.search(version) + if match is None: + return None + return tuple(int(part) for part in match.group(0).split(".")) diff --git a/src/framex/utils/__init__.py b/src/framex/utils/__init__.py new file mode 100644 index 0000000..05c8324 --- /dev/null +++ b/src/framex/utils/__init__.py @@ -0,0 +1,41 @@ +from .cache import cache_decode, cache_encode +from .common import ( + StreamEnventType, + escape_tag, + extract_method_params, + format_uptime, + make_stream_event, + path_to_module_name, + plugin_to_deployment_name, + safe_error_message, + shorten_str, +) +from .config_docs import ( + build_plugin_config_html, + collect_embedded_config_files, + mask_sensitive_config_data, + mask_sensitive_config_text, + mask_sensitive_embedded_config_content, +) +from .docs import build_plugin_description, build_swagger_ui_html + +__all__ = [ + "StreamEnventType", + "build_plugin_config_html", + "build_plugin_description", + "build_swagger_ui_html", + "cache_decode", + "cache_encode", + "collect_embedded_config_files", + "escape_tag", + "extract_method_params", + "format_uptime", + "make_stream_event", + "mask_sensitive_config_data", + "mask_sensitive_config_text", + "mask_sensitive_embedded_config_content", + "path_to_module_name", + "plugin_to_deployment_name", + "safe_error_message", + "shorten_str", +] diff --git a/src/framex/utils/cache.py b/src/framex/utils/cache.py new file mode 100644 index 0000000..fd4a9e1 --- /dev/null +++ b/src/framex/utils/cache.py @@ -0,0 +1,80 @@ +import base64 +import importlib +import json +import zlib +from datetime import datetime +from enum import Enum +from itertools import cycle +from typing import Any + + +def xor_crypt(data: bytes, key: str = "01234567890abcdefghijklmnopqrstuvwxyz") -> bytes: + return bytes(a ^ b for a, b in zip(data, cycle(key.encode()))) + + +def cache_encode(data: Any) -> str: + def transform(obj: Any) -> Any: + if hasattr(obj, "__dict__"): + raw_attributes = {k: transform(v) for k, v in 
obj.__dict__.items() if not k.startswith("_")} + return { + "__type__": "dynamic_obj", + "__module__": obj.__class__.__module__, + "__class__": obj.__class__.__name__, + "data": raw_attributes, + } + if isinstance(obj, list): + return [transform(i) for i in obj] + if isinstance(obj, dict): + return {k: transform(v) for k, v in obj.items()} + if isinstance(obj, datetime): + return obj.isoformat() + if isinstance(obj, Enum): + return obj.value + return obj + + json_str = json.dumps(transform(data), ensure_ascii=False) + compressed = zlib.compress(json_str.encode("utf-8")) + encrypted = xor_crypt(compressed) + return base64.b64encode(encrypted).decode("ascii") + + +def cache_decode(res: Any) -> Any: + current = res + while isinstance(current, str): + try: + decoded_bytes = base64.b64decode(current, validate=True) + current = zlib.decompress(xor_crypt(decoded_bytes)).decode("utf-8") + except Exception: + try: + temp = json.loads(current) + if temp == current: + break + current = temp + except Exception: + break + + def restore_models(item: Any) -> Any: + if isinstance(item, list): + return [restore_models(i) for i in item] + + if isinstance(item, dict): + if item.get("__type__") == "dynamic_obj": + try: + module = importlib.import_module(item["__module__"]) + cls = getattr(module, item["__class__"]) + + cleaned_data = {k: restore_models(v) for k, v in item["data"].items()} + + if hasattr(cls, "model_validate"): + return cls.model_validate(cleaned_data) + return cls(**cleaned_data) + except Exception: + from types import SimpleNamespace + + return SimpleNamespace(**{k: restore_models(v) for k, v in item["data"].items()}) + + return {k: restore_models(v) for k, v in item.items()} + + return item + + return restore_models(current) diff --git a/src/framex/utils/common.py b/src/framex/utils/common.py new file mode 100644 index 0000000..66cf271 --- /dev/null +++ b/src/framex/utils/common.py @@ -0,0 +1,86 @@ +import inspect +import json +import re +from collections.abc import 
Callable +from datetime import timedelta +from enum import StrEnum +from pathlib import Path +from typing import Any + +from pydantic import BaseModel + + +def plugin_to_deployment_name(plugin_name: str, obj_name: str) -> str: + return f"{plugin_name}.{obj_name}" + + +def path_to_module_name(path: Path) -> str: + """Convert path to module name.""" + rel_path = path.resolve().relative_to(Path.cwd().resolve()) + if rel_path.stem == "__init__": + module_name = ".".join(rel_path.parts[:-1]) + else: + module_name = ".".join([*rel_path.parts[:-1], rel_path.stem]) # type: ignore[arg-type] + return module_name.removeprefix("src.") + + +def escape_tag(s: str) -> str: + """Escape `<tag>` like markers used in colored logs.""" + return re.sub(r"</?((?:[fb]g\s)?[^<>\s]*)>", r"\\\g<0>", s) + + +def extract_method_params(func: Callable) -> list[tuple[str, Any]]: + sig = inspect.signature(func) + params: list[tuple[str, Any]] = [] + for param in sig.parameters.values(): + if param.name == "self": + continue + params.append((param.name, param.annotation)) + return params + + +class StreamEnventType(StrEnum): + MESSAGE_CHUNK = "message_chunk" + FINISH = "finish" + ERROR = "error" + DEBUG = "debug" + + +def make_stream_event(event_type: StreamEnventType | str, data: str | dict[str, Any] | BaseModel | None = None) -> str: + if not data: + data = {} + elif isinstance(data, BaseModel): + data = data.model_dump() + elif isinstance(data, str): + data = {"content": data} + return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" + + +def format_uptime(delta: timedelta) -> str: + days = delta.days + hours, remainder = divmod(delta.seconds, 3600) + minutes, seconds = divmod(remainder, 60) + + parts: list[str] = [] + if days: + parts.append(f"{days}d") + if hours: + parts.append(f"{hours}h") + if minutes: + parts.append(f"{minutes}m") + if seconds or not parts: + parts.append(f"{seconds}s") + + return " ".join(parts) + + +def safe_error_message(e: Exception) -> str: + if hasattr(e, "cause") and
e.cause: + return str(e.cause) + if e.args: + return str(e.args[0]) + return "Internal Server Error" + + +def shorten_str(data: str, max_len: int = 45) -> str: + return data if len(data) <= max_len else data[: max_len - 3] + "..." diff --git a/src/framex/utils/config_docs.py b/src/framex/utils/config_docs.py new file mode 100644 index 0000000..2972e8e --- /dev/null +++ b/src/framex/utils/config_docs.py @@ -0,0 +1,390 @@ +import html +import json +import re +import tomllib +from collections.abc import Sequence +from fnmatch import fnmatch +from pathlib import Path +from typing import Any + +import yaml +from fastapi.responses import HTMLResponse + +from framex.config import settings + +SUPPORTED_EMBEDDED_CONFIG_SUFFIXES = (".yaml", ".yml", ".toml") + +SENSITIVE_CONFIG_KEYWORDS = ( + "token", + "secret", + "password", + "passwd", + "authorization", + "api_key", + "apikey", + "access_key", + "private_key", + "client_secret", + "cookie", + "session", + "credential", +) + + +def _normalize_whitelist_pattern(pattern: str) -> str: + return pattern.strip().lstrip("/") + + +def _is_whitelisted_embedded_config_path(candidate: Path, workspace_root: Path, whitelist: Sequence[str]) -> bool: + if not whitelist: + return False + + relative_path = candidate.relative_to(workspace_root).as_posix() + return any( + fnmatch(relative_path, _normalize_whitelist_pattern(pattern)) for pattern in whitelist if pattern.strip() + ) + + +def _resolve_embedded_config_path( + path_value: str, + workspace_root: Path, + whitelist: Sequence[str], +) -> Path | None: + candidate = Path(path_value).expanduser() + if not candidate.is_absolute(): # noqa + candidate = (workspace_root / candidate).resolve() + else: + candidate = candidate.resolve() + + if candidate.suffix.lower() not in SUPPORTED_EMBEDDED_CONFIG_SUFFIXES: + return None + if not candidate.is_file(): + return None + + try: + candidate.relative_to(workspace_root) + except ValueError: + return None + + if not 
_is_whitelisted_embedded_config_path(candidate, workspace_root, whitelist): + return None + + return candidate + + +def collect_embedded_config_files( + config_data: Any, + workspace_root: Path | None = None, + whitelist: Sequence[str] = (), +) -> list[tuple[str, str]]: + found_files: list[tuple[str, str]] = [] + visited_paths: set[Path] = set() + resolved_workspace_root = (workspace_root or Path.cwd()).resolve() + + def walk(value: Any) -> None: + if isinstance(value, dict): + for nested_value in value.values(): + walk(nested_value) + return + if isinstance(value, list): + for item in value: + walk(item) + return + if not isinstance(value, str): + return + + resolved_path = _resolve_embedded_config_path(value, resolved_workspace_root, whitelist) + if resolved_path is None or resolved_path in visited_paths: + return + + visited_paths.add(resolved_path) + found_files.append((str(resolved_path), resolved_path.read_text(encoding="utf-8"))) + + walk(config_data) + return found_files + + +def _format_toml_key(key: str) -> str: + if re.fullmatch(r"[A-Za-z0-9_-]+", key): + return key + return json.dumps(key, ensure_ascii=False) + + +def _mask_sensitive_string(value: str) -> str: + if not value: + return value + if len(value) <= 4: + return "****" + return f"{value[:2]}{'*' * max(len(value) - 4, 4)}{value[-2:]}" + + +def _is_sensitive_config_key(key: str) -> bool: + normalized_key = key.lower().replace("-", "_") + return any(keyword in normalized_key for keyword in SENSITIVE_CONFIG_KEYWORDS) + + +def _should_mask_config_path(key_path: tuple[str, ...]) -> bool: + if any(_is_sensitive_config_key(segment) for segment in key_path): + return True + if len(key_path) >= 2 and key_path[-2] == "rules" and "auth" in key_path: # noqa + return True + return False + + +def mask_sensitive_config_data(config_data: Any, key_path: tuple[str, ...] 
= ()) -> Any: + if isinstance(config_data, dict): + return { + key: mask_sensitive_config_data(value, key_path=(*key_path, str(key))) + for key, value in config_data.items() + } + if isinstance(config_data, list): + return [mask_sensitive_config_data(item, key_path=key_path) for item in config_data] + if isinstance(config_data, str) and _should_mask_config_path(key_path): + return _mask_sensitive_string(config_data) + return config_data + + +def mask_sensitive_config_text(content: str) -> str: + lines: list[str] = [] + pattern = re.compile( + r"^(?P\s*(?:-\s*)?[\"']?(?P[A-Za-z0-9_.-]+)[\"']?\s*(?::|=)\s*)(?P.*?)(?P\s*(?:#.*)?)$" + ) + + for line in content.splitlines(): + match = pattern.match(line) + if not match or not _is_sensitive_config_key(match.group("key")): + lines.append(line) + continue + + raw_value = match.group("value").strip() + if not raw_value: + lines.append(line) + continue + + quote_char = "" + if raw_value[0] in {'"', "'"} and raw_value[-1] == raw_value[0]: + quote_char = raw_value[0] + inner_value = raw_value[1:-1] + else: + inner_value = raw_value + + masked_value = _mask_sensitive_string(inner_value) + rendered_value = f"{quote_char}{masked_value}{quote_char}" if quote_char else masked_value + lines.append(f"{match.group('prefix')}{rendered_value}{match.group('suffix')}") + + return "\n".join(lines) + + +def _format_toml_value(value: Any) -> str: + if isinstance(value, bool): + return str(value).lower() + if isinstance(value, str): + return json.dumps(value, ensure_ascii=False) + if isinstance(value, int | float): + return str(value) + if value is None: + return '""' + if isinstance(value, list): + return f"[{', '.join(_format_toml_value(item) for item in value)}]" + if isinstance(value, dict): + items = ", ".join(f"{_format_toml_key(str(key))} = {_format_toml_value(item)}" for key, item in value.items()) + return f"{{ {items} }}" + return json.dumps(value, ensure_ascii=False) + + +def _dump_toml_table(data: dict[str, Any], prefix: 
tuple[str, ...] = ()) -> list[str]: + lines: list[str] = [] + nested_items: list[tuple[str, Any]] = [] + + for key, value in data.items(): + if isinstance(value, dict): + nested_items.append((key, value)) + continue + if isinstance(value, list) and value and all(isinstance(item, dict) for item in value): + nested_items.append((key, value)) + continue + lines.append(f"{_format_toml_key(str(key))} = {_format_toml_value(value)}") + + for key, value in nested_items: + section_name = ".".join([*prefix, _format_toml_key(str(key))]) + if isinstance(value, dict): + if lines: + lines.append("") + lines.append(f"[{section_name}]") + lines.extend(_dump_toml_table(value, (*prefix, _format_toml_key(str(key))))) + continue + + for item in value: + if lines: + lines.append("") + lines.append(f"[[{section_name}]]") + lines.extend(_dump_toml_table(item, (*prefix, _format_toml_key(str(key))))) + + return lines + + +def _format_plugin_config_toml(config_data: Any) -> str: + if not isinstance(config_data, dict): + return _format_toml_value(config_data) + return "\n".join(_dump_toml_table(config_data)) + + +def _normalize_display_config_paths( + config_data: Any, + workspace_root: Path | None = None, + whitelist: Sequence[str] = (), +) -> Any: + resolved_workspace_root = (workspace_root or Path.cwd()).resolve() + if isinstance(config_data, dict): + return { + key: _normalize_display_config_paths( + value, + workspace_root=resolved_workspace_root, + whitelist=whitelist, + ) + for key, value in config_data.items() + } + if isinstance(config_data, list): + return [ + _normalize_display_config_paths( + item, + workspace_root=resolved_workspace_root, + whitelist=whitelist, + ) + for item in config_data + ] + if isinstance(config_data, str): + resolved_path = _resolve_embedded_config_path( + config_data, + resolved_workspace_root, + whitelist, + ) + if resolved_path is not None: + return _to_display_embedded_config_path(str(resolved_path), workspace_root=resolved_workspace_root) + + 
candidate = Path(config_data).expanduser() + try: + if not candidate.is_absolute(): + candidate = (resolved_workspace_root / candidate).resolve() + else: + candidate = candidate.resolve() + except OSError: + return config_data + + if candidate.suffix.lower() in SUPPORTED_EMBEDDED_CONFIG_SUFFIXES and candidate.is_file(): + return "[restricted config path]" + return config_data + + +def mask_sensitive_embedded_config_content(file_path: str, content: str) -> str: + suffix = Path(file_path).suffix.lower() + + try: + if suffix == ".toml": + parsed = tomllib.loads(content) + return _format_plugin_config_toml(mask_sensitive_config_data(parsed)) + if suffix in {".yaml", ".yml"}: + parsed = yaml.safe_load(content) + masked = mask_sensitive_config_data(parsed) + return yaml.safe_dump(masked, allow_unicode=True, sort_keys=False).rstrip() + except Exception: + return mask_sensitive_config_text(content) + + return mask_sensitive_config_text(content) + + +def _to_display_embedded_config_path(file_path: str, workspace_root: Path | None = None) -> str: + resolved_workspace_root = (workspace_root or Path.cwd()).resolve() + resolved_file_path = Path(file_path).resolve() + try: + return resolved_file_path.relative_to(resolved_workspace_root).as_posix() + except ValueError: + return resolved_file_path.as_posix() + + +def build_plugin_config_html(config_data: Any, embedded_files: list[tuple[str, str]] | None = None) -> HTMLResponse: + workspace_root = Path.cwd().resolve() + embedded_path_whitelist = tuple(settings.docs.embedded_config_file_whitelist or []) + normalized_config_data = _normalize_display_config_paths( + config_data, + workspace_root=workspace_root, + whitelist=embedded_path_whitelist, + ) + masked_config_data = mask_sensitive_config_data(normalized_config_data) + escaped_toml: str = html.escape(_format_plugin_config_toml(masked_config_data)) + masked_embedded_files = [ + ( + _to_display_embedded_config_path(file_path), + mask_sensitive_embedded_config_content(file_path, 
file_content), + ) + for file_path, file_content in (embedded_files or []) + ] + embedded_sections = "".join( + f""" +

\n
\n

Referenced Config: {html.escape(file_path)}

\n
\n
{html.escape(file_content)}
\n
+ """ + for file_path, file_content in masked_embedded_files + ) + return HTMLResponse( + f""" + + + + + + Plugin Config + + + +
+
+

Plugin Config (TOML)

+
+
{escaped_toml}
+
+ {embedded_sections} + + + """ + ) diff --git a/src/framex/utils.py b/src/framex/utils/docs.py similarity index 66% rename from src/framex/utils.py rename to src/framex/utils/docs.py index 044018f..bb2de6e 100644 --- a/src/framex/utils.py +++ b/src/framex/utils/docs.py @@ -1,166 +1,8 @@ -import base64 -import importlib -import inspect -import json -import re -import zlib -from collections.abc import Callable -from datetime import datetime, timedelta -from enum import Enum, StrEnum -from itertools import cycle -from pathlib import Path -from typing import Any +from urllib.parse import quote from fastapi.responses import HTMLResponse -from pydantic import BaseModel - - -def plugin_to_deployment_name(plugin_name: str, obj_name: str) -> str: - return f"{plugin_name}.{obj_name}" - - -def path_to_module_name(path: Path) -> str: - """Convert path to module name""" - rel_path = path.resolve().relative_to(Path.cwd().resolve()) - if rel_path.stem == "__init__": - module_name = ".".join(rel_path.parts[:-1]) - else: - module_name = ".".join([*rel_path.parts[:-1], rel_path.stem]) # type: ignore - return module_name.removeprefix("src.") - - -def escape_tag(s: str) -> str: - """Used to escape `` type special tags when recording color logs""" - return re.sub(r"\s]*)>", r"\\\g<0>", s) - - -def extract_method_params(func: Callable) -> list[tuple[str, Any]]: - sig = inspect.signature(func) - params = [] - for param in sig.parameters.values(): - if param.name == "self": - continue - params.append((param.name, param.annotation)) - return params - - -class StreamEnventType(StrEnum): - MESSAGE_CHUNK = "message_chunk" - FINISH = "finish" - ERROR = "error" - DEBUG = "debug" - - -def make_stream_event(event_type: StreamEnventType | str, data: str | dict[str, Any] | BaseModel | None = None) -> str: - if not data: - data = {} - elif isinstance(data, BaseModel): - data = data.model_dump() - elif isinstance(data, str): - data = {"content": data} - return f"event: {event_type}\ndata: 
{json.dumps(data, ensure_ascii=False)}\n\n" - - -def xor_crypt(data: bytes, key: str = "01234567890abcdefghijklmnopqrstuvwxyz") -> bytes: - return bytes(a ^ b for a, b in zip(data, cycle(key.encode()))) - -def cache_encode(data: Any) -> str: - def transform(obj: Any) -> Any: - if hasattr(obj, "__dict__"): - raw_attributes = {k: transform(v) for k, v in obj.__dict__.items() if not k.startswith("_")} - return { - "__type__": "dynamic_obj", - "__module__": obj.__class__.__module__, - "__class__": obj.__class__.__name__, - "data": raw_attributes, - } - if isinstance(obj, list): - return [transform(i) for i in obj] - if isinstance(obj, dict): - return {k: transform(v) for k, v in obj.items()} - if isinstance(obj, datetime): - return obj.isoformat() - if isinstance(obj, Enum): - return obj.value - return obj - - json_str = json.dumps(transform(data), ensure_ascii=False) - compressed = zlib.compress(json_str.encode("utf-8")) - encrypted = xor_crypt(compressed) - return base64.b64encode(encrypted).decode("ascii") - - -def cache_decode(res: Any) -> Any: - current = res - while isinstance(current, str): - try: - decoded_bytes = base64.b64decode(current, validate=True) - current = zlib.decompress(xor_crypt(decoded_bytes)).decode("utf-8") - except Exception: - try: - temp = json.loads(current) - if temp == current: - break - current = temp - except Exception: - break - - def restore_models(item: Any) -> Any: - if isinstance(item, list): - return [restore_models(i) for i in item] - - if isinstance(item, dict): - if item.get("__type__") == "dynamic_obj": - try: - module = importlib.import_module(item["__module__"]) - cls = getattr(module, item["__class__"]) - - cleaned_data = {k: restore_models(v) for k, v in item["data"].items()} - - if hasattr(cls, "model_validate"): - return cls.model_validate(cleaned_data) - return cls(**cleaned_data) - except Exception: - from types import SimpleNamespace - - return SimpleNamespace(**{k: restore_models(v) for k, v in item["data"].items()}) 
- - return {k: restore_models(v) for k, v in item.items()} - - return item - - return restore_models(current) - - -def format_uptime(delta: timedelta) -> str: - days = delta.days - hours, remainder = divmod(delta.seconds, 3600) - minutes, seconds = divmod(remainder, 60) - - parts = [] - if days: - parts.append(f"{days}d") - if hours: - parts.append(f"{hours}h") - if minutes: - parts.append(f"{minutes}m") - if seconds or not parts: - parts.append(f"{seconds}s") - - return " ".join(parts) - - -def safe_error_message(e: Exception) -> str: - if hasattr(e, "cause") and e.cause: - return str(e.cause) - if e.args: - return str(e.args[0]) - return "Internal Server Error" - - -def shorten_str(data: str, max_len: int = 45) -> str: - return data if len(data) <= max_len else data[: max_len - 3] + "..." +from framex.config import settings def build_swagger_ui_html(openapi_url: str, title: str) -> HTMLResponse: @@ -240,7 +82,6 @@ def build_swagger_ui_html(openapi_url: str, title: str) -> HTMLResponse: box-shadow: var(--fx-shadow); }} - /* 外层三列: tag | description | arrow */ .swagger-ui .opblock-tag {{ display: grid !important; grid-template-columns: 420px minmax(0, 1fr) 28px; @@ -256,7 +97,6 @@ def build_swagger_ui_html(openapi_url: str, title: str) -> HTMLResponse: background: #fafbfc; }} - /* tag 标题 */ .swagger-ui .opblock-tag .nostyle {{ grid-column: 1; min-width: 0; @@ -270,7 +110,6 @@ def build_swagger_ui_html(openapi_url: str, title: str) -> HTMLResponse: color: var(--fx-text) !important; }} - /* description 容器 */ .swagger-ui .opblock-tag small {{ grid-column: 2; display: block !important; @@ -293,7 +132,6 @@ def build_swagger_ui_html(openapi_url: str, title: str) -> HTMLResponse: padding: 0 !important; }} - /* 第一行 description */ .swagger-ui .opblock-tag small .markdown p:first-child {{ margin-bottom: 3px !important; color: var(--fx-text) !important; @@ -308,14 +146,12 @@ def build_swagger_ui_html(openapi_url: str, title: str) -> HTMLResponse: color: var(--fx-text) 
!important; }} - /* 第二行: 作者、版本、Repo */ .swagger-ui .opblock-tag small .markdown p:last-child {{ color: var(--fx-text-soft) !important; font-size: 12px !important; line-height: 1.4 !important; }} - /* Repo 链接 */ .swagger-ui .opblock-tag small a {{ color: var(--fx-link); text-decoration: none; @@ -327,7 +163,6 @@ def build_swagger_ui_html(openapi_url: str, title: str) -> HTMLResponse: text-decoration: underline; }} - /* 右侧展开箭头 */ .swagger-ui .opblock-tag > button {{ grid-column: 3 !important; justify-self: end !important; @@ -370,7 +205,6 @@ def build_swagger_ui_html(openapi_url: str, title: str) -> HTMLResponse: word-break: break-word; }} - /* 新增: 按钮容器, 放在第一个 tag 上方 */ .swagger-ui .tag-toolbar {{ display: flex; justify-content: flex-end; @@ -495,6 +329,61 @@ def build_swagger_ui_html(openapi_url: str, title: str) -> HTMLResponse: syncToolbarText(); }} + function getTagDescriptionLink(target) {{ + const link = target.closest(".swagger-ui .opblock-tag small a"); + return link instanceof HTMLAnchorElement ? link : null; + }} + + function hydrateLatestReleaseLinks() {{ + const releaseLinks = document.querySelectorAll('.swagger-ui .opblock-tag small a[href*="/docs/plugin-release?plugin="]'); + releaseLinks.forEach((link) => {{ + if (!(link instanceof HTMLAnchorElement) || link.dataset.releaseHydrated === "true") {{ + return; + }} + + link.dataset.releaseHydrated = "true"; + fetch(link.href, {{ credentials: "same-origin" }}) + .then((response) => response.ok ? 
response.json() : null) + .then((data) => {{ + if (!data || !data.has_update || !data.latest_version) {{ + link.remove(); + return; + }} + + link.textContent = "⬆️ " + data.latest_version; + if (data.repo_url) {{ + link.href = data.repo_url; + }} + }}) + .catch(() => {{ + link.remove(); + }}); + }}); + }} + + document.addEventListener("pointerdown", (event) => {{ + const link = getTagDescriptionLink(event.target); + if (!link) return; + + event.stopPropagation(); + if (link.href.includes("/docs/plugin-config?plugin=")) {{ + event.preventDefault(); + }} + }}, true); + + document.addEventListener("click", (event) => {{ + const link = getTagDescriptionLink(event.target); + if (!link) return; + + event.stopPropagation(); + if (!link.href.includes("/docs/plugin-config?plugin=")) {{ + return; + }} + + event.preventDefault(); + window.open(link.href, "_blank", "noopener,noreferrer"); + }}, true); + window.ui = SwaggerUIBundle({{ url: "{openapi_url}", dom_id: "#swagger-ui", @@ -509,11 +398,13 @@ def build_swagger_ui_html(openapi_url: str, title: str) -> HTMLResponse: layout: "BaseLayout", onComplete: function() {{ insertToolbar(); + hydrateLatestReleaseLinks(); }} }}); const observer = new MutationObserver(() => {{ insertToolbar(); + hydrateLatestReleaseLinks(); }}); observer.observe(document.body, {{ @@ -524,7 +415,21 @@ def build_swagger_ui_html(openapi_url: str, title: str) -> HTMLResponse: """ - ) # roqa + ) + + +def _format_plugin_release_view(plugin_name: str | None = None) -> str: + if not plugin_name: + return "" + plugin_query = quote(plugin_name) + return f" [](/docs/plugin-release?plugin={plugin_query})" + + +def _format_plugin_config_view(plugin_name: str | None = None) -> str: + if not plugin_name or not settings.auth.oauth: + return "" + plugin_query = quote(plugin_name) + return f"[⚙️ View Config](/docs/plugin-config?plugin={plugin_query})" def build_plugin_description( @@ -532,5 +437,10 @@ def build_plugin_description( version: str, description: str, repo: 
str, + plugin_name: str | None = None, ) -> str: - return f"**{description}**\n\n\n👤 {author} · 🧩 {version} · [🔗 Repo]({repo})" + + latest_release = _format_plugin_release_view(plugin_name) + config_view = _format_plugin_config_view(plugin_name) + config_suffix = f" · {config_view}" if config_view else "" + return f"**{description}**{latest_release}\n\n\n👤 {author} · 🧩 {version} · [🔗 Repo]({repo}){config_suffix}" diff --git a/tests/api/test_proxy.py b/tests/api/test_proxy.py index 1e5b160..6135a47 100644 --- a/tests/api/test_proxy.py +++ b/tests/api/test_proxy.py @@ -1,9 +1,46 @@ +import asyncio + import pytest from fastapi.testclient import TestClient from framex.consts import API_STR +from framex.driver.auth import create_auth_session, create_jwt from framex.utils import cache_decode, cache_encode -from tests.test_plugins import ExchangeModel, SubModel +from tests.test_plugins import ExchangeModel, SubModel, local_exchange_key_value + + +def _set_oauth_session(client: TestClient, monkeypatch) -> None: + from framex.config import settings + + oauth = type( + "OAuthConfig", + (), + { + "provider": "gitlab", + "jwt_secret": "test-secret-test-secret-test-secret", + "jwt_algorithm": "HS256", + "authorization_url": "https://oauth.example.com/authorize", + "client_id": "client", + "call_back_url": "http://test/callback", + }, + )() + monkeypatch.setattr(settings.auth, "oauth", oauth) + + session_id = create_auth_session( + { + "username": "tester", + "oauth_provider": "gitlab", + "oauth_access_token": "oauth-token", + } + ) + token = create_jwt( + { + "username": "tester", + "oauth_provider": "gitlab", + "session_id": session_id, + } + ) + client.cookies.set("framex_token", token) def test_get_proxy_version(client: TestClient): @@ -60,6 +97,150 @@ def test_get_proxy_upload(client: TestClient): } +def test_openapi_tag_description_shows_lazy_release_view(client: TestClient, monkeypatch): + from framex.config import settings + + monkeypatch.setattr(settings.auth, "oauth", 
None) + data = client.get("/api/v1/openapi.json").json() + + descriptions = [tag.get("description") or "" for tag in data.get("tags", [])] + assert any("/docs/plugin-release?plugin=proxy" in description for description in descriptions) + + +def test_openapi_tag_description_hides_plugin_config_without_auth(client: TestClient, monkeypatch): + from framex.config import settings + + monkeypatch.setattr(settings.auth, "oauth", None) + data = client.get("/api/v1/openapi.json").json() + + descriptions = [tag.get("description") or "" for tag in data.get("tags", [])] + assert all("View Config" not in description for description in descriptions) + assert all("/docs/plugin-config?plugin=proxy" not in description for description in descriptions) + + +def test_get_plugin_release_documentation(client: TestClient, monkeypatch): + monkeypatch.setattr("framex.driver.application.get_latest_repository_version", lambda _: "v9.9.9") + + response = client.get("/docs/plugin-release", params={"plugin": "proxy"}) + + assert response.status_code == 200 + assert response.json() == { + "has_update": True, + "latest_version": "v9.9.9", + "repo_url": "https://github.com/touale/FrameX-kit", + } + + +def test_get_plugin_config_documentation_requires_auth(client: TestClient, monkeypatch): + from framex.config import settings + + monkeypatch.setattr(settings.auth, "oauth", None) + response = client.get("/docs/plugin-config", params={"plugin": "proxy"}) + + assert response.status_code == 403 + assert response.json()["message"] == "Plugin config documentation requires auth" + + +def test_get_plugin_config_documentation_rejects_settings_only_config(client: TestClient, tmp_path, monkeypatch): + from framex.config import settings + + _set_oauth_session(client, monkeypatch) + yaml_path = tmp_path / "proxy-extra.yaml" + yaml_path.write_text( + "token: embedded-secret\nnested:\n client_secret: inner-secret\nname: proxy\n", encoding="utf-8" + ) + monkeypatch.chdir(tmp_path) + 
monkeypatch.setattr(settings.docs, "embedded_config_file_whitelist", ["proxy-extra.yaml"]) + monkeypatch.setitem( + settings.plugins, + "embedded_demo", + { + "extra_config_path": str(yaml_path), + "enabled": True, + "api_key": "demo-api-key", + }, + ) + + response = client.get("/docs/plugin-config", params={"plugin": "embedded_demo"}) + + assert response.status_code == 403 + assert response.json()["message"] == "Repository access denied: embedded_demo" + + +def test_get_plugin_config_documentation_rejects_settings_only_config_without_whitelist( + client: TestClient, tmp_path, monkeypatch +): + from framex.config import settings + + _set_oauth_session(client, monkeypatch) + yaml_path = tmp_path / "proxy-extra.yaml" + yaml_path.write_text("name: proxy\n", encoding="utf-8") + monkeypatch.chdir(tmp_path) + monkeypatch.setattr(settings.docs, "embedded_config_file_whitelist", []) + monkeypatch.setitem( + settings.plugins, + "embedded_demo_without_whitelist", + {"extra_config_path": str(yaml_path), "enabled": True}, + ) + + response = client.get("/docs/plugin-config", params={"plugin": "embedded_demo_without_whitelist"}) + + assert response.status_code == 403 + assert response.json()["message"] == "Repository access denied: embedded_demo_without_whitelist" + + +def test_get_plugin_config_documentation(client: TestClient, monkeypatch): + monkeypatch.setattr("framex.driver.application.is_private_repository", lambda *_: False) + _set_oauth_session(client, monkeypatch) + response = client.get("/docs/plugin-config", params={"plugin": "proxy"}) + + assert response.status_code == 200 + assert "Plugin Config (TOML)" in response.text + assert "proxy_urls" in response.text + + +def test_get_plugin_config_documentation_requires_repository_access(client: TestClient, monkeypatch): + _set_oauth_session(client, monkeypatch) + monkeypatch.setattr("framex.driver.application.is_private_repository", lambda *_: True) + monkeypatch.setattr("framex.driver.application.can_access_repository", 
lambda *_: False) + + response = client.get("/docs/plugin-config", params={"plugin": "proxy"}) + + assert response.status_code == 403 + assert response.json()["message"] == "Repository access denied: proxy" + + +def test_get_plugin_config_documentation_checks_public_probe_before_token(client: TestClient, monkeypatch): + called = {"can_access": 0, "is_private": 0} + + def fake_can_access(*_args): + called["can_access"] += 1 + return True + + def fake_is_private(*_args): + called["is_private"] += 1 + return False + + monkeypatch.setattr("framex.driver.application.can_access_repository", fake_can_access) + monkeypatch.setattr("framex.driver.application.is_private_repository", fake_is_private) + _set_oauth_session(client, monkeypatch) + + response = client.get("/docs/plugin-config", params={"plugin": "proxy"}) + + assert response.status_code == 200 + assert called == {"can_access": 0, "is_private": 1} + + +def test_get_plugin_config_documentation_skips_repository_check_for_public_repo(client: TestClient, monkeypatch): + monkeypatch.setattr("framex.driver.application.is_private_repository", lambda *_: False) + _set_oauth_session(client, monkeypatch) + + response = client.get("/docs/plugin-config", params={"plugin": "proxy"}) + + assert response.status_code == 200 + assert "Plugin Config (TOML)" in response.text + + def test_get_proxy_upload_openapi(client: TestClient): data = client.get("/api/v1/openapi.json").json() post = data["paths"]["/proxy/mock/upload"]["post"] @@ -88,6 +269,9 @@ def test_get_proxy_auth_sget(client: TestClient): @pytest.mark.order(2) def test_call_proxy_func(client: TestClient): + from framex.plugin.load import register_proxy_func + + asyncio.run(register_proxy_func(local_exchange_key_value)) func = cache_encode("tests.test_plugins.local_exchange_key_value") data = cache_encode( { diff --git a/tests/conftest.py b/tests/conftest.py index b0c4a06..9086741 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ import framex from 
framex.config import settings -from tests.mock import mock_get, mock_request +from tests.mock import mock_get, mock_repository_fetch_json, mock_request @pytest.fixture(autouse=True) @@ -58,13 +58,16 @@ def before_record_response(response): @pytest.fixture(scope="session", autouse=True) def test_app() -> Generator: - plugins = framex.load_plugins(str(Path(__file__).parent / "plugins")) - assert len(plugins) == len(["invoker", "export", "alias_model"]) - with ( + patch( + "framex.repository.providers.base.RepositoryVersionProvider.fetch_json", + new=staticmethod(mock_repository_fetch_json), + ), patch("httpx.AsyncClient.get", new=mock_get), patch("httpx.AsyncClient.request", new=mock_request), ): + plugins = framex.load_plugins(str(Path(__file__).parent / "plugins")) + assert len(plugins) == len(["invoker", "export", "alias_model"]) yield framex.run(test_mode=True) # type: ignore[return-value] diff --git a/tests/driver/test_auth.py b/tests/driver/test_auth.py index 73d6974..f886e03 100644 --- a/tests/driver/test_auth.py +++ b/tests/driver/test_auth.py @@ -2,7 +2,7 @@ from datetime import UTC, datetime, timedelta from types import SimpleNamespace from unittest.mock import AsyncMock, Mock, patch -from urllib.parse import urlparse +from urllib.parse import parse_qs, urlparse import jwt import pytest @@ -13,7 +13,14 @@ from framex.config import AuthConfig from framex.consts import AUTH_COOKIE_NAME, DOCS_URL from framex.driver.application import create_fastapi_application -from framex.driver.auth import auth_jwt, authenticate, create_jwt, oauth_callback +from framex.driver.auth import ( + auth_jwt, + authenticate, + create_auth_session, + create_jwt, + decode_auth_token, + oauth_callback, +) JWT_SECRET = uuid.uuid4().hex @@ -24,6 +31,7 @@ def fake_oauth(**overrides): data = dict( # noqa: C408 + provider="gitlab", authorization_url="https://oauth.example.com/authorize", token_url="https://oauth.example.com/token", # noqa: S106 
user_info_url="https://oauth.example.com/user", @@ -83,16 +91,8 @@ def test_returns_true_when_token_is_valid(self): mock_oauth.jwt_secret = JWT_SECRET mock_oauth.jwt_algorithm = "HS256" - now = datetime.now(UTC) - token = jwt.encode( - { - "username": "test", - "iat": int(now.timestamp()), - "exp": int((now + timedelta(hours=1)).timestamp()), - }, - JWT_SECRET, - algorithm="HS256", - ) + session_id = create_auth_session({"username": "test"}) + token = create_jwt({"username": "test", "session_id": session_id}) req = Mock(spec=Request) req.cookies.get.return_value = token @@ -183,7 +183,16 @@ async def test_oauth_callback_success(self): res = await oauth_callback(code="abc") assert res.status_code == status.HTTP_302_FOUND assert res.headers["location"] == DOCS_URL - assert f"{AUTH_COOKIE_NAME}=" in res.headers.get("set-cookie", "") + cookie_header = res.headers.get("set-cookie", "") + assert f"{AUTH_COOKIE_NAME}=" in cookie_header + token = cookie_header.split(f"{AUTH_COOKIE_NAME}=", 1)[1].split(";", 1)[0] + payload = jwt.decode(token, JWT_SECRET, algorithms=["HS256"]) + assert payload["oauth_provider"] == "gitlab" + assert payload["session_id"] + assert "oauth_access_token" not in payload + decoded_payload = decode_auth_token(token) + assert decoded_payload is not None + assert decoded_payload["oauth_access_token"] == "oauth-token" # noqa # ========================================================= @@ -203,22 +212,15 @@ def test_docs_redirects_when_not_authenticated(self): location = resp.headers["location"] parsed = urlparse(location) assert parsed.hostname == "oauth.example.com" + assert parse_qs(parsed.query)["scope"] == ["read_user read_api"] def test_docs_accessible_with_valid_jwt(self): with patch("framex.config.settings.auth.oauth", fake_oauth()): app = create_fastapi_application() client = TestClient(app) - now = datetime.now(UTC) - token = jwt.encode( - { - "username": "test", - "iat": int(now.timestamp()), - "exp": int((now + 
timedelta(hours=1)).timestamp()), - }, - JWT_SECRET, - algorithm="HS256", - ) + session_id = create_auth_session({"username": "test"}) + token = create_jwt({"username": "test", "session_id": session_id}) client.cookies.set(AUTH_COOKIE_NAME, token) resp = client.get("/docs", follow_redirects=False) assert resp.status_code == status.HTTP_200_OK diff --git a/tests/mock.py b/tests/mock.py index 5cbfb04..e53f3d9 100644 --- a/tests/mock.py +++ b/tests/mock.py @@ -87,7 +87,7 @@ async def mock_request(_, method: str, url: str, **kwargs: Any): "params": params, } elif url.endswith("/proxy/remote") and method == "POST": - if headers.get("Authorization") != "i_am_proxy_func_auth_keys": + if headers.get("Authorization") not in {"i_am_local_proxy_auth_keys", "i_am_proxy_func_auth_keys"}: resp.json.return_value = { "status": 401, "message": f"Invalid API Key({headers.get('Authorization')}) for API(/api/v1/proxy/mock/auth/get)", @@ -106,3 +106,14 @@ async def mock_request(_, method: str, url: str, **kwargs: Any): raise AssertionError(f"Unexpected request: {method} {url}") return resp + + +def mock_repository_fetch_json(url: str, headers: dict[str, str] | None = None): + headers = headers or {} + if url.endswith("/releases/latest") and "api.github.com/repos/" in url: + return {"tag_name": "v9.9.9"} + if url.endswith("/releases/permalink/latest") and "/api/v4/projects/" in url: + if headers.get("PRIVATE-TOKEN") == "gitlab-private-token": + return {"tag_name": "v8.8.8"} + return None + raise AssertionError(f"Unexpected repository metadata request: {url}") diff --git a/tests/test_config.py b/tests/test_config.py index ca8dea7..e202251 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,4 +1,4 @@ -from framex.config import OauthConfig +from framex.config import OauthConfig, RepositoryConfig def test_config(): @@ -60,3 +60,59 @@ def test_proxy_config(): assert not proxy_config.is_white_url("http://localhost:10000", "/health") assert 
proxy_config.is_white_url("http://localhost:10000", "/echo") assert proxy_config.is_white_url("http://localhost:10001", "/health") + + +def test_repository_auth_config_default_headers(): + cfg = RepositoryConfig() + + assert cfg.auth.github.build_headers() == {} + assert cfg.auth.gitlab.build_headers() == {} + + +def test_repository_auth_config_builds_provider_headers(): + cfg = RepositoryConfig( + auth={ + "github": {"token": "gh-secret"}, + "gitlab": {"token": "gl-secret"}, + } + ) + + assert cfg.auth.github.build_headers() == {"Authorization": "Bearer gh-secret"} + assert cfg.auth.gitlab.build_headers() == {"PRIVATE-TOKEN": "gl-secret"} + + +def test_gitlab_repository_auth_config_uses_matching_endpoint_headers(): + cfg = RepositoryConfig( + auth={ + "gitlab": { + "endpoints": [ + {"host": "gitlab.internal.test", "token": "team-a-token", "path_prefix": "/team-a"}, + {"host": "gitlab.internal.test", "token": "team-b-token", "path_prefix": "/team-b"}, + ] + } + } + ) + + assert cfg.auth.gitlab.build_headers_for_url("gitlab.internal.test", "/team-a/repo") == { + "PRIVATE-TOKEN": "team-a-token" + } + assert cfg.auth.gitlab.build_headers_for_url("gitlab.internal.test", "/team-b/repo") == { + "PRIVATE-TOKEN": "team-b-token" + } + + +def test_gitlab_repository_auth_config_prefers_longest_path_prefix(): + cfg = RepositoryConfig( + auth={ + "gitlab": { + "endpoints": [ + {"host": "gitlab.internal.test", "token": "group-token", "path_prefix": "/team"}, + {"host": "gitlab.internal.test", "token": "project-token", "path_prefix": "/team/project"}, + ] + } + } + ) + + assert cfg.auth.gitlab.build_headers_for_url("gitlab.internal.test", "/team/project/repo") == { + "PRIVATE-TOKEN": "project-token" + } diff --git a/tests/test_utils.py b/tests/test_utils.py index f09d13d..d220b82 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,3 +1,5 @@ +import html +import importlib import json from datetime import datetime, timedelta from typing import Any @@ -5,13 +7,20 @@ 
import pytest from pydantic import BaseModel -from framex.config import AuthConfig +from framex.config import AuthConfig, GitLabRepositoryAuthEndpointConfig, settings +from framex.repository import can_access_repository, get_latest_repository_version, has_newer_release_version from framex.utils import ( StreamEnventType, + build_plugin_config_html, + build_plugin_description, cache_decode, cache_encode, + collect_embedded_config_files, format_uptime, make_stream_event, + mask_sensitive_config_data, + mask_sensitive_config_text, + mask_sensitive_embedded_config_content, safe_error_message, ) @@ -181,3 +190,409 @@ def test_safe_error_message_fallback(): e = Exception() e.args = () assert safe_error_message(e) == "Internal Server Error" + + +def test_build_plugin_description_shows_lazy_release_view(): + description = build_plugin_description( + author="tester", + version="v0.3.4", + description="demo plugin", + repo="https://github.com/example/repo", + plugin_name="demo", + ) + + assert "/docs/plugin-release?plugin=demo" in description + + +def test_collect_embedded_config_files_reads_yaml_and_toml(tmp_path): + yaml_path = tmp_path / "demo.yaml" + yaml_path.write_text("name: demo\n", encoding="utf-8") + toml_path = tmp_path / "demo.toml" + toml_path.write_text('name = "demo"\n', encoding="utf-8") + + embedded_files = collect_embedded_config_files( + { + "yaml_path": str(yaml_path), + "nested": {"toml_path": str(toml_path)}, + "ignored": str(tmp_path / "demo.json"), + }, + workspace_root=tmp_path, + whitelist=["*.yaml", "*.toml"], + ) + + assert embedded_files == [ + (str(yaml_path.resolve()), "name: demo\n"), + (str(toml_path.resolve()), 'name = "demo"\n'), + ] + + +def test_collect_embedded_config_files_requires_whitelist(tmp_path): + yaml_path = tmp_path / "demo.yaml" + yaml_path.write_text("name: demo\n", encoding="utf-8") + + embedded_files = collect_embedded_config_files( + {"yaml_path": str(yaml_path)}, + workspace_root=tmp_path, + whitelist=[], + ) + + assert 
embedded_files == [] + + +def test_collect_embedded_config_files_blocks_outside_workspace(tmp_path): + workspace_root = tmp_path / "workspace" + workspace_root.mkdir() + outside_yaml_path = tmp_path / "outside.yaml" + outside_yaml_path.write_text("name: outside\n", encoding="utf-8") + + embedded_files = collect_embedded_config_files( + {"yaml_path": str(outside_yaml_path)}, + workspace_root=workspace_root, + whitelist=["*.yaml"], + ) + + assert embedded_files == [] + + +def test_mask_sensitive_config_data_masks_nested_values(): + masked = mask_sensitive_config_data( + { + "token": "abcdef123456", + "nested": {"client_secret": "secret-value"}, + "headers": [{"authorization": "Bearer demo-token"}], + "safe": "visible", + } + ) + + assert masked["token"] != "abcdef123456" # noqa + assert masked["nested"]["client_secret"] != "secret-value" # noqa + assert masked["headers"][0]["authorization"] != "Bearer demo-token" + assert masked["safe"] == "visible" + + +def test_mask_sensitive_config_data_masks_auth_rules_values(): + masked = mask_sensitive_config_data( + { + "auth": { + "rules": { + "/api/v1/*": ["Basic YWRtaW46Z3podQ=="], + "/proxy/mock/auth/*": ["i_am_proxy_general_auth_keys"], + } + } + } + ) + + assert masked["auth"]["rules"]["/api/v1/*"][0] != "Basic YWRtaW46Z3podQ==" + assert masked["auth"]["rules"]["/proxy/mock/auth/*"][0] != "i_am_proxy_general_auth_keys" + + +def test_mask_sensitive_config_text_masks_yaml_and_toml_lines(): + content = "\n".join( # noqa + [ + 'token = "abcdef123456"', + "client_secret: secret-value", + 'safe = "visible"', + ] + ) + + masked = mask_sensitive_config_text(content) + + assert 'token = "abcdef123456"' not in masked + assert "client_secret: secret-value" not in masked + assert 'safe = "visible"' in masked + + +def test_mask_sensitive_embedded_config_content_parses_yaml_and_toml(): + yaml_masked = mask_sensitive_embedded_config_content( + "demo.yaml", + "token: abcdef123456\nnested:\n client_secret: secret-value\nsafe: visible\n", 
+ ) + toml_masked = mask_sensitive_embedded_config_content( + "demo.toml", + 'token = "abcdef123456"\n[nested]\nclient_secret = "secret-value"\nsafe = "visible"\n', + ) + + assert "abcdef123456" not in yaml_masked + assert "secret-value" not in yaml_masked + assert "safe: visible" in yaml_masked + assert "abcdef123456" not in toml_masked + assert "secret-value" not in toml_masked + assert 'safe = "visible"' in toml_masked + + +def test_build_plugin_config_html_uses_toml_format(): + response = build_plugin_config_html( + { + "enabled": True, + "name": "demo", + "proxy_urls": ["https://example.com"], + "nested": {"timeout": 30}, + "endpoints": [{"host": "gitlab.example.com", "token": "demo-token"}], + } + ) + + body = html.unescape(response.body.decode()) # type: ignore + assert "Plugin Config (TOML)" in body + assert "enabled = true" in body + assert 'name = "demo"' in body + assert 'proxy_urls = ["https://example.com"]' in body + assert "[nested]" in body + assert "timeout = 30" in body + assert "[[endpoints]]" in body + assert 'host = "gitlab.example.com"' in body + assert 'token = "demo-token"' not in body + + +def test_build_plugin_config_html_hides_restricted_embedded_paths(tmp_path, monkeypatch): + workspace_root = tmp_path / "workspace" + workspace_root.mkdir() + allowed_path = workspace_root / "configs" / "allowed.yaml" + allowed_path.parent.mkdir() + allowed_path.write_text("name: demo\n", encoding="utf-8") + + outside_path = tmp_path / "secret.yaml" + outside_path.write_text("token: secret\n", encoding="utf-8") + + monkeypatch.chdir(workspace_root) + monkeypatch.setattr(settings.docs, "embedded_config_file_whitelist", ["configs/*.yaml"]) + + response = build_plugin_config_html( + { + "allowed_config": str(allowed_path), + "blocked_config": str(outside_path), + } + ) + + body = html.unescape(response.body.decode()) # type: ignore + assert 'allowed_config = "configs/allowed.yaml"' in body + assert 'blocked_config = "[restricted config path]"' in body + 
assert str(outside_path) not in body + + +def test_has_newer_release_version(): + assert has_newer_release_version("v0.3.4", "v0.3.5") + assert not has_newer_release_version("v0.3.4", "v0.3.4") + assert not has_newer_release_version("v0.3.4", "invalid") + + +def test_get_latest_repository_version_uses_github_auth_token(monkeypatch): + get_latest_repository_version.cache_clear() + monkeypatch.setattr(settings.repository.auth.github, "token", "gh-private-token") + + captured_headers: dict[str, str | None] = {} + + def fake_fetch_json(url: str, headers: dict[str, str] | None = None): + captured_headers["authorization"] = (headers or {}).get("Authorization") + return {"tag_name": "v1.2.3"} + + monkeypatch.setattr( + "framex.repository.providers.base.RepositoryVersionProvider.fetch_json", + staticmethod(fake_fetch_json), + ) + + version = get_latest_repository_version("https://github.com/example/private-repo") + + assert version == "v1.2.3" + assert captured_headers["authorization"] == "Bearer gh-private-token" + get_latest_repository_version.cache_clear() + + +def test_get_latest_repository_version_uses_gitlab_private_token(monkeypatch): + get_latest_repository_version.cache_clear() + monkeypatch.setattr(settings.repository.auth.gitlab, "token", "gitlab-private-token") + monkeypatch.setattr( + settings.repository.auth.gitlab, + "endpoints", + [GitLabRepositoryAuthEndpointConfig(host="gitlab.example.test", token="gitlab-private-token")], # noqa + ) + + captured_headers: dict[str, str | None] = {} + + def fake_fetch_json(url: str, headers: dict[str, str] | None = None): + captured_headers["private_token"] = (headers or {}).get("PRIVATE-TOKEN") + return {"tag_name": "v2.0.0"} + + monkeypatch.setattr( + "framex.repository.providers.base.RepositoryVersionProvider.fetch_json", + staticmethod(fake_fetch_json), + ) + + version = get_latest_repository_version("https://gitlab.com/example/private-repo") + + assert version == "v2.0.0" + assert captured_headers["private_token"] == 
"gitlab-private-token" # noqa + get_latest_repository_version.cache_clear() + + +def test_get_latest_repository_version_uses_gitlab_endpoint_token(monkeypatch): + get_latest_repository_version.cache_clear() + monkeypatch.setattr( + settings.repository.auth.gitlab, + "endpoints", + [ + GitLabRepositoryAuthEndpointConfig( + host="gitlab.internal.test", + path_prefix="/team-a", + token="team-a-token", # noqa + ) + ], + ) + + captured_headers: dict[str, str | None] = {} + + def fake_fetch_json(url: str, headers: dict[str, str] | None = None): + captured_headers["private_token"] = (headers or {}).get("PRIVATE-TOKEN") + return {"tag_name": "v3.0.0"} + + monkeypatch.setattr( + "framex.repository.providers.base.RepositoryVersionProvider.fetch_json", + staticmethod(fake_fetch_json), + ) + + version = get_latest_repository_version("https://gitlab.internal.test/team-a/private-repo") + + assert version == "v3.0.0" + assert captured_headers["private_token"] == "team-a-token" # noqa + get_latest_repository_version.cache_clear() + + +def test_get_latest_repository_version_resolves_gitlab_project_from_subdirectory_url(monkeypatch): + get_latest_repository_version.cache_clear() + monkeypatch.setattr(settings.repository.auth.gitlab, "token", "gitlab-private-token") + monkeypatch.setattr( + settings.repository.auth.gitlab, + "endpoints", + [GitLabRepositoryAuthEndpointConfig(host="gitlab.example.test", token="gitlab-private-token")], # noqa + ) + + captured_urls: list[str] = [] + + def fake_can_fetch(url: str, headers: dict[str, str] | None = None): + captured_urls.append(url) + return url.endswith("/api/v4/projects/example-group%2Fexample-repo") + + def fake_fetch_json(url: str, headers: dict[str, str] | None = None): + captured_urls.append(url) + return {"tag_name": "v4.0.0"} + + monkeypatch.setattr( + "framex.repository.providers.base.RepositoryVersionProvider.can_fetch", + staticmethod(fake_can_fetch), + ) + monkeypatch.setattr( + 
"framex.repository.providers.base.RepositoryVersionProvider.fetch_json", + staticmethod(fake_fetch_json), + ) + + version = get_latest_repository_version( + "https://gitlab.example.test/example-group/example-repo/plugins/example-plugin" + ) + + assert version == "v4.0.0" + assert any(url.endswith("/api/v4/projects/example-group%2Fexample-repo") for url in captured_urls) + assert any( + url.endswith("/api/v4/projects/example-group%2Fexample-repo/releases/permalink/latest") + for url in captured_urls + ) + get_latest_repository_version.cache_clear() + + +def test_can_access_repository_resolves_gitlab_project_from_subdirectory_url(monkeypatch): + monkeypatch.setattr( + settings.repository.auth.gitlab, + "endpoints", + [GitLabRepositoryAuthEndpointConfig(host="gitlab.example.test", token="gitlab-private-token")], # noqa + ) + captured: dict[str, object] = {"urls": []} + + def fake_can_fetch( + url: str, + headers: dict[str, str] | None = None, + *, + follow_redirects: bool = True, + ): + captured["urls"].append(url) # type: ignore + captured["authorization"] = (headers or {}).get("Authorization") # type: ignore + return url.endswith("/api/v4/projects/example-group%2Fexample-repo") + + monkeypatch.setattr( + "framex.repository.providers.base.RepositoryVersionProvider.can_fetch", + staticmethod(fake_can_fetch), + ) + + has_access = can_access_repository( + "https://gitlab.example.test/example-group/example-repo/plugins/example-plugin", + "gitlab", + "oauth-token", + ) + + assert has_access is True + assert any( + url.endswith("/api/v4/projects/example-group%2Fexample-repo") + for url in captured["urls"] # type: ignore + ) + assert captured["authorization"] == "Bearer oauth-token" + + +def test_gitlab_public_repository_probe_disables_redirects(monkeypatch): + from urllib.parse import urlparse + + from framex.repository.providers.gitlab import GitLabRepositoryVersionProvider + + captured: dict[str, object] = {} + + def fake_can_fetch( + url: str, + headers: dict[str, str] 
| None = None, + *, + follow_redirects: bool = True, + ) -> bool: + captured["url"] = url + captured["follow_redirects"] = follow_redirects + return False + + monkeypatch.setattr( + "framex.repository.providers.base.RepositoryVersionProvider.can_fetch", + staticmethod(fake_can_fetch), + ) + + provider = GitLabRepositoryVersionProvider() + provider.is_public_repository( + urlparse("https://gitlab.example.test/example-group/example-repo/plugins/example-plugin") + ) + + assert captured["url"] == ("https://gitlab.example.test/example-group/example-repo/plugins/example-plugin") + assert captured["follow_redirects"] is False + + +def test_repository_fetch_json_follows_redirects(monkeypatch): + import framex.repository.providers.base as base_module + + base_module = importlib.reload(base_module) + captured: dict[str, bool] = {} + + class FakeClient: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def get(self, url: str, follow_redirects: bool = False): # noqa + captured["follow_redirects"] = follow_redirects + response = type("Response", (), {})() + response.status_code = 200 + response.json = lambda: {"tag_name": "v0.0.15"} + return response + + monkeypatch.setattr(base_module.httpx, "Client", FakeClient) + + payload = base_module.RepositoryVersionProvider.fetch_json( + "https://gitlab.internal.test/api/v4/projects/184/releases/permalink/latest" + ) + + assert payload == {"tag_name": "v0.0.15"} + assert captured["follow_redirects"] is True