From 9d66e07da0826476076bc58e4d43ca5fa93eb396 Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Fri, 26 Dec 2025 10:40:13 +0800 Subject: [PATCH 1/3] fix: improve proxy function registration and auth configuration --- book/src/advanced_usage/authentication.md | 10 ++-------- book/src/advanced_usage/proxy_function.md | 11 +++-------- src/framex/config.py | 22 +++------------------ src/framex/plugin/load.py | 24 +++++++++++++++++++---- src/framex/plugin/on.py | 19 +++++------------- src/framex/plugins/proxy/__init__.py | 4 ++++ src/framex/plugins/proxy/config.py | 14 ++++++------- tests/test_plugins.py | 2 ++ 8 files changed, 46 insertions(+), 60 deletions(-) diff --git a/book/src/advanced_usage/authentication.md b/book/src/advanced_usage/authentication.md index 797a035..b579eaf 100644 --- a/book/src/advanced_usage/authentication.md +++ b/book/src/advanced_usage/authentication.md @@ -61,10 +61,7 @@ The system supports API-level authentication using **access keys**, configured t ```toml [auth] -rules = { - "/api/v1/echo_model" = ["key-1", "key-2"], - "/api/v2/*" = ["key-3"] -} +rules = {"/api/v1/echo_model" = ["key-1", "key-2"],"/api/v2/*" = ["key-3"]} ``` ### Runtime Behavior @@ -85,8 +82,5 @@ as standard API authentication. ```toml [plugins.proxy.auth] -rules = { - "/api/v1/proxy/remote" = ["proxy-key"], - "/api/v1/echo_model" = ["echo-key"] -} +rules = {"/api/v1/proxy/remote" = ["proxy-key"],"/api/v1/echo_model" = ["echo-key"]} ``` diff --git a/book/src/advanced_usage/proxy_function.md b/book/src/advanced_usage/proxy_function.md index 2b2e219..d893ad0 100644 --- a/book/src/advanced_usage/proxy_function.md +++ b/book/src/advanced_usage/proxy_function.md @@ -46,7 +46,7 @@ class ExamplePlugin(BasePlugin): async def on_start(self) -> None: from demo.some import example_func # Important: Lazy import - register_proxy_func(example_func) + await register_proxy_func(example_func) ``` ______________________________________________________________________ @@ -73,19 +73,14 @@ To enable remote execution, add the following configuration to `config.toml`. [plugins.proxy] proxy_urls = ["http://remotehost:8080"] white_list = ["/api/v1/proxy/remote"] -proxy_functions = { - "http://localhost:8083" = [ - "demo.some.example_func" - ] -} +proxy_functions = {"http://remotehost:8080" = ["demo.some.example_func"]} ``` ### Authentication Configuration ```toml [plugins.proxy.auth] -general_auth_keys = ["7a23713f-c85f-48c0-a349-7a3363f2d30c"] -auth_urls = ["/api/v1/proxy/remote"] +rules = {"/api/v1/proxy/remote" = ["proxy-key"],"/api/v1/openapi.json" = ["openapi-key"]} ``` Once configured, FrameX automatically routes function calls to remote instances when required. diff --git a/src/framex/config.py b/src/framex/config.py index 5cfc2ae..9c5a577 100644 --- a/src/framex/config.py +++ b/src/framex/config.py @@ -1,7 +1,6 @@ -from typing import Any, Literal, Self -from uuid import uuid4 +from typing import Any, Literal -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field from pydantic_settings import ( BaseSettings, PydanticBaseSettingsSource, @@ -10,8 +9,6 @@ TomlConfigSettingsSource, ) -from framex.consts import PROXY_FUNC_HTTP_PATH - class LogConfig(BaseModel): simple_log: bool = True @@ -68,19 +65,6 @@ class TestConfig(BaseModel): class AuthConfig(BaseModel): rules: dict[str, list[str]] = Field(default_factory=dict) - @model_validator(mode="after") - def normalize_and_validate(self) -> Self: - if PROXY_FUNC_HTTP_PATH not in self.rules: - key = str(uuid4()) - self.rules[PROXY_FUNC_HTTP_PATH] = [key] - from framex.log import logger - - logger.warning( - f"No auth key found for {PROXY_FUNC_HTTP_PATH}. A random key {key} was generated. " - "Please configure auth.rules explicitly in production.", - ) - return self - def _is_url_protected(self, url: str) -> bool: for rule in self.rules: if rule == url: @@ -113,7 +97,7 @@ def get_auth_keys(self, url: str) -> list[str] | None: class Settings(BaseSettings): # Global config - base_ingress_config: dict[str, Any] = {"max_ongoing_requests": 10} + base_ingress_config: dict[str, Any] = Field(default_factory=lambda: {"max_ongoing_requests": 10}) server: ServerConfig = Field(default_factory=ServerConfig) log: LogConfig = Field(default_factory=LogConfig) diff --git a/src/framex/plugin/load.py b/src/framex/plugin/load.py index 7258310..6e5e92f 100644 --- a/src/framex/plugin/load.py +++ b/src/framex/plugin/load.py @@ -1,9 +1,11 @@ from collections.abc import Callable +from framex.consts import PROXY_PLUGIN_NAME from framex.log import logger -from framex.plugin.model import Plugin +from framex.plugin.model import ApiType, Plugin, PluginApi +from framex.plugin.on import _PROXY_REGISTRY -from . import _manager, get_loaded_plugins +from . import _manager, call_plugin_api, get_loaded_plugins def load_plugins(*plugin_dir: str) -> set[Plugin]: @@ -37,5 +39,19 @@ def auto_load_plugins(builtin_plugins: list[str], plugins: list[str], enable_pro return builtin_plugin_instances | plugin_instances -def register_proxy_func(_: Callable) -> None: # pragma: no cover - pass +async def register_proxy_func(func: Callable) -> None: + full_func_name = f"{func.__module__}.{func.__name__}" + if full_func_name not in _PROXY_REGISTRY: # pragma: no cover + raise RuntimeError(f"Function {full_func_name} is not registered as a proxy function.") + + api_reg = PluginApi( + deployment_name=PROXY_PLUGIN_NAME, + call_type=ApiType.PROXY, + func_name="register_proxy_function", + ) + await call_plugin_api( + api_reg, + None, + func_name=full_func_name, + func_callable=_PROXY_REGISTRY[full_func_name], + ) diff --git a/src/framex/plugin/on.py b/src/framex/plugin/on.py index 531b670..f0ce108 100644 --- a/src/framex/plugin/on.py +++ b/src/framex/plugin/on.py @@ -98,6 +98,9 @@ def wrapper(func: Callable) -> Callable: return wrapper +_PROXY_REGISTRY: dict[str, Callable] = {} + + def on_proxy() -> Callable: def decorator(func: Callable) -> Callable: from framex.config import settings @@ -124,6 +127,8 @@ async def safe_callable(*args: Any, **kwargs: Any) -> Any: return await raw(*args, **kwargs) return raw(*args, **kwargs) + _PROXY_REGISTRY[full_func_name] = safe_callable + @functools.wraps(func) async def wrapper(*args: Any, **kwargs: Any) -> Any: nonlocal is_registered @@ -131,20 +136,6 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: if args: # pragma: no cover raise TypeError(f"The proxy function '{func.__name__}' only supports keyword arguments.") - if not is_registered: - api_reg = PluginApi( - deployment_name=PROXY_PLUGIN_NAME, - call_type=ApiType.PROXY, - func_name="register_proxy_function", - ) - await call_plugin_api( - api_reg, - None, - func_name=full_func_name, - func_callable=safe_callable, - ) - is_registered = True - api_call = PluginApi( deployment_name=PROXY_PLUGIN_NAME, call_type=ApiType.PROXY, diff --git a/src/framex/plugins/proxy/__init__.py b/src/framex/plugins/proxy/__init__.py index 3a5a7d9..74fda63 100644 --- a/src/framex/plugins/proxy/__init__.py +++ b/src/framex/plugins/proxy/__init__.py @@ -69,6 +69,10 @@ async def _get_openai_docs(self, url: str, docs_path: str = "/api/v1/openapi.jso headers = None async with httpx.AsyncClient(timeout=self.time_out) as client: response = await client.get(f"{url}{docs_path}", headers=headers) + if response.status_code != 200: + logger.error( + f"Failed to get openai docs from {url}, status code: {response.status_code}, response: {response.text}" + ) response.raise_for_status() return cast(dict[str, Any], response.json()) diff --git a/src/framex/plugins/proxy/config.py b/src/framex/plugins/proxy/config.py index 32218fa..73b205b 100644 --- a/src/framex/plugins/proxy/config.py +++ b/src/framex/plugins/proxy/config.py @@ -1,22 +1,22 @@ from typing import Any, Self -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, Field, model_validator from framex.config import AuthConfig from framex.plugin import get_plugin_config class ProxyPluginConfig(BaseModel): - proxy_urls: list[str] = [] - force_stream_apis: list[str] = [] + proxy_urls: list[str] = Field(default_factory=list) + force_stream_apis: list[str] = Field(default_factory=list) timeout: int = 600 - ingress_config: dict[str, Any] = {"max_ongoing_requests": 60} + ingress_config: dict[str, Any] = Field(default_factory=lambda: {"max_ongoing_requests": 60}) - white_list: list[str] = [] + white_list: list[str] = Field(default_factory=list) - auth: AuthConfig = AuthConfig() + auth: AuthConfig = Field(default_factory=AuthConfig) - proxy_functions: dict[str, list[str]] = {} + proxy_functions: dict[str, list[str]] = Field(default_factory=dict) def is_white_url(self, url: str) -> bool: """Check if a URL is protected by any auth_urls rule.""" diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 3b6a0a6..eebfea3 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -6,6 +6,7 @@ import pytest from pydantic import BaseModel, ConfigDict, Field +from framex.plugin.load import register_proxy_func from framex.plugin.on import on_proxy @@ -183,6 +184,7 @@ async def remote_exchange_key_value(a_str: str, b_int: int, c_model: ExchangeMod @pytest.mark.order(1) async def test_on_proxy_local_call(): + await register_proxy_func(local_exchange_key_value) res = await local_exchange_key_value( a_str="test", b_int=123, c_model=ExchangeModel(id="id_1", name=100, model=SubModel(id=1, name="sub_name")) ) From 160d16a50d0913f16b8747dbaa4c4a57bb553c12 Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Fri, 26 Dec 2025 12:08:06 +0800 Subject: [PATCH 2/3] feat: enhance error handling and proxy response format --- src/framex/driver/ingress.py | 11 ++++++-- src/framex/plugins/proxy/__init__.py | 2 +- src/framex/utils.py | 8 ++++++ tests/api/test_proxy.py | 41 +++++++++++++++++++++++++--- uv.lock | 12 ++++---- 5 files changed, 62 insertions(+), 12 deletions(-) diff --git a/src/framex/driver/ingress.py b/src/framex/driver/ingress.py index e81bcf4..1854583 100644 --- a/src/framex/driver/ingress.py +++ b/src/framex/driver/ingress.py @@ -16,7 +16,7 @@ from framex.driver.decorator import api_ingress from framex.log import setup_logger from framex.plugin.model import ApiType, PluginApi -from framex.utils import escape_tag +from framex.utils import escape_tag, safe_error_message app = create_fastapi_application() api_key_header = APIKeyHeader(name="Authorization", auto_error=True) @@ -100,7 +100,14 @@ async def route_handler(response: Response, model: Model = Depends()) -> Any: # gen, media_type="text/event-stream", ) - return await adapter._acall(c_handle, **model.__dict__) # type: ignore + + try: + return await adapter._acall(c_handle, **model.__dict__) # type: ignore + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=safe_error_message(e), + ) from None # Inject auth dependency if needed dependencies = [] diff --git a/src/framex/plugins/proxy/__init__.py b/src/framex/plugins/proxy/__init__.py index 74fda63..2fb87bb 100644 --- a/src/framex/plugins/proxy/__init__.py +++ b/src/framex/plugins/proxy/__init__.py @@ -166,7 +166,7 @@ async def register_proxy_func_route( params=[("model", ProxyFuncHttpBody)], handle=handle, stream=False, - direct_output=True, + direct_output=False, tags=[__plugin_meta__.name], ) diff --git a/src/framex/utils.py b/src/framex/utils.py index dbcda97..8668345 100644 --- a/src/framex/utils.py +++ b/src/framex/utils.py @@ -148,3 +148,11 @@ def format_uptime(delta: timedelta) -> str: parts.append(f"{seconds}s") return " ".join(parts) + + +def safe_error_message(e: Exception) -> str: + if hasattr(e, "cause") and e.cause: + return str(e.cause) + if e.args: + return str(e.args[0]) + return "Internal Server Error" diff --git a/tests/api/test_proxy.py b/tests/api/test_proxy.py index 41957c8..5ce5049 100644 --- a/tests/api/test_proxy.py +++ b/tests/api/test_proxy.py @@ -59,11 +59,44 @@ def test_call_proxy_func(client: TestClient): body = {"func_name": func, "data": data} headers = {"Authorization": "i_am_local_proxy_auth_keys"} res = client.post("/api/v1/proxy/remote", json=body, headers=headers).json() - res = cache_decode(res) - assert res["a_str"] == "test" - assert res["b_int"] == 123 - model = res["c_model"] + assert res["status"] == 200 + data = cache_decode(res["data"]) + assert data["a_str"] == "test" + assert data["b_int"] == 123 + model = data["c_model"] assert model.id == "id_1" assert model.name == 100 assert model.model.id == 1 assert model.model.name == "sub_name" + + +@pytest.mark.order(2) +def test_call_proxy_func_with_error_func_name(client: TestClient): + func = cache_encode("tests.test_plugins.error_func") + data = cache_encode( + { + "a_str": "test", + "b_int": 123, + "c_model": ExchangeModel(id="id_1", name=100, model=SubModel(id=1, name="sub_name")), + } + ) + body = {"func_name": func, "data": data} + headers = {"Authorization": "i_am_local_proxy_auth_keys"} + res = client.post("/api/v1/proxy/remote", json=body, headers=headers).json() + assert res["status"] == 500 + + +@pytest.mark.order(2) +def test_call_proxy_func_with_error_api_key(client: TestClient): + func = cache_encode("tests.test_plugins.local_exchange_key_value") + data = cache_encode( + { + "a_str": "test", + "b_int": 123, + "c_model": ExchangeModel(id="id_1", name=100, model=SubModel(id=1, name="sub_name")), + } + ) + body = {"func_name": func, "data": data} + headers = {"Authorization": "i_am_error_keys"} + res = client.post("/api/v1/proxy/remote", json=body, headers=headers).json() + assert res["status"] == 401 diff --git a/uv.lock b/uv.lock index 890e00d..2e3ad84 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.11" resolution-markers = [ "python_full_version >= '3.13' and platform_python_implementation == 'PyPy'", @@ -300,7 +301,7 @@ name = "click" version = "8.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } wheels = [ @@ -333,7 +334,7 @@ name = "colorful" version = "0.5.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/82/31/109ef4bedeb32b4202e02ddb133162457adc4eb890a9ed9c05c9dd126ed0/colorful-0.5.8.tar.gz", hash = "sha256:bb16502b198be2f1c42ba3c52c703d5f651d826076817185f0294c1a549a7445", size = 209361 } wheels = [ @@ -543,6 +544,7 @@ requires-dist = [ { name = "sentry-sdk", extras = ["fastapi"], specifier = ">=2.33.0" }, { name = "tomli", specifier = ">=2.2.1" }, ] +provides-extras = ["release"] [package.metadata.requires-dev] dev = [ @@ -2185,15 +2187,15 @@ wheels = [ [[package]] name = "sentry-sdk" -version = "2.43.0" +version = "2.48.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b3/18/09875b4323b03ca9025bae7e6539797b27e4fc032998a466b4b9c3d24653/sentry_sdk-2.43.0.tar.gz", hash = "sha256:52ed6e251c5d2c084224d73efee56b007ef5c2d408a4a071270e82131d336e20", size = 368953 } +sdist = { url = "https://files.pythonhosted.org/packages/40/f0/0e9dc590513d5e742d7799e2038df3a05167cba084c6ca4f3cdd75b55164/sentry_sdk-2.48.0.tar.gz", hash = "sha256:5213190977ff7fdff8a58b722fb807f8d5524a80488626ebeda1b5676c0c1473", size = 384828 } wheels = [ - { url = "https://files.pythonhosted.org/packages/69/31/8228fa962f7fd8814d634e4ebece8780e2cdcfbdf0cd2e14d4a6861a7cd5/sentry_sdk-2.43.0-py2.py3-none-any.whl", hash = "sha256:4aacafcf1756ef066d359ae35030881917160ba7f6fc3ae11e0e58b09edc2d5d", size = 400997 }, + { url = "https://files.pythonhosted.org/packages/4d/19/8d77f9992e5cbfcaa9133c3bf63b4fbbb051248802e1e803fed5c552fbb2/sentry_sdk-2.48.0-py2.py3-none-any.whl", hash = "sha256:6b12ac256769d41825d9b7518444e57fa35b5642df4c7c5e322af4d2c8721172", size = 414555 }, ] [package.optional-dependencies] From ccfee690c105d2455a1ea94dc824ae6baf0dde6e Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Fri, 26 Dec 2025 12:25:19 +0800 Subject: [PATCH 3/3] test: enhance error handling and add status code constants --- src/framex/plugins/proxy/__init__.py | 3 ++- tests/test_utils.py | 31 +++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/framex/plugins/proxy/__init__.py b/src/framex/plugins/proxy/__init__.py index 2fb87bb..ada264b 100644 --- a/src/framex/plugins/proxy/__init__.py +++ b/src/framex/plugins/proxy/__init__.py @@ -5,6 +5,7 @@ import httpx from pydantic import BaseModel, create_model +from starlette import status from typing_extensions import override from framex.adapter import get_adapter @@ -69,7 +70,7 @@ async def _get_openai_docs(self, url: str, docs_path: str = "/api/v1/openapi.jso headers = None async with httpx.AsyncClient(timeout=self.time_out) as client: response = await client.get(f"{url}{docs_path}", headers=headers) - if response.status_code != 200: + if response.status_code != status.HTTP_200_OK: # pragma: no cover logger.error( f"Failed to get openai docs from {url}, status code: {response.status_code}, response: {response.text}" ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 5075d49..f09d13d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,7 +6,14 @@ from pydantic import BaseModel from framex.config import AuthConfig -from framex.utils import StreamEnventType, cache_decode, cache_encode, format_uptime, make_stream_event +from framex.utils import ( + StreamEnventType, + cache_decode, + cache_encode, + format_uptime, + make_stream_event, + safe_error_message, +) class StreamDataModel(BaseModel): @@ -152,3 +159,25 @@ def test_format_uptime(): # Test only minutes (no seconds) delta = timedelta(minutes=5, seconds=0) assert format_uptime(delta) == "5m" + + +class CauseError(Exception): + def __init__(self, cause: Exception): + self.cause = cause + super().__init__("outer") + + +def test_safe_error_message_with_cause(): + e = CauseError(RuntimeError("inner")) + assert safe_error_message(e) == "inner" + + +def test_safe_error_message_with_args(): + e = RuntimeError("simple error") + assert safe_error_message(e) == "simple error" + + +def test_safe_error_message_fallback(): + e = Exception() + e.args = () + assert safe_error_message(e) == "Internal Server Error"