Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions book/src/advanced_usage/authentication.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,7 @@ The system supports API-level authentication using **access keys**, configured t

```toml
[auth]
rules = {
"/api/v1/echo_model" = ["key-1", "key-2"],
"/api/v2/*" = ["key-3"]
}
rules = {"/api/v1/echo_model" = ["key-1", "key-2"],"/api/v2/*" = ["key-3"]}
```

### Runtime Behavior
Expand All @@ -85,8 +82,5 @@ as standard API authentication.

```toml
[plugins.proxy.auth]
rules = {
"/api/v1/proxy/remote" = ["proxy-key"],
"/api/v1/echo_model" = ["echo-key"]
}
rules = {"/api/v1/proxy/remote" = ["proxy-key"],"/api/v1/echo_model" = ["echo-key"]}
```
11 changes: 3 additions & 8 deletions book/src/advanced_usage/proxy_function.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class ExamplePlugin(BasePlugin):
async def on_start(self) -> None:
from demo.some import example_func # Important: Lazy import

register_proxy_func(example_func)
await register_proxy_func(example_func)
```

______________________________________________________________________
Expand All @@ -73,19 +73,14 @@ To enable remote execution, add the following configuration to `config.toml`.
[plugins.proxy]
proxy_urls = ["http://remotehost:8080"]
white_list = ["/api/v1/proxy/remote"]
proxy_functions = {
"http://localhost:8083" = [
"demo.some.example_func"
]
}
proxy_functions = {"http://remotehost:8080" = ["demo.some.example_func"]}
```

### Authentication Configuration

```toml
[plugins.proxy.auth]
general_auth_keys = ["7a23713f-c85f-48c0-a349-7a3363f2d30c"]
auth_urls = ["/api/v1/proxy/remote"]
rules = {"/api/v1/proxy/remote" = ["proxy-key"],"/api/v1/openapi.json" = ["openapi-key"]}
```

Once configured, FrameX automatically routes function calls to remote instances when required.
Expand Down
22 changes: 3 additions & 19 deletions src/framex/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, Literal, Self
from uuid import uuid4
from typing import Any, Literal

from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
Expand All @@ -10,8 +9,6 @@
TomlConfigSettingsSource,
)

from framex.consts import PROXY_FUNC_HTTP_PATH


class LogConfig(BaseModel):
simple_log: bool = True
Expand Down Expand Up @@ -68,19 +65,6 @@ class TestConfig(BaseModel):
class AuthConfig(BaseModel):
rules: dict[str, list[str]] = Field(default_factory=dict)

@model_validator(mode="after")
def normalize_and_validate(self) -> Self:
if PROXY_FUNC_HTTP_PATH not in self.rules:
key = str(uuid4())
self.rules[PROXY_FUNC_HTTP_PATH] = [key]
from framex.log import logger

logger.warning(
f"No auth key found for {PROXY_FUNC_HTTP_PATH}. A random key {key} was generated. "
"Please configure auth.rules explicitly in production.",
)
return self

def _is_url_protected(self, url: str) -> bool:
for rule in self.rules:
if rule == url:
Expand Down Expand Up @@ -113,7 +97,7 @@ def get_auth_keys(self, url: str) -> list[str] | None:

class Settings(BaseSettings):
# Global config
base_ingress_config: dict[str, Any] = {"max_ongoing_requests": 10}
base_ingress_config: dict[str, Any] = Field(default_factory=lambda: {"max_ongoing_requests": 10})

server: ServerConfig = Field(default_factory=ServerConfig)
log: LogConfig = Field(default_factory=LogConfig)
Expand Down
11 changes: 9 additions & 2 deletions src/framex/driver/ingress.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from framex.driver.decorator import api_ingress
from framex.log import setup_logger
from framex.plugin.model import ApiType, PluginApi
from framex.utils import escape_tag
from framex.utils import escape_tag, safe_error_message

app = create_fastapi_application()
api_key_header = APIKeyHeader(name="Authorization", auto_error=True)
Expand Down Expand Up @@ -100,7 +100,14 @@ async def route_handler(response: Response, model: Model = Depends()) -> Any: #
gen,
media_type="text/event-stream",
)
return await adapter._acall(c_handle, **model.__dict__) # type: ignore

try:
return await adapter._acall(c_handle, **model.__dict__) # type: ignore
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=safe_error_message(e),
) from None

# Inject auth dependency if needed
dependencies = []
Expand Down
24 changes: 20 additions & 4 deletions src/framex/plugin/load.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from collections.abc import Callable

from framex.consts import PROXY_PLUGIN_NAME
from framex.log import logger
from framex.plugin.model import Plugin
from framex.plugin.model import ApiType, Plugin, PluginApi
from framex.plugin.on import _PROXY_REGISTRY

from . import _manager, get_loaded_plugins
from . import _manager, call_plugin_api, get_loaded_plugins


