Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies = [
"pydantic>=2.11.7",
"pydantic-settings>=2.10.1",
"pyjwt>=2.10.1",
"python-multipart>=0.0.22",
"pytz>=2025.2",
"sentry-sdk[fastapi]>=2.33.0",
"tomli>=2.2.1",
Expand Down
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ env =
server__enable_proxy=true
load_builtin_plugins=["echo","proxy","invoker"]
plugins__proxy__proxy_urls=["http://localhost:9527"]
plugins__proxy__white_list=["/api/v1/proxy/mock/info","/proxy/mock/get","/proxy/mock/post","/proxy/mock/post_model","/proxy/mock/auth/*"]
plugins__proxy__white_list=["/api/v1/proxy/mock/info","/proxy/mock/get","/proxy/mock/post","/proxy/mock/post_model","/proxy/mock/upload","/proxy/mock/auth/*"]
plugins__proxy__auth__rules={{"/proxy/mock/auth/*":["i_am_proxy_general_auth_keys"],"/api/v1/openapi.json":["i_am_proxy_docs_auth_keys"],"/proxy/mock/auth/sget":["i_am_proxy_special_auth_keys"],"/api/v1/proxy/remote":["i_am_proxy_func_auth_keys"]}}
plugins__proxy__proxy_functions={{"http://localhost:9527":["tests.test_plugins.remote_exchange_key_value"]}}
; plugins__proxy__force_stream_apis=[]
Expand Down
31 changes: 24 additions & 7 deletions src/framex/driver/ingress.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import inspect
import os
import re
from collections.abc import Callable
from enum import Enum
from typing import Any

from fastapi import Depends, HTTPException, Request, Response, status
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import create_model
from starlette.routing import Route

from framex.adapter import get_adapter
Expand Down Expand Up @@ -73,7 +72,7 @@ def register_route(
handle: Any,
stream: bool = False,
direct_output: bool = False,
tags: list[str | Enum] | None = None,
tags: list[str] | None = None,
auth_keys: list[str] | None = None,
include_in_schema: bool = True,
**kwargs: Any,
Expand Down Expand Up @@ -101,23 +100,41 @@ def register_route(
return False
if (not path) or (not methods):
raise RuntimeError(f"Api({path}) or methods({methods}) is empty")
Model: BaseModel = create_model(f"{func_name}_InputModel", **{name: (tp, ...) for name, tp in params}) # type:ignore # noqa

async def route_handler(response: Response, model: Model = Depends()) -> Any: # type: ignore [valid-type]
async def route_handler(response: Response, **request_kwargs: Any) -> Any:
c_handle = getattr(handle, func_name)
if not c_handle:
raise RuntimeError(
f"No handle found for api({methods}): {path} from {handle.deployment_name}.{func_name}"
)
response.headers["X-Raw-Output"] = str(direct_output)
if stream:
gen = adapter._stream_call(c_handle, **(model.__dict__))
gen = adapter._stream_call(c_handle, **request_kwargs)
return StreamingResponse( # type: ignore
gen,
media_type="text/event-stream",
)

return await adapter._acall(c_handle, **model.__dict__) # type: ignore
return await adapter._acall(c_handle, **request_kwargs) # type: ignore

route_handler.__signature__ = inspect.Signature( # type: ignore
[
inspect.Parameter(
"response",
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=Response,
),
*[
inspect.Parameter(
name,
inspect.Parameter.KEYWORD_ONLY,
annotation=tp,
default=inspect.Parameter.empty,
)
for name, tp in params
],
]
)

# Inject auth dependency if needed
dependencies = []
Expand Down
37 changes: 8 additions & 29 deletions src/framex/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,11 @@
from framex.log import logger
from framex.plugin.manage import _manager
from framex.plugin.model import Plugin, PluginApi
from framex.plugin.resolver import (
ApiResolver,
_set_default_api_resolver,
get_current_api_resolver,
get_current_remote_apis,
get_default_api_resolver,
)
from framex.plugin.resolver import coerce_plugin_api, get_current_remote_apis

C = TypeVar("C", bound=BaseModel)

_current_plugin: ContextVar[Optional["Plugin"]] = ContextVar("_current_plugin", default=None)
_set_default_api_resolver(ApiResolver(manager=_manager))


def get_plugin(plugin_id: str) -> Plugin | None:
Expand All @@ -47,7 +40,6 @@ def check_plugin_config_exists(plugin_name: str) -> bool:
@logger.catch()
def init_all_deployments(enable_proxy: bool) -> list[Any]:
deployments = []
all_apis = {**_manager.all_plugin_apis[ApiType.FUNC], **_manager.all_plugin_apis[ApiType.HTTP]}
for plugin in get_loaded_plugins():
for dep in plugin.deployments:
remote_apis = {
Expand All @@ -63,8 +55,8 @@ def init_all_deployments(enable_proxy: bool) -> list[Any]:
call_type=ApiType.PROXY,
)
logger.opt(colors=True).warning(
f"Api(<r>{api_name}</r>) not found, "
f"plugin(<r>{dep.deployment}</r>) will "
f"Api(<y>{api_name}</y>) not found, "
f"plugin(<y>{dep.deployment}</y>) will "
f"use proxy plugin({PROXY_PLUGIN_NAME}) to transfer!"
)
else: # pragma: no cover
Expand All @@ -74,7 +66,6 @@ def init_all_deployments(enable_proxy: bool) -> list[Any]:
deployment = get_adapter().bind(
dep.deployment,
remote_apis=remote_apis,
api_registry=all_apis,
config=plugin.config,
)

Expand All @@ -83,20 +74,12 @@ def init_all_deployments(enable_proxy: bool) -> list[Any]:
return deployments


def _resolve_plugin_api(
api_name: str | PluginApi,
resolver: ApiResolver | None = None,
) -> tuple[PluginApi, bool]:
def _resolve_plugin_api(api_name: str | PluginApi) -> tuple[PluginApi, bool]:
current_remote_apis = get_current_remote_apis()
if isinstance(api_name, PluginApi):
return api_name, api_name.call_type == ApiType.PROXY

active_resolver = resolver or get_current_api_resolver() or get_default_api_resolver()
if current_remote_apis is not None:
api = active_resolver.coerce_plugin_api(current_remote_apis.get(api_name))
else:
api = active_resolver.resolve(api_name, None)

api = coerce_plugin_api(current_remote_apis.get(api_name)) if current_remote_apis is not None else None
if api is None:
if current_remote_apis is not None:
raise RuntimeError(
Expand Down Expand Up @@ -149,19 +132,15 @@ def _unwrap_plugin_call_result(api_name: str | PluginApi, result: Any, use_proxy
res = result.get("data")
status = result.get("status")
if status not in settings.server.legal_proxy_code:
logger.opt(colors=True).error(f"<>Proxy API {api_name} call illegal: <r>{result}</r>")
logger.opt(colors=True).error(f"Proxy API {api_name} call illegal: <r>{result}</r>")
raise RuntimeError(f"Proxy API {api_name} returned status {status}")
if res is None:
logger.opt(colors=True).warning(f"API {api_name} returned empty data")
return res


async def call_plugin_api(
api_name: str | PluginApi,
resolver: ApiResolver | None = None,
**kwargs: Any,
) -> Any:
api, use_proxy = _resolve_plugin_api(api_name, resolver=resolver)
async def call_plugin_api(api_name: str | PluginApi, **kwargs: Any) -> Any:
api, use_proxy = _resolve_plugin_api(api_name)
normalized_kwargs = _normalize_plugin_call_kwargs(api, kwargs)
result = await get_adapter().call_func(api, **normalized_kwargs)
return _unwrap_plugin_call_result(api_name, result, use_proxy)
Expand Down
28 changes: 6 additions & 22 deletions src/framex/plugin/base.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,20 @@
import inspect
from collections.abc import Mapping
from functools import wraps
from typing import Any, final

from framex.config import settings
from framex.log import setup_logger
from framex.plugin import call_plugin_api
from framex.plugin.model import PluginApi
from framex.plugin.resolver import (
ApiResolver,
reset_current_api_resolver,
reset_current_remote_apis,
set_current_api_resolver,
set_current_remote_apis,
)
from framex.plugin.resolver import reset_current_remote_apis, set_current_remote_apis


class BasePlugin:
"""Base class for all plugins"""

def __init__(self, **kwargs: Any) -> None:
setup_logger()
self.remote_apis: dict[str, PluginApi] = kwargs.get("remote_apis", {})
self.api_registry: Mapping[str, PluginApi] = kwargs.get("api_registry", {})
self.api_resolver = ApiResolver(api_registry=self.api_registry)
self._bind_api_resolver_context()
self.remote_apis = kwargs.get("remote_apis", {})
self._bind_remote_api_context()
if settings.server.use_ray:
import asyncio

Expand All @@ -33,52 +23,46 @@ def __init__(self, **kwargs: Any) -> None:
async def on_start(self) -> None:
pass

def _bind_api_resolver_context(self) -> None:
def _bind_remote_api_context(self) -> None:
for name, func in inspect.getmembers(type(self), predicate=callable):
if not getattr(func, "_on_request", False):
continue
bound = getattr(self, name)
setattr(self, name, self._wrap_with_api_resolver(bound))
setattr(self, name, self._wrap_with_remote_api_context(bound))

def _wrap_with_api_resolver(self, func: Any) -> Any:
def _wrap_with_remote_api_context(self, func: Any) -> Any:
if inspect.isasyncgenfunction(func):

@wraps(func)
async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any:
resolver_token = set_current_api_resolver(self.api_resolver)
remote_token = set_current_remote_apis(self.remote_apis)
try:
async for chunk in func(*args, **kwargs):
yield chunk
finally:
reset_current_remote_apis(remote_token)
reset_current_api_resolver(resolver_token)

return async_gen_wrapper

if inspect.iscoroutinefunction(func):

@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
resolver_token = set_current_api_resolver(self.api_resolver)
remote_token = set_current_remote_apis(self.remote_apis)
try:
return await func(*args, **kwargs)
finally:
reset_current_remote_apis(remote_token)
reset_current_api_resolver(resolver_token)

return async_wrapper

@wraps(func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
resolver_token = set_current_api_resolver(self.api_resolver)
remote_token = set_current_remote_apis(self.remote_apis)
try:
return func(*args, **kwargs)
finally:
reset_current_remote_apis(remote_token)
reset_current_api_resolver(resolver_token)

return sync_wrapper

Expand Down
1 change: 0 additions & 1 deletion src/framex/plugin/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ async def register_proxy_func(func: Callable) -> None:
)
await call_plugin_api(
api_reg,
None,
func_name=full_func_name,
func_callable=_PROXY_REGISTRY[full_func_name],
)
4 changes: 2 additions & 2 deletions src/framex/plugin/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum, StrEnum
from enum import StrEnum
from types import ModuleType
from typing import Any

Expand Down Expand Up @@ -33,7 +33,7 @@ class PluginApi(BaseModel):
methods: list[str] = Field(default_factory=lambda: ["POST"])
params: list[tuple[str, type[Any] | Callable[..., Any]]] = Field(default_factory=list)
call_type: ApiType = ApiType.HTTP
tags: list[str | Enum] | None = None
tags: list[str] | None = None
stream: bool = False
raw_response: bool = False
extend_kwargs: dict[str, Any] = Field(default_factory=dict)
Expand Down
6 changes: 3 additions & 3 deletions src/framex/plugin/on.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ def decorator(cls: type) -> type:
params = extract_method_params(func)
version: str = plugin.module.__plugin_meta__.version
version = f"v{version}" if not version.startswith("v") else version
tags = [f"{plugin.name}({version}): {plugin.module.__plugin_meta__.description}"]

if func.__tags:
tags.extend(func.__tags)
tags: list[str] = func.__tags
else:
tags = [f"{plugin.name}({version}): {plugin.module.__plugin_meta__.description}"]

plugin_apis.append(
PluginApi(
Expand Down Expand Up @@ -177,7 +178,6 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
)
res = await call_plugin_api(
api_call,
None,
func_name=cache_encode(full_func_name),
data=cache_encode(data=kwargs),
)
Expand Down
64 changes: 7 additions & 57 deletions src/framex/plugin/resolver.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,21 @@
from collections.abc import Mapping
from contextvars import ContextVar
from typing import Any, Protocol
from typing import Any

from framex.plugin.model import PluginApi


class SupportsApiLookup(Protocol):
def get_api(self, api_name: str) -> PluginApi | None: ...
def coerce_plugin_api(api: PluginApi | dict[str, Any] | None) -> PluginApi | None:
if api is None or isinstance(api, PluginApi):
return api
if isinstance(api, dict):
return PluginApi.model_validate(api)
return None


class ApiResolver:
def __init__(
self,
manager: SupportsApiLookup | None = None,
api_registry: Mapping[str, PluginApi | dict[str, Any]] | None = None,
) -> None:
self._manager = manager
self._api_registry = api_registry or {}

@staticmethod
def coerce_plugin_api(api: PluginApi | dict[str, Any] | None) -> PluginApi | None:
if api is None or isinstance(api, PluginApi):
return api
if isinstance(api, dict):
return PluginApi.model_validate(api)
return None

def resolve(
self,
api_name: str,
api_registry: Mapping[str, PluginApi | dict[str, Any]] | None = None,
) -> PluginApi | None:
if api_registry is not None and (api := self.coerce_plugin_api(api_registry.get(api_name))):
return api
if self._manager and (api := self._manager.get_api(api_name)):
return api
return self.coerce_plugin_api(self._api_registry.get(api_name))


_current_api_resolver: ContextVar[ApiResolver | None] = ContextVar("_current_api_resolver", default=None)
_current_remote_apis: ContextVar[Mapping[str, PluginApi | dict[str, Any]] | None] = ContextVar(
"_current_remote_apis", default=None
)
_default_api_resolver: ApiResolver | None = None


def get_current_api_resolver() -> ApiResolver | None:
return _current_api_resolver.get()


def get_default_api_resolver() -> ApiResolver:
if _default_api_resolver is None:
raise RuntimeError("Default API resolver is not configured")
return _default_api_resolver


def _set_default_api_resolver(resolver: ApiResolver) -> None:
global _default_api_resolver
_default_api_resolver = resolver


def set_current_api_resolver(resolver: ApiResolver | None) -> Any:
return _current_api_resolver.set(resolver)


def reset_current_api_resolver(token: Any) -> None:
_current_api_resolver.reset(token)


def get_current_remote_apis() -> Mapping[str, PluginApi | dict[str, Any]] | None:
Expand Down
Loading
Loading