Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ exclude_also =
@(abc\.)?abstractmethod
pass
def __str__
if settings.server.use_ray:
# if settings.server.use_ray:

2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- name: Install UV
run: |
pip install uv==${{ env.UV_VERSION }}
uv sync --group dev
uv sync --extra ray --dev
- name: Run Tests
run: uv run poe test-ci
- name: Upload coverage reports to Codecov
Expand Down
10 changes: 9 additions & 1 deletion ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,15 @@ ignore = [
]

[lint.per-file-ignores]
"tests/**/*.py" = ["S101", "ANN201", "ANN001", "ANN002", "ANN003"]
"tests/**/*.py" = [
"S101",
"ANN201",
"ANN001",
"ANN002",
"ANN003",
"ANN202",
"ARG001",
]

[lint.flake8-builtins]
builtins-ignorelist = ["input", "id", "bytes", "type"]
Expand Down
2 changes: 1 addition & 1 deletion src/framex/adapter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from framex.adapter.base import BaseAdapter
from framex.adapter.local_adapyer import LocalAdapter
from framex.adapter.local_adapter import LocalAdapter
from framex.config import settings

_adapter: BaseAdapter | None = None
Expand Down
32 changes: 17 additions & 15 deletions src/framex/adapter/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import abc
import inspect
from collections.abc import Callable
from enum import StrEnum
from typing import Any, cast
Expand All @@ -25,17 +24,21 @@ def to_ingress(self, cls: type, app: FastAPI, **kwargs: Any) -> type: # noqa: A
def to_deployment(self, cls: type, **kwargs: Any) -> type: # noqa: ARG002
return cls

async def call_func(self, api: PluginApi, **kwargs: Any) -> Any:
func = self.get_handle_func(api.deployment_name, api.func_name)
stream = api.stream
async def _resolve_stream(self, api: PluginApi, kwargs: dict[str, Any]) -> bool:
if api.call_type == ApiType.PROXY and api.api:
kwargs["proxy_path"] = api.api
stream = await self._check_is_gen_api(api.api)
return bool(await self._check_is_gen_api(api.api))
return api.stream

@abc.abstractmethod
async def _invoke(self, func: Callable[..., Any], **kwargs: Any) -> Any: ...

async def call_func(self, api: PluginApi, **kwargs: Any) -> Any:
func = self.get_handle_func(api.deployment_name, api.func_name)
stream = await self._resolve_stream(api, kwargs)
if stream:
return [chunk async for chunk in self._stream_call(func, **kwargs)]
if inspect.iscoroutinefunction(func):
return await self._acall(func, **kwargs) # type: ignore
return self._call(func, **kwargs)
return await self._invoke(func, **kwargs)

def get_handle_func(self, deployment_name: str, func_name: str) -> Any:
handle = self.get_handle(deployment_name)
Expand All @@ -57,11 +60,10 @@ def bind(self, deployment: Callable[..., Any], **kwargs: Any) -> Any: ...
@abc.abstractmethod
def to_remote_func(self, func: Callable) -> Callable: ...

def _stream_call(self, func: Callable[..., Any], **kwargs: Any) -> Any:
return func(**kwargs)

async def _acall(self, func: Callable[..., Any], **kwargs: Any) -> Any:
return await func(**kwargs)
@abc.abstractmethod
def _stream_call(self, func: Callable[..., Any], **kwargs: Any) -> Any: ...

def _call(self, func: Callable[..., Any], **kwargs: Any) -> Any:
return func(**kwargs)
@abc.abstractmethod
async def _acall(self, func: Callable[..., Any], **kwargs: Any) -> Any: ...
@abc.abstractmethod
def _call(self, func: Callable[..., Any], **kwargs: Any) -> Any: ...
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,21 @@ async def _remote_func(*args: tuple[Any, ...], **kwargs: Any) -> Any:

func.remote = _remote_func # type: ignore[attr-defined]
return func

@override
async def _invoke(self, func: Callable[..., Any], **kwargs: Any) -> Any:
if inspect.iscoroutinefunction(func):
return await self._acall(func, **kwargs)
return self._call(func, **kwargs)

@override
async def _acall(self, func: Callable[..., Any], **kwargs: Any) -> Any:
return await func(**kwargs)

@override
def _call(self, func: Callable[..., Any], **kwargs: Any) -> Any:
return func(**kwargs)

@override
def _stream_call(self, func: Callable[..., Any], **kwargs: Any) -> Any:
return func(**kwargs)
13 changes: 12 additions & 1 deletion src/framex/adapter/ray_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
try:
import ray # type: ignore[import-not-found]
from ray import serve # type: ignore[import-not-found]
from ray.serve.handle import DeploymentHandle # type: ignore[import-not-found]
except ImportError as e:
raise RuntimeError('Ray engine requires extra dependency.\nInstall with: uv pip install "framex-kit[ray]"') from e
raise RuntimeError('Ray engine requires extra dependency.\nInstall with: uv add "framex-kit[ray]"') from e
from fastapi import FastAPI
from typing_extensions import override

Expand Down Expand Up @@ -55,3 +56,13 @@ def _stream_call(self, func: Callable[..., Any], **kwargs: Any) -> Any:
@override
async def _acall(self, func: Callable[..., Any], **kwargs: Any) -> Any:
return await func.remote(**kwargs) # type: ignore [attr-defined]

@override
async def _invoke(self, func: Callable[..., Any], **kwargs: Any) -> Any:
if inspect.iscoroutinefunction(func) or isinstance(func, DeploymentHandle):
return await self._acall(func, **kwargs) # type: ignore
return self._call(func, **kwargs)

@override
def _call(self, func: Callable[..., Any], **kwargs: Any) -> Any:
return func(**kwargs)
Comment on lines +66 to +68
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, let's explore the file structure
ls -la src/framex/adapter/

# Look at the ray_adapter.py file around the _call method
cat -n src/framex/adapter/ray_adapter.py | head -100

Repository: touale/FrameX-kit

Length of output: 3356


🏁 Script executed:

#!/bin/bash
# Check what get_handle returns and how it's used
rg -n 'get_handle\b' src/framex/adapter/ray_adapter.py -A 5 -B 2

Repository: touale/FrameX-kit

Length of output: 394


🏁 Script executed:

#!/bin/bash
# Check the _invoke method and routing logic
rg -n '_invoke|_acall|_call' src/framex/adapter/ray_adapter.py -B 2 -A 5

Repository: touale/FrameX-kit

Length of output: 860


🏁 Script executed:

#!/bin/bash
# Look at base.py to understand get_handle_func
rg -n 'get_handle_func\|get_handle' src/framex/adapter/base.py -A 5 -B 2

Repository: touale/FrameX-kit

Length of output: 43


🏁 Script executed:

#!/bin/bash
# Check the base.py implementation
cat -n src/framex/adapter/base.py

Repository: touale/FrameX-kit

Length of output: 2984


🏁 Script executed:

#!/bin/bash
# Search for where _invoke is called from
rg -n '_invoke\|get_handle' src/framex/ -A 3 -B 2 --type py

Repository: touale/FrameX-kit

Length of output: 43


🏁 Script executed:

#!/bin/bash
# Check local_adapter.py to see how it implements _call
cat -n src/framex/adapter/local_adapter.py

Repository: touale/FrameX-kit

Length of output: 3112


🏁 Script executed:

#!/antml/bash
# Search for all places where _invoke is called
rg -n 'await.*_invoke\|_invoke\(' src/framex/ -B 3 -A 2 --type py

Repository: touale/FrameX-kit

Length of output: 43


🏁 Script executed:

#!/bin/bash
# Check for any sync function calls or patterns
rg -n '_call\(' src/framex/adapter/ -B 3 -A 2 --type py

Repository: touale/FrameX-kit

Length of output: 3419


🏁 Script executed:

#!/bin/bash
# Search for tests or usage of _invoke to understand what functions are passed
rg -n 'call_func\|_invoke' src/framex/ -B 5 -A 5 --type py | head -100

Repository: touale/FrameX-kit

Length of output: 43


🏁 Script executed:

#!/bin/bash
# Check if there are any tests that exercise the _call path
find src/framex -name "*.py" -type f | xargs rg -l 'test.*_invoke\|test.*call_func' 2>/dev/null

Repository: touale/FrameX-kit

Length of output: 45


🏁 Script executed:

#!/bin/bash
# Look for documentation or comments about Ray DeploymentHandle behavior
rg -n 'DeploymentHandle\|method.*ref\|\.remote' src/framex/ --type py

Repository: touale/FrameX-kit

Length of output: 43


RayAdapter._call is reachable but broken for sync remote methods.

When get_handle_func calls getattr(deployment_handle, func_name) on a Ray DeploymentHandle, it returns a method reference—not another DeploymentHandle. For a sync (non-async) method on a Ray deployment:

  • isinstance(func, DeploymentHandle) → False
  • inspect.iscoroutinefunction(func) → False
  • Falls through to _call

However, the current _call implementation (return func(**kwargs)) is incorrect for Ray. Ray method references require .remote() to execute remotely (see _acall's func.remote(**kwargs)). Calling func(**kwargs) directly will fail at runtime because it attempts local execution instead of remote execution.

The _call path should either:

  1. Use .remote() like _acall does, or
  2. Raise an error if sync remote methods are not supported
🤖 Prompt for AI Agents
In `@src/framex/adapter/ray_adapter.py` around lines 66 - 68, RayAdapter._call
currently invokes func(**kwargs) which fails for Ray DeploymentHandle method
references; change RayAdapter._call to call func.remote(**kwargs) (mirroring
_acall) and return its result (or, if you prefer to disallow sync calls, raise
NotImplementedError with a clear message). Update the implementation referenced
by RayAdapter._call so it treats Ray deployment method references like _acall
does (using func.remote) and keep
get_handle_func/deployment_handle/DeploymentHandle logic unchanged.

Empty file added tests/adapter/__init__.py
Empty file.
112 changes: 112 additions & 0 deletions tests/adapter/test_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Tests for framex.adapter.__init__ module."""

from unittest.mock import MagicMock, patch

from framex.adapter import get_adapter
from framex.adapter.base import BaseAdapter
from framex.adapter.local_adapter import LocalAdapter


class TestGetAdapter:
"""Tests for the get_adapter factory function."""

def setup_method(self):
"""Reset the global adapter before each test."""
import framex.adapter as adapter_module

adapter_module._adapter = None

def test_get_adapter_returns_local_adapter_when_ray_disabled(self):
"""Test get_adapter returns LocalAdapter when use_ray is False."""
with patch("framex.adapter.settings.server.use_ray", False):
adapter = get_adapter()
assert isinstance(adapter, LocalAdapter)
assert isinstance(adapter, BaseAdapter)

def test_get_adapter_returns_ray_adapter_when_ray_enabled(self):
"""Test get_adapter returns RayAdapter when use_ray is True."""
with (
patch("framex.adapter.settings.server.use_ray", True),
patch("framex.adapter.ray_adapter.RayAdapter") as mock_ray_adapter,
):
mock_instance = MagicMock()
mock_ray_adapter.return_value = mock_instance

adapter = get_adapter()
assert adapter == mock_instance
mock_ray_adapter.assert_called_once()

def test_get_adapter_returns_same_instance_on_multiple_calls(self):
"""Test get_adapter returns the same singleton instance."""
with patch("framex.adapter.settings.server.use_ray", False):
adapter1 = get_adapter()
adapter2 = get_adapter()
assert adapter1 is adapter2

def test_get_adapter_caches_local_adapter(self):
"""Test that LocalAdapter is cached after first call."""
with patch("framex.adapter.settings.server.use_ray", False):
adapter1 = get_adapter()
# Second call should return cached instance
adapter2 = get_adapter()
assert adapter1 is adapter2
assert isinstance(adapter1, LocalAdapter)

def test_get_adapter_caches_ray_adapter(self):
"""Test that RayAdapter is cached after first call."""
with (
patch("framex.adapter.settings.server.use_ray", True),
patch("framex.adapter.ray_adapter.RayAdapter") as mock_ray_adapter,
):
mock_instance = MagicMock()
mock_ray_adapter.return_value = mock_instance

adapter1 = get_adapter()
adapter2 = get_adapter()

# Should only instantiate once
assert adapter1 is adapter2
assert mock_ray_adapter.call_count == 1

def test_get_adapter_lazy_imports_ray_adapter(self):
"""Test that RayAdapter is only imported when needed."""
with patch("framex.adapter.settings.server.use_ray", False): # noqa
# Import should not happen when use_ray is False
with patch("framex.adapter.ray_adapter") as mock_ray_module:
get_adapter()
# RayAdapter module should not be accessed
mock_ray_module.RayAdapter.assert_not_called()

def test_adapter_initially_none(self):
"""Test that _adapter global is None before first call."""
import framex.adapter as adapter_module

adapter_module._adapter = None
assert adapter_module._adapter is None

def test_adapter_set_after_first_call(self):
"""Test that _adapter global is set after first call."""
import framex.adapter as adapter_module

adapter_module._adapter = None
with patch("framex.adapter.settings.server.use_ray", False):
get_adapter()
assert adapter_module._adapter is not None
assert isinstance(adapter_module._adapter, LocalAdapter)

def test_get_adapter_with_switching_ray_setting(self):
"""Test that once adapter is set, changing ray setting doesn't affect it."""
import framex.adapter as adapter_module

adapter_module._adapter = None

with patch("framex.adapter.settings.server.use_ray", False):
adapter1 = get_adapter()
assert isinstance(adapter1, LocalAdapter)

# Change setting, but adapter should still be the same
with patch("framex.adapter.settings.server.use_ray", True):
adapter2 = get_adapter()
# Should still be the same LocalAdapter instance
assert adapter2 is adapter1
assert isinstance(adapter2, LocalAdapter)
Loading
Loading