From b6e0c61af106a4286a787aa8eda38cd4bf3c7601 Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Wed, 17 Dec 2025 19:17:41 +0800 Subject: [PATCH 1/5] feat: add API key authentication support --- src/framex/config.py | 31 +++++++++++++++++++++++++++--- src/framex/driver/ingress.py | 23 ++++++++++++++++++++-- src/framex/utils.py | 37 ++++++++++++++++++++++++++++++++++++ tests/test_config.py | 35 ++++++++++++++++++++++++++++++++++ tests/test_utils.py | 22 ++++++++++++++++++++- 5 files changed, 142 insertions(+), 6 deletions(-) diff --git a/src/framex/config.py b/src/framex/config.py index 7d241f0..a00f7e2 100644 --- a/src/framex/config.py +++ b/src/framex/config.py @@ -1,6 +1,6 @@ -from typing import Any, Literal +from typing import Any, Literal, Self -from pydantic import BaseModel +from pydantic import BaseModel, Field, model_validator from pydantic_settings import ( BaseSettings, PydanticBaseSettingsSource, @@ -51,6 +51,31 @@ class TestConfig(BaseModel): silent: bool = False +class AuthConfig(BaseModel): + general_auth_keys: list[str] = Field(default_factory=list) + auth_urls: list[str] = Field(default_factory=list) + special_auth_keys: dict[str, list[str]] = Field(default_factory=dict) + + @model_validator(mode="after") + def validate_special_auth_urls(self) -> Self: + for special_url in self.special_auth_keys: + if not self._is_url_allowed(special_url): + raise ValueError(f"special_auth_keys url '{special_url}' is not covered by any auth_urls rule") + return self + + def _is_url_allowed(self, url: str) -> bool: + for rule in self.auth_urls: + if rule == url: + return True + + if rule.endswith("/*"): + prefix = rule[:-1] + if url.startswith(prefix): + return True + + return False + + class Settings(BaseSettings): # Global config base_ingress_config: dict[str, Any] = {"max_ongoing_requests": 10} @@ -66,8 +91,8 @@ class Settings(BaseSettings): load_builtin_plugins: list[str] = [] test: TestConfig = TestConfig() - sentry: SentryConfig = SentryConfig() + auth: AuthConfig = AuthConfig() model_config = SettingsConfigDict( # `.env.prod` takes priority over `.env` diff --git a/src/framex/driver/ingress.py b/src/framex/driver/ingress.py index addab90..f38e397 100644 --- a/src/framex/driver/ingress.py +++ b/src/framex/driver/ingress.py @@ -2,9 +2,10 @@ from enum import Enum from typing import Any -from fastapi import Depends, Response +from fastapi import Depends, HTTPException, Response, status from fastapi.responses import JSONResponse, StreamingResponse from fastapi.routing import APIRoute +from fastapi.security import APIKeyHeader from pydantic import create_model from ray.serve.handle import DeploymentHandle from starlette.routing import Route @@ -15,9 +16,10 @@ from framex.driver.decorator import api_ingress from framex.log import setup_logger from framex.plugin.model import ApiType, PluginApi -from framex.utils import escape_tag +from framex.utils import escape_tag, get_auth_keys_by_url app = create_fastapi_application() +api_key_header = APIKeyHeader(name="Authorization", auto_error=True) @app.get("/health") @@ -42,6 +44,7 @@ def __init__(self, deployments: list[DeploymentHandle], plugin_apis: list["Plugi ApiType.ALL, ] ): + auth_keys = get_auth_keys_by_url(plugin_api.api) self.register_route( plugin_api.api, plugin_api.methods, @@ -51,6 +54,7 @@ def __init__(self, deployments: list[DeploymentHandle], plugin_apis: list["Plugi stream=plugin_api.stream, direct_output=False, tags=plugin_api.tags, + auth_keys=auth_keys, ) def register_route( @@ -63,6 +67,7 @@ def register_route( stream: bool = False, direct_output: bool = False, tags: list[str | Enum] | None = None, + auth_keys: list[str] | None = None, ) -> bool: if tags is None: tags = ["default"] @@ -93,12 +98,26 @@ async def route_handler(response: Response, model: Model = Depends()) -> Any: # ) return await adapter._acall(c_handle, **model.__dict__) # type: ignore + # Inject auth dependency if needed + dependencies = [] + if auth_keys is not None: + + def _verify_api_key(api_key: str = Depends(api_key_header)) -> None: + if not api_key or api_key not in auth_keys: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API Key", + ) + + dependencies.append(Depends(_verify_api_key)) + app.add_api_route( path, route_handler, methods=methods, tags=tags, response_class=StreamingResponse if stream else JSONResponse, + dependencies=dependencies, ) logger.opt(colors=True).success( f"Succeeded to register api({methods}): {path} from {handle.deployment_name}" diff --git a/src/framex/utils.py b/src/framex/utils.py index e67808f..9446192 100644 --- a/src/framex/utils.py +++ b/src/framex/utils.py @@ -53,3 +53,40 @@ def make_stream_event(event_type: StreamEnventType | str, data: str | dict[str, elif isinstance(data, str): data = {"content": data} return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" + + +def get_auth_keys_by_url(url: str) -> list[str] | None: + from framex.config import settings + + auth_config = settings.auth + is_protected = False + for rule in auth_config.auth_urls: + if rule == url: + is_protected = True + break + if rule.endswith("/*") and url.startswith(rule[:-1]): + is_protected = True + break + + if not is_protected: + return None + + if url in auth_config.special_auth_keys: + return auth_config.special_auth_keys[url] + + matched_keys = None + matched_len = -1 + + for rule, keys in auth_config.special_auth_keys.items(): + if not rule.endswith("/*"): + continue + + prefix = rule[:-1] + if url.startswith(prefix) and len(prefix) > matched_len: + matched_keys = keys + matched_len = len(prefix) + + if matched_keys is not None: + return matched_keys + + return auth_config.general_auth_keys diff --git a/tests/test_config.py b/tests/test_config.py index 72b80f4..576786f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,3 +1,9 @@ +import pytest +from pydantic_core import ValidationError + +from framex.config import AuthConfig + + def test_config(): from framex.plugin import get_plugin_config from framex.plugins.proxy.config import ProxyPluginConfig @@ -5,3 +11,32 @@ def test_config(): cfg = get_plugin_config("proxy", ProxyPluginConfig) assert isinstance(cfg, ProxyPluginConfig) assert cfg.proxy_urls is not None + + +def test_auth_config(): + AuthConfig( + general_auth_keys=["abcdefg"], + auth_urls=[ + "/api/v1/a/*", + "/api/b/call", + "/api/v1/c/*", + ], + special_auth_keys={"/api/v1/a/call": ["0123456789"], "/api/v1/c/*": ["0123456789a", "0123456789b"]}, + ) + + AuthConfig( + general_auth_keys=["abcdefg"], + auth_urls=[ + "/api/v1/a/*", + ], + special_auth_keys={"/api/v1/a/call": ["0123456789"]}, + ) + + with pytest.raises(ValidationError) as exc_info: + AuthConfig( + auth_urls=["/api/v1/*"], + special_auth_keys={ + "/admin/login": ["0123456789"], + }, + ) + assert "special_auth_keys url '/admin/login'" in str(exc_info.value) diff --git a/tests/test_utils.py b/tests/test_utils.py index 6a07d8d..2374342 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,11 @@ from typing import Any +from unittest.mock import patch import pytest from pydantic import BaseModel -from framex.utils import StreamEnventType, make_stream_event +from framex.config import AuthConfig +from framex.utils import StreamEnventType, get_auth_keys_by_url, make_stream_event class StreamDataModel(BaseModel): @@ -30,3 +32,21 @@ class StreamDataModel(BaseModel): def test_make_stream_event(event_type: StreamEnventType | str, data: str | dict[str, Any] | BaseModel, result: str): res = make_stream_event(event_type, data) assert res == result + + +def test_get_auth_keys_by_url(): + auth = AuthConfig( + general_auth_keys=["g"], + auth_urls=["/api/v1/*", "/api/v2/echo"], + special_auth_keys={ + "/api/v1/echo": ["s"], + }, + ) + + with patch("framex.config.settings.auth", auth): + assert get_auth_keys_by_url("/health") is None + assert get_auth_keys_by_url("/api/v1/user") == ["g"] + assert get_auth_keys_by_url("/api/v1/echo") == ["s"] + assert get_auth_keys_by_url("/api/v1/echo/sub") == ["g"] + assert get_auth_keys_by_url("/api/v2/echo") == ["g"] + assert get_auth_keys_by_url("/api/v2/echo/sub") is None From 55cb1b62771b0bdfbab7b69ffdd8f99d5f0cb21d Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Wed, 17 Dec 2025 19:29:28 +0800 Subject: [PATCH 2/5] test: add auth configuration and test for echo API --- pytest.ini | 2 ++ src/framex/driver/ingress.py | 1 + tests/api/test_echo.py | 3 ++- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pytest.ini b/pytest.ini index 29b76f7..0515cf6 100644 --- a/pytest.ini +++ b/pytest.ini @@ -26,4 +26,6 @@ env = ; plugins__proxy__force_stream_apis=[] sentry__enable=false test__silent=true + auth__general_auth_keys=["i_am_general_auth_keys"] + auth__auth_urls=["/api/v1/echo"] asyncio_mode = auto diff --git a/src/framex/driver/ingress.py b/src/framex/driver/ingress.py index f38e397..7c5e9a6 100644 --- a/src/framex/driver/ingress.py +++ b/src/framex/driver/ingress.py @@ -101,6 +101,7 @@ async def route_handler(response: Response, model: Model = Depends()) -> Any: # # Inject auth dependency if needed dependencies = [] if auth_keys is not None: + logger.debug(f"API({path}) with tags {tags} requires auth.") def _verify_api_key(api_key: str = Depends(api_key_header)) -> None: if not api_key or api_key not in auth_keys: diff --git a/tests/api/test_echo.py b/tests/api/test_echo.py index d17565c..635617b 100644 --- a/tests/api/test_echo.py +++ b/tests/api/test_echo.py @@ -7,7 +7,8 @@ def test_echo(client: TestClient): params = {"message": "hello world"} - res = client.get(f"{API_STR}/echo", params=params).json() + headers = {"Authorization": "i_am_general_auth_keys"} + res = client.get(f"{API_STR}/echo", params=params, headers=headers).json() assert res["status"] == 200 assert res["data"] == params["message"] From b371beee2bc084142a2fe2c6d5e94c6f5948d324 Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Thu, 18 Dec 2025 10:16:07 +0800 Subject: [PATCH 3/5] test: Enhance api_key test cases --- src/framex/driver/ingress.py | 3 ++- src/framex/utils.py | 4 ++-- tests/api/test_echo.py | 15 +++++++++++++++ tests/test_utils.py | 9 +++++++-- 4 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/framex/driver/ingress.py b/src/framex/driver/ingress.py index 7c5e9a6..82b9de3 100644 --- a/src/framex/driver/ingress.py +++ b/src/framex/driver/ingress.py @@ -105,9 +105,10 @@ async def route_handler(response: Response, model: Model = Depends()) -> Any: # def _verify_api_key(api_key: str = Depends(api_key_header)) -> None: if not api_key or api_key not in auth_keys: + logger.error(f"Unauthorized access attempt with API Key({api_key}) for API({path})") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid API Key", + detail=f"Invalid API Key({api_key}) for API({path})", ) dependencies.append(Depends(_verify_api_key)) diff --git a/src/framex/utils.py b/src/framex/utils.py index 9446192..84e3700 100644 --- a/src/framex/utils.py +++ b/src/framex/utils.py @@ -8,6 +8,8 @@ from pydantic import BaseModel +from framex.config import settings + def plugin_to_deployment_name(plugin_name: str, obj_name: str) -> str: return f"{plugin_name}.{obj_name}" @@ -56,8 +58,6 @@ def make_stream_event(event_type: StreamEnventType | str, data: str | dict[str, def get_auth_keys_by_url(url: str) -> list[str] | None: - from framex.config import settings - auth_config = settings.auth is_protected = False for rule in auth_config.auth_urls: diff --git a/tests/api/test_echo.py b/tests/api/test_echo.py index 635617b..e384776 100644 --- a/tests/api/test_echo.py +++ b/tests/api/test_echo.py @@ -13,6 +13,21 @@ def test_echo(client: TestClient): assert res["data"] == params["message"] +def test_echo_with_no_api_key(client: TestClient): + params = {"message": "hello world"} + res = client.get(f"{API_STR}/echo", params=params).json() + assert res["status"] == 403 + assert res["message"] == "Not authenticated" + + +def test_echo_with_error_api_key(client: TestClient): + params = {"message": "hello world"} + headers = {"Authorization": "error_key"} + res = client.get(url=f"{API_STR}/echo", params=params, headers=headers).json() + assert res["status"] == 401 + assert res["message"] == "Invalid API Key(error_key) for API(/api/v1/echo)" + + def test_echo_model(client: TestClient): params = {"message": "hello world"} data = {"id": 1, "name": "原神"} diff --git a/tests/test_utils.py b/tests/test_utils.py index 2374342..87951cb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -37,16 +37,21 @@ def test_make_stream_event(event_type: StreamEnventType | str, data: str | dict[ def test_get_auth_keys_by_url(): auth = AuthConfig( general_auth_keys=["g"], - auth_urls=["/api/v1/*", "/api/v2/echo"], + auth_urls=["/api/v1/*", "/api/v2/echo", "/api/v3/*"], special_auth_keys={ "/api/v1/echo": ["s"], + "/api/v3/echo/*": ["b"], + "/api/v3/echo/hi": ["c"], }, ) - with patch("framex.config.settings.auth", auth): + with patch("framex.utils.settings.auth", auth): assert get_auth_keys_by_url("/health") is None assert get_auth_keys_by_url("/api/v1/user") == ["g"] assert get_auth_keys_by_url("/api/v1/echo") == ["s"] assert get_auth_keys_by_url("/api/v1/echo/sub") == ["g"] assert get_auth_keys_by_url("/api/v2/echo") == ["g"] assert get_auth_keys_by_url("/api/v2/echo/sub") is None + assert get_auth_keys_by_url("/api/v3/sub") == ["g"] + assert get_auth_keys_by_url("/api/v3/echo/1") == ["b"] + assert get_auth_keys_by_url("/api/v3/echo/hi") == ["c"] From b008e5c9ee7f7c352e493c6041df7e39d07b0c56 Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Thu, 18 Dec 2025 10:23:23 +0800 Subject: [PATCH 4/5] perf: Reduce code duplication with get_auth_keys_by_url --- src/framex/config.py | 16 +++------------- src/framex/utils.py | 23 +++++++++++++---------- tests/test_utils.py | 2 +- 3 files changed, 17 insertions(+), 24 deletions(-) diff --git a/src/framex/config.py b/src/framex/config.py index a00f7e2..6517c3c 100644 --- a/src/framex/config.py +++ b/src/framex/config.py @@ -9,6 +9,8 @@ TomlConfigSettingsSource, ) +from framex.utils import is_url_protected + class LogConfig(BaseModel): simple_log: bool = True @@ -59,22 +61,10 @@ class AuthConfig(BaseModel): @model_validator(mode="after") def validate_special_auth_urls(self) -> Self: for special_url in self.special_auth_keys: - if not self._is_url_allowed(special_url): + if not is_url_protected(special_url, self.auth_urls): raise ValueError(f"special_auth_keys url '{special_url}' is not covered by any auth_urls rule") return self - def _is_url_allowed(self, url: str) -> bool: - for rule in self.auth_urls: - if rule == url: - return True - - if rule.endswith("/*"): - prefix = rule[:-1] - if url.startswith(prefix): - return True - - return False - class Settings(BaseSettings): # Global config diff --git a/src/framex/utils.py b/src/framex/utils.py index 84e3700..0b4856d 100644 --- a/src/framex/utils.py +++ b/src/framex/utils.py @@ -8,8 +8,6 @@ from pydantic import BaseModel -from framex.config import settings - def plugin_to_deployment_name(plugin_name: str, obj_name: str) -> str: return f"{plugin_name}.{obj_name}" @@ -57,16 +55,21 @@ def make_stream_event(event_type: StreamEnventType | str, data: str | dict[str, return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" -def get_auth_keys_by_url(url: str) -> list[str] | None: - auth_config = settings.auth - is_protected = False - for rule in auth_config.auth_urls: +def is_url_protected(url: str, auth_urls: list[str]) -> bool: + """Check if a URL is protected by any auth_urls rule.""" + for rule in auth_urls: if rule == url: - is_protected = True - break + return True if rule.endswith("/*") and url.startswith(rule[:-1]): - is_protected = True - break + return True + return False + + +def get_auth_keys_by_url(url: str) -> list[str] | None: + from framex.config import settings + + auth_config = settings.auth + is_protected = is_url_protected(url, auth_config.auth_urls) if not is_protected: return None diff --git a/tests/test_utils.py b/tests/test_utils.py index 87951cb..4bba9a1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -45,7 +45,7 @@ def test_get_auth_keys_by_url(): }, ) - with patch("framex.utils.settings.auth", auth): + with patch("framex.config.settings.auth", auth): assert get_auth_keys_by_url("/health") is None assert get_auth_keys_by_url("/api/v1/user") == ["g"] assert get_auth_keys_by_url("/api/v1/echo") == ["s"] From 57df12d888628483a99e50e7be45375b3691cef0 Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Thu, 18 Dec 2025 10:24:30 +0800 Subject: [PATCH 5/5] perf: improve API key authentication validation --- src/framex/driver/ingress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/framex/driver/ingress.py b/src/framex/driver/ingress.py index 82b9de3..a7610b2 100644 --- a/src/framex/driver/ingress.py +++ b/src/framex/driver/ingress.py @@ -104,7 +104,7 @@ async def route_handler(response: Response, model: Model = Depends()) -> Any: # logger.debug(f"API({path}) with tags {tags} requires auth.") def _verify_api_key(api_key: str = Depends(api_key_header)) -> None: - if not api_key or api_key not in auth_keys: + if api_key not in auth_keys: logger.error(f"Unauthorized access attempt with API Key({api_key}) for API({path})") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED,