From 1fc81d81e6c5fe13a087666f4a092fa7da5c6d80 Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Wed, 25 Mar 2026 18:15:10 +0800 Subject: [PATCH 1/6] feat: add ApiResolver and context management for plugin APIs --- src/framex/plugin/__init__.py | 76 ++++++++++++++++++++++- src/framex/plugin/base.py | 66 +++++++++++++++++++- tests/test_exception.py | 5 +- tests/test_plugin.py | 112 +++++++++++++++++++++++++++++++++- 4 files changed, 249 insertions(+), 10 deletions(-) diff --git a/src/framex/plugin/__init__.py b/src/framex/plugin/__init__.py index 88b923d..9ae8bc8 100644 --- a/src/framex/plugin/__init__.py +++ b/src/framex/plugin/__init__.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from contextvars import ContextVar from functools import lru_cache from typing import Any, Optional, TypeVar @@ -15,6 +16,66 @@ _manager: PluginManager = PluginManager(silent=settings.test.silent) _current_plugin: ContextVar[Optional["Plugin"]] = ContextVar("_current_plugin", default=None) +_current_api_resolver: ContextVar["ApiResolver | None"] = ContextVar("_current_api_resolver", default=None) +_current_remote_apis: ContextVar[Mapping[str, PluginApi | dict[str, Any]] | None] = ContextVar( + "_current_remote_apis", default=None +) + + +class ApiResolver: + def __init__( + self, + manager: PluginManager | None = None, + api_registry: Mapping[str, PluginApi | dict[str, Any]] | None = None, + ) -> None: + self._manager = manager + self._api_registry = api_registry or {} + + @staticmethod + def _coerce_plugin_api(api: PluginApi | dict[str, Any] | None) -> PluginApi | None: + if api is None or isinstance(api, PluginApi): + return api + if isinstance(api, dict): + return PluginApi.model_validate(api) + return None + + def resolve( + self, + api_name: str, + api_registry: Mapping[str, PluginApi | dict[str, Any]] | None = None, + ) -> PluginApi | None: + if api_registry is not None and (api := self._coerce_plugin_api(api_registry.get(api_name))): + return api + if self._manager and (api := self._manager.get_api(api_name)): + return api + return self._coerce_plugin_api(self._api_registry.get(api_name)) + + +_api_resolver = ApiResolver(manager=_manager) + + +def get_current_api_resolver() -> ApiResolver | None: + return _current_api_resolver.get() + + +def set_current_api_resolver(resolver: ApiResolver | None) -> Any: + return _current_api_resolver.set(resolver) + + +def reset_current_api_resolver(token: Any) -> None: + _current_api_resolver.reset(token) + + +def get_current_remote_apis() -> Mapping[str, PluginApi | dict[str, Any]] | None: + return _current_remote_apis.get() + + +def set_current_remote_apis(remote_apis: Mapping[str, PluginApi | dict[str, Any]] | None) -> Any: + return _current_remote_apis.set(remote_apis) + + +def reset_current_remote_apis(token: Any) -> None: + _current_remote_apis.reset(token) def get_plugin(plugin_id: str) -> Plugin | None: @@ -40,6 +101,7 @@ def check_plugin_config_exists(plugin_name: str) -> bool: @logger.catch() def init_all_deployments(enable_proxy: bool) -> list[Any]: deployments = [] + all_apis = {**_manager.all_plugin_apis[ApiType.FUNC], **_manager.all_plugin_apis[ApiType.HTTP]} for plugin in get_loaded_plugins(): for dep in plugin.deployments: remote_apis = { @@ -66,6 +128,7 @@ def init_all_deployments(enable_proxy: bool) -> list[Any]: deployment = get_adapter().bind( dep.deployment, remote_apis=remote_apis, + api_registry=all_apis, config=plugin.config, ) @@ -76,15 +139,24 @@ def init_all_deployments(enable_proxy: bool) -> list[Any]: async def call_plugin_api( api_name: str | PluginApi, - interval_apis: dict[str, PluginApi] | None = None, + resolver: ApiResolver | None = None, **kwargs: Any, ) -> Any: + current_remote_apis = get_current_remote_apis() if isinstance(api_name, PluginApi): api: PluginApi | None = api_name elif isinstance(api_name, str): - api = interval_apis.get(api_name) if interval_apis else _manager.get_api(api_name) + active_resolver = resolver or get_current_api_resolver() or _api_resolver + if current_remote_apis is not None: + api = active_resolver._coerce_plugin_api(current_remote_apis.get(api_name)) + else: + api = active_resolver.resolve(api_name, None) use_proxy = False if not api: + if isinstance(api_name, str) and current_remote_apis is not None: + raise RuntimeError( + f"API {api_name} is not declared in current plugin remote_apis; add it to required_remote_apis." + ) if isinstance(api_name, str) and api_name.startswith("/") and settings.server.enable_proxy: api = PluginApi( api=api_name, diff --git a/src/framex/plugin/base.py b/src/framex/plugin/base.py index ba49c59..9a16db5 100644 --- a/src/framex/plugin/base.py +++ b/src/framex/plugin/base.py @@ -1,8 +1,18 @@ +import inspect +from collections.abc import Mapping +from functools import wraps from typing import Any, final from framex.config import settings from framex.log import setup_logger -from framex.plugin import call_plugin_api +from framex.plugin import ( + ApiResolver, + call_plugin_api, + reset_current_api_resolver, + reset_current_remote_apis, + set_current_api_resolver, + set_current_remote_apis, +) from framex.plugin.model import PluginApi @@ -12,6 +22,9 @@ class BasePlugin: def __init__(self, **kwargs: Any) -> None: setup_logger() self.remote_apis: dict[str, PluginApi] = kwargs.get("remote_apis", {}) + self.api_registry: Mapping[str, PluginApi] = kwargs.get("api_registry", {}) + self.api_resolver = ApiResolver(api_registry=self.api_registry) + self._bind_api_resolver_context() if settings.server.use_ray: import asyncio @@ -20,9 +33,58 @@ def __init__(self, **kwargs: Any) -> None: async def on_start(self) -> None: pass + def _bind_api_resolver_context(self) -> None: + for name, func in inspect.getmembers(type(self), predicate=callable): + if not getattr(func, "_on_request", False): + continue + bound = getattr(self, name) + setattr(self, name, self._wrap_with_api_resolver(bound)) + + def _wrap_with_api_resolver(self, func: Any) -> Any: + if inspect.isasyncgenfunction(func): + + @wraps(func) + async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any: + resolver_token = set_current_api_resolver(self.api_resolver) + remote_token = set_current_remote_apis(self.remote_apis) + try: + async for chunk in func(*args, **kwargs): + yield chunk + finally: + reset_current_remote_apis(remote_token) + reset_current_api_resolver(resolver_token) + + return async_gen_wrapper + + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + resolver_token = set_current_api_resolver(self.api_resolver) + remote_token = set_current_remote_apis(self.remote_apis) + try: + return await func(*args, **kwargs) + finally: + reset_current_remote_apis(remote_token) + reset_current_api_resolver(resolver_token) + + return async_wrapper + + @wraps(func) + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + resolver_token = set_current_api_resolver(self.api_resolver) + remote_token = set_current_remote_apis(self.remote_apis) + try: + return func(*args, **kwargs) + finally: + reset_current_remote_apis(remote_token) + reset_current_api_resolver(resolver_token) + + return sync_wrapper + @final async def _call_remote_api(self, api_name: str, **kwargs: Any) -> Any: - res = await call_plugin_api(api_name, self.remote_apis, **kwargs) + res = await call_plugin_api(api_name, **kwargs) return self._post_call_remote_api_hook(res) def _post_call_remote_api_hook(self, data: Any) -> Any: diff --git a/tests/test_exception.py b/tests/test_exception.py index c8d0a3e..d2ca1c7 100644 --- a/tests/test_exception.py +++ b/tests/test_exception.py @@ -26,8 +26,5 @@ async def test_call_not_exist_plugin() -> None: assert "API func.call_not_exist_plugin is not found" in str(excinfo.value) with pytest.raises(expected_exception=RuntimeError) as excinfo: - await call_plugin_api( - api_name="/call_not_exist_plugin", - interval_apis={"/call_not_exist_plugin": PluginApi(deployment_name="deployment_name")}, - ) + await call_plugin_api(api_name=PluginApi(api="/call_not_exist_plugin", deployment_name="deployment_name")) assert "No handle or function found for deployment(deployment_name:__call__)" in str(excinfo.value) diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 46784ab..925e13f 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -14,13 +14,21 @@ def test_get_plugin(): assert plugin.config.model_dump() == {"id": 123, "name": "test"} +import inspect from unittest.mock import AsyncMock, patch import pytest from pydantic import BaseModel from framex.consts import PROXY_PLUGIN_NAME -from framex.plugin import call_plugin_api +from framex.plugin import ( + ApiResolver, + call_plugin_api, + reset_current_api_resolver, + reset_current_remote_apis, + set_current_api_resolver, + set_current_remote_apis, +) from framex.plugin.model import ApiType, PluginApi @@ -173,12 +181,112 @@ async def test_call_plugin_api_with_interval_apis(self): ): mock_adapter.return_value.call_func = AsyncMock(return_value="interval_result") - result = await call_plugin_api("test_api", interval_apis=interval_apis) + token = set_current_remote_apis(interval_apis) + try: + result = await call_plugin_api("test_api") + finally: + reset_current_remote_apis(token) # Manager get_api should not be called mock_get_api.assert_not_called() assert result == "interval_result" + @pytest.mark.asyncio + async def test_call_plugin_api_with_interval_apis_dict_payload(self): + """Test interval_apis values serialized through Ray can be rehydrated.""" + api = PluginApi(api="test_api", deployment_name="test_deployment") + interval_apis = {"test_api": api.model_dump()} + + with ( + patch("framex.plugin._manager.get_api") as mock_get_api, + patch("framex.plugin.get_adapter") as mock_adapter, + ): + mock_adapter.return_value.call_func = AsyncMock(return_value="interval_dict_result") + + token = set_current_remote_apis(interval_apis) + try: + result = await call_plugin_api("test_api") + finally: + reset_current_remote_apis(token) + + mock_get_api.assert_not_called() + assert result == "interval_dict_result" + + @pytest.mark.asyncio + async def test_call_plugin_api_uses_resolver_when_manager_misses(self): + """Test ApiResolver fallback works when Ray workers lack manager state.""" + func_api = PluginApi(deployment_name="demo_plugin.DemoDeployment", func_name="run", call_type=ApiType.FUNC) + http_api = PluginApi(api="/api/v1/resource_match/resource_detail", deployment_name="resource_match.Detail") + resolver = ApiResolver( + api_registry={ + "demo_plugin.DemoDeployment.run": func_api, + "/api/v1/resource_match/resource_detail": http_api, + } + ) + + with ( + patch("framex.plugin._manager.get_api", return_value=None), + patch("framex.plugin.get_adapter") as mock_adapter, + ): + mock_adapter.return_value.call_func = AsyncMock(side_effect=["func_result", "http_result"]) + + assert await call_plugin_api("demo_plugin.DemoDeployment.run", resolver=resolver) == "func_result" + assert await call_plugin_api("/api/v1/resource_match/resource_detail", resolver=resolver) == "http_result" + + @pytest.mark.asyncio + async def test_call_plugin_api_uses_context_resolver_when_no_explicit_resolver(self): + """Test context-bound resolver works for nested service-style calls.""" + http_api = PluginApi(api="/api/v1/fetch/fetch_company_by_dim_info", deployment_name="fetch.FetchPlugin") + resolver = ApiResolver(api_registry={"/api/v1/fetch/fetch_company_by_dim_info": http_api}) + + with ( + patch("framex.plugin._manager.get_api", return_value=None), + patch("framex.plugin.get_adapter") as mock_adapter, + ): + mock_adapter.return_value.call_func = AsyncMock(return_value="context_result") + token = set_current_api_resolver(resolver) + try: + result = await call_plugin_api("/api/v1/fetch/fetch_company_by_dim_info") + finally: + reset_current_api_resolver(token) + + assert result == "context_result" + assert mock_adapter.return_value.call_func.call_args[0][0] == http_api + + @pytest.mark.asyncio + async def test_call_plugin_api_plugin_context_requires_remote_apis(self): + """Test plugin-context calls do not fall back to global registry outside remote_apis.""" + http_api = PluginApi(api="/api/v1/fetch/fetch_company_by_dim_info", deployment_name="fetch.FetchPlugin") + resolver = ApiResolver(api_registry={"/api/v1/fetch/fetch_company_by_dim_info": http_api}) + + with patch("framex.plugin._manager.get_api", return_value=http_api): + resolver_token = set_current_api_resolver(resolver) + remote_token = set_current_remote_apis({}) + try: + with pytest.raises(RuntimeError, match="not declared in current plugin remote_apis"): + await call_plugin_api("/api/v1/fetch/fetch_company_by_dim_info") + finally: + reset_current_remote_apis(remote_token) + reset_current_api_resolver(resolver_token) + + @pytest.mark.asyncio + async def test_call_plugin_api_context_wrapper_preserves_async_generator(self): + """Test resolver context wrapping keeps stream handlers as async generators.""" + from framex.plugin.base import BasePlugin + from framex.plugin.on import on_request + + class DemoPlugin(BasePlugin): + @on_request("/demo", stream=True) + async def stream_api(self): + yield "a" + yield "b" + + plugin = DemoPlugin(api_registry={}) + stream = plugin.stream_api() + + assert inspect.isasyncgen(stream) + assert [chunk async for chunk in stream] == ["a", "b"] + @pytest.mark.asyncio async def test_call_plugin_api_proxy_creates_correct_plugin_api(self): """Test that proxy fallback creates PluginApi with correct parameters.""" From 940761f735a253315056f00fac63e37e91c82330 Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Wed, 25 Mar 2026 19:55:55 +0800 Subject: [PATCH 2/6] test: add Ray integration tests and coverage support --- .coveragerc | 2 +- pytest.ini | 1 + src/framex/__init__.py | 11 +- tests/conftest.py | 21 ++++ tests/ray_plugins/__init__.py | 0 tests/ray_plugins/test_service_caller.py | 32 ++++++ tests/test_use_ray_integration.py | 128 +++++++++++++++++++++++ tools/coverage_support/__init__.py | 0 tools/coverage_support/sitecustomize.py | 8 ++ 9 files changed, 201 insertions(+), 2 deletions(-) create mode 100644 tests/ray_plugins/__init__.py create mode 100644 tests/ray_plugins/test_service_caller.py create mode 100644 tests/test_use_ray_integration.py create mode 100644 tools/coverage_support/__init__.py create mode 100644 tools/coverage_support/sitecustomize.py diff --git a/.coveragerc b/.coveragerc index 406fd24..08ac309 100644 --- a/.coveragerc +++ b/.coveragerc @@ -4,7 +4,7 @@ data_file = .coverage/.coverage parallel = true concurrency = thread sigterm = true -source = src +source = src/framex # plugins = coverage_plugins.subprocess diff --git a/pytest.ini b/pytest.ini index 87681a8..261728c 100644 --- a/pytest.ini +++ b/pytest.ini @@ -19,6 +19,7 @@ addopts = -vv -o junit_family=legacy env = + COVERAGE_PROCESS_START=.coveragerc server__use_ray=false server__enable_proxy=true load_builtin_plugins=["echo","proxy","invoker"] diff --git a/src/framex/__init__.py b/src/framex/__init__.py index 0bd733d..aafbc8d 100644 --- a/src/framex/__init__.py +++ b/src/framex/__init__.py @@ -66,6 +66,15 @@ def before_send(event, hint): # noqa ) +def _setup_ray_worker() -> None: # pragma: no cover + settings.server.use_ray = True + + import framex.adapter as adapter_module + + adapter_module._adapter = None + _setup_sentry() + + def run( *, server_host: str | None = None, @@ -146,7 +155,7 @@ def run( "env_vars": { "REVERSION": reversion, }, - "worker_process_setup_hook": _setup_sentry, + "worker_process_setup_hook": _setup_ray_worker, }, ) serve.start( diff --git a/tests/conftest.py b/tests/conftest.py index 567862b..f0150d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -68,3 +68,24 @@ def client(test_app: FastAPI) -> Generator: def runner(): """Provide a reusable Click CLI runner.""" return CliRunner() + + +@pytest.hookimpl(trylast=True) +def pytest_sessionfinish(session, exitstatus): + try: + from coverage import Coverage + except Exception: + return + + coverage_dir = Path(".coverage") + if not coverage_dir.exists(): + return + + data_files = list(coverage_dir.glob(".coverage.*")) + if not data_files: + return + + cov = Coverage(config_file=True, data_file=str(coverage_dir / ".coverage")) + cov.load() + cov.combine() + cov.save() diff --git a/tests/ray_plugins/__init__.py b/tests/ray_plugins/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/ray_plugins/test_service_caller.py b/tests/ray_plugins/test_service_caller.py new file mode 100644 index 0000000..736930f --- /dev/null +++ b/tests/ray_plugins/test_service_caller.py @@ -0,0 +1,32 @@ +from typing import Any + +from framex.consts import VERSION +from framex.plugin import BasePlugin, PluginMetadata, call_plugin_api, on_register, on_request + + +class EchoService: + async def echo(self, message: str) -> str: + result = await call_plugin_api("/api/v1/echo", message=message) + assert isinstance(result, str) + return result + + +__plugin_meta__ = PluginMetadata( + name="ray_service_caller", + version=VERSION, + description="Ray integration test plugin for nested service plugin calls", + author="tests", + url="https://example.invalid", + required_remote_apis=["/api/v1/echo"], +) + + +@on_register() +class RayServiceCallerPlugin(BasePlugin): + def __init__(self, **kwargs: Any) -> None: + self.service = EchoService() + super().__init__(**kwargs) + + @on_request("/service_echo", methods=["GET"]) + async def service_echo(self, message: str) -> str: + return await self.service.echo(message) diff --git a/tests/test_use_ray_integration.py b/tests/test_use_ray_integration.py new file mode 100644 index 0000000..f8ed161 --- /dev/null +++ b/tests/test_use_ray_integration.py @@ -0,0 +1,128 @@ +import json +import os +import socket +import subprocess +import sys +import tempfile +from contextlib import suppress +from pathlib import Path + +import pytest + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def test_use_ray_nested_service_call_uses_remote_apis() -> None: + pytest.importorskip("ray") + + repo_root = Path(__file__).resolve().parents[1] + plugin_dir = repo_root / "tests" / "ray_plugins" + server_port = _find_free_port() + dashboard_port = _find_free_port() + with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8", suffix=".json", delete=False) as output_file: + output_path = Path(output_file.name) + with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8", suffix=".log", delete=False) as named_log_file: + log_path = Path(named_log_file.name) + + launch_code = """ +import asyncio +import json +import time +from pathlib import Path + +import framex +import ray +from ray import serve +from framex.config import settings +from framex.consts import APP_NAME + + +async def main() -> None: + settings.server.use_ray = True + plugin_dir = Path(r"__PLUGIN_DIR__") + output_path = Path(r"__OUTPUT_PATH__") + framex.run( + use_ray=True, + blocking=False, + server_host="127.0.0.1", + server_port=__SERVER_PORT__, + dashboard_host="127.0.0.1", + dashboard_port=__DASHBOARD_PORT__, + num_cpus=4, + enable_proxy=False, + load_builtin_plugins=["echo"], + load_plugins=[str(plugin_dir)], + ) + handle = serve.get_deployment_handle( + "test_service_caller.RayServiceCallerPlugin", + app_name=APP_NAME, + ) + deadline = time.time() + 90.0 + last_error = None + while time.time() < deadline: + try: + result = await handle.service_echo.remote(message="ray-ok") + output_path.write_text(json.dumps({"result": result}), encoding="utf-8") + return + except Exception as exc: + last_error = exc + await asyncio.sleep(1) + raise RuntimeError(f"Timed out waiting for Ray service call: {last_error!r}") + + +try: + asyncio.run(main()) +finally: + try: + serve.shutdown() + finally: + ray.shutdown() +""" + launch_code = ( + launch_code.replace("__PLUGIN_DIR__", str(plugin_dir)) + .replace("__OUTPUT_PATH__", str(output_path)) + .replace("__SERVER_PORT__", str(server_port)) + .replace("__DASHBOARD_PORT__", str(dashboard_port)) + ) + + env = os.environ.copy() + env["PYTHONPATH"] = os.pathsep.join( + (str(repo_root / "tools" / "coverage_support"), str(repo_root), env.get("PYTHONPATH", "")) + ) + env["SERVER__USE_RAY"] = "true" + + with log_path.open("w", encoding="utf-8") as log_file: + proc = subprocess.Popen( # noqa: S603 + [sys.executable, "-c", launch_code], + cwd=repo_root, + env=env, + stdout=log_file, + stderr=subprocess.STDOUT, + text=True, + start_new_session=True, + ) + + try: + return_code = proc.wait(timeout=150) + logs = log_path.read_text(encoding="utf-8", errors="replace") + if logs: + sys.stdout.write(logs if logs.endswith("\n") else f"{logs}\n") + if return_code != 0: + pytest.fail(f"Ray integration subprocess failed with code {return_code}.\n{logs}") + + payload = json.loads(output_path.read_text(encoding="utf-8")) + assert payload["result"] == "ray-ok" + finally: + if proc.poll() is None: + proc.terminate() + with suppress(subprocess.TimeoutExpired): + proc.wait(timeout=10) + if proc.poll() is None: + proc.kill() + proc.wait(timeout=10) + output_path.unlink(missing_ok=True) + log_path.unlink(missing_ok=True) diff --git a/tools/coverage_support/__init__.py b/tools/coverage_support/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tools/coverage_support/sitecustomize.py b/tools/coverage_support/sitecustomize.py new file mode 100644 index 0000000..f8629a1 --- /dev/null +++ b/tools/coverage_support/sitecustomize.py @@ -0,0 +1,8 @@ +import os +from contextlib import suppress + +if os.getenv("COVERAGE_PROCESS_START"): + with suppress(Exception): + import coverage + + coverage.process_startup() From 6c184cb4954f7e543ea6664aa41c6cdf0ff9a397 Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Thu, 26 Mar 2026 10:04:46 +0800 Subject: [PATCH 3/6] fix: update use_ray setting and improve exception handling --- src/framex/__init__.py | 1 + tests/conftest.py | 2 +- tests/test_plugin.py | 31 ++++++++++++++----------------- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/framex/__init__.py b/src/framex/__init__.py index aafbc8d..58e9604 100644 --- a/src/framex/__init__.py +++ b/src/framex/__init__.py @@ -96,6 +96,7 @@ def run( dashboard_port = dashboard_port if dashboard_port is not None else settings.server.dashboard_port num_cpus = num_cpus if num_cpus is not None else settings.server.num_cpus use_ray = use_ray if use_ray is not None else settings.server.use_ray + settings.server.use_ray = use_ray enable_proxy = enable_proxy if enable_proxy is not None else settings.server.enable_proxy builtin_plugins = settings.load_builtin_plugins if load_builtin_plugins is None else load_builtin_plugins external_plugins = settings.load_plugins if load_plugins is None else load_plugins diff --git a/tests/conftest.py b/tests/conftest.py index f0150d1..22ce0df 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -74,7 +74,7 @@ def runner(): def pytest_sessionfinish(session, exitstatus): try: from coverage import Coverage - except Exception: + except ImportError: return coverage_dir = Path(".coverage") diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 925e13f..13e91ac 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -1,26 +1,11 @@ -import framex -from framex.consts import VERSION - - -def test_get_plugin(): - # check simple plugin - plugin = framex.get_plugin("export") - assert plugin - assert plugin.version == VERSION - assert plugin.name == "export" - assert plugin.module_name == "tests.plugins.export" - - assert plugin.config - assert plugin.config.model_dump() == {"id": 123, "name": "test"} - - import inspect from unittest.mock import AsyncMock, patch import pytest from pydantic import BaseModel -from framex.consts import PROXY_PLUGIN_NAME +import framex +from framex.consts import PROXY_PLUGIN_NAME, VERSION from framex.plugin import ( ApiResolver, call_plugin_api, @@ -32,6 +17,18 @@ def test_get_plugin(): from framex.plugin.model import ApiType, PluginApi +def test_get_plugin(): + # check simple plugin + plugin = framex.get_plugin("export") + assert plugin + assert plugin.version == VERSION + assert plugin.name == "export" + assert plugin.module_name == "tests.plugins.export" + + assert plugin.config + assert plugin.config.model_dump() == {"id": 123, "name": "test"} + + class SampleModel(BaseModel): """Sample model for testing parameter conversion.""" From d7e3262a474045b0712122af834a3e977a168600 Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Thu, 26 Mar 2026 11:14:28 +0800 Subject: [PATCH 4/6] refactor: extract ApiResolver to separate module --- src/framex/plugin/__init__.py | 155 +++++++++++++--------------------- src/framex/plugin/base.py | 6 +- src/framex/plugin/manage.py | 3 + src/framex/plugin/resolver.py | 78 +++++++++++++++++ tests/test_plugin.py | 6 +- 5 files changed, 146 insertions(+), 102 deletions(-) create mode 100644 src/framex/plugin/resolver.py diff --git a/src/framex/plugin/__init__.py b/src/framex/plugin/__init__.py index 9ae8bc8..728afa6 100644 --- a/src/framex/plugin/__init__.py +++ b/src/framex/plugin/__init__.py @@ -1,4 +1,3 @@ -from collections.abc import Mapping from contextvars import ContextVar from functools import lru_cache from typing import Any, Optional, TypeVar @@ -9,73 +8,20 @@ from framex.config import settings from framex.consts import PROXY_PLUGIN_NAME from framex.log import logger -from framex.plugin.manage import PluginManager +from framex.plugin.manage import _manager from framex.plugin.model import Plugin, PluginApi +from framex.plugin.resolver import ( + ApiResolver, + _set_default_api_resolver, + get_current_api_resolver, + get_current_remote_apis, + get_default_api_resolver, +) C = TypeVar("C", bound=BaseModel) -_manager: PluginManager = PluginManager(silent=settings.test.silent) _current_plugin: ContextVar[Optional["Plugin"]] = ContextVar("_current_plugin", default=None) -_current_api_resolver: ContextVar["ApiResolver | None"] = ContextVar("_current_api_resolver", default=None) -_current_remote_apis: ContextVar[Mapping[str, PluginApi | dict[str, Any]] | None] = ContextVar( - "_current_remote_apis", default=None -) - - -class ApiResolver: - def __init__( - self, - manager: PluginManager | None = None, - api_registry: Mapping[str, PluginApi | dict[str, Any]] | None = None, - ) -> None: - self._manager = manager - self._api_registry = api_registry or {} - - @staticmethod - def _coerce_plugin_api(api: PluginApi | dict[str, Any] | None) -> PluginApi | None: - if api is None or isinstance(api, PluginApi): - return api - if isinstance(api, dict): - return PluginApi.model_validate(api) - return None - - def resolve( - self, - api_name: str, - api_registry: Mapping[str, PluginApi | dict[str, Any]] | None = None, - ) -> PluginApi | None: - if api_registry is not None and (api := self._coerce_plugin_api(api_registry.get(api_name))): - return api - if self._manager and (api := self._manager.get_api(api_name)): - return api - return self._coerce_plugin_api(self._api_registry.get(api_name)) - - -_api_resolver = ApiResolver(manager=_manager) - - -def get_current_api_resolver() -> ApiResolver | None: - return _current_api_resolver.get() - - -def set_current_api_resolver(resolver: ApiResolver | None) -> Any: - return _current_api_resolver.set(resolver) - - -def reset_current_api_resolver(token: Any) -> None: - _current_api_resolver.reset(token) - - -def get_current_remote_apis() -> Mapping[str, PluginApi | dict[str, Any]] | None: - return _current_remote_apis.get() - - -def set_current_remote_apis(remote_apis: Mapping[str, PluginApi | dict[str, Any]] | None) -> Any: - return _current_remote_apis.set(remote_apis) - - -def reset_current_remote_apis(token: Any) -> None: - _current_remote_apis.reset(token) +_set_default_api_resolver(ApiResolver(manager=_manager)) def get_plugin(plugin_id: str) -> Plugin | None: @@ -137,27 +83,26 @@ def init_all_deployments(enable_proxy: bool) -> list[Any]: return deployments -async def call_plugin_api( +def _resolve_plugin_api( api_name: str | PluginApi, resolver: ApiResolver | None = None, - **kwargs: Any, -) -> Any: +) -> tuple[PluginApi, bool]: current_remote_apis = get_current_remote_apis() if isinstance(api_name, PluginApi): - api: PluginApi | None = api_name - elif isinstance(api_name, str): - active_resolver = resolver or get_current_api_resolver() or _api_resolver + return api_name, api_name.call_type == ApiType.PROXY + + active_resolver = resolver or get_current_api_resolver() or get_default_api_resolver() + if current_remote_apis is not None: + api = active_resolver.coerce_plugin_api(current_remote_apis.get(api_name)) + else: + api = active_resolver.resolve(api_name, None) + + if api is None: if current_remote_apis is not None: - api = active_resolver._coerce_plugin_api(current_remote_apis.get(api_name)) - else: - api = active_resolver.resolve(api_name, None) - use_proxy = False - if not api: - if isinstance(api_name, str) and current_remote_apis is not None: raise RuntimeError( f"API {api_name} is not declared in current plugin remote_apis; add it to required_remote_apis." ) - if isinstance(api_name, str) and api_name.startswith("/") and settings.server.enable_proxy: + if api_name.startswith("/") and settings.server.enable_proxy: api = PluginApi( api=api_name, deployment_name=PROXY_PLUGIN_NAME, @@ -166,15 +111,18 @@ async def call_plugin_api( logger.opt(colors=True).warning( f"Api({api_name}) not found, use proxy plugin({PROXY_PLUGIN_NAME}) to transfer!" ) - use_proxy = True else: raise RuntimeError( f"API {api_name} is not found, please check if the plugin is loaded or the API name is correct." ) - if api.call_type == ApiType.PROXY: - use_proxy = True + + return api, api.call_type == ApiType.PROXY + + +def _normalize_plugin_call_kwargs(api: PluginApi, kwargs: dict[str, Any]) -> dict[str, Any]: + normalized_kwargs = dict(kwargs) param_type_map = dict(api.params) - for key, val in kwargs.items(): + for key, val in normalized_kwargs.items(): if ( isinstance(val, dict) and (expected_type := param_type_map.get(key)) @@ -182,26 +130,41 @@ async def call_plugin_api( and issubclass(expected_type, BaseModel) ): try: - kwargs[key] = expected_type(**val) + normalized_kwargs[key] = expected_type(**val) except Exception as e: # pragma: no cover raise RuntimeError(f"Failed to convert '{key}' to {expected_type}") from e - result = await get_adapter().call_func(api, **kwargs) + return normalized_kwargs + + +def _unwrap_plugin_call_result(api_name: str | PluginApi, result: Any, use_proxy: bool) -> Any: if isinstance(result, BaseModel): return result.model_dump(by_alias=True) - if use_proxy: - if not isinstance(result, dict): - return result - if "status" not in result: - raise RuntimeError(f"Proxy API {api_name} returned invalid response: missing 'status' field") - res = result.get("data") - status = result.get("status") - if status not in settings.server.legal_proxy_code: - logger.opt(colors=True).error(f"<>Proxy API {api_name} call illegal: {result}") - raise RuntimeError(f"Proxy API {api_name} returned status {status}") - if res is None: - logger.opt(colors=True).warning(f"API {api_name} returned empty data") - return res - return result + if not use_proxy: + return result + if not isinstance(result, dict): + return result + if "status" not in result: + raise RuntimeError(f"Proxy API {api_name} returned invalid response: missing 'status' field") + + res = result.get("data") + status = result.get("status") + if status not in settings.server.legal_proxy_code: + logger.opt(colors=True).error(f"<>Proxy API {api_name} call illegal: {result}") + raise RuntimeError(f"Proxy API {api_name} returned status {status}") + if res is None: + logger.opt(colors=True).warning(f"API {api_name} returned empty data") + return res + + +async def call_plugin_api( + api_name: str | PluginApi, + resolver: ApiResolver | None = None, + **kwargs: Any, +) -> Any: + api, use_proxy = _resolve_plugin_api(api_name, resolver=resolver) + normalized_kwargs = _normalize_plugin_call_kwargs(api, kwargs) + result = await get_adapter().call_func(api, **normalized_kwargs) + return _unwrap_plugin_call_result(api_name, result, use_proxy) def get_http_plugin_apis() -> list["PluginApi"]: diff --git a/src/framex/plugin/base.py b/src/framex/plugin/base.py index 9a16db5..0042e67 100644 --- a/src/framex/plugin/base.py +++ b/src/framex/plugin/base.py @@ -5,15 +5,15 @@ from framex.config import settings from framex.log import setup_logger -from framex.plugin import ( +from framex.plugin import call_plugin_api +from framex.plugin.model import PluginApi +from framex.plugin.resolver import ( ApiResolver, - call_plugin_api, reset_current_api_resolver, reset_current_remote_apis, set_current_api_resolver, set_current_remote_apis, ) -from framex.plugin.model import PluginApi class BasePlugin: diff --git a/src/framex/plugin/manage.py b/src/framex/plugin/manage.py index 1e4dcaf..185e4a5 100644 --- a/src/framex/plugin/manage.py +++ b/src/framex/plugin/manage.py @@ -15,6 +15,7 @@ from pathlib import Path from types import ModuleType +from framex.config import settings from framex.log import logger from framex.plugin.model import ApiType, Plugin, PluginApi, PluginMetadata from framex.utils import escape_tag, path_to_module_name @@ -269,3 +270,5 @@ def exec_module(self, module: ModuleType) -> None: # Insert a custom plugin module finder into the front of the Python import system to intercept and load plugin modules sys.meta_path.insert(0, PluginFinder()) + +_manager: PluginManager = PluginManager(silent=settings.test.silent) diff --git a/src/framex/plugin/resolver.py b/src/framex/plugin/resolver.py new file mode 100644 index 0000000..1748e09 --- /dev/null +++ b/src/framex/plugin/resolver.py @@ -0,0 +1,78 @@ +from collections.abc import Mapping +from contextvars import ContextVar +from typing import Any, Protocol + +from framex.plugin.model import PluginApi + + +class SupportsApiLookup(Protocol): + def get_api(self, api_name: str) -> PluginApi | None: ... + + +class ApiResolver: + def __init__( + self, + manager: SupportsApiLookup | None = None, + api_registry: Mapping[str, PluginApi | dict[str, Any]] | None = None, + ) -> None: + self._manager = manager + self._api_registry = api_registry or {} + + @staticmethod + def coerce_plugin_api(api: PluginApi | dict[str, Any] | None) -> PluginApi | None: + if api is None or isinstance(api, PluginApi): + return api + if isinstance(api, dict): + return PluginApi.model_validate(api) + return None + + def resolve( + self, + api_name: str, + api_registry: Mapping[str, PluginApi | dict[str, Any]] | None = None, + ) -> PluginApi | None: + if api_registry is not None and (api := self.coerce_plugin_api(api_registry.get(api_name))): + return api + if self._manager and (api := self._manager.get_api(api_name)): + return api + return self.coerce_plugin_api(self._api_registry.get(api_name)) + + +_current_api_resolver: ContextVar[ApiResolver | None] = ContextVar("_current_api_resolver", default=None) +_current_remote_apis: ContextVar[Mapping[str, PluginApi | dict[str, Any]] | None] = ContextVar( + "_current_remote_apis", default=None +) +_default_api_resolver: ApiResolver | None = None + + +def get_current_api_resolver() -> ApiResolver | None: + return _current_api_resolver.get() + + +def get_default_api_resolver() -> ApiResolver | None: + return _default_api_resolver + + +def _set_default_api_resolver(resolver: ApiResolver | None) -> None: + global _default_api_resolver + _default_api_resolver = resolver + + +def set_current_api_resolver(resolver: ApiResolver | None) -> Any: + return _current_api_resolver.set(resolver) + + +def reset_current_api_resolver(token: Any) -> None: + _current_api_resolver.reset(token) + + +def get_current_remote_apis() -> Mapping[str, PluginApi | dict[str, Any]] | None: + return _current_remote_apis.get() + + +def set_current_remote_apis(remote_apis: Mapping[str, PluginApi | dict[str, Any]] | None) -> Any: + return _current_remote_apis.set(remote_apis) + + +def reset_current_remote_apis(token: Any) -> None: + _current_remote_apis.reset(token) diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 13e91ac..b2ff485 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -6,15 +6,15 @@ import framex from framex.consts import PROXY_PLUGIN_NAME, VERSION -from framex.plugin import ( +from framex.plugin import call_plugin_api +from framex.plugin.model import ApiType, PluginApi +from framex.plugin.resolver import ( ApiResolver, - call_plugin_api, reset_current_api_resolver, reset_current_remote_apis, set_current_api_resolver, set_current_remote_apis, ) -from framex.plugin.model import ApiType, PluginApi def test_get_plugin(): From 37212a3ee648e696b505b703724edcc4f7f8b559 Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Thu, 26 Mar 2026 11:32:31 +0800 Subject: [PATCH 5/6] fix: enforce default API resolver configuration --- src/framex/plugin/resolver.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/framex/plugin/resolver.py b/src/framex/plugin/resolver.py index 1748e09..e2a2f4b 100644 --- a/src/framex/plugin/resolver.py +++ b/src/framex/plugin/resolver.py @@ -49,7 +49,9 @@ def get_current_api_resolver() -> ApiResolver | None: return _current_api_resolver.get() -def get_default_api_resolver() -> ApiResolver | None: +def get_default_api_resolver() -> ApiResolver: + if _default_api_resolver is None: + raise RuntimeError("Default API resolver is not configured") return _default_api_resolver From 0e297020fee5c8372f1e48268f37c0027dadae1a Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Thu, 26 Mar 2026 11:47:37 +0800 Subject: [PATCH 6/6] update: improve proxy params logging with format_proxy_params --- src/framex/plugin/resolver.py | 2 +- src/framex/plugins/proxy/__init__.py | 4 ++-- src/framex/plugins/proxy/builder.py | 11 +++++++++++ 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/framex/plugin/resolver.py b/src/framex/plugin/resolver.py index e2a2f4b..1e5ddb8 100644 --- a/src/framex/plugin/resolver.py +++ b/src/framex/plugin/resolver.py @@ -55,7 +55,7 @@ def get_default_api_resolver() -> ApiResolver: return _default_api_resolver -def _set_default_api_resolver(resolver: ApiResolver | None) -> None: +def _set_default_api_resolver(resolver: ApiResolver) -> None: global _default_api_resolver _default_api_resolver = resolver diff --git a/src/framex/plugins/proxy/__init__.py b/src/framex/plugins/proxy/__init__.py index 8b9b324..509edf2 100644 --- a/src/framex/plugins/proxy/__init__.py +++ b/src/framex/plugins/proxy/__init__.py @@ -15,7 +15,7 @@ from framex.plugin import BasePlugin, PluginApi, PluginMetadata, on_register from framex.plugin.model import ApiType from framex.plugin.on import on_request -from framex.plugins.proxy.builder import create_pydantic_model, type_map +from framex.plugins.proxy.builder import create_pydantic_model, format_proxy_params, type_map from framex.plugins.proxy.config import ProxyPluginConfig, settings from framex.plugins.proxy.model import ProxyFunc, ProxyFuncHttpBody from framex.utils import cache_decode, cache_encode, shorten_str @@ -235,7 +235,7 @@ def _create_dynamic_method( # Construct dynamic methods async def dynamic_method(**kwargs: Any) -> AsyncGenerator[str, None] | dict[str, Any] | str: - log_info = shorten_str(str(kwargs), 512) + log_info = shorten_str(format_proxy_params(**kwargs), 512) logger.info(f"Calling proxy url: {url} with kwargs: {log_info}") validated = RequestModel(**kwargs) # Type Validation query = {} diff --git a/src/framex/plugins/proxy/builder.py b/src/framex/plugins/proxy/builder.py index b654ebe..29657c0 100644 --- a/src/framex/plugins/proxy/builder.py +++ b/src/framex/plugins/proxy/builder.py @@ -2,6 +2,9 @@ from pydantic import BaseModel, create_model +from framex.plugins.proxy.model import ProxyFuncHttpBody +from framex.utils import cache_decode + _created_models: dict[str, type[BaseModel]] = {} type_map = { @@ -103,3 +106,11 @@ def create_pydantic_model( model: type[BaseModel] = create_model(name, **fields) # type: ignore _created_models[name] = model return model + + +def format_proxy_params(**kwargs: Any) -> str: + if (model := kwargs.get("model")) and isinstance(model, ProxyFuncHttpBody): + func_name = cache_decode(model.func_name) + data = cache_decode(model.data) + return str({"func_name": func_name, "data": data}) + return str(kwargs)