diff --git a/pytest.ini b/pytest.ini index 29b76f7..0515cf6 100644 --- a/pytest.ini +++ b/pytest.ini @@ -26,4 +26,6 @@ env = ; plugins__proxy__force_stream_apis=[] sentry__enable=false test__silent=true + auth__general_auth_keys=["i_am_general_auth_keys"] + auth__auth_urls=["/api/v1/echo"] asyncio_mode = auto diff --git a/src/framex/config.py b/src/framex/config.py index 7d241f0..6517c3c 100644 --- a/src/framex/config.py +++ b/src/framex/config.py @@ -1,6 +1,6 @@ -from typing import Any, Literal +from typing import Any, Literal, Self -from pydantic import BaseModel +from pydantic import BaseModel, Field, model_validator from pydantic_settings import ( BaseSettings, PydanticBaseSettingsSource, @@ -9,6 +9,8 @@ TomlConfigSettingsSource, ) +from framex.utils import is_url_protected + class LogConfig(BaseModel): simple_log: bool = True @@ -51,6 +53,19 @@ class TestConfig(BaseModel): silent: bool = False +class AuthConfig(BaseModel): + general_auth_keys: list[str] = Field(default_factory=list) + auth_urls: list[str] = Field(default_factory=list) + special_auth_keys: dict[str, list[str]] = Field(default_factory=dict) + + @model_validator(mode="after") + def validate_special_auth_urls(self) -> Self: + for special_url in self.special_auth_keys: + if not is_url_protected(special_url, self.auth_urls): + raise ValueError(f"special_auth_keys url '{special_url}' is not covered by any auth_urls rule") + return self + + class Settings(BaseSettings): # Global config base_ingress_config: dict[str, Any] = {"max_ongoing_requests": 10} @@ -66,8 +81,8 @@ class Settings(BaseSettings): load_builtin_plugins: list[str] = [] test: TestConfig = TestConfig() - sentry: SentryConfig = SentryConfig() + auth: AuthConfig = AuthConfig() model_config = SettingsConfigDict( # `.env.prod` takes priority over `.env` diff --git a/src/framex/driver/ingress.py b/src/framex/driver/ingress.py index addab90..a7610b2 100644 --- a/src/framex/driver/ingress.py +++ b/src/framex/driver/ingress.py @@ -2,9 +2,10 @@ from enum import Enum from typing import Any -from fastapi import Depends, Response +from fastapi import Depends, HTTPException, Response, status from fastapi.responses import JSONResponse, StreamingResponse from fastapi.routing import APIRoute +from fastapi.security import APIKeyHeader from pydantic import create_model from ray.serve.handle import DeploymentHandle from starlette.routing import Route @@ -15,9 +16,10 @@ from framex.driver.decorator import api_ingress from framex.log import setup_logger from framex.plugin.model import ApiType, PluginApi -from framex.utils import escape_tag +from framex.utils import escape_tag, get_auth_keys_by_url app = create_fastapi_application() +api_key_header = APIKeyHeader(name="Authorization", auto_error=True) @app.get("/health") @@ -42,6 +44,7 @@ def __init__(self, deployments: list[DeploymentHandle], plugin_apis: list["Plugi ApiType.ALL, ] ): + auth_keys = get_auth_keys_by_url(plugin_api.api) self.register_route( plugin_api.api, plugin_api.methods, @@ -51,6 +54,7 @@ def __init__(self, deployments: list[DeploymentHandle], plugin_apis: list["Plugi stream=plugin_api.stream, direct_output=False, tags=plugin_api.tags, + auth_keys=auth_keys, ) def register_route( @@ -63,6 +67,7 @@ def register_route( stream: bool = False, direct_output: bool = False, tags: list[str | Enum] | None = None, + auth_keys: list[str] | None = None, ) -> bool: if tags is None: tags = ["default"] @@ -93,12 +98,28 @@ async def route_handler(response: Response, model: Model = Depends()) -> Any: # ) return await adapter._acall(c_handle, **model.__dict__) # type: ignore + # Inject auth dependency if needed + dependencies = [] + if auth_keys is not None: + logger.debug(f"API({path}) with tags {tags} requires auth.") + + def _verify_api_key(api_key: str = Depends(api_key_header)) -> None: + if api_key not in auth_keys: + logger.error(f"Unauthorized access attempt with API Key({api_key}) for API({path})") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"Invalid API Key({api_key}) for API({path})", + ) + + dependencies.append(Depends(_verify_api_key)) + app.add_api_route( path, route_handler, methods=methods, tags=tags, response_class=StreamingResponse if stream else JSONResponse, + dependencies=dependencies, ) logger.opt(colors=True).success( f"Succeeded to register api({methods}): {path} from {handle.deployment_name}" diff --git a/src/framex/utils.py b/src/framex/utils.py index e67808f..0b4856d 100644 --- a/src/framex/utils.py +++ b/src/framex/utils.py @@ -53,3 +53,43 @@ def make_stream_event(event_type: StreamEnventType | str, data: str | dict[str, elif isinstance(data, str): data = {"content": data} return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" + + +def is_url_protected(url: str, auth_urls: list[str]) -> bool: + """Check if a URL is protected by any auth_urls rule.""" + for rule in auth_urls: + if rule == url: + return True + if rule.endswith("/*") and url.startswith(rule[:-1]): + return True + return False + + +def get_auth_keys_by_url(url: str) -> list[str] | None: + from framex.config import settings + + auth_config = settings.auth + is_protected = is_url_protected(url, auth_config.auth_urls) + + if not is_protected: + return None + + if url in auth_config.special_auth_keys: + return auth_config.special_auth_keys[url] + + matched_keys = None + matched_len = -1 + + for rule, keys in auth_config.special_auth_keys.items(): + if not rule.endswith("/*"): + continue + + prefix = rule[:-1] + if url.startswith(prefix) and len(prefix) > matched_len: + matched_keys = keys + matched_len = len(prefix) + + if matched_keys is not None: + return matched_keys + + return auth_config.general_auth_keys diff --git a/tests/api/test_echo.py b/tests/api/test_echo.py index d17565c..e384776 100644 --- a/tests/api/test_echo.py +++ b/tests/api/test_echo.py @@ -7,11 +7,27 @@ def test_echo(client: TestClient): params = {"message": "hello world"} - res = client.get(f"{API_STR}/echo", params=params).json() + headers = {"Authorization": "i_am_general_auth_keys"} + res = client.get(f"{API_STR}/echo", params=params, headers=headers).json() assert res["status"] == 200 assert res["data"] == params["message"] +def test_echo_with_no_api_key(client: TestClient): + params = {"message": "hello world"} + res = client.get(f"{API_STR}/echo", params=params).json() + assert res["status"] == 403 + assert res["message"] == "Not authenticated" + + +def test_echo_with_error_api_key(client: TestClient): + params = {"message": "hello world"} + headers = {"Authorization": "error_key"} + res = client.get(url=f"{API_STR}/echo", params=params, headers=headers).json() + assert res["status"] == 401 + assert res["message"] == "Invalid API Key(error_key) for API(/api/v1/echo)" + + def test_echo_model(client: TestClient): params = {"message": "hello world"} data = {"id": 1, "name": "原神"} diff --git a/tests/test_config.py b/tests/test_config.py index 72b80f4..576786f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,3 +1,9 @@ +import pytest +from pydantic_core import ValidationError + +from framex.config import AuthConfig + + def test_config(): from framex.plugin import get_plugin_config from framex.plugins.proxy.config import ProxyPluginConfig @@ -5,3 +11,32 @@ def test_config(): cfg = get_plugin_config("proxy", ProxyPluginConfig) assert isinstance(cfg, ProxyPluginConfig) assert cfg.proxy_urls is not None + + +def test_auth_config(): + AuthConfig( + general_auth_keys=["abcdefg"], + auth_urls=[ + "/api/v1/a/*", + "/api/b/call", + "/api/v1/c/*", + ], + special_auth_keys={"/api/v1/a/call": ["0123456789"], "/api/v1/c/*": ["0123456789a", "0123456789b"]}, + ) + + AuthConfig( + general_auth_keys=["abcdefg"], + auth_urls=[ + "/api/v1/a/*", + ], + special_auth_keys={"/api/v1/a/call": ["0123456789"]}, + ) + + with pytest.raises(ValidationError) as exc_info: + AuthConfig( + auth_urls=["/api/v1/*"], + special_auth_keys={ + "/admin/login": ["0123456789"], + }, + ) + assert "special_auth_keys url '/admin/login'" in str(exc_info.value) diff --git a/tests/test_utils.py b/tests/test_utils.py index 6a07d8d..4bba9a1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,11 @@ from typing import Any +from unittest.mock import patch import pytest from pydantic import BaseModel -from framex.utils import StreamEnventType, make_stream_event +from framex.config import AuthConfig +from framex.utils import StreamEnventType, get_auth_keys_by_url, make_stream_event class StreamDataModel(BaseModel): @@ -30,3 +32,26 @@ class StreamDataModel(BaseModel): def test_make_stream_event(event_type: StreamEnventType | str, data: str | dict[str, Any] | BaseModel, result: str): res = make_stream_event(event_type, data) assert res == result + + +def test_get_auth_keys_by_url(): + auth = AuthConfig( + general_auth_keys=["g"], + auth_urls=["/api/v1/*", "/api/v2/echo", "/api/v3/*"], + special_auth_keys={ + "/api/v1/echo": ["s"], + "/api/v3/echo/*": ["b"], + "/api/v3/echo/hi": ["c"], + }, + ) + + with patch("framex.config.settings.auth", auth): + assert get_auth_keys_by_url("/health") is None + assert get_auth_keys_by_url("/api/v1/user") == ["g"] + assert get_auth_keys_by_url("/api/v1/echo") == ["s"] + assert get_auth_keys_by_url("/api/v1/echo/sub") == ["g"] + assert get_auth_keys_by_url("/api/v2/echo") == ["g"] + assert get_auth_keys_by_url("/api/v2/echo/sub") is None + assert get_auth_keys_by_url("/api/v3/sub") == ["g"] + assert get_auth_keys_by_url("/api/v3/echo/1") == ["b"] + assert get_auth_keys_by_url("/api/v3/echo/hi") == ["c"]