Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,6 @@ env =
; plugins__proxy__force_stream_apis=[]
sentry__enable=false
test__silent=true
auth__general_auth_keys=["i_am_general_auth_keys"]
auth__auth_urls=["/api/v1/echo"]
asyncio_mode = auto
21 changes: 18 additions & 3 deletions src/framex/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Literal
from typing import Any, Literal, Self

from pydantic import BaseModel
from pydantic import BaseModel, Field, model_validator
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
Expand All @@ -9,6 +9,8 @@
TomlConfigSettingsSource,
)

from framex.utils import is_url_protected


class LogConfig(BaseModel):
simple_log: bool = True
Expand Down Expand Up @@ -51,6 +53,19 @@ class TestConfig(BaseModel):
silent: bool = False


class AuthConfig(BaseModel):
general_auth_keys: list[str] = Field(default_factory=list)
auth_urls: list[str] = Field(default_factory=list)
special_auth_keys: dict[str, list[str]] = Field(default_factory=dict)

@model_validator(mode="after")
def validate_special_auth_urls(self) -> Self:
for special_url in self.special_auth_keys:
if not is_url_protected(special_url, self.auth_urls):
raise ValueError(f"special_auth_keys url '{special_url}' is not covered by any auth_urls rule")
return self


class Settings(BaseSettings):
# Global config
base_ingress_config: dict[str, Any] = {"max_ongoing_requests": 10}
Expand All @@ -66,8 +81,8 @@ class Settings(BaseSettings):
load_builtin_plugins: list[str] = []

test: TestConfig = TestConfig()

sentry: SentryConfig = SentryConfig()
auth: AuthConfig = AuthConfig()