def load_plugins(*plugin_dir: str) -> set[Plugin]:
Expand Down Expand Up @@ -37,5 +39,19 @@ def auto_load_plugins(builtin_plugins: list[str], plugins: list[str], enable_pro
return builtin_plugin_instances | plugin_instances


def register_proxy_func(_: Callable) -> None: # pragma: no cover
pass
async def register_proxy_func(func: Callable) -> None:
full_func_name = f"{func.__module__}.{func.__name__}"
if full_func_name not in _PROXY_REGISTRY: # pragma: no cover
raise RuntimeError(f"Function {full_func_name} is not registered as a proxy function.")

api_reg = PluginApi(
deployment_name=PROXY_PLUGIN_NAME,
call_type=ApiType.PROXY,
func_name="register_proxy_function",
)
await call_plugin_api(
api_reg,
None,
func_name=full_func_name,
func_callable=_PROXY_REGISTRY[full_func_name],
)
19 changes: 5 additions & 14 deletions src/framex/plugin/on.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ def wrapper(func: Callable) -> Callable:
return wrapper


_PROXY_REGISTRY: dict[str, Callable] = {}


def on_proxy() -> Callable:
def decorator(func: Callable) -> Callable:
from framex.config import settings
Expand All @@ -124,27 +127,15 @@ async def safe_callable(*args: Any, **kwargs: Any) -> Any:
return await raw(*args, **kwargs)
return raw(*args, **kwargs)

_PROXY_REGISTRY[full_func_name] = safe_callable

@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
nonlocal is_registered

if args: # pragma: no cover
raise TypeError(f"The proxy function '{func.__name__}' only supports keyword arguments.")

if not is_registered:
api_reg = PluginApi(
deployment_name=PROXY_PLUGIN_NAME,
call_type=ApiType.PROXY,
func_name="register_proxy_function",
)
await call_plugin_api(
api_reg,
None,
func_name=full_func_name,
func_callable=safe_callable,
)
is_registered = True

api_call = PluginApi(
deployment_name=PROXY_PLUGIN_NAME,
call_type=ApiType.PROXY,
Expand Down
7 changes: 6 additions & 1 deletion src/framex/plugins/proxy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import httpx
from pydantic import BaseModel, create_model
from starlette import status
from typing_extensions import override

from framex.adapter import get_adapter
Expand Down Expand Up @@ -69,6 +70,10 @@ async def _get_openai_docs(self, url: str, docs_path: str = "/api/v1/openapi.jso
headers = None
async with httpx.AsyncClient(timeout=self.time_out) as client:
response = await client.get(f"{url}{docs_path}", headers=headers)
if response.status_code != status.HTTP_200_OK: # pragma: no cover
logger.error(
f"Failed to get openai docs from {url}, status code: {response.status_code}, response: {response.text}"
)
Comment on lines +73 to +76
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Potential information disclosure in error logging.

Logging the full response.text in error messages could expose sensitive information if the remote service returns detailed error responses (stack traces, internal paths, configuration details, etc.). Consider sanitizing or truncating the response text, or using the safe_error_message utility introduced in this PR.

🔎 Proposed fix to limit logged response content
             if response.status_code != status.HTTP_200_OK:  # pragma: no cover
                 logger.error(
-                    f"Failed to get openai docs from {url}, status code: {response.status_code}, response: {response.text}"
+                    f"Failed to get openai docs from {url}, status code: {response.status_code}, response: {response.text[:200]}"
                 )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if response.status_code != status.HTTP_200_OK: # pragma: no cover
logger.error(
f"Failed to get openai docs from {url}, status code: {response.status_code}, response: {response.text}"
)
if response.status_code != status.HTTP_200_OK: # pragma: no cover
logger.error(
f"Failed to get openai docs from {url}, status code: {response.status_code}, response: {response.text[:200]}"
)
🤖 Prompt for AI Agents
In src/framex/plugins/proxy/__init__.py around lines 73 to 76, the error log
currently includes the full response.text which may leak sensitive data; replace
that direct inclusion with a sanitized/truncated value by using the
safe_error_message utility (or implement a short sanitizer) and log only the
sanitized message along with url and status code; if safe_error_message is not
imported in this file, add the import and ensure the sanitizer trims long bodies
and strips sensitive-looking content before passing to logger.error.

response.raise_for_status()
return cast(dict[str, Any], response.json())

Expand Down Expand Up @@ -162,7 +167,7 @@ async def register_proxy_func_route(
params=[("model", ProxyFuncHttpBody)],
handle=handle,
stream=False,
direct_output=True,
direct_output=False,
tags=[__plugin_meta__.name],
)

Expand Down
14 changes: 7 additions & 7 deletions src/framex/plugins/proxy/config.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
from typing import Any, Self

from pydantic import BaseModel, model_validator
from pydantic import BaseModel, Field, model_validator

from framex.config import AuthConfig
from framex.plugin import get_plugin_config


class ProxyPluginConfig(BaseModel):
proxy_urls: list[str] = []
force_stream_apis: list[str] = []
proxy_urls: list[str] = Field(default_factory=list)
force_stream_apis: list[str] = Field(default_factory=list)
timeout: int = 600
ingress_config: dict[str, Any] = {"max_ongoing_requests": 60}
ingress_config: dict[str, Any] = Field(default_factory=lambda: {"max_ongoing_requests": 60})

white_list: list[str] = []
white_list: list[str] = Field(default_factory=list)

auth: AuthConfig = AuthConfig()
auth: AuthConfig = Field(default_factory=AuthConfig)

proxy_functions: dict[str, list[str]] = {}
proxy_functions: dict[str, list[str]] = Field(default_factory=dict)

def is_white_url(self, url: str) -> bool:
"""Check if a URL is protected by any auth_urls rule."""
Expand Down
8 changes: 8 additions & 0 deletions src/framex/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,11 @@ def format_uptime(delta: timedelta) -> str:
parts.append(f"{seconds}s")

return " ".join(parts)


def safe_error_message(e: Exception) -> str:
if hasattr(e, "cause") and e.cause:
return str(e.cause)
if e.args:
return str(e.args[0])
return "Internal Server Error"
41 changes: 37 additions & 4 deletions tests/api/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,44 @@ def test_call_proxy_func(client: TestClient):
body = {"func_name": func, "data": data}
headers = {"Authorization": "i_am_local_proxy_auth_keys"}
res = client.post("/api/v1/proxy/remote", json=body, headers=headers).json()
res = cache_decode(res)
assert res["a_str"] == "test"
assert res["b_int"] == 123
model = res["c_model"]
assert res["status"] == 200
data = cache_decode(res["data"])
assert data["a_str"] == "test"
assert data["b_int"] == 123
model = data["c_model"]
assert model.id == "id_1"
assert model.name == 100
assert model.model.id == 1
assert model.model.name == "sub_name"


@pytest.mark.order(2)
def test_call_proxy_func_with_error_func_name(client: TestClient):
func = cache_encode("tests.test_plugins.error_func")
data = cache_encode(
{
"a_str": "test",
"b_int": 123,
"c_model": ExchangeModel(id="id_1", name=100, model=SubModel(id=1, name="sub_name")),
}
)
body = {"func_name": func, "data": data}
headers = {"Authorization": "i_am_local_proxy_auth_keys"}
res = client.post("/api/v1/proxy/remote", json=body, headers=headers).json()
assert res["status"] == 500


@pytest.mark.order(2)
def test_call_proxy_func_with_error_api_key(client: TestClient):
func = cache_encode("tests.test_plugins.local_exchange_key_value")
data = cache_encode(
{
"a_str": "test",
"b_int": 123,
"c_model": ExchangeModel(id="id_1", name=100, model=SubModel(id=1, name="sub_name")),
}
)
body = {"func_name": func, "data": data}
headers = {"Authorization": "i_am_error_keys"}
res = client.post("/api/v1/proxy/remote", json=body, headers=headers).json()
assert res["status"] == 401
2 changes: 2 additions & 0 deletions tests/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from pydantic import BaseModel, ConfigDict, Field

from framex.plugin.load import register_proxy_func
from framex.plugin.on import on_proxy


Expand Down Expand Up @@ -183,6 +184,7 @@ async def remote_exchange_key_value(a_str: str, b_int: int, c_model: ExchangeMod

@pytest.mark.order(1)
async def test_on_proxy_local_call():
await register_proxy_func(local_exchange_key_value)
res = await local_exchange_key_value(
a_str="test", b_int=123, c_model=ExchangeModel(id="id_1", name=100, model=SubModel(id=1, name="sub_name"))
)
Expand Down
31 changes: 30 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@
from pydantic import BaseModel

from framex.config import AuthConfig
from framex.utils import StreamEnventType, cache_decode, cache_encode, format_uptime, make_stream_event
from framex.utils import (
StreamEnventType,
cache_decode,
cache_encode,
format_uptime,
make_stream_event,
safe_error_message,
)


class StreamDataModel(BaseModel):
Expand Down Expand Up @@ -152,3 +159,25 @@ def test_format_uptime():
# Test only minutes (no seconds)
delta = timedelta(minutes=5, seconds=0)
assert format_uptime(delta) == "5m"


class CauseError(Exception):
def __init__(self, cause: Exception):
self.cause = cause
super().__init__("outer")


def test_safe_error_message_with_cause():
e = CauseError(RuntimeError("inner"))
assert safe_error_message(e) == "inner"


def test_safe_error_message_with_args():
e = RuntimeError("simple error")
assert safe_error_message(e) == "simple error"


def test_safe_error_message_fallback():
e = Exception()
e.args = ()
assert safe_error_message(e) == "Internal Server Error"
Loading
Loading