diff --git a/pyproject.toml b/pyproject.toml
index a2b9d41..59a1e96 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,7 @@ dependencies = [
"pydantic>=2.11.7",
"pydantic-settings>=2.10.1",
"pyjwt>=2.10.1",
+ "python-multipart>=0.0.22",
"pytz>=2025.2",
"sentry-sdk[fastapi]>=2.33.0",
"tomli>=2.2.1",
diff --git a/pytest.ini b/pytest.ini
index 261728c..4ee217e 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -24,7 +24,7 @@ env =
server__enable_proxy=true
load_builtin_plugins=["echo","proxy","invoker"]
plugins__proxy__proxy_urls=["http://localhost:9527"]
- plugins__proxy__white_list=["/api/v1/proxy/mock/info","/proxy/mock/get","/proxy/mock/post","/proxy/mock/post_model","/proxy/mock/auth/*"]
+ plugins__proxy__white_list=["/api/v1/proxy/mock/info","/proxy/mock/get","/proxy/mock/post","/proxy/mock/post_model","/proxy/mock/upload","/proxy/mock/auth/*"]
plugins__proxy__auth__rules={{"/proxy/mock/auth/*":["i_am_proxy_general_auth_keys"],"/api/v1/openapi.json":["i_am_proxy_docs_auth_keys"],"/proxy/mock/auth/sget":["i_am_proxy_special_auth_keys"],"/api/v1/proxy/remote":["i_am_proxy_func_auth_keys"]}}
plugins__proxy__proxy_functions={{"http://localhost:9527":["tests.test_plugins.remote_exchange_key_value"]}}
; plugins__proxy__force_stream_apis=[]
diff --git a/src/framex/driver/ingress.py b/src/framex/driver/ingress.py
index fa49f64..8fa7af4 100644
--- a/src/framex/driver/ingress.py
+++ b/src/framex/driver/ingress.py
@@ -1,13 +1,12 @@
+import inspect
import os
import re
from collections.abc import Callable
-from enum import Enum
from typing import Any
from fastapi import Depends, HTTPException, Request, Response, status
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
-from pydantic import create_model
from starlette.routing import Route
from framex.adapter import get_adapter
@@ -73,7 +72,7 @@ def register_route(
handle: Any,
stream: bool = False,
direct_output: bool = False,
- tags: list[str | Enum] | None = None,
+ tags: list[str] | None = None,
auth_keys: list[str] | None = None,
include_in_schema: bool = True,
**kwargs: Any,
@@ -101,9 +100,8 @@ def register_route(
return False
if (not path) or (not methods):
raise RuntimeError(f"Api({path}) or methods({methods}) is empty")
- Model: BaseModel = create_model(f"{func_name}_InputModel", **{name: (tp, ...) for name, tp in params}) # type:ignore # noqa
- async def route_handler(response: Response, model: Model = Depends()) -> Any: # type: ignore [valid-type]
+ async def route_handler(response: Response, **request_kwargs: Any) -> Any:
c_handle = getattr(handle, func_name)
if not c_handle:
raise RuntimeError(
@@ -111,13 +109,32 @@ async def route_handler(response: Response, model: Model = Depends()) -> Any: #
)
response.headers["X-Raw-Output"] = str(direct_output)
if stream:
- gen = adapter._stream_call(c_handle, **(model.__dict__))
+ gen = adapter._stream_call(c_handle, **request_kwargs)
return StreamingResponse( # type: ignore
gen,
media_type="text/event-stream",
)
- return await adapter._acall(c_handle, **model.__dict__) # type: ignore
+ return await adapter._acall(c_handle, **request_kwargs) # type: ignore
+
+ route_handler.__signature__ = inspect.Signature( # type: ignore
+ [
+ inspect.Parameter(
+ "response",
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ annotation=Response,
+ ),
+ *[
+ inspect.Parameter(
+ name,
+ inspect.Parameter.KEYWORD_ONLY,
+ annotation=tp,
+ default=inspect.Parameter.empty,
+ )
+ for name, tp in params
+ ],
+ ]
+ )
# Inject auth dependency if needed
dependencies = []
diff --git a/src/framex/plugin/__init__.py b/src/framex/plugin/__init__.py
index 728afa6..11f8d23 100644
--- a/src/framex/plugin/__init__.py
+++ b/src/framex/plugin/__init__.py
@@ -10,18 +10,11 @@
from framex.log import logger
from framex.plugin.manage import _manager
from framex.plugin.model import Plugin, PluginApi
-from framex.plugin.resolver import (
- ApiResolver,
- _set_default_api_resolver,
- get_current_api_resolver,
- get_current_remote_apis,
- get_default_api_resolver,
-)
+from framex.plugin.resolver import coerce_plugin_api, get_current_remote_apis
C = TypeVar("C", bound=BaseModel)
_current_plugin: ContextVar[Optional["Plugin"]] = ContextVar("_current_plugin", default=None)
-_set_default_api_resolver(ApiResolver(manager=_manager))
def get_plugin(plugin_id: str) -> Plugin | None:
@@ -47,7 +40,6 @@ def check_plugin_config_exists(plugin_name: str) -> bool:
@logger.catch()
def init_all_deployments(enable_proxy: bool) -> list[Any]:
deployments = []
- all_apis = {**_manager.all_plugin_apis[ApiType.FUNC], **_manager.all_plugin_apis[ApiType.HTTP]}
for plugin in get_loaded_plugins():
for dep in plugin.deployments:
remote_apis = {
@@ -63,8 +55,8 @@ def init_all_deployments(enable_proxy: bool) -> list[Any]:
call_type=ApiType.PROXY,
)
logger.opt(colors=True).warning(
- f"Api({api_name}) not found, "
- f"plugin({dep.deployment}) will "
+ f"Api({api_name}) not found, "
+ f"plugin({dep.deployment}) will "
f"use proxy plugin({PROXY_PLUGIN_NAME}) to transfer!"
)
else: # pragma: no cover
@@ -74,7 +66,6 @@ def init_all_deployments(enable_proxy: bool) -> list[Any]:
deployment = get_adapter().bind(
dep.deployment,
remote_apis=remote_apis,
- api_registry=all_apis,
config=plugin.config,
)
@@ -83,20 +74,12 @@ def init_all_deployments(enable_proxy: bool) -> list[Any]:
return deployments
-def _resolve_plugin_api(
- api_name: str | PluginApi,
- resolver: ApiResolver | None = None,
-) -> tuple[PluginApi, bool]:
+def _resolve_plugin_api(api_name: str | PluginApi) -> tuple[PluginApi, bool]:
current_remote_apis = get_current_remote_apis()
if isinstance(api_name, PluginApi):
return api_name, api_name.call_type == ApiType.PROXY
- active_resolver = resolver or get_current_api_resolver() or get_default_api_resolver()
- if current_remote_apis is not None:
- api = active_resolver.coerce_plugin_api(current_remote_apis.get(api_name))
- else:
- api = active_resolver.resolve(api_name, None)
-
+ api = coerce_plugin_api(current_remote_apis.get(api_name)) if current_remote_apis is not None else None
if api is None:
if current_remote_apis is not None:
raise RuntimeError(
@@ -149,19 +132,15 @@ def _unwrap_plugin_call_result(api_name: str | PluginApi, result: Any, use_proxy
res = result.get("data")
status = result.get("status")
if status not in settings.server.legal_proxy_code:
- logger.opt(colors=True).error(f"<>Proxy API {api_name} call illegal: {result}")
+ logger.opt(colors=True).error(f"Proxy API {api_name} call illegal: {result}")
raise RuntimeError(f"Proxy API {api_name} returned status {status}")
if res is None:
logger.opt(colors=True).warning(f"API {api_name} returned empty data")
return res
-async def call_plugin_api(
- api_name: str | PluginApi,
- resolver: ApiResolver | None = None,
- **kwargs: Any,
-) -> Any:
- api, use_proxy = _resolve_plugin_api(api_name, resolver=resolver)
+async def call_plugin_api(api_name: str | PluginApi, **kwargs: Any) -> Any:
+ api, use_proxy = _resolve_plugin_api(api_name)
normalized_kwargs = _normalize_plugin_call_kwargs(api, kwargs)
result = await get_adapter().call_func(api, **normalized_kwargs)
return _unwrap_plugin_call_result(api_name, result, use_proxy)
diff --git a/src/framex/plugin/base.py b/src/framex/plugin/base.py
index 0042e67..a784823 100644
--- a/src/framex/plugin/base.py
+++ b/src/framex/plugin/base.py
@@ -1,19 +1,11 @@
import inspect
-from collections.abc import Mapping
from functools import wraps
from typing import Any, final
from framex.config import settings
from framex.log import setup_logger
from framex.plugin import call_plugin_api
-from framex.plugin.model import PluginApi
-from framex.plugin.resolver import (
- ApiResolver,
- reset_current_api_resolver,
- reset_current_remote_apis,
- set_current_api_resolver,
- set_current_remote_apis,
-)
+from framex.plugin.resolver import reset_current_remote_apis, set_current_remote_apis
class BasePlugin:
@@ -21,10 +13,8 @@ class BasePlugin:
def __init__(self, **kwargs: Any) -> None:
setup_logger()
- self.remote_apis: dict[str, PluginApi] = kwargs.get("remote_apis", {})
- self.api_registry: Mapping[str, PluginApi] = kwargs.get("api_registry", {})
- self.api_resolver = ApiResolver(api_registry=self.api_registry)
- self._bind_api_resolver_context()
+ self.remote_apis = kwargs.get("remote_apis", {})
+ self._bind_remote_api_context()
if settings.server.use_ray:
import asyncio
@@ -33,26 +23,24 @@ def __init__(self, **kwargs: Any) -> None:
async def on_start(self) -> None:
pass
- def _bind_api_resolver_context(self) -> None:
+ def _bind_remote_api_context(self) -> None:
for name, func in inspect.getmembers(type(self), predicate=callable):
if not getattr(func, "_on_request", False):
continue
bound = getattr(self, name)
- setattr(self, name, self._wrap_with_api_resolver(bound))
+ setattr(self, name, self._wrap_with_remote_api_context(bound))
- def _wrap_with_api_resolver(self, func: Any) -> Any:
+ def _wrap_with_remote_api_context(self, func: Any) -> Any:
if inspect.isasyncgenfunction(func):
@wraps(func)
async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any:
- resolver_token = set_current_api_resolver(self.api_resolver)
remote_token = set_current_remote_apis(self.remote_apis)
try:
async for chunk in func(*args, **kwargs):
yield chunk
finally:
reset_current_remote_apis(remote_token)
- reset_current_api_resolver(resolver_token)
return async_gen_wrapper
@@ -60,25 +48,21 @@ async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any:
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
- resolver_token = set_current_api_resolver(self.api_resolver)
remote_token = set_current_remote_apis(self.remote_apis)
try:
return await func(*args, **kwargs)
finally:
reset_current_remote_apis(remote_token)
- reset_current_api_resolver(resolver_token)
return async_wrapper
@wraps(func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
- resolver_token = set_current_api_resolver(self.api_resolver)
remote_token = set_current_remote_apis(self.remote_apis)
try:
return func(*args, **kwargs)
finally:
reset_current_remote_apis(remote_token)
- reset_current_api_resolver(resolver_token)
return sync_wrapper
diff --git a/src/framex/plugin/load.py b/src/framex/plugin/load.py
index d260f49..50bf3c9 100644
--- a/src/framex/plugin/load.py
+++ b/src/framex/plugin/load.py
@@ -56,7 +56,6 @@ async def register_proxy_func(func: Callable) -> None:
)
await call_plugin_api(
api_reg,
- None,
func_name=full_func_name,
func_callable=_PROXY_REGISTRY[full_func_name],
)
diff --git a/src/framex/plugin/model.py b/src/framex/plugin/model.py
index 7dc0c96..c923374 100644
--- a/src/framex/plugin/model.py
+++ b/src/framex/plugin/model.py
@@ -1,6 +1,6 @@
from collections.abc import Callable
from dataclasses import dataclass, field
-from enum import Enum, StrEnum
+from enum import StrEnum
from types import ModuleType
from typing import Any
@@ -33,7 +33,7 @@ class PluginApi(BaseModel):
methods: list[str] = Field(default_factory=lambda: ["POST"])
params: list[tuple[str, type[Any] | Callable[..., Any]]] = Field(default_factory=list)
call_type: ApiType = ApiType.HTTP
- tags: list[str | Enum] | None = None
+ tags: list[str] | None = None
stream: bool = False
raw_response: bool = False
extend_kwargs: dict[str, Any] = Field(default_factory=dict)
diff --git a/src/framex/plugin/on.py b/src/framex/plugin/on.py
index 668f4bf..f42c60a 100644
--- a/src/framex/plugin/on.py
+++ b/src/framex/plugin/on.py
@@ -40,10 +40,11 @@ def decorator(cls: type) -> type:
params = extract_method_params(func)
version: str = plugin.module.__plugin_meta__.version
version = f"v{version}" if not version.startswith("v") else version
- tags = [f"{plugin.name}({version}): {plugin.module.__plugin_meta__.description}"]
if func.__tags:
- tags.extend(func.__tags)
+ tags: list[str] = func.__tags
+ else:
+ tags = [f"{plugin.name}({version}): {plugin.module.__plugin_meta__.description}"]
plugin_apis.append(
PluginApi(
@@ -177,7 +178,6 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
)
res = await call_plugin_api(
api_call,
- None,
func_name=cache_encode(full_func_name),
data=cache_encode(data=kwargs),
)
diff --git a/src/framex/plugin/resolver.py b/src/framex/plugin/resolver.py
index 1e5ddb8..b68e59f 100644
--- a/src/framex/plugin/resolver.py
+++ b/src/framex/plugin/resolver.py
@@ -1,71 +1,21 @@
from collections.abc import Mapping
from contextvars import ContextVar
-from typing import Any, Protocol
+from typing import Any
from framex.plugin.model import PluginApi
-class SupportsApiLookup(Protocol):
- def get_api(self, api_name: str) -> PluginApi | None: ...
+def coerce_plugin_api(api: PluginApi | dict[str, Any] | None) -> PluginApi | None:
+ if api is None or isinstance(api, PluginApi):
+ return api
+ if isinstance(api, dict):
+ return PluginApi.model_validate(api)
+ return None
-class ApiResolver:
- def __init__(
- self,
- manager: SupportsApiLookup | None = None,
- api_registry: Mapping[str, PluginApi | dict[str, Any]] | None = None,
- ) -> None:
- self._manager = manager
- self._api_registry = api_registry or {}
-
- @staticmethod
- def coerce_plugin_api(api: PluginApi | dict[str, Any] | None) -> PluginApi | None:
- if api is None or isinstance(api, PluginApi):
- return api
- if isinstance(api, dict):
- return PluginApi.model_validate(api)
- return None
-
- def resolve(
- self,
- api_name: str,
- api_registry: Mapping[str, PluginApi | dict[str, Any]] | None = None,
- ) -> PluginApi | None:
- if api_registry is not None and (api := self.coerce_plugin_api(api_registry.get(api_name))):
- return api
- if self._manager and (api := self._manager.get_api(api_name)):
- return api
- return self.coerce_plugin_api(self._api_registry.get(api_name))
-
-
-_current_api_resolver: ContextVar[ApiResolver | None] = ContextVar("_current_api_resolver", default=None)
_current_remote_apis: ContextVar[Mapping[str, PluginApi | dict[str, Any]] | None] = ContextVar(
"_current_remote_apis", default=None
)
-_default_api_resolver: ApiResolver | None = None
-
-
-def get_current_api_resolver() -> ApiResolver | None:
- return _current_api_resolver.get()
-
-
-def get_default_api_resolver() -> ApiResolver:
- if _default_api_resolver is None:
- raise RuntimeError("Default API resolver is not configured")
- return _default_api_resolver
-
-
-def _set_default_api_resolver(resolver: ApiResolver) -> None:
- global _default_api_resolver
- _default_api_resolver = resolver
-
-
-def set_current_api_resolver(resolver: ApiResolver | None) -> Any:
- return _current_api_resolver.set(resolver)
-
-
-def reset_current_api_resolver(token: Any) -> None:
- _current_api_resolver.reset(token)
def get_current_remote_apis() -> Mapping[str, PluginApi | dict[str, Any]] | None:
diff --git a/src/framex/plugins/proxy/__init__.py b/src/framex/plugins/proxy/__init__.py
index 509edf2..f9c6972 100644
--- a/src/framex/plugins/proxy/__init__.py
+++ b/src/framex/plugins/proxy/__init__.py
@@ -15,7 +15,14 @@
from framex.plugin import BasePlugin, PluginApi, PluginMetadata, on_register
from framex.plugin.model import ApiType
from framex.plugin.on import on_request
-from framex.plugins.proxy.builder import create_pydantic_model, format_proxy_params, type_map
+from framex.plugins.proxy.builder import (
+ create_pydantic_model,
+ format_proxy_params,
+ is_upload_annotation,
+ resolve_annotation,
+ to_multipart_annotation,
+ type_map,
+)
from framex.plugins.proxy.config import ProxyPluginConfig, settings
from framex.plugins.proxy.model import ProxyFunc, ProxyFuncHttpBody
from framex.utils import cache_decode, cache_encode, shorten_str
@@ -48,6 +55,7 @@ async def on_start(self) -> None:
return
for url in settings.proxy_url_list:
+ logger.info(f"Try to parse openapi docs from {url}")
await self._parse_openai_docs(url)
if settings.proxy_functions:
@@ -96,6 +104,7 @@ async def _parse_openai_docs(self, url: str) -> None:
headers = None
for method, body in details.items():
+ func_name = body.get("operationId")
# Process request parameters
params: list[tuple[str, Any]] = [
(name, c_type)
@@ -104,26 +113,54 @@ async def _parse_openai_docs(self, url: str) -> None:
and (typ := param.get("schema").get("type"))
and (c_type := type_map.get(typ))
]
+ body_param_names: set[str] = set()
+ file_param_names: set[str] = set()
# Process request body
if request_body := body.get("requestBody"):
- schema_name = (
- request_body.get("content", {})
- .get("application/json", {})
- .get("schema", {})
- .get("$ref", "")
- .rsplit("/", 1)[-1]
- )
- if not (model_schema := components.get(schema_name)): # pragma: no cover
- raise ValueError(f"Schema '{schema_name}' not found in components.")
-
- Model = create_pydantic_model(schema_name, model_schema, components) # noqa
- params.append(("model", Model))
+ body_content = request_body.get("content", {})
+ if "application/json" in body_content:
+ content_type = "application/json"
+ elif "multipart/form-data" in body_content:
+ content_type = "multipart/form-data"
+ else:
+ logger.opt(colors=True).error(
+ f"Failed to proxy api({method}) {url}{path}, unsupported content type: ${body_content.keys()}"
+ )
+ continue
+
+ schema = body_content.get(content_type, {}).get("schema", {})
+ if (schema_name := schema.get("$ref", {}).rsplit("/", 1)[-1]) and not (
+ model_schema := components.get(schema_name)
+ ):
+ logger.opt(colors=True).error(
+ f"Failed to proxy api({method}) {url}{path}, schema '{schema_name}' not found in components"
+ )
+ continue
+
+ if content_type == "application/json":
+ Model = create_pydantic_model(schema_name, model_schema, components) # noqa
+ params.append(("model", Model))
+ body_param_names.add("model")
+ else:
+ for field_name, prop_schema in model_schema.get("properties", {}).items():
+ annotation = resolve_annotation(prop_schema, components)
+ params.append((field_name, to_multipart_annotation(annotation)))
+ body_param_names.add(field_name)
+ if is_upload_annotation(annotation):
+ file_param_names.add(field_name)
+
logger.opt(colors=True).trace(f"Found proxy api({method}) {url}{path}")
- func_name = body.get("operationId")
is_stream = path in settings.force_stream_apis
func = self._create_dynamic_method(
- func_name, method, params, f"{url}{path}", stream=is_stream, headers=headers
+ func_name,
+ method,
+ params,
+ f"{url}{path}",
+ body_param_names=body_param_names,
+ file_param_names=file_param_names,
+ stream=is_stream,
+ headers=headers,
)
setattr(self, func_name, func)
@@ -196,6 +233,7 @@ async def __call__(self, proxy_path: str, **kwargs: Any) -> Any:
return await func(**kwargs)
raise RuntimeError(f"api({proxy_path}) not found")
+ # @logger.catch
async def fetch_response(
self,
stream: bool = False,
@@ -226,12 +264,16 @@ def _create_dynamic_method(
method: str,
params: list[tuple[str, type]],
url: str,
+ body_param_names: set[str] | None = None,
+ file_param_names: set[str] | None = None,
stream: bool = False,
headers: dict[str, str] | None = None,
) -> Callable[..., Any]:
# Build a Pydantic request model (for data validation)
model_name = f"{func_name.title()}_RequestModel"
RequestModel = create_model(model_name, **{k: (t, ...) for k, t in params}) # type: ignore # noqa
+ body_param_names = body_param_names or set()
+ file_param_names = file_param_names or set()
# Construct dynamic methods
async def dynamic_method(**kwargs: Any) -> AsyncGenerator[str, None] | dict[str, Any] | str:
@@ -240,19 +282,50 @@ async def dynamic_method(**kwargs: Any) -> AsyncGenerator[str, None] | dict[str,
validated = RequestModel(**kwargs) # Type Validation
query = {}
json_body = None
+ form_body = {}
+ files = []
for field_name, value in validated:
- if isinstance(value, BaseModel):
+ if field_name in body_param_names:
+ if field_name == "model" and isinstance(value, BaseModel):
+ json_body = value.model_dump()
+ elif field_name in file_param_names:
+ upload_values = value if isinstance(value, list) else [value]
+ for upload in upload_values:
+ upload.file.seek(0)
+ files.append(
+ (
+ field_name,
+ (
+ upload.filename or field_name,
+ upload.file,
+ upload.content_type or "application/octet-stream",
+ ),
+ )
+ )
+ else:
+ form_body[field_name] = value
+ elif isinstance(value, BaseModel):
json_body = value.model_dump()
else:
query[field_name] = value
try:
+ request_kwargs: dict[str, Any] = {
+ "stream": stream,
+ "method": method.upper(),
+ "url": url,
+ "params": query,
+ "headers": headers,
+ }
+ if method.upper() != "GET":
+ if files or form_body:
+ if form_body:
+ request_kwargs["data"] = form_body
+ if files:
+ request_kwargs["files"] = files
+ elif json_body is not None:
+ request_kwargs["json"] = json_body
return await self.fetch_response(
- stream=stream,
- method=method.upper(),
- url=url,
- params=query,
- json=json_body if method.upper() != "GET" else None,
- headers=headers,
+ **request_kwargs,
)
except Exception as e:
logger.opt(exception=e, colors=True).error(f"Error calling proxy api({method}) {url}: {e}")
diff --git a/src/framex/plugins/proxy/builder.py b/src/framex/plugins/proxy/builder.py
index 29657c0..5177415 100644
--- a/src/framex/plugins/proxy/builder.py
+++ b/src/framex/plugins/proxy/builder.py
@@ -1,5 +1,6 @@
-from typing import Any, Union
+from typing import Annotated, Any, Union, get_args, get_origin
+from fastapi import File, Form, UploadFile
from pydantic import BaseModel, create_model
from framex.plugins.proxy.model import ProxyFuncHttpBody
@@ -16,7 +17,27 @@
"object": dict,
"null": None,
}
-from typing import get_args, get_origin
+
+
+def unwrap_annotation(annotation: Any) -> Any:
+ if get_origin(annotation) is Annotated:
+ return get_args(annotation)[0]
+ return annotation
+
+
+def is_upload_annotation(annotation: Any) -> bool:
+ annotation = unwrap_annotation(annotation)
+ origin = get_origin(annotation)
+ if origin is list:
+ args = get_args(annotation)
+ return len(args) == 1 and is_upload_annotation(args[0])
+ return annotation is UploadFile
+
+
+def to_multipart_annotation(annotation: Any) -> Any:
+ if is_upload_annotation(annotation):
+ return Annotated[annotation, File(...)]
+ return Annotated[annotation, Form(...)]
def resolve_annotation(
@@ -54,9 +75,18 @@ def resolve_annotation(
if "$ref" in prop_schema:
item_type = resolve_annotation(prop_schema, components)
else:
+ if prop_schema.get("type") == "string" and (
+ prop_schema.get("contentMediaType") == "application/octet-stream"
+ or prop_schema.get("format") == "binary"
+ ):
+ return list[UploadFile] # type: ignore [valid-type]
item_type = type_map.get(prop_schema.get("type", "string"), str)
return list[item_type] # type: ignore [valid-type]
if typ:
+ if typ == "string" and (
+ prop_schema.get("contentMediaType") == "application/octet-stream" or prop_schema.get("format") == "binary"
+ ):
+ return UploadFile
return type_map.get(typ, str)
raise RuntimeError(f"Unsupported prop_schema: {prop_schema}")
diff --git a/tests/adapter/test_ray_adapter.py b/tests/adapter/test_ray_adapter.py
index 46c8553..2503334 100644
--- a/tests/adapter/test_ray_adapter.py
+++ b/tests/adapter/test_ray_adapter.py
@@ -180,7 +180,12 @@ def capture_remote(func):
# Now test the captured wrapper
assert captured_wrapper is not None
with patch("asyncio.run") as mock_asyncio_run:
- mock_asyncio_run.return_value = 10
+
+ def _consume_coroutine(coro):
+ coro.close()
+ return 10
+
+ mock_asyncio_run.side_effect = _consume_coroutine
result = captured_wrapper(5)
mock_asyncio_run.assert_called_once()
assert result == 10
diff --git a/tests/api/test_proxy.py b/tests/api/test_proxy.py
index 5ce5049..1e5b160 100644
--- a/tests/api/test_proxy.py
+++ b/tests/api/test_proxy.py
@@ -29,6 +29,46 @@ def test_get_proxy_post_model(client: TestClient):
assert res == {"method": "POST", "body": data}
+def test_get_proxy_upload(client: TestClient):
+ params = {"message": "hello world"}
+ data = {"note": "upload note"}
+ files = {
+ "ppt_file": (
+ "demo.pptx",
+ b"ppt-content",
+ "application/vnd.openxmlformats-officedocument.presentationml.presentation",
+ ),
+ "txt_file": ("demo.txt", b"txt-content", "text/plain"),
+ }
+ res = client.post("/proxy/mock/upload", params=params, data=data, files=files).json()
+ assert res == {
+ "method": "POST",
+ "params": params,
+ "body": data,
+ "files": [
+ {
+ "field": "ppt_file",
+ "filename": "demo.pptx",
+ "content_type": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
+ },
+ {
+ "field": "txt_file",
+ "filename": "demo.txt",
+ "content_type": "text/plain",
+ },
+ ],
+ }
+
+
+def test_get_proxy_upload_openapi(client: TestClient):
+ data = client.get("/api/v1/openapi.json").json()
+ post = data["paths"]["/proxy/mock/upload"]["post"]
+ ref = post["requestBody"]["content"]["multipart/form-data"]["schema"]["$ref"].split("/")[-1]
+ schema = data["components"]["schemas"][ref]
+ assert schema["properties"]["ppt_file"]["format"] == "binary"
+ assert schema["properties"]["txt_file"]["format"] == "binary"
+
+
def test_get_proxy_black_get(client: TestClient):
res = client.get("/proxy/mock/black_get").json()
assert res["status"] == 404
diff --git a/tests/consts.py b/tests/consts.py
index ff78640..44cb188 100644
--- a/tests/consts.py
+++ b/tests/consts.py
@@ -70,6 +70,34 @@
},
}
},
+ "/proxy/mock/upload": {
+ "post": {
+ "tags": ["proxy"],
+ "summary": "Proxy Mock Upload",
+ "operationId": "proxy_mock_upload",
+ "parameters": [
+ {
+ "name": "message",
+ "in": "query",
+ "required": True,
+ "schema": {
+ "type": "string",
+ "title": "Message",
+ },
+ }
+ ],
+ "requestBody": {
+ "required": True,
+ "content": {"multipart/form-data": {"schema": {"$ref": "#/components/schemas/MockUploadBody"}}},
+ },
+ "responses": {
+ "200": {
+ "description": "Successful Response",
+ "content": {"application/json": {"schema": {"type": "object"}}},
+ }
+ },
+ }
+ },
"/proxy/mock/black_get": {
"get": {
"tags": ["proxy"],
@@ -197,7 +225,28 @@
"default": "default",
},
},
- }
+ },
+ "MockUploadBody": {
+ "title": "MockUploadBody",
+ "type": "object",
+ "required": ["note", "ppt_file", "txt_file"],
+ "properties": {
+ "note": {
+ "type": "string",
+ "title": "Note",
+ },
+ "ppt_file": {
+ "type": "string",
+ "contentMediaType": "application/octet-stream",
+ "title": "Ppt File",
+ },
+ "txt_file": {
+ "type": "string",
+ "format": "binary",
+ "title": "Txt File",
+ },
+ },
+ },
}
},
}
diff --git a/tests/driver/test_auth.py b/tests/driver/test_auth.py
index 87d7317..ea50b52 100644
--- a/tests/driver/test_auth.py
+++ b/tests/driver/test_auth.py
@@ -1,3 +1,4 @@
+import uuid
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock, patch
@@ -14,6 +15,8 @@
from framex.driver.application import create_fastapi_application
from framex.driver.auth import auth_jwt, authenticate, create_jwt, oauth_callback
+JWT_SECRET = uuid.uuid4().hex
+
# =========================================================
# helpers
# =========================================================
@@ -28,7 +31,7 @@ def fake_oauth(**overrides):
client_secret="secret", # noqa: S106
redirect_uri="/oauth/callback",
call_back_url="http://test/callback",
- jwt_secret="secret", # noqa: S106
+ jwt_secret=JWT_SECRET,
jwt_algorithm="HS256",
)
data.update(overrides)
@@ -44,7 +47,7 @@ class TestCreateJWT:
def test_create_jwt_success(self):
with patch("framex.config.settings.auth.oauth", fake_oauth()):
token = create_jwt({"username": "test"})
- decoded = jwt.decode(token, "secret", algorithms=["HS256"])
+ decoded = jwt.decode(token, JWT_SECRET, algorithms=["HS256"])
assert decoded["username"] == "test"
assert "iat" in decoded
assert "exp" in decoded
@@ -67,7 +70,7 @@ def test_returns_false_when_oauth_not_configured(self):
def test_returns_false_when_no_token_cookie(self):
with patch("framex.config.settings.auth.oauth") as mock_oauth:
- mock_oauth.jwt_secret = "secret" # noqa: S105
+ mock_oauth.jwt_secret = JWT_SECRET
mock_oauth.jwt_algorithm = "HS256"
req = Mock(spec=Request)
@@ -77,7 +80,7 @@ def test_returns_false_when_no_token_cookie(self):
def test_returns_true_when_token_is_valid(self):
with patch("framex.config.settings.auth.oauth") as mock_oauth:
- mock_oauth.jwt_secret = "secret" # noqa: S105
+ mock_oauth.jwt_secret = JWT_SECRET
mock_oauth.jwt_algorithm = "HS256"
now = datetime.now(UTC)
@@ -87,7 +90,7 @@ def test_returns_true_when_token_is_valid(self):
"iat": int(now.timestamp()),
"exp": int((now + timedelta(hours=1)).timestamp()),
},
- "secret",
+ JWT_SECRET,
algorithm="HS256",
)
@@ -98,7 +101,7 @@ def test_returns_true_when_token_is_valid(self):
def test_returns_false_when_token_is_invalid(self):
with patch("framex.config.settings.auth.oauth") as mock_oauth:
- mock_oauth.jwt_secret = "secret" # noqa: S105
+ mock_oauth.jwt_secret = JWT_SECRET
mock_oauth.jwt_algorithm = "HS256"
req = Mock(spec=Request)
@@ -108,7 +111,7 @@ def test_returns_false_when_token_is_invalid(self):
def test_returns_false_when_token_is_expired(self):
with patch("framex.config.settings.auth.oauth") as mock_oauth:
- mock_oauth.jwt_secret = "secret" # noqa: S105
+ mock_oauth.jwt_secret = JWT_SECRET
mock_oauth.jwt_algorithm = "HS256"
now = datetime.now(UTC)
@@ -118,7 +121,7 @@ def test_returns_false_when_token_is_expired(self):
"iat": int((now - timedelta(days=2)).timestamp()),
"exp": int((now - timedelta(days=1)).timestamp()),
},
- "secret",
+ JWT_SECRET,
algorithm="HS256",
)
@@ -213,7 +216,7 @@ def test_docs_accessible_with_valid_jwt(self):
"iat": int(now.timestamp()),
"exp": int((now + timedelta(hours=1)).timestamp()),
},
- "secret",
+ JWT_SECRET,
algorithm="HS256",
)
client.cookies.set("token", token)
diff --git a/tests/mock.py b/tests/mock.py
index 22e8eae..5cbfb04 100644
--- a/tests/mock.py
+++ b/tests/mock.py
@@ -26,6 +26,7 @@ async def mock_get(_, url: str, *__, **kwargs: Any):
async def mock_request(_, method: str, url: str, **kwargs: Any):
params = kwargs.get("params")
body = kwargs.get("json") or kwargs.get("data")
+ files = kwargs.get("files")
headers = kwargs.get("headers", {})
resp = MagicMock()
@@ -45,6 +46,20 @@ async def mock_request(_, method: str, url: str, **kwargs: Any):
"method": "POST",
"body": body,
}
+ elif url.endswith("/proxy/mock/upload") and method == "POST":
+ resp.json.return_value = {
+ "method": "POST",
+ "params": params,
+ "body": body,
+ "files": [
+ {
+ "field": field_name,
+ "filename": file_info[0],
+ "content_type": file_info[2],
+ }
+ for field_name, file_info in (files or [])
+ ],
+ }
elif url.endswith("/proxy/mock/info") and method == "GET":
resp.json.return_value = {
"info": "i_am_mock_proxy_info",
diff --git a/tests/test_plugin.py b/tests/test_plugin.py
index b2ff485..2b27031 100644
--- a/tests/test_plugin.py
+++ b/tests/test_plugin.py
@@ -8,17 +8,10 @@
from framex.consts import PROXY_PLUGIN_NAME, VERSION
from framex.plugin import call_plugin_api
from framex.plugin.model import ApiType, PluginApi
-from framex.plugin.resolver import (
- ApiResolver,
- reset_current_api_resolver,
- reset_current_remote_apis,
- set_current_api_resolver,
- set_current_remote_apis,
-)
+from framex.plugin.resolver import reset_current_remote_apis, set_current_remote_apis
def test_get_plugin():
- # check simple plugin
plugin = framex.get_plugin("export")
assert plugin
assert plugin.version == VERSION
@@ -30,108 +23,51 @@ def test_get_plugin():
class SampleModel(BaseModel):
- """Sample model for testing parameter conversion."""
-
field1: str
field2: int
class TestCallPluginApi:
- """Comprehensive tests for call_plugin_api function with proxy handling."""
-
- @pytest.mark.asyncio
- async def test_call_plugin_api_with_existing_api(self):
- """Test calling an API that exists in the manager."""
- # Setup
- api = PluginApi(api="test_api", deployment_name="test_deployment", params=[("param1", str), ("param2", int)])
-
- with (
- patch("framex.plugin._manager.get_api", return_value=api),
- patch("framex.plugin.get_adapter") as mock_adapter,
- ):
- mock_adapter.return_value.call_func = AsyncMock(return_value="test_result")
-
- # Execute
- result = await call_plugin_api("test_api", param1="value1", param2=42)
-
- # Assert
- assert result == "test_result"
- mock_adapter.return_value.call_func.assert_called_once()
-
- @pytest.mark.asyncio
- async def test_call_plugin_api_with_basemodel_result(self):
- """Test that BaseModel results are converted to dict with aliases."""
- api = PluginApi(api="test_api", deployment_name="test_deployment")
- model_result = SampleModel(field1="test", field2=123)
-
- with (
- patch("framex.plugin._manager.get_api", return_value=api),
- patch("framex.plugin.get_adapter") as mock_adapter,
- ):
- mock_adapter.return_value.call_func = AsyncMock(return_value=model_result)
-
- result = await call_plugin_api("test_api")
-
- assert isinstance(result, dict)
- assert result == {"field1": "test", "field2": 123}
-
@pytest.mark.asyncio
async def test_call_plugin_api_with_proxy_success(self):
- """Test proxy API call with successful response (status 200)."""
with (
patch("framex.plugin._manager.get_api", return_value=None),
patch("framex.plugin.settings.server.enable_proxy", True),
patch("framex.plugin.get_adapter") as mock_adapter,
):
- # Simulate proxy response
- proxy_response = {"status": 200, "data": {"result": "proxy_success"}}
- mock_adapter.return_value.call_func = AsyncMock(return_value=proxy_response)
-
+ mock_adapter.return_value.call_func = AsyncMock(return_value={"status": 200, "data": {"result": "ok"}})
result = await call_plugin_api("/external/api")
-
- # Should return just the data field
- assert result == {"result": "proxy_success"}
+ assert result == {"result": "ok"}
@pytest.mark.asyncio
async def test_call_plugin_api_with_proxy_empty_data(self):
- """Test proxy API call that returns empty data with warning."""
with (
patch("framex.plugin._manager.get_api", return_value=None),
patch("framex.plugin.settings.server.enable_proxy", True),
patch("framex.plugin.get_adapter") as mock_adapter,
patch("framex.plugin.logger") as mock_logger,
):
- proxy_response = {"status": 200, "data": None}
- mock_adapter.return_value.call_func = AsyncMock(return_value=proxy_response)
-
+ mock_adapter.return_value.call_func = AsyncMock(return_value={"status": 200, "data": None})
result = await call_plugin_api("/external/api")
-
assert result is None
- # Verify warning was logged
mock_logger.opt.return_value.warning.assert_called()
@pytest.mark.asyncio
async def test_call_plugin_api_with_proxy_error_status(self):
- """Test proxy API call with non-200 status logs error."""
with (
patch("framex.plugin._manager.get_api", return_value=None),
patch("framex.plugin.settings.server.enable_proxy", True),
patch("framex.plugin.get_adapter") as mock_adapter,
patch("framex.plugin.logger") as mock_logger,
):
- proxy_response = {"status": 500, "data": None}
- mock_adapter.return_value.call_func = AsyncMock(return_value=proxy_response)
+ mock_adapter.return_value.call_func = AsyncMock(return_value={"status": 500, "data": None})
with pytest.raises(RuntimeError, match="Proxy API /external/api returned status 500"):
await call_plugin_api("/external/api")
-
- # Verify error was logged
mock_logger.opt.return_value.error.assert_called()
@pytest.mark.asyncio
async def test_call_plugin_api_not_found_no_proxy(self):
- """Test API not found when proxy is disabled raises RuntimeError."""
with (
- patch("framex.plugin._manager.get_api", return_value=None),
patch("framex.plugin.settings.server.enable_proxy", False),
pytest.raises(RuntimeError, match="API test_api is not found"),
):
@@ -139,167 +75,131 @@ async def test_call_plugin_api_not_found_no_proxy(self):
@pytest.mark.asyncio
async def test_call_plugin_api_not_found_non_slash_with_proxy(self):
- """Test non-slash prefixed API not found with proxy enabled raises error."""
with (
- patch("framex.plugin._manager.get_api", return_value=None),
patch("framex.plugin.settings.server.enable_proxy", True),
pytest.raises(RuntimeError, match="API test_api is not found"),
):
await call_plugin_api("test_api")
@pytest.mark.asyncio
- async def test_call_plugin_api_with_dict_to_basemodel_conversion(self):
- """Test automatic conversion of dict parameters to BaseModel."""
- api = PluginApi(api="test_api", deployment_name="test_deployment", params=[("model_param", SampleModel)])
-
- with (
- patch("framex.plugin._manager.get_api", return_value=api),
- patch("framex.plugin.get_adapter") as mock_adapter,
- ):
- mock_adapter.return_value.call_func = AsyncMock(return_value="success")
-
- # Pass dict that should be converted to SampleModel
- result = await call_plugin_api("test_api", model_param={"field1": "test", "field2": 456})
-
- # Verify the call was made and dict was converted
- assert result == "success"
- call_args = mock_adapter.return_value.call_func.call_args
- assert isinstance(call_args[1]["model_param"], SampleModel)
+ async def test_call_plugin_api_requires_remote_apis_in_context(self):
+ token = set_current_remote_apis({})
+ try:
+ with pytest.raises(RuntimeError, match="not declared in current plugin remote_apis"):
+ await call_plugin_api("test_api")
+ finally:
+ reset_current_remote_apis(token)
@pytest.mark.asyncio
- async def test_call_plugin_api_with_interval_apis(self):
- """Test using interval_apis parameter to override manager lookup."""
+ async def test_call_plugin_api_with_remote_apis_dict_to_basemodel_conversion(self):
+ api = PluginApi(api="test_api", deployment_name="test_deployment", params=[("model_param", SampleModel)])
+ token = set_current_remote_apis({"test_api": api})
+ try:
+ with patch("framex.plugin.get_adapter") as mock_adapter:
+ mock_adapter.return_value.call_func = AsyncMock(return_value="success")
+ result = await call_plugin_api("test_api", model_param={"field1": "test", "field2": 456})
+ assert result == "success"
+ call_args = mock_adapter.return_value.call_func.call_args
+ assert isinstance(call_args[1]["model_param"], SampleModel)
+ finally:
+ reset_current_remote_apis(token)
+
+ @pytest.mark.asyncio
+ async def test_call_plugin_api_with_remote_apis(self):
api = PluginApi(api="test_api", deployment_name="test_deployment")
- interval_apis = {"test_api": api}
-
- with (
- patch("framex.plugin._manager.get_api") as mock_get_api,
- patch("framex.plugin.get_adapter") as mock_adapter,
- ):
- mock_adapter.return_value.call_func = AsyncMock(return_value="interval_result")
-
- token = set_current_remote_apis(interval_apis)
- try:
+ token = set_current_remote_apis({"test_api": api})
+ try:
+ with (
+ patch("framex.plugin._manager.get_api") as mock_get_api,
+ patch("framex.plugin.get_adapter") as mock_adapter,
+ ):
+ mock_adapter.return_value.call_func = AsyncMock(return_value="remote_result")
result = await call_plugin_api("test_api")
- finally:
- reset_current_remote_apis(token)
-
- # Manager get_api should not be called
- mock_get_api.assert_not_called()
- assert result == "interval_result"
+ mock_get_api.assert_not_called()
+ assert result == "remote_result"
+ finally:
+ reset_current_remote_apis(token)
@pytest.mark.asyncio
- async def test_call_plugin_api_with_interval_apis_dict_payload(self):
- """Test interval_apis values serialized through Ray can be rehydrated."""
+ async def test_call_plugin_api_with_remote_apis_dict_payload(self):
api = PluginApi(api="test_api", deployment_name="test_deployment")
- interval_apis = {"test_api": api.model_dump()}
-
- with (
- patch("framex.plugin._manager.get_api") as mock_get_api,
- patch("framex.plugin.get_adapter") as mock_adapter,
- ):
- mock_adapter.return_value.call_func = AsyncMock(return_value="interval_dict_result")
-
- token = set_current_remote_apis(interval_apis)
- try:
+ token = set_current_remote_apis({"test_api": api.model_dump()})
+ try:
+ with (
+ patch("framex.plugin._manager.get_api") as mock_get_api,
+ patch("framex.plugin.get_adapter") as mock_adapter,
+ ):
+ mock_adapter.return_value.call_func = AsyncMock(return_value="remote_dict_result")
result = await call_plugin_api("test_api")
- finally:
- reset_current_remote_apis(token)
-
- mock_get_api.assert_not_called()
- assert result == "interval_dict_result"
+ mock_get_api.assert_not_called()
+ assert result == "remote_dict_result"
+ finally:
+ reset_current_remote_apis(token)
@pytest.mark.asyncio
- async def test_call_plugin_api_uses_resolver_when_manager_misses(self):
- """Test ApiResolver fallback works when Ray workers lack manager state."""
- func_api = PluginApi(deployment_name="demo_plugin.DemoDeployment", func_name="run", call_type=ApiType.FUNC)
- http_api = PluginApi(api="/api/v1/resource_match/resource_detail", deployment_name="resource_match.Detail")
- resolver = ApiResolver(
- api_registry={
- "demo_plugin.DemoDeployment.run": func_api,
- "/api/v1/resource_match/resource_detail": http_api,
- }
- )
+ async def test_call_plugin_api_context_wrapper_preserves_async_generator(self):
+ from framex.plugin.base import BasePlugin
+ from framex.plugin.on import on_request
- with (
- patch("framex.plugin._manager.get_api", return_value=None),
- patch("framex.plugin.get_adapter") as mock_adapter,
- ):
- mock_adapter.return_value.call_func = AsyncMock(side_effect=["func_result", "http_result"])
+ class DemoPlugin(BasePlugin):
+ @on_request("/demo", stream=True)
+ async def stream_api(self):
+ yield "a"
+ yield "b"
- assert await call_plugin_api("demo_plugin.DemoDeployment.run", resolver=resolver) == "func_result"
- assert await call_plugin_api("/api/v1/resource_match/resource_detail", resolver=resolver) == "http_result"
+ plugin = DemoPlugin()
+ stream = plugin.stream_api()
+
+ assert inspect.isasyncgen(stream)
+ assert [chunk async for chunk in stream] == ["a", "b"]
@pytest.mark.asyncio
- async def test_call_plugin_api_uses_context_resolver_when_no_explicit_resolver(self):
- """Test context-bound resolver works for nested service-style calls."""
- http_api = PluginApi(api="/api/v1/fetch/fetch_company_by_dim_info", deployment_name="fetch.FetchPlugin")
- resolver = ApiResolver(api_registry={"/api/v1/fetch/fetch_company_by_dim_info": http_api})
+ async def test_call_remote_api_uses_whitelist_via_request_context(self):
+ from framex.plugin.base import BasePlugin
+ from framex.plugin.on import on_request
- with (
- patch("framex.plugin._manager.get_api", return_value=None),
- patch("framex.plugin.get_adapter") as mock_adapter,
- ):
- mock_adapter.return_value.call_func = AsyncMock(return_value="context_result")
- token = set_current_api_resolver(resolver)
- try:
- result = await call_plugin_api("/api/v1/fetch/fetch_company_by_dim_info")
- finally:
- reset_current_api_resolver(token)
+ api = PluginApi(api="test_api", deployment_name="test_deployment", params=[("model_param", SampleModel)])
- assert result == "context_result"
- assert mock_adapter.return_value.call_func.call_args[0][0] == http_api
+ class DemoPlugin(BasePlugin):
+ @on_request("/demo")
+ async def request_api(self):
+ return await self._call_remote_api("test_api", model_param={"field1": "test", "field2": 456})
- @pytest.mark.asyncio
- async def test_call_plugin_api_plugin_context_requires_remote_apis(self):
- """Test plugin-context calls do not fall back to global registry outside remote_apis."""
- http_api = PluginApi(api="/api/v1/fetch/fetch_company_by_dim_info", deployment_name="fetch.FetchPlugin")
- resolver = ApiResolver(api_registry={"/api/v1/fetch/fetch_company_by_dim_info": http_api})
-
- with patch("framex.plugin._manager.get_api", return_value=http_api):
- resolver_token = set_current_api_resolver(resolver)
- remote_token = set_current_remote_apis({})
- try:
- with pytest.raises(RuntimeError, match="not declared in current plugin remote_apis"):
- await call_plugin_api("/api/v1/fetch/fetch_company_by_dim_info")
- finally:
- reset_current_remote_apis(remote_token)
- reset_current_api_resolver(resolver_token)
+ plugin = DemoPlugin(remote_apis={"test_api": api})
+
+ with patch("framex.plugin.get_adapter") as mock_adapter:
+ mock_adapter.return_value.call_func = AsyncMock(return_value="success")
+ result = await plugin.request_api()
+ assert result == "success"
+ call_args = mock_adapter.return_value.call_func.call_args
+ assert isinstance(call_args[0][0], PluginApi)
+ assert isinstance(call_args[1]["model_param"], SampleModel)
@pytest.mark.asyncio
- async def test_call_plugin_api_context_wrapper_preserves_async_generator(self):
- """Test resolver context wrapping keeps stream handlers as async generators."""
+ async def test_call_remote_api_rejects_missing_whitelist_entry_via_request_context(self):
from framex.plugin.base import BasePlugin
from framex.plugin.on import on_request
class DemoPlugin(BasePlugin):
- @on_request("/demo", stream=True)
- async def stream_api(self):
- yield "a"
- yield "b"
+ @on_request("/demo")
+ async def request_api(self):
+ return await self._call_remote_api("test_api")
- plugin = DemoPlugin(api_registry={})
- stream = plugin.stream_api()
+ plugin = DemoPlugin(remote_apis={})
- assert inspect.isasyncgen(stream)
- assert [chunk async for chunk in stream] == ["a", "b"]
+ with pytest.raises(RuntimeError, match="not declared in current plugin remote_apis"):
+ await plugin.request_api()
@pytest.mark.asyncio
async def test_call_plugin_api_proxy_creates_correct_plugin_api(self):
- """Test that proxy fallback creates PluginApi with correct parameters."""
with (
- patch("framex.plugin._manager.get_api", return_value=None),
patch("framex.plugin.settings.server.enable_proxy", True),
patch("framex.plugin.get_adapter") as mock_adapter,
patch("framex.plugin.logger"),
):
mock_adapter.return_value.call_func = AsyncMock(return_value={"status": 200, "data": "ok"})
-
await call_plugin_api("/proxy/test")
-
- # Check the PluginApi passed to call_func
- call_args = mock_adapter.return_value.call_func.call_args
- api = call_args[0][0]
+ api = mock_adapter.return_value.call_func.call_args[0][0]
assert isinstance(api, PluginApi)
assert api.api == "/proxy/test"
assert api.deployment_name == PROXY_PLUGIN_NAME
@@ -307,47 +207,37 @@ async def test_call_plugin_api_proxy_creates_correct_plugin_api(self):
@pytest.mark.asyncio
async def test_call_plugin_api_regular_dict_result_not_proxy(self):
- """Test that regular dict results (non-proxy) are returned as-is."""
api = PluginApi(api="test_api", deployment_name="test_deployment")
-
- with (
- patch("framex.plugin._manager.get_api", return_value=api),
- patch("framex.plugin.get_adapter") as mock_adapter,
- ):
- # Regular dict result (not from proxy)
- mock_adapter.return_value.call_func = AsyncMock(return_value={"key": "value", "status": 200})
-
- result = await call_plugin_api("test_api")
-
- # Should return the entire dict, not extract "data"
- assert result == {"key": "value", "status": 200}
+ token = set_current_remote_apis({"test_api": api})
+ try:
+ with patch("framex.plugin.get_adapter") as mock_adapter:
+ mock_adapter.return_value.call_func = AsyncMock(return_value={"key": "value", "status": 200})
+ result = await call_plugin_api("test_api")
+ assert result == {"key": "value", "status": 200}
+ finally:
+ reset_current_remote_apis(token)
@pytest.mark.asyncio
async def test_call_plugin_api_with_multiple_kwargs(self):
- """Test calling API with multiple keyword arguments."""
api = PluginApi(
api="test_api", deployment_name="test_deployment", params=[("a", int), ("b", str), ("c", bool)]
)
-
- with (
- patch("framex.plugin._manager.get_api", return_value=api),
- patch("framex.plugin.get_adapter") as mock_adapter,
- ):
- mock_adapter.return_value.call_func = AsyncMock(return_value="multi_args")
-
- result = await call_plugin_api("test_api", a=1, b="test", c=True)
-
- assert result == "multi_args"
- call_kwargs = mock_adapter.return_value.call_func.call_args[1]
- assert call_kwargs["a"] == 1
- assert call_kwargs["b"] == "test"
- assert call_kwargs["c"] is True
+ token = set_current_remote_apis({"test_api": api})
+ try:
+ with patch("framex.plugin.get_adapter") as mock_adapter:
+ mock_adapter.return_value.call_func = AsyncMock(return_value="multi_args")
+ result = await call_plugin_api("test_api", a=1, b="test", c=True)
+ assert result == "multi_args"
+ call_kwargs = mock_adapter.return_value.call_func.call_args[1]
+ assert call_kwargs["a"] == 1
+ assert call_kwargs["b"] == "test"
+ assert call_kwargs["c"] is True
+ finally:
+ reset_current_remote_apis(token)
@pytest.mark.asyncio
async def test_call_plugin_api_with_proxy_missing_status(self):
- """Test proxy API call raises when status field is missing."""
with (
- patch("framex.plugin._manager.get_api", return_value=None),
patch("framex.plugin.settings.server.enable_proxy", True),
patch("framex.plugin.get_adapter") as mock_adapter,
patch("framex.plugin.logger"),
diff --git a/uv.lock b/uv.lock
index 6e6c45f..20312a6 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,5 +1,4 @@
version = 1
-revision = 3
requires-python = ">=3.11"
resolution-markers = [
"python_full_version >= '3.13' and platform_python_implementation == 'PyPy'",
@@ -500,6 +499,7 @@ dependencies = [
{ name = "pydantic" },
{ name = "pydantic-settings" },
{ name = "pyjwt" },
+ { name = "python-multipart" },
{ name = "pytz" },
{ name = "sentry-sdk", extra = ["fastapi"] },
{ name = "tomli" },
@@ -543,6 +543,7 @@ requires-dist = [
{ name = "pydantic", specifier = ">=2.11.7" },
{ name = "pydantic-settings", specifier = ">=2.10.1" },
{ name = "pyjwt", specifier = ">=2.10.1" },
+ { name = "python-multipart", specifier = ">=0.0.22" },
{ name = "python-semantic-release", marker = "extra == 'release'", specifier = ">=10.2.0" },
{ name = "pytz", specifier = ">=2025.2" },
{ name = "ray", extras = ["serve"], marker = "extra == 'ray'", specifier = "==2.54.0" },
@@ -1931,6 +1932,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/34/bd/b0d440685fbcafee462bed793a74aea88541887c4c30556a55ac64914b8d/python_gitlab-6.5.0-py3-none-any.whl", hash = "sha256:494e1e8e5edd15286eaf7c286f3a06652688f1ee20a49e2a0218ddc5cc475e32", size = 144419, upload-time = "2025-10-17T21:40:01.233Z" },
]
+[[package]]
+name = "python-multipart"
+version = "0.0.22"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/94/01/979e98d542a70714b0cb2b6728ed0b7c46792b695e3eaec3e20711271ca3/python_multipart-0.0.22.tar.gz", hash = "sha256:7340bef99a7e0032613f56dc36027b959fd3b30a787ed62d310e951f7c3a3a58", size = 37612 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/1b/d0/397f9626e711ff749a95d96b7af99b9c566a9bb5129b8e4c10fc4d100304/python_multipart-0.0.22-py3-none-any.whl", hash = "sha256:2b2cd894c83d21bf49d702499531c7bafd057d730c201782048f7945d82de155", size = 24579 },
+]
+
[[package]]
name = "python-semantic-release"
version = "10.5.3"