model_config = SettingsConfigDict(
# `.env.prod` takes priority over `.env`
Expand Down
25 changes: 23 additions & 2 deletions src/framex/driver/ingress.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from enum import Enum
from typing import Any

from fastapi import Depends, Response
from fastapi import Depends, HTTPException, Response, status
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from fastapi.security import APIKeyHeader
from pydantic import create_model
from ray.serve.handle import DeploymentHandle
from starlette.routing import Route
Expand All @@ -15,9 +16,10 @@
from framex.driver.decorator import api_ingress
from framex.log import setup_logger
from framex.plugin.model import ApiType, PluginApi
from framex.utils import escape_tag
from framex.utils import escape_tag, get_auth_keys_by_url

app = create_fastapi_application()
api_key_header = APIKeyHeader(name="Authorization", auto_error=True)


@app.get("/health")
Expand All @@ -42,6 +44,7 @@ def __init__(self, deployments: list[DeploymentHandle], plugin_apis: list["Plugi
ApiType.ALL,
]
):
auth_keys = get_auth_keys_by_url(plugin_api.api)
self.register_route(
plugin_api.api,
plugin_api.methods,
Expand All @@ -51,6 +54,7 @@ def __init__(self, deployments: list[DeploymentHandle], plugin_apis: list["Plugi
stream=plugin_api.stream,
direct_output=False,
tags=plugin_api.tags,
auth_keys=auth_keys,
)

def register_route(
Expand All @@ -63,6 +67,7 @@ def register_route(
stream: bool = False,
direct_output: bool = False,
tags: list[str | Enum] | None = None,
auth_keys: list[str] | None = None,
) -> bool:
if tags is None:
tags = ["default"]
Expand Down Expand Up @@ -93,12 +98,28 @@ async def route_handler(response: Response, model: Model = Depends()) -> Any: #
)
return await adapter._acall(c_handle, **model.__dict__) # type: ignore

# Inject auth dependency if needed
dependencies = []
if auth_keys is not None:
logger.debug(f"API({path}) with tags {tags} requires auth.")

def _verify_api_key(api_key: str = Depends(api_key_header)) -> None:
if api_key not in auth_keys:
logger.error(f"Unauthorized access attempt with API Key({api_key}) for API({path})")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Invalid API Key({api_key}) for API({path})",
)

dependencies.append(Depends(_verify_api_key))

app.add_api_route(
path,
route_handler,
methods=methods,
tags=tags,
response_class=StreamingResponse if stream else JSONResponse,
dependencies=dependencies,
)
logger.opt(colors=True).success(
f"Succeeded to register api({methods}): {path} from {handle.deployment_name}"
Expand Down
40 changes: 40 additions & 0 deletions src/framex/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,43 @@ def make_stream_event(event_type: StreamEnventType | str, data: str | dict[str,
elif isinstance(data, str):
data = {"content": data}
return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"


def is_url_protected(url: str, auth_urls: list[str]) -> bool:
"""Check if a URL is protected by any auth_urls rule."""
for rule in auth_urls:
if rule == url:
return True
if rule.endswith("/*") and url.startswith(rule[:-1]):
return True
return False


def get_auth_keys_by_url(url: str) -> list[str] | None:
from framex.config import settings

auth_config = settings.auth
is_protected = is_url_protected(url, auth_config.auth_urls)

if not is_protected:
return None

if url in auth_config.special_auth_keys:
return auth_config.special_auth_keys[url]

matched_keys = None
matched_len = -1

for rule, keys in auth_config.special_auth_keys.items():
if not rule.endswith("/*"):
continue

prefix = rule[:-1]
if url.startswith(prefix) and len(prefix) > matched_len:
matched_keys = keys
matched_len = len(prefix)

if matched_keys is not None:
return matched_keys

return auth_config.general_auth_keys
18 changes: 17 additions & 1 deletion tests/api/test_echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,27 @@

def test_echo(client: TestClient):
params = {"message": "hello world"}
res = client.get(f"{API_STR}/echo", params=params).json()
headers = {"Authorization": "i_am_general_auth_keys"}
res = client.get(f"{API_STR}/echo", params=params, headers=headers).json()
assert res["status"] == 200
assert res["data"] == params["message"]


def test_echo_with_no_api_key(client: TestClient):
params = {"message": "hello world"}
res = client.get(f"{API_STR}/echo", params=params).json()
assert res["status"] == 403
assert res["message"] == "Not authenticated"


def test_echo_with_error_api_key(client: TestClient):
params = {"message": "hello world"}
headers = {"Authorization": "error_key"}
res = client.get(url=f"{API_STR}/echo", params=params, headers=headers).json()
assert res["status"] == 401
assert res["message"] == "Invalid API Key(error_key) for API(/api/v1/echo)"


def test_echo_model(client: TestClient):
params = {"message": "hello world"}
data = {"id": 1, "name": "原神"}
Expand Down
35 changes: 35 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,42 @@
import pytest
from pydantic_core import ValidationError

from framex.config import AuthConfig


def test_config():
from framex.plugin import get_plugin_config
from framex.plugins.proxy.config import ProxyPluginConfig

cfg = get_plugin_config("proxy", ProxyPluginConfig)
assert isinstance(cfg, ProxyPluginConfig)
assert cfg.proxy_urls is not None


def test_auth_config():
AuthConfig(
general_auth_keys=["abcdefg"],
auth_urls=[
"/api/v1/a/*",
"/api/b/call",
"/api/v1/c/*",
],
special_auth_keys={"/api/v1/a/call": ["0123456789"], "/api/v1/c/*": ["0123456789a", "0123456789b"]},
)

AuthConfig(
general_auth_keys=["abcdefg"],
auth_urls=[
"/api/v1/a/*",
],
special_auth_keys={"/api/v1/a/call": ["0123456789"]},
)

with pytest.raises(ValidationError) as exc_info:
AuthConfig(
auth_urls=["/api/v1/*"],
special_auth_keys={
"/admin/login": ["0123456789"],
},
)
assert "special_auth_keys url '/admin/login'" in str(exc_info.value)
27 changes: 26 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Any
from unittest.mock import patch

import pytest
from pydantic import BaseModel

from framex.utils import StreamEnventType, make_stream_event
from framex.config import AuthConfig
from framex.utils import StreamEnventType, get_auth_keys_by_url, make_stream_event


class StreamDataModel(BaseModel):
Expand All @@ -30,3 +32,26 @@ class StreamDataModel(BaseModel):
def test_make_stream_event(event_type: StreamEnventType | str, data: str | dict[str, Any] | BaseModel, result: str):
res = make_stream_event(event_type, data)
assert res == result


def test_get_auth_keys_by_url():
auth = AuthConfig(
general_auth_keys=["g"],
auth_urls=["/api/v1/*", "/api/v2/echo", "/api/v3/*"],
special_auth_keys={
"/api/v1/echo": ["s"],
"/api/v3/echo/*": ["b"],
"/api/v3/echo/hi": ["c"],
},
)

with patch("framex.config.settings.auth", auth):
assert get_auth_keys_by_url("/health") is None
assert get_auth_keys_by_url("/api/v1/user") == ["g"]
assert get_auth_keys_by_url("/api/v1/echo") == ["s"]
assert get_auth_keys_by_url("/api/v1/echo/sub") == ["g"]
assert get_auth_keys_by_url("/api/v2/echo") == ["g"]
assert get_auth_keys_by_url("/api/v2/echo/sub") is None
assert get_auth_keys_by_url("/api/v3/sub") == ["g"]
assert get_auth_keys_by_url("/api/v3/echo/1") == ["b"]
assert get_auth_keys_by_url("/api/v3/echo/hi") == ["c"]
Loading