diff --git a/.coveragerc b/.coveragerc index 48d7d40..406fd24 100644 --- a/.coveragerc +++ b/.coveragerc @@ -23,5 +23,5 @@ exclude_also = @(abc\.)?abstractmethod pass def __str__ - if settings.server.use_ray: + # if settings.server.use_ray: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c07dbac..afbedb9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,7 +16,7 @@ jobs: - name: Install UV run: | pip install uv==${{ env.UV_VERSION }} - uv sync --group dev + uv sync --extra ray --dev - name: Run Tests run: uv run poe test-ci - name: Upload coverage reports to Codecov diff --git a/ruff.toml b/ruff.toml index fe30ebd..874ef7a 100644 --- a/ruff.toml +++ b/ruff.toml @@ -79,7 +79,15 @@ ignore = [ ] [lint.per-file-ignores] -"tests/**/*.py" = ["S101", "ANN201", "ANN001", "ANN002", "ANN003"] +"tests/**/*.py" = [ + "S101", + "ANN201", + "ANN001", + "ANN002", + "ANN003", + "ANN202", + "ARG001", +] [lint.flake8-builtins] builtins-ignorelist = ["input", "id", "bytes", "type"] diff --git a/src/framex/adapter/__init__.py b/src/framex/adapter/__init__.py index db98f36..a4e7094 100644 --- a/src/framex/adapter/__init__.py +++ b/src/framex/adapter/__init__.py @@ -1,5 +1,5 @@ from framex.adapter.base import BaseAdapter -from framex.adapter.local_adapyer import LocalAdapter +from framex.adapter.local_adapter import LocalAdapter from framex.config import settings _adapter: BaseAdapter | None = None diff --git a/src/framex/adapter/base.py b/src/framex/adapter/base.py index fafd978..cda5bad 100644 --- a/src/framex/adapter/base.py +++ b/src/framex/adapter/base.py @@ -1,5 +1,4 @@ import abc -import inspect from collections.abc import Callable from enum import StrEnum from typing import Any, cast @@ -25,17 +24,21 @@ def to_ingress(self, cls: type, app: FastAPI, **kwargs: Any) -> type: # noqa: A def to_deployment(self, cls: type, **kwargs: Any) -> type: # noqa: ARG002 return cls - async def call_func(self, api: PluginApi, **kwargs: Any) -> Any: - func = self.get_handle_func(api.deployment_name, api.func_name) - stream = api.stream + async def _resolve_stream(self, api: PluginApi, kwargs: dict[str, Any]) -> bool: if api.call_type == ApiType.PROXY and api.api: kwargs["proxy_path"] = api.api - stream = await self._check_is_gen_api(api.api) + return bool(await self._check_is_gen_api(api.api)) + return api.stream + + @abc.abstractmethod + async def _invoke(self, func: Callable[..., Any], **kwargs: Any) -> Any: ... + + async def call_func(self, api: PluginApi, **kwargs: Any) -> Any: + func = self.get_handle_func(api.deployment_name, api.func_name) + stream = await self._resolve_stream(api, kwargs) if stream: return [chunk async for chunk in self._stream_call(func, **kwargs)] - if inspect.iscoroutinefunction(func): - return await self._acall(func, **kwargs) # type: ignore - return self._call(func, **kwargs) + return await self._invoke(func, **kwargs) def get_handle_func(self, deployment_name: str, func_name: str) -> Any: handle = self.get_handle(deployment_name) @@ -57,11 +60,10 @@ def bind(self, deployment: Callable[..., Any], **kwargs: Any) -> Any: ... @abc.abstractmethod def to_remote_func(self, func: Callable) -> Callable: ... - def _stream_call(self, func: Callable[..., Any], **kwargs: Any) -> Any: - return func(**kwargs) - - async def _acall(self, func: Callable[..., Any], **kwargs: Any) -> Any: - return await func(**kwargs) + @abc.abstractmethod + def _stream_call(self, func: Callable[..., Any], **kwargs: Any) -> Any: ... - def _call(self, func: Callable[..., Any], **kwargs: Any) -> Any: - return func(**kwargs) + @abc.abstractmethod + async def _acall(self, func: Callable[..., Any], **kwargs: Any) -> Any: ... + @abc.abstractmethod + def _call(self, func: Callable[..., Any], **kwargs: Any) -> Any: ... diff --git a/src/framex/adapter/local_adapyer.py b/src/framex/adapter/local_adapter.py similarity index 76% rename from src/framex/adapter/local_adapyer.py rename to src/framex/adapter/local_adapter.py index 7f05dfb..3acf0b1 100644 --- a/src/framex/adapter/local_adapyer.py +++ b/src/framex/adapter/local_adapter.py @@ -55,3 +55,21 @@ async def _remote_func(*args: tuple[Any, ...], **kwargs: Any) -> Any: func.remote = _remote_func # type: ignore[attr-defined] return func + + @override + async def _invoke(self, func: Callable[..., Any], **kwargs: Any) -> Any: + if inspect.iscoroutinefunction(func): + return await self._acall(func, **kwargs) + return self._call(func, **kwargs) + + @override + async def _acall(self, func: Callable[..., Any], **kwargs: Any) -> Any: + return await func(**kwargs) + + @override + def _call(self, func: Callable[..., Any], **kwargs: Any) -> Any: + return func(**kwargs) + + @override + def _stream_call(self, func: Callable[..., Any], **kwargs: Any) -> Any: + return func(**kwargs) diff --git a/src/framex/adapter/ray_adapter.py b/src/framex/adapter/ray_adapter.py index 89fb325..73a9afd 100644 --- a/src/framex/adapter/ray_adapter.py +++ b/src/framex/adapter/ray_adapter.py @@ -6,8 +6,9 @@ try: import ray # type: ignore[import-not-found] from ray import serve # type: ignore[import-not-found] + from ray.serve.handle import DeploymentHandle # type: ignore[import-not-found] except ImportError as e: - raise RuntimeError('Ray engine requires extra dependency.\nInstall with: uv pip install "framex-kit[ray]"') from e + raise RuntimeError('Ray engine requires extra dependency.\nInstall with: uv add "framex-kit[ray]"') from e from fastapi import FastAPI from typing_extensions import override @@ -55,3 +56,13 @@ def _stream_call(self, func: Callable[..., Any], **kwargs: Any) -> Any: @override async def _acall(self, func: Callable[..., Any], **kwargs: Any) -> Any: return await func.remote(**kwargs) # type: ignore [attr-defined] + + @override + async def _invoke(self, func: Callable[..., Any], **kwargs: Any) -> Any: + if inspect.iscoroutinefunction(func) or isinstance(func, DeploymentHandle): + return await self._acall(func, **kwargs) # type: ignore + return self._call(func, **kwargs) + + @override + def _call(self, func: Callable[..., Any], **kwargs: Any) -> Any: + return func(**kwargs) diff --git a/tests/adapter/__init__.py b/tests/adapter/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/adapter/test_init.py b/tests/adapter/test_init.py new file mode 100644 index 0000000..ae84f12 --- /dev/null +++ b/tests/adapter/test_init.py @@ -0,0 +1,112 @@ +"""Tests for framex.adapter.__init__ module.""" + +from unittest.mock import MagicMock, patch + +from framex.adapter import get_adapter +from framex.adapter.base import BaseAdapter +from framex.adapter.local_adapter import LocalAdapter + + +class TestGetAdapter: + """Tests for the get_adapter factory function.""" + + def setup_method(self): + """Reset the global adapter before each test.""" + import framex.adapter as adapter_module + + adapter_module._adapter = None + + def test_get_adapter_returns_local_adapter_when_ray_disabled(self): + """Test get_adapter returns LocalAdapter when use_ray is False.""" + with patch("framex.adapter.settings.server.use_ray", False): + adapter = get_adapter() + assert isinstance(adapter, LocalAdapter) + assert isinstance(adapter, BaseAdapter) + + def test_get_adapter_returns_ray_adapter_when_ray_enabled(self): + """Test get_adapter returns RayAdapter when use_ray is True.""" + with ( + patch("framex.adapter.settings.server.use_ray", True), + patch("framex.adapter.ray_adapter.RayAdapter") as mock_ray_adapter, + ): + mock_instance = MagicMock() + mock_ray_adapter.return_value = mock_instance + + adapter = get_adapter() + assert adapter == mock_instance + mock_ray_adapter.assert_called_once() + + def test_get_adapter_returns_same_instance_on_multiple_calls(self): + """Test get_adapter returns the same singleton instance.""" + with patch("framex.adapter.settings.server.use_ray", False): + adapter1 = get_adapter() + adapter2 = get_adapter() + assert adapter1 is adapter2 + + def test_get_adapter_caches_local_adapter(self): + """Test that LocalAdapter is cached after first call.""" + with patch("framex.adapter.settings.server.use_ray", False): + adapter1 = get_adapter() + # Second call should return cached instance + adapter2 = get_adapter() + assert adapter1 is adapter2 + assert isinstance(adapter1, LocalAdapter) + + def test_get_adapter_caches_ray_adapter(self): + """Test that RayAdapter is cached after first call.""" + with ( + patch("framex.adapter.settings.server.use_ray", True), + patch("framex.adapter.ray_adapter.RayAdapter") as mock_ray_adapter, + ): + mock_instance = MagicMock() + mock_ray_adapter.return_value = mock_instance + + adapter1 = get_adapter() + adapter2 = get_adapter() + + # Should only instantiate once + assert adapter1 is adapter2 + assert mock_ray_adapter.call_count == 1 + + def test_get_adapter_lazy_imports_ray_adapter(self): + """Test that RayAdapter is only imported when needed.""" + with patch("framex.adapter.settings.server.use_ray", False): # noqa + # Import should not happen when use_ray is False + with patch("framex.adapter.ray_adapter") as mock_ray_module: + get_adapter() + # RayAdapter module should not be accessed + mock_ray_module.RayAdapter.assert_not_called() + + def test_adapter_initially_none(self): + """Test that _adapter global is None before first call.""" + import framex.adapter as adapter_module + + adapter_module._adapter = None + assert adapter_module._adapter is None + + def test_adapter_set_after_first_call(self): + """Test that _adapter global is set after first call.""" + import framex.adapter as adapter_module + + adapter_module._adapter = None + with patch("framex.adapter.settings.server.use_ray", False): + get_adapter() + assert adapter_module._adapter is not None + assert isinstance(adapter_module._adapter, LocalAdapter) + + def test_get_adapter_with_switching_ray_setting(self): + """Test that once adapter is set, changing ray setting doesn't affect it.""" + import framex.adapter as adapter_module + + adapter_module._adapter = None + + with patch("framex.adapter.settings.server.use_ray", False): + adapter1 = get_adapter() + assert isinstance(adapter1, LocalAdapter) + + # Change setting, but adapter should still be the same + with patch("framex.adapter.settings.server.use_ray", True): + adapter2 = get_adapter() + # Should still be the same LocalAdapter instance + assert adapter2 is adapter1 + assert isinstance(adapter2, LocalAdapter) diff --git a/tests/adapter/test_local_adapter.py b/tests/adapter/test_local_adapter.py new file mode 100644 index 0000000..6c32f95 --- /dev/null +++ b/tests/adapter/test_local_adapter.py @@ -0,0 +1,328 @@ +"""Tests for framex.adapter.local_adapter module.""" + +import asyncio +import threading +from unittest.mock import MagicMock, patch + +from framex.adapter.base import AdapterMode +from framex.adapter.local_adapter import LocalAdapter + + +class TestLocalAdapter: + """Tests for LocalAdapter class.""" + + def test_mode_is_local(self): + """Test LocalAdapter mode is LOCAL.""" + adapter = LocalAdapter() + assert adapter.mode == AdapterMode.LOCAL + + def test_to_deployment_sets_deployment_name(self): + """Test to_deployment sets deployment_name attribute on class.""" + adapter = LocalAdapter() + + class TestClass: + pass + + result = adapter.to_deployment(TestClass, name="my_deployment") + assert hasattr(result, "deployment_name") + assert result.deployment_name == "my_deployment" + + def test_to_deployment_without_name(self): + """Test to_deployment without name doesn't set attribute.""" + adapter = LocalAdapter() + + class TestClass: + pass + + result = adapter.to_deployment(TestClass) + assert not hasattr(result, "deployment_name") + + def test_to_deployment_with_other_kwargs(self): + """Test to_deployment handles other kwargs.""" + adapter = LocalAdapter() + + class TestClass: + pass + + result = adapter.to_deployment(TestClass, name="test", other_param="value") + assert hasattr(result, "deployment_name") + assert result.deployment_name == "test" + + def test_get_handle_returns_deployment(self): + """Test get_handle returns deployment from app state.""" + adapter = LocalAdapter() + + mock_deployment = MagicMock() + mock_app = MagicMock() + mock_app.state.deployments_dict = {"test_deployment": mock_deployment} + + with patch("framex.driver.ingress.app", mock_app): + handle = adapter.get_handle("test_deployment") + assert handle == mock_deployment + + def test_get_handle_returns_ingress_for_backend(self): + """Test get_handle returns ingress for BACKEND_NAME.""" + adapter = LocalAdapter() + + mock_ingress = MagicMock() + mock_app = MagicMock() + mock_app.state.deployments_dict = {} + mock_app.state.ingress = mock_ingress + + with ( + patch("framex.driver.ingress.app", mock_app), + patch("framex.consts.BACKEND_NAME", "backend"), + ): + handle = adapter.get_handle("backend") + assert handle == mock_ingress + + def test_get_handle_returns_none_when_not_found(self): + """Test get_handle returns None when deployment not found.""" + adapter = LocalAdapter() + + mock_app = MagicMock() + mock_app.state.deployments_dict = {} + + with patch("framex.driver.ingress.app", mock_app): + handle = adapter.get_handle("nonexistent") + assert handle is None + + def test_bind_calls_deployment_with_kwargs(self): + """Test bind calls deployment function with kwargs.""" + adapter = LocalAdapter() + + def mock_deployment(**kwargs): + return kwargs + + result = adapter.bind(mock_deployment, param1="value1", param2="value2") + assert result == {"param1": "value1", "param2": "value2"} + + def test_bind_with_no_kwargs(self): + """Test bind calls deployment with no kwargs.""" + adapter = LocalAdapter() + + def mock_deployment(**kwargs): + return "called" + + result = adapter.bind(mock_deployment) + assert result == "called" + + def test_safe_plot_wrapper_uses_lock(self): + """Test _safe_plot_wrapper uses thread lock.""" + adapter = LocalAdapter() + call_order = [] + lock_acquired = [] + + def mock_func(*args, **kwargs): + # Check if lock is held + lock_acquired.append(threading.current_thread().ident) + call_order.append("func") + return "result" + + result = adapter._safe_plot_wrapper(mock_func, "arg1", kwarg1="value1") + assert result == "result" + assert len(lock_acquired) == 1 + + def test_safe_plot_wrapper_serializes_calls(self): + """Test _safe_plot_wrapper serializes concurrent calls.""" + adapter = LocalAdapter() + call_order = [] + lock = threading.Lock() # noqa + + def slow_func(name): + call_order.append(f"{name}_start") + threading.Event().wait(0.01) # Small delay + call_order.append(f"{name}_end") + return name + + # Run two calls concurrently + def run_wrapper(name): + return adapter._safe_plot_wrapper(slow_func, name) + + thread1 = threading.Thread(target=run_wrapper, args=("call1",)) + thread2 = threading.Thread(target=run_wrapper, args=("call2",)) + + thread1.start() + thread2.start() + thread1.join() + thread2.join() + + # Both calls should complete (in any order due to thread scheduling) + assert "call1_start" in call_order + assert "call1_end" in call_order + assert "call2_start" in call_order + assert "call2_end" in call_order + + async def test_to_remote_func_with_async_function(self): + """Test to_remote_func handles async functions.""" + adapter = LocalAdapter() + + async def async_func(value): + return value * 2 + + wrapped_func = adapter.to_remote_func(async_func) + assert hasattr(wrapped_func, "remote") + + result = await wrapped_func.remote(5) + assert result == 10 + + async def test_to_remote_func_with_sync_function(self): + """Test to_remote_func handles sync functions with asyncio.to_thread.""" + adapter = LocalAdapter() + + def sync_func(value): + return value * 3 + + wrapped_func = adapter.to_remote_func(sync_func) + assert hasattr(wrapped_func, "remote") + + result = await wrapped_func.remote(5) + assert result == 15 + + async def test_to_remote_func_with_sync_function_uses_safe_wrapper(self): + """Test to_remote_func uses _safe_plot_wrapper for sync functions.""" + adapter = LocalAdapter() + + def sync_func(value): + return value + 1 + + with patch.object(adapter, "_safe_plot_wrapper", return_value=11) as mock_wrapper: + wrapped_func = adapter.to_remote_func(sync_func) + result = await wrapped_func.remote(10) # type: ignore + + # Verify safe wrapper was used + mock_wrapper.assert_called_once_with(sync_func, 10) + assert result == 11 + + async def test_invoke_with_async_function(self): + """Test _invoke delegates to _acall for async functions.""" + adapter = LocalAdapter() + + async def async_func(**kwargs): + return "async_result" + + result = await adapter._invoke(async_func, param="value") + assert result == "async_result" + + async def test_invoke_with_sync_function(self): + """Test _invoke delegates to _call for sync functions.""" + adapter = LocalAdapter() + + def sync_func(**kwargs): + return "sync_result" + + result = await adapter._invoke(sync_func, param="value") + assert result == "sync_result" + + async def test_acall_awaits_async_function(self): + """Test _acall awaits async function with kwargs.""" + adapter = LocalAdapter() + + async def async_func(**kwargs): + return kwargs + + result = await adapter._acall(async_func, key1="value1", key2="value2") + assert result == {"key1": "value1", "key2": "value2"} + + def test_call_invokes_sync_function(self): + """Test _call invokes sync function with kwargs.""" + adapter = LocalAdapter() + + def sync_func(**kwargs): + return kwargs + + result = adapter._call(sync_func, key1="value1", key2="value2") + assert result == {"key1": "value1", "key2": "value2"} + + def test_stream_call_invokes_function(self): + """Test _stream_call invokes function with kwargs.""" + adapter = LocalAdapter() + + def stream_func(**kwargs): + yield from kwargs.values() + + result = adapter._stream_call(stream_func, key1="value1", key2="value2") + values = list(result) + assert "value1" in values + assert "value2" in values + + async def test_stream_call_with_async_generator(self): + """Test _stream_call works with async generators.""" + adapter = LocalAdapter() + + async def async_stream(**kwargs): + for value in kwargs.values(): + yield value + + result = adapter._stream_call(async_stream, key1="value1", key2="value2") + values = [] + async for value in result: + values.append(value) # noqa + assert "value1" in values + assert "value2" in values + + def test_to_remote_func_preserves_original_function(self): + """Test to_remote_func preserves the original function.""" + adapter = LocalAdapter() + + def original_func(x): + return x * 2 + + wrapped_func = adapter.to_remote_func(original_func) + # Original function should still be callable + assert wrapped_func(5) == 10 + + async def test_to_remote_func_remote_attribute(self): + """Test to_remote_func adds remote attribute.""" + adapter = LocalAdapter() + + def sync_func(x): + return x + 1 + + wrapped_func = adapter.to_remote_func(sync_func) + assert hasattr(wrapped_func, "remote") + assert callable(wrapped_func.remote) + + async def test_concurrent_safe_plot_wrapper_calls(self): + """Test multiple concurrent calls to _safe_plot_wrapper are serialized.""" + adapter = LocalAdapter() + results = [] + + def counting_func(n): + import time + + time.sleep(0.001) # Small delay to ensure overlap without lock + results.append(n) + return n + + async def async_wrapper(n): + return await asyncio.to_thread(adapter._safe_plot_wrapper, counting_func, n) + + # Run multiple concurrent calls + await asyncio.gather(*[async_wrapper(i) for i in range(5)]) + + # All calls should complete + assert len(results) == 5 + assert set(results) == {0, 1, 2, 3, 4} + + def test_get_handle_with_empty_deployments_dict(self): + """Test get_handle with empty deployments dict.""" + adapter = LocalAdapter() + + mock_app = MagicMock() + mock_app.state.deployments_dict = {} + mock_app.state.ingress = None + + with patch("framex.driver.ingress.app", mock_app): + handle = adapter.get_handle("any_deployment") + assert handle is None + + async def test_invoke_preserves_kwargs(self): + """Test _invoke preserves kwargs correctly.""" + adapter = LocalAdapter() + + async def async_func(**kwargs): + return kwargs + + result = await adapter._invoke(async_func, a=1, b=2, c=3) + assert result == {"a": 1, "b": 2, "c": 3} diff --git a/tests/adapter/test_ray_adapter.py b/tests/adapter/test_ray_adapter.py new file mode 100644 index 0000000..ea52834 --- /dev/null +++ b/tests/adapter/test_ray_adapter.py @@ -0,0 +1,417 @@ +"""Tests for framex.adapter.ray_adapter module.""" + +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI + +from framex.adapter.base import AdapterMode + + +class TestRayAdapterImport: + """Tests for RayAdapter import error handling.""" + + def test_import_error_without_ray(self): + """Test that importing RayAdapter without ray raises RuntimeError.""" + # Remove ray from sys.modules if it exists + ray_modules = [key for key in sys.modules if key.startswith("ray")] + saved_modules = {} + for mod in ray_modules: + saved_modules[mod] = sys.modules.pop(mod, None) + + try: + with patch.dict("sys.modules", {"ray": None, "ray.serve": None}): # noqa + with pytest.raises(RuntimeError, match="Ray engine requires extra dependency"): # noqa + # Force reload to trigger ImportError + import importlib + + import framex.adapter.ray_adapter + + importlib.reload(framex.adapter.ray_adapter) + finally: + # Restore modules + for mod, val in saved_modules.items(): + if val is not None: + sys.modules[mod] = val + + +@pytest.fixture +def mock_ray(): + """Mock ray and ray.serve modules.""" + mock_ray_module = MagicMock() + mock_serve_module = MagicMock() + mock_deployment_handle = MagicMock() + + with ( + patch.dict( + "sys.modules", + { + "ray": mock_ray_module, + "ray.serve": mock_serve_module, + "ray.serve.handle": MagicMock(DeploymentHandle=mock_deployment_handle), + }, + ), + patch("framex.adapter.ray_adapter.ray", mock_ray_module), + patch("framex.adapter.ray_adapter.serve", mock_serve_module), + patch("framex.adapter.ray_adapter.DeploymentHandle", mock_deployment_handle), + ): + yield mock_ray_module, mock_serve_module, mock_deployment_handle + + +class TestRayAdapter: + """Tests for RayAdapter class with mocked ray dependencies.""" + + def test_mode_is_ray(self, mock_ray): # noqa + """Test RayAdapter mode is RAY.""" + from framex.adapter.ray_adapter import RayAdapter + + adapter = RayAdapter() + assert adapter.mode == AdapterMode.RAY + + def test_to_ingress_calls_serve_ingress(self, mock_ray): + """Test to_ingress calls serve.ingress and to_deployment.""" + from framex.adapter.ray_adapter import RayAdapter + + _, mock_serve_module, _ = mock_ray + adapter = RayAdapter() + + class TestClass: + pass + + app = FastAPI() + + # Setup mock chain + mock_ingress_decorator = MagicMock() + mock_ingress_result = MagicMock() + mock_ingress_decorator.return_value = mock_ingress_result + mock_serve_module.ingress.return_value = mock_ingress_decorator + + mock_deployment_decorator = MagicMock() + mock_deployment_result = MagicMock() + mock_deployment_decorator.return_value = mock_deployment_result + mock_serve_module.deployment.return_value = mock_deployment_decorator + + adapter.to_ingress(TestClass, app, name="test") + + mock_serve_module.ingress.assert_called_once_with(app) + mock_ingress_decorator.assert_called_once_with(TestClass) + mock_serve_module.deployment.assert_called_once_with(name="test") + + def test_to_deployment_calls_serve_deployment(self, mock_ray): + """Test to_deployment calls serve.deployment with kwargs.""" + from framex.adapter.ray_adapter import RayAdapter + + _, mock_serve_module, _ = mock_ray + adapter = RayAdapter() + + class TestClass: + pass + + mock_decorator = MagicMock() + mock_result = MagicMock() + mock_decorator.return_value = mock_result + mock_serve_module.deployment.return_value = mock_decorator + + adapter.to_deployment(TestClass, name="test", num_replicas=3) + + mock_serve_module.deployment.assert_called_once_with(name="test", num_replicas=3) + mock_decorator.assert_called_once_with(TestClass) + + def test_to_remote_func_with_sync_function(self, mock_ray): + """Test to_remote_func wraps sync function with ray.remote.""" + from framex.adapter.ray_adapter import RayAdapter + + mock_ray_module, _, _ = mock_ray + adapter = RayAdapter() + + def sync_func(x): + return x * 2 + + mock_remote_func = MagicMock() + mock_ray_module.remote.return_value = mock_remote_func + + result = adapter.to_remote_func(sync_func) + + mock_ray_module.remote.assert_called_once() + assert result == mock_remote_func + + def test_to_remote_func_with_async_function(self, mock_ray): + """Test to_remote_func wraps async function to sync then ray.remote.""" + from framex.adapter.ray_adapter import RayAdapter + + mock_ray_module, _, _ = mock_ray + adapter = RayAdapter() + + async def async_func(x): + return x * 2 + + mock_remote_func = MagicMock() + mock_ray_module.remote.return_value = mock_remote_func + + result = adapter.to_remote_func(async_func) + + # Should wrap async function and call ray.remote + mock_ray_module.remote.assert_called_once() + assert result == mock_remote_func + + def test_to_remote_func_async_wrapper_runs_asyncio(self, mock_ray): + """Test that async wrapper uses asyncio.run.""" + from framex.adapter.ray_adapter import RayAdapter + + mock_ray_module, _, _ = mock_ray + adapter = RayAdapter() + + call_log = [] + + async def async_func(x): + call_log.append(f"called with {x}") + return x * 2 + + # Capture the wrapper function passed to ray.remote + captured_wrapper = None + + def capture_remote(func): + nonlocal captured_wrapper + captured_wrapper = func + return MagicMock() + + mock_ray_module.remote = capture_remote + + adapter.to_remote_func(async_func) + + # Now test the captured wrapper + assert captured_wrapper is not None + with patch("asyncio.run") as mock_asyncio_run: + mock_asyncio_run.return_value = 10 + result = captured_wrapper(5) + mock_asyncio_run.assert_called_once() + assert result == 10 + + def test_get_handle_calls_serve_get_deployment_handle(self, mock_ray): + """Test get_handle calls serve.get_deployment_handle.""" + from framex.adapter.ray_adapter import RayAdapter + + _, mock_serve_module, _ = mock_ray + adapter = RayAdapter() + + mock_handle = MagicMock() + mock_serve_module.get_deployment_handle.return_value = mock_handle + + with patch("framex.adapter.ray_adapter.APP_NAME", "test_app"): + result = adapter.get_handle("test_deployment") + + mock_serve_module.get_deployment_handle.assert_called_once_with("test_deployment", app_name="test_app") + assert result == mock_handle + + def test_bind_calls_deployment_bind(self, mock_ray): # noqa + """Test bind calls deployment.bind with kwargs.""" + from framex.adapter.ray_adapter import RayAdapter + + adapter = RayAdapter() + + mock_deployment = MagicMock() + mock_bound = MagicMock() + mock_deployment.bind.return_value = mock_bound + + result = adapter.bind(mock_deployment, param1="value1", param2="value2") + + mock_deployment.bind.assert_called_once_with(param1="value1", param2="value2") + assert result == mock_bound + + def test_stream_call_uses_options_stream(self, mock_ray): # noqa + """Test _stream_call uses options(stream=True).remote.""" + from framex.adapter.ray_adapter import RayAdapter + + adapter = RayAdapter() + + mock_func = MagicMock() + mock_options = MagicMock() + mock_remote_result = MagicMock() + mock_func.options.return_value = mock_options + mock_options.remote.return_value = mock_remote_result + + result = adapter._stream_call(mock_func, param="value") + + mock_func.options.assert_called_once_with(stream=True) + mock_options.remote.assert_called_once_with(param="value") + assert result == mock_remote_result + + async def test_acall_awaits_remote(self, mock_ray): + """Test _acall awaits func.remote.""" + from framex.adapter.ray_adapter import RayAdapter + + _, _, _ = mock_ray + adapter = RayAdapter() + + mock_func = MagicMock() + mock_remote = AsyncMock(return_value="result") + mock_func.remote = mock_remote + + result = await adapter._acall(mock_func, param="value") + + mock_remote.assert_called_once_with(param="value") + assert result == "result" + + async def test_invoke_with_async_function(self, mock_ray): + """Test _invoke delegates to _acall for async functions.""" + from framex.adapter.ray_adapter import RayAdapter + + _, _, mock_deployment_handle = mock_ray + adapter = RayAdapter() + + async def async_func(**kwargs): + return "async_result" + + # Patch isinstance to always return False for DeploymentHandle check + with ( + patch( + "framex.adapter.ray_adapter.isinstance", + side_effect=lambda obj, cls: False if cls is mock_deployment_handle else isinstance(obj, cls), + ), + patch.object(adapter, "_acall", new=AsyncMock(return_value="async_result")) as mock_acall, + ): + result = await adapter._invoke(async_func, param="value") + mock_acall.assert_called_once_with(async_func, param="value") + assert result == "async_result" + + async def test_invoke_with_deployment_handle(self, mock_ray): + """Test _invoke delegates to _acall for DeploymentHandle-like object.""" + from framex.adapter.ray_adapter import RayAdapter + + _, _, mock_deployment_handle = mock_ray + adapter = RayAdapter() + + # Create a mock that will match isinstance check + # We'll patch inspect.iscoroutinefunction to return False + # and the isinstance(func, DeploymentHandle) will be checked + mock_handle = MagicMock() + + with ( + patch("inspect.iscoroutinefunction", return_value=False), + patch( + "framex.adapter.ray_adapter.isinstance", + side_effect=lambda obj, cls: obj is mock_handle and cls is mock_deployment_handle, + ), + patch.object(adapter, "_acall", new=AsyncMock(return_value="handle_result")) as mock_acall, + ): + result = await adapter._invoke(mock_handle, param="value") + mock_acall.assert_called_once_with(mock_handle, param="value") + assert result == "handle_result" + + async def test_invoke_with_sync_function(self, mock_ray): + """Test _invoke delegates to _call for sync functions.""" + from framex.adapter.ray_adapter import RayAdapter + + adapter = RayAdapter() + _, _, mock_deployment_handle = mock_ray + + def sync_func(**kwargs): + return "sync_result" + + # Patch isinstance to always return False for DeploymentHandle check + with ( + patch( + "framex.adapter.ray_adapter.isinstance", + side_effect=lambda obj, cls: False if cls is mock_deployment_handle else isinstance(obj, cls), + ), + patch.object(adapter, "_call", return_value="sync_result") as mock_call, + ): + result = await adapter._invoke(sync_func, param="value") + mock_call.assert_called_once_with(sync_func, param="value") + assert result == "sync_result" + + def test_call_invokes_function_directly(self, mock_ray): # noqa + """Test _call invokes function directly with kwargs.""" + from framex.adapter.ray_adapter import RayAdapter + + adapter = RayAdapter() + + def mock_func(**kwargs): + return kwargs + + result = adapter._call(mock_func, key1="value1", key2="value2") + assert result == {"key1": "value1", "key2": "value2"} + + def test_bind_with_no_kwargs(self, mock_ray): # noqa + """Test bind with no kwargs.""" + from framex.adapter.ray_adapter import RayAdapter + + adapter = RayAdapter() + + mock_deployment = MagicMock() + mock_bound = MagicMock() + mock_deployment.bind.return_value = mock_bound + + result = adapter.bind(mock_deployment) + + mock_deployment.bind.assert_called_once_with() + assert result == mock_bound + + def test_stream_call_with_multiple_kwargs(self, mock_ray): # noqa + """Test _stream_call passes all kwargs correctly.""" + from framex.adapter.ray_adapter import RayAdapter + + adapter = RayAdapter() + + mock_func = MagicMock() + mock_options = MagicMock() + mock_remote_result = MagicMock() + mock_func.options.return_value = mock_options + mock_options.remote.return_value = mock_remote_result + + result = adapter._stream_call(mock_func, a=1, b=2, c=3) + + mock_options.remote.assert_called_once_with(a=1, b=2, c=3) + assert result == mock_remote_result + + def test_to_deployment_with_no_kwargs(self, mock_ray): + """Test to_deployment with no kwargs.""" + from framex.adapter.ray_adapter import RayAdapter + + _, mock_serve_module, _ = mock_ray + adapter = RayAdapter() + + class TestClass: + pass + + mock_decorator = MagicMock() + mock_result = MagicMock() + mock_decorator.return_value = mock_result + mock_serve_module.deployment.return_value = mock_decorator + + adapter.to_deployment(TestClass) + + mock_serve_module.deployment.assert_called_once_with() + mock_decorator.assert_called_once_with(TestClass) + + async def test_acall_with_no_kwargs(self, mock_ray): # noqa + """Test _acall with no kwargs.""" + from framex.adapter.ray_adapter import RayAdapter + + adapter = RayAdapter() + + mock_func = MagicMock() + mock_remote = AsyncMock(return_value="result") + mock_func.remote = mock_remote + + result = await adapter._acall(mock_func) + + mock_remote.assert_called_once_with() + assert result == "result" + + def test_get_handle_with_different_app_name(self, mock_ray): + """Test get_handle uses APP_NAME constant correctly.""" + from framex.adapter.ray_adapter import RayAdapter + + _, mock_serve_module, _ = mock_ray + adapter = RayAdapter() + + mock_handle = MagicMock() + mock_serve_module.get_deployment_handle.return_value = mock_handle + + with patch("framex.adapter.ray_adapter.APP_NAME", "custom_app"): + result = adapter.get_handle("my_deployment") + + mock_serve_module.get_deployment_handle.assert_called_once_with("my_deployment", app_name="custom_app") + assert result == mock_handle diff --git a/tests/conftest.py b/tests/conftest.py index 8ebca33..567862b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,7 +24,7 @@ def vcr_config(): @pytest.fixture(scope="module") -def disable_recording(request): # noqa +def disable_recording(request): return settings.test.disable_record_request diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 3746a4c..d7ca5cf 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -268,7 +268,7 @@ async def local_exchange_key_value(a_str: str, b_int: int, c_model: ExchangeMode @on_proxy() -async def remote_exchange_key_value(a_str: str, b_int: int, c_model: ExchangeModel) -> Any: # noqa: ARG001 +async def remote_exchange_key_value(a_str: str, b_int: int, c_model: ExchangeModel) -> Any: raise RuntimeError("This function should be called remotely")