From 7bcf9c1d5ca971b3ea0a6a5f0d591974334333f0 Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Fri, 27 Mar 2026 10:35:35 +0800 Subject: [PATCH 1/6] feat: remove Enum from tags type hints --- src/framex/driver/ingress.py | 3 +-- src/framex/plugin/model.py | 4 ++-- src/framex/plugin/on.py | 5 +++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/framex/driver/ingress.py b/src/framex/driver/ingress.py index fa49f64..0a39922 100644 --- a/src/framex/driver/ingress.py +++ b/src/framex/driver/ingress.py @@ -1,7 +1,6 @@ import os import re from collections.abc import Callable -from enum import Enum from typing import Any from fastapi import Depends, HTTPException, Request, Response, status @@ -73,7 +72,7 @@ def register_route( handle: Any, stream: bool = False, direct_output: bool = False, - tags: list[str | Enum] | None = None, + tags: list[str] | None = None, auth_keys: list[str] | None = None, include_in_schema: bool = True, **kwargs: Any, diff --git a/src/framex/plugin/model.py b/src/framex/plugin/model.py index 7dc0c96..c923374 100644 --- a/src/framex/plugin/model.py +++ b/src/framex/plugin/model.py @@ -1,6 +1,6 @@ from collections.abc import Callable from dataclasses import dataclass, field -from enum import Enum, StrEnum +from enum import StrEnum from types import ModuleType from typing import Any @@ -33,7 +33,7 @@ class PluginApi(BaseModel): methods: list[str] = Field(default_factory=lambda: ["POST"]) params: list[tuple[str, type[Any] | Callable[..., Any]]] = Field(default_factory=list) call_type: ApiType = ApiType.HTTP - tags: list[str | Enum] | None = None + tags: list[str] | None = None stream: bool = False raw_response: bool = False extend_kwargs: dict[str, Any] = Field(default_factory=dict) diff --git a/src/framex/plugin/on.py b/src/framex/plugin/on.py index 668f4bf..d9cd22c 100644 --- a/src/framex/plugin/on.py +++ b/src/framex/plugin/on.py @@ -40,10 +40,11 @@ def decorator(cls: type) -> type: params = extract_method_params(func) version: str = plugin.module.__plugin_meta__.version version = f"v{version}" if not version.startswith("v") else version - tags = [f"{plugin.name}({version}): {plugin.module.__plugin_meta__.description}"] if func.__tags: - tags.extend(func.__tags) + tags: list[str] = func.__tags + else: + tags = [f"{plugin.name}({version}): {plugin.module.__plugin_meta__.description}"] plugin_apis.append( PluginApi( From b7c0cad005fe0a5c9abd5ab0e5bab5193b41ce62 Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Fri, 27 Mar 2026 11:02:33 +0800 Subject: [PATCH 2/6] test: improve auth tests and coroutine handling --- tests/adapter/test_ray_adapter.py | 7 ++++++- tests/driver/test_auth.py | 21 ++++++++++++--------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/tests/adapter/test_ray_adapter.py b/tests/adapter/test_ray_adapter.py index 46c8553..2503334 100644 --- a/tests/adapter/test_ray_adapter.py +++ b/tests/adapter/test_ray_adapter.py @@ -180,7 +180,12 @@ def capture_remote(func): # Now test the captured wrapper assert captured_wrapper is not None with patch("asyncio.run") as mock_asyncio_run: - mock_asyncio_run.return_value = 10 + + def _consume_coroutine(coro): + coro.close() + return 10 + + mock_asyncio_run.side_effect = _consume_coroutine result = captured_wrapper(5) mock_asyncio_run.assert_called_once() assert result == 10 diff --git a/tests/driver/test_auth.py b/tests/driver/test_auth.py index 87d7317..ea50b52 100644 --- a/tests/driver/test_auth.py +++ b/tests/driver/test_auth.py @@ -1,3 +1,4 @@ +import uuid from datetime import UTC, datetime, timedelta from types import SimpleNamespace from unittest.mock import AsyncMock, Mock, patch @@ -14,6 +15,8 @@ from framex.driver.application import create_fastapi_application from framex.driver.auth import auth_jwt, authenticate, create_jwt, oauth_callback +JWT_SECRET = uuid.uuid4().hex + # ========================================================= # helpers # ========================================================= @@ -28,7 +31,7 @@ def fake_oauth(**overrides): client_secret="secret", # noqa: S106 redirect_uri="/oauth/callback", call_back_url="http://test/callback", - jwt_secret="secret", # noqa: S106 + jwt_secret=JWT_SECRET, jwt_algorithm="HS256", ) data.update(overrides) @@ -44,7 +47,7 @@ class TestCreateJWT: def test_create_jwt_success(self): with patch("framex.config.settings.auth.oauth", fake_oauth()): token = create_jwt({"username": "test"}) - decoded = jwt.decode(token, "secret", algorithms=["HS256"]) + decoded = jwt.decode(token, JWT_SECRET, algorithms=["HS256"]) assert decoded["username"] == "test" assert "iat" in decoded assert "exp" in decoded @@ -67,7 +70,7 @@ def test_returns_false_when_oauth_not_configured(self): def test_returns_false_when_no_token_cookie(self): with patch("framex.config.settings.auth.oauth") as mock_oauth: - mock_oauth.jwt_secret = "secret" # noqa: S105 + mock_oauth.jwt_secret = JWT_SECRET mock_oauth.jwt_algorithm = "HS256" req = Mock(spec=Request) @@ -77,7 +80,7 @@ def test_returns_false_when_no_token_cookie(self): def test_returns_true_when_token_is_valid(self): with patch("framex.config.settings.auth.oauth") as mock_oauth: - mock_oauth.jwt_secret = "secret" # noqa: S105 + mock_oauth.jwt_secret = JWT_SECRET mock_oauth.jwt_algorithm = "HS256" now = datetime.now(UTC) @@ -87,7 +90,7 @@ def test_returns_true_when_token_is_valid(self): "iat": int(now.timestamp()), "exp": int((now + timedelta(hours=1)).timestamp()), }, - "secret", + JWT_SECRET, algorithm="HS256", ) @@ -98,7 +101,7 @@ def test_returns_true_when_token_is_valid(self): def test_returns_false_when_token_is_invalid(self): with patch("framex.config.settings.auth.oauth") as mock_oauth: - mock_oauth.jwt_secret = "secret" # noqa: S105 + mock_oauth.jwt_secret = JWT_SECRET mock_oauth.jwt_algorithm = "HS256" req = Mock(spec=Request) @@ -108,7 +111,7 @@ def test_returns_false_when_token_is_invalid(self): def test_returns_false_when_token_is_expired(self): with patch("framex.config.settings.auth.oauth") as mock_oauth: - mock_oauth.jwt_secret = "secret" # noqa: S105 + mock_oauth.jwt_secret = JWT_SECRET mock_oauth.jwt_algorithm = "HS256" now = datetime.now(UTC) @@ -118,7 +121,7 @@ def test_returns_false_when_token_is_expired(self): "iat": int((now - timedelta(days=2)).timestamp()), "exp": int((now - timedelta(days=1)).timestamp()), }, - "secret", + JWT_SECRET, algorithm="HS256", ) @@ -213,7 +216,7 @@ def test_docs_accessible_with_valid_jwt(self): "iat": int(now.timestamp()), "exp": int((now + timedelta(hours=1)).timestamp()), }, - "secret", + JWT_SECRET, algorithm="HS256", ) client.cookies.set("token", token) From 6a5c129197b239a7357438a09e00ee0ee00fb92e Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Sat, 28 Mar 2026 15:56:40 +0800 Subject: [PATCH 3/6] feat: add multipart/form-data support for proxy plugin --- pyproject.toml | 1 + pytest.ini | 2 +- src/framex/driver/ingress.py | 42 ++++++++-- src/framex/plugin/__init__.py | 4 +- src/framex/plugins/proxy/__init__.py | 121 ++++++++++++++++++++++----- src/framex/plugins/proxy/builder.py | 34 +++++++- tests/api/test_proxy.py | 40 +++++++++ tests/consts.py | 51 ++++++++++- tests/mock.py | 15 ++++ uv.lock | 12 ++- 10 files changed, 287 insertions(+), 35 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a2b9d41..59a1e96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "pydantic>=2.11.7", "pydantic-settings>=2.10.1", "pyjwt>=2.10.1", + "python-multipart>=0.0.22", "pytz>=2025.2", "sentry-sdk[fastapi]>=2.33.0", "tomli>=2.2.1", diff --git a/pytest.ini b/pytest.ini index 261728c..4ee217e 100644 --- a/pytest.ini +++ b/pytest.ini @@ -24,7 +24,7 @@ env = server__enable_proxy=true load_builtin_plugins=["echo","proxy","invoker"] plugins__proxy__proxy_urls=["http://localhost:9527"] - plugins__proxy__white_list=["/api/v1/proxy/mock/info","/proxy/mock/get","/proxy/mock/post","/proxy/mock/post_model","/proxy/mock/auth/*"] + plugins__proxy__white_list=["/api/v1/proxy/mock/info","/proxy/mock/get","/proxy/mock/post","/proxy/mock/post_model","/proxy/mock/upload","/proxy/mock/auth/*"] plugins__proxy__auth__rules={{"/proxy/mock/auth/*":["i_am_proxy_general_auth_keys"],"/api/v1/openapi.json":["i_am_proxy_docs_auth_keys"],"/proxy/mock/auth/sget":["i_am_proxy_special_auth_keys"],"/api/v1/proxy/remote":["i_am_proxy_func_auth_keys"]}} plugins__proxy__proxy_functions={{"http://localhost:9527":["tests.test_plugins.remote_exchange_key_value"]}} ; plugins__proxy__force_stream_apis=[] diff --git a/src/framex/driver/ingress.py b/src/framex/driver/ingress.py index 0a39922..a2ddb85 100644 --- a/src/framex/driver/ingress.py +++ b/src/framex/driver/ingress.py @@ -1,12 +1,13 @@ +import inspect import os import re from collections.abc import Callable -from typing import Any +from typing import Annotated, Any, get_args, get_origin from fastapi import Depends, HTTPException, Request, Response, status from fastapi.responses import JSONResponse, StreamingResponse from fastapi.routing import APIRoute -from pydantic import create_model +from pydantic import BaseModel from starlette.routing import Route from framex.adapter import get_adapter @@ -21,6 +22,17 @@ app = create_fastapi_application() +def _unwrap_annotation(annotation: Any) -> Any: + if get_origin(annotation) is Annotated: + return get_args(annotation)[0] + return annotation + + +def _is_basemodel_annotation(annotation: Any) -> bool: + annotation = _unwrap_annotation(annotation) + return isinstance(annotation, type) and issubclass(annotation, BaseModel) + + @app.get("/health") async def health() -> str: return "ok" @@ -100,9 +112,8 @@ def register_route( return False if (not path) or (not methods): raise RuntimeError(f"Api({path}) or methods({methods}) is empty") - Model: BaseModel = create_model(f"{func_name}_InputModel", **{name: (tp, ...) for name, tp in params}) # type:ignore # noqa - async def route_handler(response: Response, model: Model = Depends()) -> Any: # type: ignore [valid-type] + async def route_handler(response: Response, **request_kwargs: Any) -> Any: c_handle = getattr(handle, func_name) if not c_handle: raise RuntimeError( @@ -110,13 +121,32 @@ async def route_handler(response: Response, model: Model = Depends()) -> Any: # ) response.headers["X-Raw-Output"] = str(direct_output) if stream: - gen = adapter._stream_call(c_handle, **(model.__dict__)) + gen = adapter._stream_call(c_handle, **request_kwargs) return StreamingResponse( # type: ignore gen, media_type="text/event-stream", ) - return await adapter._acall(c_handle, **model.__dict__) # type: ignore + return await adapter._acall(c_handle, **request_kwargs) # type: ignore + + route_handler.__signature__ = inspect.Signature( # type: ignore + [ + inspect.Parameter( + "response", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=Response, + ), + *[ + inspect.Parameter( + name, + inspect.Parameter.KEYWORD_ONLY, + annotation=tp, + default=inspect.Parameter.empty, + ) + for name, tp in params + ], + ] + ) # Inject auth dependency if needed dependencies = [] diff --git a/src/framex/plugin/__init__.py b/src/framex/plugin/__init__.py index 728afa6..3117145 100644 --- a/src/framex/plugin/__init__.py +++ b/src/framex/plugin/__init__.py @@ -63,8 +63,8 @@ def init_all_deployments(enable_proxy: bool) -> list[Any]: call_type=ApiType.PROXY, ) logger.opt(colors=True).warning( - f"Api({api_name}) not found, " - f"plugin({dep.deployment}) will " + f"Api({api_name}) not found, " + f"plugin({dep.deployment}) will " f"use proxy plugin({PROXY_PLUGIN_NAME}) to transfer!" ) else: # pragma: no cover diff --git a/src/framex/plugins/proxy/__init__.py b/src/framex/plugins/proxy/__init__.py index 509edf2..90fea99 100644 --- a/src/framex/plugins/proxy/__init__.py +++ b/src/framex/plugins/proxy/__init__.py @@ -15,7 +15,14 @@ from framex.plugin import BasePlugin, PluginApi, PluginMetadata, on_register from framex.plugin.model import ApiType from framex.plugin.on import on_request -from framex.plugins.proxy.builder import create_pydantic_model, format_proxy_params, type_map +from framex.plugins.proxy.builder import ( + create_pydantic_model, + format_proxy_params, + is_upload_annotation, + resolve_annotation, + to_multipart_annotation, + type_map, +) from framex.plugins.proxy.config import ProxyPluginConfig, settings from framex.plugins.proxy.model import ProxyFunc, ProxyFuncHttpBody from framex.utils import cache_decode, cache_encode, shorten_str @@ -48,6 +55,7 @@ async def on_start(self) -> None: return for url in settings.proxy_url_list: + logger.info(f"Try to parse openapi docs from {url}") await self._parse_openai_docs(url) if settings.proxy_functions: @@ -96,6 +104,7 @@ async def _parse_openai_docs(self, url: str) -> None: headers = None for method, body in details.items(): + func_name = body.get("operationId") # Process request parameters params: list[tuple[str, Any]] = [ (name, c_type) @@ -104,26 +113,58 @@ async def _parse_openai_docs(self, url: str) -> None: and (typ := param.get("schema").get("type")) and (c_type := type_map.get(typ)) ] + body_param_names: set[str] = set() + file_param_names: set[str] = set() # Process request body if request_body := body.get("requestBody"): - schema_name = ( - request_body.get("content", {}) - .get("application/json", {}) - .get("schema", {}) - .get("$ref", "") - .rsplit("/", 1)[-1] - ) - if not (model_schema := components.get(schema_name)): # pragma: no cover - raise ValueError(f"Schema '{schema_name}' not found in components.") - - Model = create_pydantic_model(schema_name, model_schema, components) # noqa - params.append(("model", Model)) + body_content = request_body.get("content", {}) + if "application/json" in body_content: + content_type = "application/json" + elif "multipart/form-data" in body_content: + content_type = "multipart/form-data" + else: + logger.opt(colors=True).error( + f"Failed to proxy api({method}) {url}{path}, unsupported content type: ${body_content.keys()}" + ) + continue + + schema = body_content.get(content_type, {}).get("schema", {}) + if (schema_name := schema.get("$ref", {}).rsplit("/", 1)[-1]) and not ( + model_schema := components.get(schema_name) + ): + logger.opt(colors=True).error( + f"Failed to proxy api({method}) {url}{path}, schema '{schema_name}' not found in components" + ) + continue + + if content_type == "application/json": + Model = create_pydantic_model(schema_name, model_schema, components) # noqa + params.append(("model", Model)) + body_param_names.add("model") + elif content_type == "multipart/form-data": + for field_name, prop_schema in model_schema.get("properties", {}).items(): + annotation = resolve_annotation(prop_schema, components) + params.append((field_name, to_multipart_annotation(annotation))) + body_param_names.add(field_name) + if is_upload_annotation(annotation): + file_param_names.add(field_name) + else: + logger.opt(colors=True).error( + f"Failed to proxy api({method}) {url}{path}, unsupported content type: ${body_content.keys()}" + ) + continue logger.opt(colors=True).trace(f"Found proxy api({method}) {url}{path}") - func_name = body.get("operationId") is_stream = path in settings.force_stream_apis func = self._create_dynamic_method( - func_name, method, params, f"{url}{path}", stream=is_stream, headers=headers + func_name, + method, + params, + f"{url}{path}", + body_param_names=body_param_names, + file_param_names=file_param_names, + stream=is_stream, + headers=headers, ) setattr(self, func_name, func) @@ -196,6 +237,7 @@ async def __call__(self, proxy_path: str, **kwargs: Any) -> Any: return await func(**kwargs) raise RuntimeError(f"api({proxy_path}) not found") + # @logger.catch async def fetch_response( self, stream: bool = False, @@ -226,12 +268,16 @@ def _create_dynamic_method( method: str, params: list[tuple[str, type]], url: str, + body_param_names: set[str] | None = None, + file_param_names: set[str] | None = None, stream: bool = False, headers: dict[str, str] | None = None, ) -> Callable[..., Any]: # Build a Pydantic request model (for data validation) model_name = f"{func_name.title()}_RequestModel" RequestModel = create_model(model_name, **{k: (t, ...) for k, t in params}) # type: ignore # noqa + body_param_names = body_param_names or set() + file_param_names = file_param_names or set() # Construct dynamic methods async def dynamic_method(**kwargs: Any) -> AsyncGenerator[str, None] | dict[str, Any] | str: @@ -240,19 +286,50 @@ async def dynamic_method(**kwargs: Any) -> AsyncGenerator[str, None] | dict[str, validated = RequestModel(**kwargs) # Type Validation query = {} json_body = None + form_body = {} + files = [] for field_name, value in validated: - if isinstance(value, BaseModel): + if field_name in body_param_names: + if field_name == "model" and isinstance(value, BaseModel): + json_body = value.model_dump() + elif field_name in file_param_names: + upload_values = value if isinstance(value, list) else [value] + for upload in upload_values: + upload.file.seek(0) + files.append( + ( + field_name, + ( + upload.filename or field_name, + upload.file, + upload.content_type or "application/octet-stream", + ), + ) + ) + else: + form_body[field_name] = value + elif isinstance(value, BaseModel): json_body = value.model_dump() else: query[field_name] = value try: + request_kwargs: dict[str, Any] = { + "stream": stream, + "method": method.upper(), + "url": url, + "params": query, + "headers": headers, + } + if method.upper() != "GET": + if files or form_body: + if form_body: + request_kwargs["data"] = form_body + if files: + request_kwargs["files"] = files + elif json_body is not None: + request_kwargs["json"] = json_body return await self.fetch_response( - stream=stream, - method=method.upper(), - url=url, - params=query, - json=json_body if method.upper() != "GET" else None, - headers=headers, + **request_kwargs, ) except Exception as e: logger.opt(exception=e, colors=True).error(f"Error calling proxy api({method}) {url}: {e}") diff --git a/src/framex/plugins/proxy/builder.py b/src/framex/plugins/proxy/builder.py index 29657c0..5177415 100644 --- a/src/framex/plugins/proxy/builder.py +++ b/src/framex/plugins/proxy/builder.py @@ -1,5 +1,6 @@ -from typing import Any, Union +from typing import Annotated, Any, Union, get_args, get_origin +from fastapi import File, Form, UploadFile from pydantic import BaseModel, create_model from framex.plugins.proxy.model import ProxyFuncHttpBody @@ -16,7 +17,27 @@ "object": dict, "null": None, } -from typing import get_args, get_origin + + +def unwrap_annotation(annotation: Any) -> Any: + if get_origin(annotation) is Annotated: + return get_args(annotation)[0] + return annotation + + +def is_upload_annotation(annotation: Any) -> bool: + annotation = unwrap_annotation(annotation) + origin = get_origin(annotation) + if origin is list: + args = get_args(annotation) + return len(args) == 1 and is_upload_annotation(args[0]) + return annotation is UploadFile + + +def to_multipart_annotation(annotation: Any) -> Any: + if is_upload_annotation(annotation): + return Annotated[annotation, File(...)] + return Annotated[annotation, Form(...)] def resolve_annotation( @@ -54,9 +75,18 @@ def resolve_annotation( if "$ref" in prop_schema: item_type = resolve_annotation(prop_schema, components) else: + if prop_schema.get("type") == "string" and ( + prop_schema.get("contentMediaType") == "application/octet-stream" + or prop_schema.get("format") == "binary" + ): + return list[UploadFile] # type: ignore [valid-type] item_type = type_map.get(prop_schema.get("type", "string"), str) return list[item_type] # type: ignore [valid-type] if typ: + if typ == "string" and ( + prop_schema.get("contentMediaType") == "application/octet-stream" or prop_schema.get("format") == "binary" + ): + return UploadFile return type_map.get(typ, str) raise RuntimeError(f"Unsupported prop_schema: {prop_schema}") diff --git a/tests/api/test_proxy.py b/tests/api/test_proxy.py index 5ce5049..1e5b160 100644 --- a/tests/api/test_proxy.py +++ b/tests/api/test_proxy.py @@ -29,6 +29,46 @@ def test_get_proxy_post_model(client: TestClient): assert res == {"method": "POST", "body": data} +def test_get_proxy_upload(client: TestClient): + params = {"message": "hello world"} + data = {"note": "upload note"} + files = { + "ppt_file": ( + "demo.pptx", + b"ppt-content", + "application/vnd.openxmlformats-officedocument.presentationml.presentation", + ), + "txt_file": ("demo.txt", b"txt-content", "text/plain"), + } + res = client.post("/proxy/mock/upload", params=params, data=data, files=files).json() + assert res == { + "method": "POST", + "params": params, + "body": data, + "files": [ + { + "field": "ppt_file", + "filename": "demo.pptx", + "content_type": "application/vnd.openxmlformats-officedocument.presentationml.presentation", + }, + { + "field": "txt_file", + "filename": "demo.txt", + "content_type": "text/plain", + }, + ], + } + + +def test_get_proxy_upload_openapi(client: TestClient): + data = client.get("/api/v1/openapi.json").json() + post = data["paths"]["/proxy/mock/upload"]["post"] + ref = post["requestBody"]["content"]["multipart/form-data"]["schema"]["$ref"].split("/")[-1] + schema = data["components"]["schemas"][ref] + assert schema["properties"]["ppt_file"]["format"] == "binary" + assert schema["properties"]["txt_file"]["format"] == "binary" + + def test_get_proxy_black_get(client: TestClient): res = client.get("/proxy/mock/black_get").json() assert res["status"] == 404 diff --git a/tests/consts.py b/tests/consts.py index ff78640..44cb188 100644 --- a/tests/consts.py +++ b/tests/consts.py @@ -70,6 +70,34 @@ }, } }, + "/proxy/mock/upload": { + "post": { + "tags": ["proxy"], + "summary": "Proxy Mock Upload", + "operationId": "proxy_mock_upload", + "parameters": [ + { + "name": "message", + "in": "query", + "required": True, + "schema": { + "type": "string", + "title": "Message", + }, + } + ], + "requestBody": { + "required": True, + "content": {"multipart/form-data": {"schema": {"$ref": "#/components/schemas/MockUploadBody"}}}, + }, + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {"type": "object"}}}, + } + }, + } + }, "/proxy/mock/black_get": { "get": { "tags": ["proxy"], @@ -197,7 +225,28 @@ "default": "default", }, }, - } + }, + "MockUploadBody": { + "title": "MockUploadBody", + "type": "object", + "required": ["note", "ppt_file", "txt_file"], + "properties": { + "note": { + "type": "string", + "title": "Note", + }, + "ppt_file": { + "type": "string", + "contentMediaType": "application/octet-stream", + "title": "Ppt File", + }, + "txt_file": { + "type": "string", + "format": "binary", + "title": "Txt File", + }, + }, + }, } }, } diff --git a/tests/mock.py b/tests/mock.py index 22e8eae..5cbfb04 100644 --- a/tests/mock.py +++ b/tests/mock.py @@ -26,6 +26,7 @@ async def mock_get(_, url: str, *__, **kwargs: Any): async def mock_request(_, method: str, url: str, **kwargs: Any): params = kwargs.get("params") body = kwargs.get("json") or kwargs.get("data") + files = kwargs.get("files") headers = kwargs.get("headers", {}) resp = MagicMock() @@ -45,6 +46,20 @@ async def mock_request(_, method: str, url: str, **kwargs: Any): "method": "POST", "body": body, } + elif url.endswith("/proxy/mock/upload") and method == "POST": + resp.json.return_value = { + "method": "POST", + "params": params, + "body": body, + "files": [ + { + "field": field_name, + "filename": file_info[0], + "content_type": file_info[2], + } + for field_name, file_info in (files or []) + ], + } elif url.endswith("/proxy/mock/info") and method == "GET": resp.json.return_value = { "info": "i_am_mock_proxy_info", diff --git a/uv.lock b/uv.lock index 6e6c45f..20312a6 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 3 requires-python = ">=3.11" resolution-markers = [ "python_full_version >= '3.13' and platform_python_implementation == 'PyPy'", @@ -500,6 +499,7 @@ dependencies = [ { name = "pydantic" }, { name = "pydantic-settings" }, { name = "pyjwt" }, + { name = "python-multipart" }, { name = "pytz" }, { name = "sentry-sdk", extra = ["fastapi"] }, { name = "tomli" }, @@ -543,6 +543,7 @@ requires-dist = [ { name = "pydantic", specifier = ">=2.11.7" }, { name = "pydantic-settings", specifier = ">=2.10.1" }, { name = "pyjwt", specifier = ">=2.10.1" }, + { name = "python-multipart", specifier = ">=0.0.22" }, { name = "python-semantic-release", marker = "extra == 'release'", specifier = ">=10.2.0" }, { name = "pytz", specifier = ">=2025.2" }, { name = "ray", extras = ["serve"], marker = "extra == 'ray'", specifier = "==2.54.0" }, @@ -1931,6 +1932,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/34/bd/b0d440685fbcafee462bed793a74aea88541887c4c30556a55ac64914b8d/python_gitlab-6.5.0-py3-none-any.whl", hash = "sha256:494e1e8e5edd15286eaf7c286f3a06652688f1ee20a49e2a0218ddc5cc475e32", size = 144419, upload-time = "2025-10-17T21:40:01.233Z" }, ] +[[package]] +name = "python-multipart" +version = "0.0.22" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/94/01/979e98d542a70714b0cb2b6728ed0b7c46792b695e3eaec3e20711271ca3/python_multipart-0.0.22.tar.gz", hash = "sha256:7340bef99a7e0032613f56dc36027b959fd3b30a787ed62d310e951f7c3a3a58", size = 37612 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1b/d0/397f9626e711ff749a95d96b7af99b9c566a9bb5129b8e4c10fc4d100304/python_multipart-0.0.22-py3-none-any.whl", hash = "sha256:2b2cd894c83d21bf49d702499531c7bafd057d730c201782048f7945d82de155", size = 24579 }, +] + [[package]] name = "python-semantic-release" version = "10.5.3" From 84f3f60cfb3ea8fd9520b37ebc3aefc9376a26d9 Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Sat, 28 Mar 2026 16:37:22 +0800 Subject: [PATCH 4/6] refactor: simplify API resolver and remove ApiResolver class --- src/framex/plugin/__init__.py | 33 +--- src/framex/plugin/base.py | 28 +-- src/framex/plugin/load.py | 1 - src/framex/plugin/on.py | 1 - src/framex/plugin/resolver.py | 64 +------ tests/test_plugin.py | 326 +++++++++++----------------------- 6 files changed, 127 insertions(+), 326 deletions(-) diff --git a/src/framex/plugin/__init__.py b/src/framex/plugin/__init__.py index 3117145..11f8d23 100644 --- a/src/framex/plugin/__init__.py +++ b/src/framex/plugin/__init__.py @@ -10,18 +10,11 @@ from framex.log import logger from framex.plugin.manage import _manager from framex.plugin.model import Plugin, PluginApi -from framex.plugin.resolver import ( - ApiResolver, - _set_default_api_resolver, - get_current_api_resolver, - get_current_remote_apis, - get_default_api_resolver, -) +from framex.plugin.resolver import coerce_plugin_api, get_current_remote_apis C = TypeVar("C", bound=BaseModel) _current_plugin: ContextVar[Optional["Plugin"]] = ContextVar("_current_plugin", default=None) -_set_default_api_resolver(ApiResolver(manager=_manager)) def get_plugin(plugin_id: str) -> Plugin | None: @@ -47,7 +40,6 @@ def check_plugin_config_exists(plugin_name: str) -> bool: @logger.catch() def init_all_deployments(enable_proxy: bool) -> list[Any]: deployments = [] - all_apis = {**_manager.all_plugin_apis[ApiType.FUNC], **_manager.all_plugin_apis[ApiType.HTTP]} for plugin in get_loaded_plugins(): for dep in plugin.deployments: remote_apis = { @@ -74,7 +66,6 @@ def init_all_deployments(enable_proxy: bool) -> list[Any]: deployment = get_adapter().bind( dep.deployment, remote_apis=remote_apis, - api_registry=all_apis, config=plugin.config, ) @@ -83,20 +74,12 @@ def init_all_deployments(enable_proxy: bool) -> list[Any]: return deployments -def _resolve_plugin_api( - api_name: str | PluginApi, - resolver: ApiResolver | None = None, -) -> tuple[PluginApi, bool]: +def _resolve_plugin_api(api_name: str | PluginApi) -> tuple[PluginApi, bool]: current_remote_apis = get_current_remote_apis() if isinstance(api_name, PluginApi): return api_name, api_name.call_type == ApiType.PROXY - active_resolver = resolver or get_current_api_resolver() or get_default_api_resolver() - if current_remote_apis is not None: - api = active_resolver.coerce_plugin_api(current_remote_apis.get(api_name)) - else: - api = active_resolver.resolve(api_name, None) - + api = coerce_plugin_api(current_remote_apis.get(api_name)) if current_remote_apis is not None else None if api is None: if current_remote_apis is not None: raise RuntimeError( @@ -149,19 +132,15 @@ def _unwrap_plugin_call_result(api_name: str | PluginApi, result: Any, use_proxy res = result.get("data") status = result.get("status") if status not in settings.server.legal_proxy_code: - logger.opt(colors=True).error(f"<>Proxy API {api_name} call illegal: {result}") + logger.opt(colors=True).error(f"Proxy API {api_name} call illegal: {result}") raise RuntimeError(f"Proxy API {api_name} returned status {status}") if res is None: logger.opt(colors=True).warning(f"API {api_name} returned empty data") return res -async def call_plugin_api( - api_name: str | PluginApi, - resolver: ApiResolver | None = None, - **kwargs: Any, -) -> Any: - api, use_proxy = _resolve_plugin_api(api_name, resolver=resolver) +async def call_plugin_api(api_name: str | PluginApi, **kwargs: Any) -> Any: + api, use_proxy = _resolve_plugin_api(api_name) normalized_kwargs = _normalize_plugin_call_kwargs(api, kwargs) result = await get_adapter().call_func(api, **normalized_kwargs) return _unwrap_plugin_call_result(api_name, result, use_proxy) diff --git a/src/framex/plugin/base.py b/src/framex/plugin/base.py index 0042e67..a784823 100644 --- a/src/framex/plugin/base.py +++ b/src/framex/plugin/base.py @@ -1,19 +1,11 @@ import inspect -from collections.abc import Mapping from functools import wraps from typing import Any, final from framex.config import settings from framex.log import setup_logger from framex.plugin import call_plugin_api -from framex.plugin.model import PluginApi -from framex.plugin.resolver import ( - ApiResolver, - reset_current_api_resolver, - reset_current_remote_apis, - set_current_api_resolver, - set_current_remote_apis, -) +from framex.plugin.resolver import reset_current_remote_apis, set_current_remote_apis class BasePlugin: @@ -21,10 +13,8 @@ class BasePlugin: def __init__(self, **kwargs: Any) -> None: setup_logger() - self.remote_apis: dict[str, PluginApi] = kwargs.get("remote_apis", {}) - self.api_registry: Mapping[str, PluginApi] = kwargs.get("api_registry", {}) - self.api_resolver = ApiResolver(api_registry=self.api_registry) - self._bind_api_resolver_context() + self.remote_apis = kwargs.get("remote_apis", {}) + self._bind_remote_api_context() if settings.server.use_ray: import asyncio @@ -33,26 +23,24 @@ def __init__(self, **kwargs: Any) -> None: async def on_start(self) -> None: pass - def _bind_api_resolver_context(self) -> None: + def _bind_remote_api_context(self) -> None: for name, func in inspect.getmembers(type(self), predicate=callable): if not getattr(func, "_on_request", False): continue bound = getattr(self, name) - setattr(self, name, self._wrap_with_api_resolver(bound)) + setattr(self, name, self._wrap_with_remote_api_context(bound)) - def _wrap_with_api_resolver(self, func: Any) -> Any: + def _wrap_with_remote_api_context(self, func: Any) -> Any: if inspect.isasyncgenfunction(func): @wraps(func) async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any: - resolver_token = set_current_api_resolver(self.api_resolver) remote_token = set_current_remote_apis(self.remote_apis) try: async for chunk in func(*args, **kwargs): yield chunk finally: reset_current_remote_apis(remote_token) - reset_current_api_resolver(resolver_token) return async_gen_wrapper @@ -60,25 +48,21 @@ async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any: @wraps(func) async def async_wrapper(*args: Any, **kwargs: Any) -> Any: - resolver_token = set_current_api_resolver(self.api_resolver) remote_token = set_current_remote_apis(self.remote_apis) try: return await func(*args, **kwargs) finally: reset_current_remote_apis(remote_token) - reset_current_api_resolver(resolver_token) return async_wrapper @wraps(func) def sync_wrapper(*args: Any, **kwargs: Any) -> Any: - resolver_token = set_current_api_resolver(self.api_resolver) remote_token = set_current_remote_apis(self.remote_apis) try: return func(*args, **kwargs) finally: reset_current_remote_apis(remote_token) - reset_current_api_resolver(resolver_token) return sync_wrapper diff --git a/src/framex/plugin/load.py b/src/framex/plugin/load.py index d260f49..50bf3c9 100644 --- a/src/framex/plugin/load.py +++ b/src/framex/plugin/load.py @@ -56,7 +56,6 @@ async def register_proxy_func(func: Callable) -> None: ) await call_plugin_api( api_reg, - None, func_name=full_func_name, func_callable=_PROXY_REGISTRY[full_func_name], ) diff --git a/src/framex/plugin/on.py b/src/framex/plugin/on.py index d9cd22c..f42c60a 100644 --- a/src/framex/plugin/on.py +++ b/src/framex/plugin/on.py @@ -178,7 +178,6 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: ) res = await call_plugin_api( api_call, - None, func_name=cache_encode(full_func_name), data=cache_encode(data=kwargs), ) diff --git a/src/framex/plugin/resolver.py b/src/framex/plugin/resolver.py index 1e5ddb8..b68e59f 100644 --- a/src/framex/plugin/resolver.py +++ b/src/framex/plugin/resolver.py @@ -1,71 +1,21 @@ from collections.abc import Mapping from contextvars import ContextVar -from typing import Any, Protocol +from typing import Any from framex.plugin.model import PluginApi -class SupportsApiLookup(Protocol): - def get_api(self, api_name: str) -> PluginApi | None: ... +def coerce_plugin_api(api: PluginApi | dict[str, Any] | None) -> PluginApi | None: + if api is None or isinstance(api, PluginApi): + return api + if isinstance(api, dict): + return PluginApi.model_validate(api) + return None -class ApiResolver: - def __init__( - self, - manager: SupportsApiLookup | None = None, - api_registry: Mapping[str, PluginApi | dict[str, Any]] | None = None, - ) -> None: - self._manager = manager - self._api_registry = api_registry or {} - - @staticmethod - def coerce_plugin_api(api: PluginApi | dict[str, Any] | None) -> PluginApi | None: - if api is None or isinstance(api, PluginApi): - return api - if isinstance(api, dict): - return PluginApi.model_validate(api) - return None - - def resolve( - self, - api_name: str, - api_registry: Mapping[str, PluginApi | dict[str, Any]] | None = None, - ) -> PluginApi | None: - if api_registry is not None and (api := self.coerce_plugin_api(api_registry.get(api_name))): - return api - if self._manager and (api := self._manager.get_api(api_name)): - return api - return self.coerce_plugin_api(self._api_registry.get(api_name)) - - -_current_api_resolver: ContextVar[ApiResolver | None] = ContextVar("_current_api_resolver", default=None) _current_remote_apis: ContextVar[Mapping[str, PluginApi | dict[str, Any]] | None] = ContextVar( "_current_remote_apis", default=None ) -_default_api_resolver: ApiResolver | None = None - - -def get_current_api_resolver() -> ApiResolver | None: - return _current_api_resolver.get() - - -def get_default_api_resolver() -> ApiResolver: - if _default_api_resolver is None: - raise RuntimeError("Default API resolver is not configured") - return _default_api_resolver - - -def _set_default_api_resolver(resolver: ApiResolver) -> None: - global _default_api_resolver - _default_api_resolver = resolver - - -def set_current_api_resolver(resolver: ApiResolver | None) -> Any: - return _current_api_resolver.set(resolver) - - -def reset_current_api_resolver(token: Any) -> None: - _current_api_resolver.reset(token) def get_current_remote_apis() -> Mapping[str, PluginApi | dict[str, Any]] | None: diff --git a/tests/test_plugin.py b/tests/test_plugin.py index b2ff485..2b27031 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -8,17 +8,10 @@ from framex.consts import PROXY_PLUGIN_NAME, VERSION from framex.plugin import call_plugin_api from framex.plugin.model import ApiType, PluginApi -from framex.plugin.resolver import ( - ApiResolver, - reset_current_api_resolver, - reset_current_remote_apis, - set_current_api_resolver, - set_current_remote_apis, -) +from framex.plugin.resolver import reset_current_remote_apis, set_current_remote_apis def test_get_plugin(): - # check simple plugin plugin = framex.get_plugin("export") assert plugin assert plugin.version == VERSION @@ -30,108 +23,51 @@ def test_get_plugin(): class SampleModel(BaseModel): - """Sample model for testing parameter conversion.""" - field1: str field2: int class TestCallPluginApi: - """Comprehensive tests for call_plugin_api function with proxy handling.""" - - @pytest.mark.asyncio - async def test_call_plugin_api_with_existing_api(self): - """Test calling an API that exists in the manager.""" - # Setup - api = PluginApi(api="test_api", deployment_name="test_deployment", params=[("param1", str), ("param2", int)]) - - with ( - patch("framex.plugin._manager.get_api", return_value=api), - patch("framex.plugin.get_adapter") as mock_adapter, - ): - mock_adapter.return_value.call_func = AsyncMock(return_value="test_result") - - # Execute - result = await call_plugin_api("test_api", param1="value1", param2=42) - - # Assert - assert result == "test_result" - mock_adapter.return_value.call_func.assert_called_once() - - @pytest.mark.asyncio - async def test_call_plugin_api_with_basemodel_result(self): - """Test that BaseModel results are converted to dict with aliases.""" - api = PluginApi(api="test_api", deployment_name="test_deployment") - model_result = SampleModel(field1="test", field2=123) - - with ( - patch("framex.plugin._manager.get_api", return_value=api), - patch("framex.plugin.get_adapter") as mock_adapter, - ): - mock_adapter.return_value.call_func = AsyncMock(return_value=model_result) - - result = await call_plugin_api("test_api") - - assert isinstance(result, dict) - assert result == {"field1": "test", "field2": 123} - @pytest.mark.asyncio async def test_call_plugin_api_with_proxy_success(self): - """Test proxy API call with successful response (status 200).""" with ( patch("framex.plugin._manager.get_api", return_value=None), patch("framex.plugin.settings.server.enable_proxy", True), patch("framex.plugin.get_adapter") as mock_adapter, ): - # Simulate proxy response - proxy_response = {"status": 200, "data": {"result": "proxy_success"}} - mock_adapter.return_value.call_func = AsyncMock(return_value=proxy_response) - + mock_adapter.return_value.call_func = AsyncMock(return_value={"status": 200, "data": {"result": "ok"}}) result = await call_plugin_api("/external/api") - - # Should return just the data field - assert result == {"result": "proxy_success"} + assert result == {"result": "ok"} @pytest.mark.asyncio async def test_call_plugin_api_with_proxy_empty_data(self): - """Test proxy API call that returns empty data with warning.""" with ( patch("framex.plugin._manager.get_api", return_value=None), patch("framex.plugin.settings.server.enable_proxy", True), patch("framex.plugin.get_adapter") as mock_adapter, patch("framex.plugin.logger") as mock_logger, ): - proxy_response = {"status": 200, "data": None} - mock_adapter.return_value.call_func = AsyncMock(return_value=proxy_response) - + mock_adapter.return_value.call_func = AsyncMock(return_value={"status": 200, "data": None}) result = await call_plugin_api("/external/api") - assert result is None - # Verify warning was logged mock_logger.opt.return_value.warning.assert_called() @pytest.mark.asyncio async def test_call_plugin_api_with_proxy_error_status(self): - """Test proxy API call with non-200 status logs error.""" with ( patch("framex.plugin._manager.get_api", return_value=None), patch("framex.plugin.settings.server.enable_proxy", True), patch("framex.plugin.get_adapter") as mock_adapter, patch("framex.plugin.logger") as mock_logger, ): - proxy_response = {"status": 500, "data": None} - mock_adapter.return_value.call_func = AsyncMock(return_value=proxy_response) + mock_adapter.return_value.call_func = AsyncMock(return_value={"status": 500, "data": None}) with pytest.raises(RuntimeError, match="Proxy API /external/api returned status 500"): await call_plugin_api("/external/api") - - # Verify error was logged mock_logger.opt.return_value.error.assert_called() @pytest.mark.asyncio async def test_call_plugin_api_not_found_no_proxy(self): - """Test API not found when proxy is disabled raises RuntimeError.""" with ( - patch("framex.plugin._manager.get_api", return_value=None), patch("framex.plugin.settings.server.enable_proxy", False), pytest.raises(RuntimeError, match="API test_api is not found"), ): @@ -139,167 +75,131 @@ async def test_call_plugin_api_not_found_no_proxy(self): @pytest.mark.asyncio async def test_call_plugin_api_not_found_non_slash_with_proxy(self): - """Test non-slash prefixed API not found with proxy enabled raises error.""" with ( - patch("framex.plugin._manager.get_api", return_value=None), patch("framex.plugin.settings.server.enable_proxy", True), pytest.raises(RuntimeError, match="API test_api is not found"), ): await call_plugin_api("test_api") @pytest.mark.asyncio - async def test_call_plugin_api_with_dict_to_basemodel_conversion(self): - """Test automatic conversion of dict parameters to BaseModel.""" - api = PluginApi(api="test_api", deployment_name="test_deployment", params=[("model_param", SampleModel)]) - - with ( - patch("framex.plugin._manager.get_api", return_value=api), - patch("framex.plugin.get_adapter") as mock_adapter, - ): - mock_adapter.return_value.call_func = AsyncMock(return_value="success") - - # Pass dict that should be converted to SampleModel - result = await call_plugin_api("test_api", model_param={"field1": "test", "field2": 456}) - - # Verify the call was made and dict was converted - assert result == "success" - call_args = mock_adapter.return_value.call_func.call_args - assert isinstance(call_args[1]["model_param"], SampleModel) + async def test_call_plugin_api_requires_remote_apis_in_context(self): + token = set_current_remote_apis({}) + try: + with pytest.raises(RuntimeError, match="not declared in current plugin remote_apis"): + await call_plugin_api("test_api") + finally: + reset_current_remote_apis(token) @pytest.mark.asyncio - async def test_call_plugin_api_with_interval_apis(self): - """Test using interval_apis parameter to override manager lookup.""" + async def test_call_plugin_api_with_remote_apis_dict_to_basemodel_conversion(self): + api = PluginApi(api="test_api", deployment_name="test_deployment", params=[("model_param", SampleModel)]) + token = set_current_remote_apis({"test_api": api}) + try: + with patch("framex.plugin.get_adapter") as mock_adapter: + mock_adapter.return_value.call_func = AsyncMock(return_value="success") + result = await call_plugin_api("test_api", model_param={"field1": "test", "field2": 456}) + assert result == "success" + call_args = mock_adapter.return_value.call_func.call_args + assert isinstance(call_args[1]["model_param"], SampleModel) + finally: + reset_current_remote_apis(token) + + @pytest.mark.asyncio + async def test_call_plugin_api_with_remote_apis(self): api = PluginApi(api="test_api", deployment_name="test_deployment") - interval_apis = {"test_api": api} - - with ( - patch("framex.plugin._manager.get_api") as mock_get_api, - patch("framex.plugin.get_adapter") as mock_adapter, - ): - mock_adapter.return_value.call_func = AsyncMock(return_value="interval_result") - - token = set_current_remote_apis(interval_apis) - try: + token = set_current_remote_apis({"test_api": api}) + try: + with ( + patch("framex.plugin._manager.get_api") as mock_get_api, + patch("framex.plugin.get_adapter") as mock_adapter, + ): + mock_adapter.return_value.call_func = AsyncMock(return_value="remote_result") result = await call_plugin_api("test_api") - finally: - reset_current_remote_apis(token) - - # Manager get_api should not be called - mock_get_api.assert_not_called() - assert result == "interval_result" + mock_get_api.assert_not_called() + assert result == "remote_result" + finally: + reset_current_remote_apis(token) @pytest.mark.asyncio - async def test_call_plugin_api_with_interval_apis_dict_payload(self): - """Test interval_apis values serialized through Ray can be rehydrated.""" + async def test_call_plugin_api_with_remote_apis_dict_payload(self): api = PluginApi(api="test_api", deployment_name="test_deployment") - interval_apis = {"test_api": api.model_dump()} - - with ( - patch("framex.plugin._manager.get_api") as mock_get_api, - patch("framex.plugin.get_adapter") as mock_adapter, - ): - mock_adapter.return_value.call_func = AsyncMock(return_value="interval_dict_result") - - token = set_current_remote_apis(interval_apis) - try: + token = set_current_remote_apis({"test_api": api.model_dump()}) + try: + with ( + patch("framex.plugin._manager.get_api") as mock_get_api, + patch("framex.plugin.get_adapter") as mock_adapter, + ): + mock_adapter.return_value.call_func = AsyncMock(return_value="remote_dict_result") result = await call_plugin_api("test_api") - finally: - reset_current_remote_apis(token) - - mock_get_api.assert_not_called() - assert result == "interval_dict_result" + mock_get_api.assert_not_called() + assert result == "remote_dict_result" + finally: + reset_current_remote_apis(token) @pytest.mark.asyncio - async def test_call_plugin_api_uses_resolver_when_manager_misses(self): - """Test ApiResolver fallback works when Ray workers lack manager state.""" - func_api = PluginApi(deployment_name="demo_plugin.DemoDeployment", func_name="run", call_type=ApiType.FUNC) - http_api = PluginApi(api="/api/v1/resource_match/resource_detail", deployment_name="resource_match.Detail") - resolver = ApiResolver( - api_registry={ - "demo_plugin.DemoDeployment.run": func_api, - "/api/v1/resource_match/resource_detail": http_api, - } - ) + async def test_call_plugin_api_context_wrapper_preserves_async_generator(self): + from framex.plugin.base import BasePlugin + from framex.plugin.on import on_request - with ( - patch("framex.plugin._manager.get_api", return_value=None), - patch("framex.plugin.get_adapter") as mock_adapter, - ): - mock_adapter.return_value.call_func = AsyncMock(side_effect=["func_result", "http_result"]) + class DemoPlugin(BasePlugin): + @on_request("/demo", stream=True) + async def stream_api(self): + yield "a" + yield "b" - assert await call_plugin_api("demo_plugin.DemoDeployment.run", resolver=resolver) == "func_result" - assert await call_plugin_api("/api/v1/resource_match/resource_detail", resolver=resolver) == "http_result" + plugin = DemoPlugin() + stream = plugin.stream_api() + + assert inspect.isasyncgen(stream) + assert [chunk async for chunk in stream] == ["a", "b"] @pytest.mark.asyncio - async def test_call_plugin_api_uses_context_resolver_when_no_explicit_resolver(self): - """Test context-bound resolver works for nested service-style calls.""" - http_api = PluginApi(api="/api/v1/fetch/fetch_company_by_dim_info", deployment_name="fetch.FetchPlugin") - resolver = ApiResolver(api_registry={"/api/v1/fetch/fetch_company_by_dim_info": http_api}) + async def test_call_remote_api_uses_whitelist_via_request_context(self): + from framex.plugin.base import BasePlugin + from framex.plugin.on import on_request - with ( - patch("framex.plugin._manager.get_api", return_value=None), - patch("framex.plugin.get_adapter") as mock_adapter, - ): - mock_adapter.return_value.call_func = AsyncMock(return_value="context_result") - token = set_current_api_resolver(resolver) - try: - result = await call_plugin_api("/api/v1/fetch/fetch_company_by_dim_info") - finally: - reset_current_api_resolver(token) + api = PluginApi(api="test_api", deployment_name="test_deployment", params=[("model_param", SampleModel)]) - assert result == "context_result" - assert mock_adapter.return_value.call_func.call_args[0][0] == http_api + class DemoPlugin(BasePlugin): + @on_request("/demo") + async def request_api(self): + return await self._call_remote_api("test_api", model_param={"field1": "test", "field2": 456}) - @pytest.mark.asyncio - async def test_call_plugin_api_plugin_context_requires_remote_apis(self): - """Test plugin-context calls do not fall back to global registry outside remote_apis.""" - http_api = PluginApi(api="/api/v1/fetch/fetch_company_by_dim_info", deployment_name="fetch.FetchPlugin") - resolver = ApiResolver(api_registry={"/api/v1/fetch/fetch_company_by_dim_info": http_api}) - - with patch("framex.plugin._manager.get_api", return_value=http_api): - resolver_token = set_current_api_resolver(resolver) - remote_token = set_current_remote_apis({}) - try: - with pytest.raises(RuntimeError, match="not declared in current plugin remote_apis"): - await call_plugin_api("/api/v1/fetch/fetch_company_by_dim_info") - finally: - reset_current_remote_apis(remote_token) - reset_current_api_resolver(resolver_token) + plugin = DemoPlugin(remote_apis={"test_api": api}) + + with patch("framex.plugin.get_adapter") as mock_adapter: + mock_adapter.return_value.call_func = AsyncMock(return_value="success") + result = await plugin.request_api() + assert result == "success" + call_args = mock_adapter.return_value.call_func.call_args + assert isinstance(call_args[0][0], PluginApi) + assert isinstance(call_args[1]["model_param"], SampleModel) @pytest.mark.asyncio - async def test_call_plugin_api_context_wrapper_preserves_async_generator(self): - """Test resolver context wrapping keeps stream handlers as async generators.""" + async def test_call_remote_api_rejects_missing_whitelist_entry_via_request_context(self): from framex.plugin.base import BasePlugin from framex.plugin.on import on_request class DemoPlugin(BasePlugin): - @on_request("/demo", stream=True) - async def stream_api(self): - yield "a" - yield "b" + @on_request("/demo") + async def request_api(self): + return await self._call_remote_api("test_api") - plugin = DemoPlugin(api_registry={}) - stream = plugin.stream_api() + plugin = DemoPlugin(remote_apis={}) - assert inspect.isasyncgen(stream) - assert [chunk async for chunk in stream] == ["a", "b"] + with pytest.raises(RuntimeError, match="not declared in current plugin remote_apis"): + await plugin.request_api() @pytest.mark.asyncio async def test_call_plugin_api_proxy_creates_correct_plugin_api(self): - """Test that proxy fallback creates PluginApi with correct parameters.""" with ( - patch("framex.plugin._manager.get_api", return_value=None), patch("framex.plugin.settings.server.enable_proxy", True), patch("framex.plugin.get_adapter") as mock_adapter, patch("framex.plugin.logger"), ): mock_adapter.return_value.call_func = AsyncMock(return_value={"status": 200, "data": "ok"}) - await call_plugin_api("/proxy/test") - - # Check the PluginApi passed to call_func - call_args = mock_adapter.return_value.call_func.call_args - api = call_args[0][0] + api = mock_adapter.return_value.call_func.call_args[0][0] assert isinstance(api, PluginApi) assert api.api == "/proxy/test" assert api.deployment_name == PROXY_PLUGIN_NAME @@ -307,47 +207,37 @@ async def test_call_plugin_api_proxy_creates_correct_plugin_api(self): @pytest.mark.asyncio async def test_call_plugin_api_regular_dict_result_not_proxy(self): - """Test that regular dict results (non-proxy) are returned as-is.""" api = PluginApi(api="test_api", deployment_name="test_deployment") - - with ( - patch("framex.plugin._manager.get_api", return_value=api), - patch("framex.plugin.get_adapter") as mock_adapter, - ): - # Regular dict result (not from proxy) - mock_adapter.return_value.call_func = AsyncMock(return_value={"key": "value", "status": 200}) - - result = await call_plugin_api("test_api") - - # Should return the entire dict, not extract "data" - assert result == {"key": "value", "status": 200} + token = set_current_remote_apis({"test_api": api}) + try: + with patch("framex.plugin.get_adapter") as mock_adapter: + mock_adapter.return_value.call_func = AsyncMock(return_value={"key": "value", "status": 200}) + result = await call_plugin_api("test_api") + assert result == {"key": "value", "status": 200} + finally: + reset_current_remote_apis(token) @pytest.mark.asyncio async def test_call_plugin_api_with_multiple_kwargs(self): - """Test calling API with multiple keyword arguments.""" api = PluginApi( api="test_api", deployment_name="test_deployment", params=[("a", int), ("b", str), ("c", bool)] ) - - with ( - patch("framex.plugin._manager.get_api", return_value=api), - patch("framex.plugin.get_adapter") as mock_adapter, - ): - mock_adapter.return_value.call_func = AsyncMock(return_value="multi_args") - - result = await call_plugin_api("test_api", a=1, b="test", c=True) - - assert result == "multi_args" - call_kwargs = mock_adapter.return_value.call_func.call_args[1] - assert call_kwargs["a"] == 1 - assert call_kwargs["b"] == "test" - assert call_kwargs["c"] is True + token = set_current_remote_apis({"test_api": api}) + try: + with patch("framex.plugin.get_adapter") as mock_adapter: + mock_adapter.return_value.call_func = AsyncMock(return_value="multi_args") + result = await call_plugin_api("test_api", a=1, b="test", c=True) + assert result == "multi_args" + call_kwargs = mock_adapter.return_value.call_func.call_args[1] + assert call_kwargs["a"] == 1 + assert call_kwargs["b"] == "test" + assert call_kwargs["c"] is True + finally: + reset_current_remote_apis(token) @pytest.mark.asyncio async def test_call_plugin_api_with_proxy_missing_status(self): - """Test proxy API call raises when status field is missing.""" with ( - patch("framex.plugin._manager.get_api", return_value=None), patch("framex.plugin.settings.server.enable_proxy", True), patch("framex.plugin.get_adapter") as mock_adapter, patch("framex.plugin.logger"), From 3a043436afa1a428ab0a1292c541259bce1ba513 Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Sat, 28 Mar 2026 16:50:45 +0800 Subject: [PATCH 5/6] fix: simplify proxy content type handling --- src/framex/plugins/proxy/__init__.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/framex/plugins/proxy/__init__.py b/src/framex/plugins/proxy/__init__.py index 90fea99..f9c6972 100644 --- a/src/framex/plugins/proxy/__init__.py +++ b/src/framex/plugins/proxy/__init__.py @@ -142,18 +142,14 @@ async def _parse_openai_docs(self, url: str) -> None: Model = create_pydantic_model(schema_name, model_schema, components) # noqa params.append(("model", Model)) body_param_names.add("model") - elif content_type == "multipart/form-data": + else: for field_name, prop_schema in model_schema.get("properties", {}).items(): annotation = resolve_annotation(prop_schema, components) params.append((field_name, to_multipart_annotation(annotation))) body_param_names.add(field_name) if is_upload_annotation(annotation): file_param_names.add(field_name) - else: - logger.opt(colors=True).error( - f"Failed to proxy api({method}) {url}{path}, unsupported content type: ${body_content.keys()}" - ) - continue + logger.opt(colors=True).trace(f"Found proxy api({method}) {url}{path}") is_stream = path in settings.force_stream_apis func = self._create_dynamic_method( From 825a0304dbd1ebebe5effeab5b1a8edbe0e81f6e Mon Sep 17 00:00:00 2001 From: touale <136764239@qq.com> Date: Sat, 28 Mar 2026 16:55:05 +0800 Subject: [PATCH 6/6] fix: remove unused annotation helpers --- src/framex/driver/ingress.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/src/framex/driver/ingress.py b/src/framex/driver/ingress.py index a2ddb85..8fa7af4 100644 --- a/src/framex/driver/ingress.py +++ b/src/framex/driver/ingress.py @@ -2,12 +2,11 @@ import os import re from collections.abc import Callable -from typing import Annotated, Any, get_args, get_origin +from typing import Any from fastapi import Depends, HTTPException, Request, Response, status from fastapi.responses import JSONResponse, StreamingResponse from fastapi.routing import APIRoute -from pydantic import BaseModel from starlette.routing import Route from framex.adapter import get_adapter @@ -22,17 +21,6 @@ app = create_fastapi_application() -def _unwrap_annotation(annotation: Any) -> Any: - if get_origin(annotation) is Annotated: - return get_args(annotation)[0] - return annotation - - -def _is_basemodel_annotation(annotation: Any) -> bool: - annotation = _unwrap_annotation(annotation) - return isinstance(annotation, type) and issubclass(annotation, BaseModel) - - @app.get("/health") async def health() -> str: return "ok"